wuhongsheng
2024-07-05 3a4281f4959534b1bf5d01acf0085f4f8e6f2ec8
funasr/models/mossformer/e2e_ss.py
@@ -4,7 +4,7 @@
import torch.nn.functional as F
import copy
from funasr.models.base_model import FunASRModel
from funasr.models.encoder.mossformer_encoder import MossFormerEncoder, MossFormer_MaskNet
from funasr.models.encoder.mossformer_encoder import MossFormerEncoder, MossFormer_MaskNet
from funasr.models.decoder.mossformer_decoder import MossFormerDecoder
@@ -48,7 +48,9 @@
        super(MossFormer, self).__init__()
        self.num_spks = num_spks
        # Encoding
        self.enc = MossFormerEncoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=1)
        self.enc = MossFormerEncoder(
            kernel_size=kernel_size, out_channels=in_channels, in_channels=1
        )
        ##Compute Mask
        self.mask_net = MossFormer_MaskNet(
@@ -62,12 +64,13 @@
            max_length=max_length,
        )
        self.dec = MossFormerDecoder(
           in_channels=out_channels,
           out_channels=1,
           kernel_size=kernel_size,
           stride = kernel_size//2,
           bias=False
            in_channels=out_channels,
            out_channels=1,
            kernel_size=kernel_size,
            stride=kernel_size // 2,
            bias=False,
        )
    def forward(self, input):
        x = self.enc(input)
        mask = self.mask_net(x)
@@ -76,10 +79,7 @@
        # Decoding
        est_source = torch.cat(
            [
                self.dec(sep_x[i]).unsqueeze(-1)
                for i in range(self.num_spks)
            ],
            [self.dec(sep_x[i]).unsqueeze(-1) for i in range(self.num_spks)],
            dim=-1,
        )
        T_origin = input.size(1)
@@ -91,5 +91,5 @@
        out = []
        for spk in range(self.num_spks):
            out.append(est_source[:,:,spk])
            out.append(est_source[:, :, spk])
        return out