| | |
| | | LFR_outputs = torch.vstack(LFR_inputs) |
| | | return LFR_outputs.type(torch.float32) |
| | | |
| | | @tables.register("frontend_classes", "wav_frontend") |
| | | @tables.register("frontend_classes", "WavFrontend") |
| | | class WavFrontend(nn.Module): |
| | | """Conventional frontend structure for ASR. |
| | |
| | | return feats_pad, feats_lens, lfr_splice_frame_idxs |
| | | |
| | | def forward( |
| | | self, input: torch.Tensor, input_lengths: torch.Tensor, cache: dict = {}, **kwargs |
| | | self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs |
| | | ): |
| | | is_final = kwargs.get("is_final", False) |
| | | reset = kwargs.get("reset", False) |
| | | if len(cache) == 0 or reset: |
| | | cache = kwargs.get("cache", {}) |
| | | if len(cache) == 0: |
| | | self.init_cache(cache) |
| | | |
| | | batch_size = input.shape[0] |
| | |
| | | feats = torch.stack(cache["lfr_splice_cache"]) |
| | | feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1] |
| | | feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final, cache=cache) |
| | | if is_final: |
| | | self.init_cache(cache) |
| | | # if is_final: |
| | | # self.init_cache(cache) |
| | | return feats, feats_lengths |
| | | |
| | | |