| | |
| | | import torch.nn.functional as F |
| | | import copy |
| | | from funasr.models.base_model import FunASRModel |
| | | from funasr.models.encoder.mossformer_encoder import MossFormerEncoder, MossFormer_MaskNet |
| | | from funasr.models.encoder.mossformer_encoder import MossFormerEncoder, MossFormer_MaskNet |
| | | from funasr.models.decoder.mossformer_decoder import MossFormerDecoder |
| | | |
| | | |
| | |
| | | super(MossFormer, self).__init__() |
| | | self.num_spks = num_spks |
| | | # Encoding |
| | | self.enc = MossFormerEncoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=1) |
| | | self.enc = MossFormerEncoder( |
| | | kernel_size=kernel_size, out_channels=in_channels, in_channels=1 |
| | | ) |
| | | |
| | | ##Compute Mask |
| | | self.mask_net = MossFormer_MaskNet( |
| | |
| | | max_length=max_length, |
| | | ) |
| | | self.dec = MossFormerDecoder( |
| | | in_channels=out_channels, |
| | | out_channels=1, |
| | | kernel_size=kernel_size, |
| | | stride = kernel_size//2, |
| | | bias=False |
| | | in_channels=out_channels, |
| | | out_channels=1, |
| | | kernel_size=kernel_size, |
| | | stride=kernel_size // 2, |
| | | bias=False, |
| | | ) |
| | | |
| | | def forward(self, input): |
| | | x = self.enc(input) |
| | | mask = self.mask_net(x) |
| | |
| | | |
| | | # Decoding |
| | | est_source = torch.cat( |
| | | [ |
| | | self.dec(sep_x[i]).unsqueeze(-1) |
| | | for i in range(self.num_spks) |
| | | ], |
| | | [self.dec(sep_x[i]).unsqueeze(-1) for i in range(self.num_spks)], |
| | | dim=-1, |
| | | ) |
| | | T_origin = input.size(1) |
| | |
| | | |
| | | out = [] |
| | | for spk in range(self.num_spks): |
| | | out.append(est_source[:,:,spk]) |
| | | out.append(est_source[:, :, spk]) |
| | | return out |