| | |
| | | self, input: torch.Tensor, input_lengths: torch.Tensor, cache: dict = {}, **kwargs |
| | | ): |
| | | is_final = kwargs.get("is_final", False) |
| | | reset = kwargs.get("reset", False) |
| | | if len(cache) == 0 or reset: |
| | | if len(cache) == 0: |
| | | self.init_cache(cache) |
| | | |
| | | batch_size = input.shape[0] |
| | |
| | | feats = torch.stack(cache["lfr_splice_cache"]) |
| | | feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1] |
| | | feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final, cache=cache) |
| | | if is_final: |
| | | self.init_cache(cache) |
| | | # if is_final: |
| | | # self.init_cache(cache) |
| | | return feats, feats_lengths |
| | | |
| | | |