Merge pull request #162 from alibaba-damo-academy/dev_zly
gpu bug fix
| | |
| | | def forward(self, input: torch.Tensor, cache: torch.Tensor): |
| | | x = torch.unsqueeze(input, 1) |
| | | x_per = x.permute(0, 3, 2, 1) # B D T C |
| | | |
| | | |
| | | cache = cache.to(x_per.device) |
| | | y_left = torch.cat((cache, x_per), dim=2) |
| | | cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :] |
| | | y_left = self.conv_left(y_left) |
| | |
| | | print('input shape: {}'.format(x.shape)) |
| | | print('output shape: {}'.format(y.shape)) |
| | | |
| | | print(fsmn.to_kaldi_net()) |
| | | print(fsmn.to_kaldi_net()) |