游雁
2024-03-26 abf5af40e934216b397c5331e0a68dc92f0a4f4e
funasr/frontends/wav_frontend.py
@@ -75,6 +75,7 @@
    LFR_outputs = torch.vstack(LFR_inputs)
    return LFR_outputs.type(torch.float32)
@tables.register("frontend_classes", "wav_frontend")
@tables.register("frontend_classes", "WavFrontend")
class WavFrontend(nn.Module):
    """Conventional frontend structure for ASR.
@@ -399,9 +400,10 @@
        return feats_pad, feats_lens, lfr_splice_frame_idxs
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor, cache: dict = {}, **kwargs
        self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs
    ):
        is_final = kwargs.get("is_final", False)
        cache = kwargs.get("cache", {})
        if len(cache) == 0:
            self.init_cache(cache)