zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/models/mossformer/e2e_ss.py
@@ -48,7 +48,9 @@
        super(MossFormer, self).__init__()
        self.num_spks = num_spks
        # Encoding
        self.enc = MossFormerEncoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=1)
        self.enc = MossFormerEncoder(
            kernel_size=kernel_size, out_channels=in_channels, in_channels=1
        )
        ##Compute Mask
        self.mask_net = MossFormer_MaskNet(
@@ -66,8 +68,9 @@
           out_channels=1,
           kernel_size=kernel_size,
           stride = kernel_size//2,
           bias=False
            bias=False,
        )
    def forward(self, input):
        x = self.enc(input)
        mask = self.mask_net(x)
@@ -76,10 +79,7 @@
        # Decoding
        est_source = torch.cat(
            [
                self.dec(sep_x[i]).unsqueeze(-1)
                for i in range(self.num_spks)
            ],
            [self.dec(sep_x[i]).unsqueeze(-1) for i in range(self.num_spks)],
            dim=-1,
        )
        T_origin = input.size(1)