| funasr/models/encoder/fsmn_encoder.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/models/encoder/fsmn_encoder.py
@@ -82,7 +82,8 @@ def forward(self, input: torch.Tensor, cache: torch.Tensor): x = torch.unsqueeze(input, 1) x_per = x.permute(0, 3, 2, 1) # B D T C cache = cache.to(x_per.device) y_left = torch.cat((cache, x_per), dim=2) cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :] y_left = self.conv_left(y_left) @@ -297,4 +298,4 @@ print('input shape: {}'.format(x.shape)) print('output shape: {}'.format(y.shape)) print(fsmn.to_kaldi_net()) print(fsmn.to_kaldi_net())