| | |
| | | super(MossFormer, self).__init__() |
| | | self.num_spks = num_spks |
| | | # Encoding |
| | | self.enc = MossFormerEncoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=1) |
| | | self.enc = MossFormerEncoder( |
| | | kernel_size=kernel_size, out_channels=in_channels, in_channels=1 |
| | | ) |
| | | |
| | | ##Compute Mask |
| | | self.mask_net = MossFormer_MaskNet( |
| | |
| | | out_channels=1, |
| | | kernel_size=kernel_size, |
| | | stride = kernel_size//2, |
| | | bias=False |
| | | bias=False, |
| | | ) |
| | | |
| | | def forward(self, input): |
| | | x = self.enc(input) |
| | | mask = self.mask_net(x) |
| | |
| | | |
| | | # Decoding |
| | | est_source = torch.cat( |
| | | [ |
| | | self.dec(sep_x[i]).unsqueeze(-1) |
| | | for i in range(self.num_spks) |
| | | ], |
| | | [self.dec(sep_x[i]).unsqueeze(-1) for i in range(self.num_spks)], |
| | | dim=-1, |
| | | ) |
| | | T_origin = input.size(1) |