| | |
| | | raise "unsupported encoder" |
| | | |
| | | |
| | | def forward(self, feats: torch.Tensor, |
| | | in_cache0: torch.Tensor, |
| | | in_cache1: torch.Tensor, |
| | | in_cache2: torch.Tensor, |
| | | in_cache3: torch.Tensor, |
| | | ): |
| | | def forward(self, feats: torch.Tensor, *args, ): |
| | | |
| | | scores, (cache0, cache1, cache2, cache3) = self.encoder(feats, |
| | | in_cache0, |
| | | in_cache1, |
| | | in_cache2, |
| | | in_cache3) # return B * T * D |
| | | return scores, cache0, cache1, cache2, cache3 |
| | | scores, out_caches = self.encoder(feats, *args) |
| | | return scores, out_caches |
| | | |
| | | def get_dummy_inputs(self, frame=30): |
| | | speech = torch.randn(1, frame, self.feats_dim) |