嘉渊
2023-04-27 10e37a721fdd2ecfd8e17f7213688927c29343a1
update
6个文件已修改
71 ■■■■ 已修改文件
funasr/models/frontend/default.py 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/fused.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/s3prl.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/wav_frontend.py 31 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/wav_frontend_kaldifeat.py 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/windowing.py 10 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/default.py
@@ -11,13 +11,13 @@
from funasr.layers.log_mel import LogMel
from funasr.layers.stft import Stft
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
from funasr.utils.get_default_kwargs import get_default_kwargs
class DefaultFrontend(torch.nn.Module):
class DefaultFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.
    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """
@@ -134,9 +134,8 @@
class MultiChannelFrontend(torch.nn.Module):
class MultiChannelFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.
    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """
@@ -254,4 +253,4 @@
        # Change torch.Tensor to ComplexTensor
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
        return input_stft, feats_lens
        return input_stft, feats_lens
funasr/models/frontend/fused.py
@@ -1,3 +1,4 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.s3prl import S3prlFrontend
import numpy as np
@@ -6,7 +7,7 @@
from typing import Tuple
class FusedFrontends(torch.nn.Module):
class FusedFrontends(AbsFrontend):
    def __init__(
        self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
    ):
@@ -142,4 +143,4 @@
        else:
            raise NotImplementedError
        return input_feats, feats_lens
        return input_feats, feats_lens
funasr/models/frontend/s3prl.py
@@ -10,6 +10,7 @@
import torch
from typeguard import check_argument_types
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
from funasr.modules.nets_utils import pad_list
from funasr.utils.get_default_kwargs import get_default_kwargs
@@ -26,7 +27,7 @@
    return args
class S3prlFrontend(torch.nn.Module):
class S3prlFrontend(AbsFrontend):
    """Speech Pretrained Representation frontend structure for ASR."""
    def __init__(
@@ -99,7 +100,6 @@
    def _tile_representations(self, feature):
        """Tile up the representations by `tile_factor`.
        Input - sequence of representations
                shape: (batch_size, seq_len, feature_dim)
        Output - sequence of tiled representations
funasr/models/frontend/wav_frontend.py
@@ -9,6 +9,7 @@
from typeguard import check_argument_types
import funasr.models.frontend.eend_ola_feature as eend_ola_feature
from funasr.models.frontend.abs_frontend import AbsFrontend
def load_cmvn(cmvn_file):
@@ -33,11 +34,11 @@
    means = np.array(means_list).astype(np.float)
    vars = np.array(vars_list).astype(np.float)
    cmvn = np.array([means, vars])
    cmvn = torch.as_tensor(cmvn)
    cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
    return cmvn
def apply_cmvn(inputs, cmvn_file):  # noqa
def apply_cmvn(inputs, cmvn):  # noqa
    """
    Apply CMVN with mvn data
    """
@@ -46,11 +47,10 @@
    dtype = inputs.dtype
    frame, dim = inputs.shape
    cmvn = load_cmvn(cmvn_file)
    means = np.tile(cmvn[0:1, :dim], (frame, 1))
    vars = np.tile(cmvn[1:2, :dim], (frame, 1))
    inputs += torch.from_numpy(means).type(dtype).to(device)
    inputs *= torch.from_numpy(vars).type(dtype).to(device)
    means = cmvn[0:1, :dim]
    vars = cmvn[1:2, :dim]
    inputs += means.to(device)
    inputs *= vars.to(device)
    return inputs.type(torch.float32)
@@ -75,7 +75,7 @@
    return LFR_outputs.type(torch.float32)
class WavFrontend(torch.nn.Module):
class WavFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.
    """
@@ -110,6 +110,7 @@
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples
        self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
    def output_size(self) -> int:
        return self.n_mels * self.lfr_m
@@ -139,8 +140,8 @@
            if self.lfr_m != 1 or self.lfr_n != 1:
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            if self.cmvn_file is not None:
                mat = apply_cmvn(mat, self.cmvn_file)
            if self.cmvn is not None:
                mat = apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
@@ -193,8 +194,8 @@
            mat = input[i, :input_lengths[i], :]
            if self.lfr_m != 1 or self.lfr_n != 1:
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            if self.cmvn_file is not None:
                mat = apply_cmvn(mat, self.cmvn_file)
            if self.cmvn is not None:
                mat = apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
@@ -206,7 +207,7 @@
        return feats_pad, feats_lens
class WavFrontendOnline(torch.nn.Module):
class WavFrontendOnline(AbsFrontend):
    """Conventional frontend structure for streaming ASR/VAD.
    """
@@ -451,7 +452,7 @@
        self.lfr_splice_cache = []
class WavFrontendMel23(torch.nn.Module):
class WavFrontendMel23(AbsFrontend):
    """Conventional frontend structure for ASR.
    """
@@ -499,4 +500,4 @@
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens
        return feats_pad, feats_lens
funasr/models/frontend/wav_frontend_kaldifeat.py
@@ -6,8 +6,11 @@
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr.models.frontend.abs_frontend import AbsFrontend
from typeguard import check_argument_types
from torch.nn.utils.rnn import pad_sequence
# import kaldifeat
def load_cmvn(cmvn_file):
@@ -32,9 +35,9 @@
    means = np.array(means_list).astype(np.float)
    vars = np.array(vars_list).astype(np.float)
    cmvn = np.array([means, vars])
    cmvn = torch.as_tensor(cmvn)
    return cmvn
    cmvn = torch.as_tensor(cmvn)
    return cmvn
def apply_cmvn(inputs, cmvn_file):  # noqa
    """
@@ -72,7 +75,6 @@
            LFR_inputs.append(frame)
    LFR_outputs = torch.vstack(LFR_inputs)
    return LFR_outputs.type(torch.float32)
# class WavFrontend_kaldifeat(AbsFrontend):
#     """Conventional frontend structure for ASR.
@@ -176,4 +178,4 @@
#         feats_pad = pad_sequence(feats,
#                                  batch_first=True,
#                                  padding_value=0.0)
#         return feats_pad, feats_lens
#         return feats_pad, feats_lens
funasr/models/frontend/windowing.py
@@ -4,19 +4,18 @@
"""Sliding Window for raw audio input data."""
from funasr.models.frontend.abs_frontend import AbsFrontend
import torch
from typeguard import check_argument_types
from typing import Tuple
class SlidingWindow(torch.nn.Module):
class SlidingWindow(AbsFrontend):
    """Sliding Window.
    Provides a sliding window over a batched continuous raw audio tensor.
    Optionally, provides padding (Currently not implemented).
    Combine this module with a pre-encoder compatible with raw audio data,
    for example Sinc convolutions.
    Known issues:
    Output length is calculated incorrectly if audio shorter than win_length.
    WARNING: trailing values are discarded - padding not implemented yet.
@@ -32,7 +31,6 @@
        fs=None,
    ):
        """Initialize.
        Args:
            win_length: Length of frame.
            hop_length: Relative starting point of next frame.
@@ -52,11 +50,9 @@
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply a sliding window on the input.
        Args:
            input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
            input_lengths: Input lengths within batch.
        Returns:
            Tensor: Output with dimensions (B, T, C, D), with D=win_length.
            Tensor: Output lengths within batch.
@@ -77,4 +73,4 @@
    def output_size(self) -> int:
        """Return output length of feature dimension D, i.e. the window length."""
        return self.win_length
        return self.win_length