游雁
2023-12-19 0e622e694e6cb4459955f1e5942a7c53349ce640
funasr/frontends/wav_frontend.py
File was renamed from funasr/models/frontend/wav_frontend.py
@@ -4,11 +4,13 @@
import numpy as np
import torch
import torch.nn as nn
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
import funasr.models.frontend.eend_ola_feature as eend_ola_feature
from funasr.models.frontend.abs_frontend import AbsFrontend
import funasr.frontends.eend_ola_feature as eend_ola_feature
from funasr.utils.register import register_class
def load_cmvn(cmvn_file):
@@ -73,8 +75,8 @@
    LFR_outputs = torch.vstack(LFR_inputs)
    return LFR_outputs.type(torch.float32)
class WavFrontend(AbsFrontend):
@register_class("frontend_classes", "WavFrontend")
class WavFrontend(nn.Module):
    """Conventional frontend structure for ASR.
    """
@@ -93,6 +95,7 @@
            dither: float = 1.0,
            snip_edges: bool = True,
            upsacle_samples: bool = True,
            **kwargs,
    ):
        super().__init__()
        self.fs = fs
@@ -208,7 +211,8 @@
        return feats_pad, feats_lens
class WavFrontendOnline(AbsFrontend):
@register_class("frontend_classes", "WavFrontendOnline")
class WavFrontendOnline(nn.Module):
    """Conventional frontend structure for streaming ASR/VAD.
    """
@@ -227,6 +231,7 @@
            dither: float = 1.0,
            snip_edges: bool = True,
            upsacle_samples: bool = True,
            **kwargs,
    ):
        super().__init__()
        self.fs = fs
@@ -454,7 +459,7 @@
        self.lfr_splice_cache = []
class WavFrontendMel23(AbsFrontend):
class WavFrontendMel23(nn.Module):
    """Conventional frontend structure for ASR.
    """
@@ -465,6 +470,7 @@
            frame_shift: int = 10,
            lfr_m: int = 1,
            lfr_n: int = 1,
            **kwargs,
    ):
        super().__init__()
        self.fs = fs