| funasr/models/frontend/default.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/models/frontend/fused.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/models/frontend/s3prl.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/models/frontend/wav_frontend.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/models/frontend/wav_frontend_kaldifeat.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/models/frontend/windowing.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/models/frontend/default.py
@@ -11,13 +11,13 @@ from funasr.layers.log_mel import LogMel from funasr.layers.stft import Stft from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.modules.frontends.frontend import Frontend from funasr.utils.get_default_kwargs import get_default_kwargs class DefaultFrontend(torch.nn.Module): class DefaultFrontend(AbsFrontend): """Conventional frontend structure for ASR. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN """ @@ -134,9 +134,8 @@ class MultiChannelFrontend(torch.nn.Module): class MultiChannelFrontend(AbsFrontend): """Conventional frontend structure for ASR. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN """ @@ -254,4 +253,4 @@ # Change torch.Tensor to ComplexTensor # input_stft: (..., F, 2) -> (..., F) input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1]) return input_stft, feats_lens return input_stft, feats_lens funasr/models/frontend/fused.py
@@ -1,3 +1,4 @@ from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.frontend.default import DefaultFrontend from funasr.models.frontend.s3prl import S3prlFrontend import numpy as np @@ -6,7 +7,7 @@ from typing import Tuple class FusedFrontends(torch.nn.Module): class FusedFrontends(AbsFrontend): def __init__( self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000 ): @@ -142,4 +143,4 @@ else: raise NotImplementedError return input_feats, feats_lens return input_feats, feats_lens funasr/models/frontend/s3prl.py
@@ -10,6 +10,7 @@ import torch from typeguard import check_argument_types from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.modules.frontends.frontend import Frontend from funasr.modules.nets_utils import pad_list from funasr.utils.get_default_kwargs import get_default_kwargs @@ -26,7 +27,7 @@ return args class S3prlFrontend(torch.nn.Module): class S3prlFrontend(AbsFrontend): """Speech Pretrained Representation frontend structure for ASR.""" def __init__( @@ -99,7 +100,6 @@ def _tile_representations(self, feature): """Tile up the representations by `tile_factor`. Input - sequence of representations shape: (batch_size, seq_len, feature_dim) Output - sequence of tiled representations funasr/models/frontend/wav_frontend.py
@@ -9,6 +9,7 @@ from typeguard import check_argument_types import funasr.models.frontend.eend_ola_feature as eend_ola_feature from funasr.models.frontend.abs_frontend import AbsFrontend def load_cmvn(cmvn_file): @@ -33,11 +34,11 @@ means = np.array(means_list).astype(np.float) vars = np.array(vars_list).astype(np.float) cmvn = np.array([means, vars]) cmvn = torch.as_tensor(cmvn) cmvn = torch.as_tensor(cmvn, dtype=torch.float32) return cmvn def apply_cmvn(inputs, cmvn_file): # noqa def apply_cmvn(inputs, cmvn): # noqa """ Apply CMVN with mvn data """ @@ -46,11 +47,10 @@ dtype = inputs.dtype frame, dim = inputs.shape cmvn = load_cmvn(cmvn_file) means = np.tile(cmvn[0:1, :dim], (frame, 1)) vars = np.tile(cmvn[1:2, :dim], (frame, 1)) inputs += torch.from_numpy(means).type(dtype).to(device) inputs *= torch.from_numpy(vars).type(dtype).to(device) means = cmvn[0:1, :dim] vars = cmvn[1:2, :dim] inputs += means.to(device) inputs *= vars.to(device) return inputs.type(torch.float32) @@ -75,7 +75,7 @@ return LFR_outputs.type(torch.float32) class WavFrontend(torch.nn.Module): class WavFrontend(AbsFrontend): """Conventional frontend structure for ASR. """ @@ -110,6 +110,7 @@ self.dither = dither self.snip_edges = snip_edges self.upsacle_samples = upsacle_samples self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file) def output_size(self) -> int: return self.n_mels * self.lfr_m @@ -139,8 +140,8 @@ if self.lfr_m != 1 or self.lfr_n != 1: mat = apply_lfr(mat, self.lfr_m, self.lfr_n) if self.cmvn_file is not None: mat = apply_cmvn(mat, self.cmvn_file) if self.cmvn is not None: mat = apply_cmvn(mat, self.cmvn) feat_length = mat.size(0) feats.append(mat) feats_lens.append(feat_length) @@ -193,8 +194,8 @@ mat = input[i, :input_lengths[i], :] if self.lfr_m != 1 or self.lfr_n != 1: mat = apply_lfr(mat, self.lfr_m, self.lfr_n) if self.cmvn_file is not None: mat = apply_cmvn(mat, self.cmvn_file) if self.cmvn is not None: mat = apply_cmvn(mat, self.cmvn) feat_length = mat.size(0) feats.append(mat) feats_lens.append(feat_length) @@ -206,7 +207,7 @@ return feats_pad, feats_lens class WavFrontendOnline(torch.nn.Module): class WavFrontendOnline(AbsFrontend): """Conventional frontend structure for streaming ASR/VAD. """ @@ -451,7 +452,7 @@ self.lfr_splice_cache = [] class WavFrontendMel23(torch.nn.Module): class WavFrontendMel23(AbsFrontend): """Conventional frontend structure for ASR. """ @@ -499,4 +500,4 @@ feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0) return feats_pad, feats_lens return feats_pad, feats_lens funasr/models/frontend/wav_frontend_kaldifeat.py
@@ -6,8 +6,11 @@ import numpy as np import torch import torchaudio.compliance.kaldi as kaldi from funasr.models.frontend.abs_frontend import AbsFrontend from typeguard import check_argument_types from torch.nn.utils.rnn import pad_sequence # import kaldifeat def load_cmvn(cmvn_file): @@ -32,9 +35,9 @@ means = np.array(means_list).astype(np.float) vars = np.array(vars_list).astype(np.float) cmvn = np.array([means, vars]) cmvn = torch.as_tensor(cmvn) return cmvn cmvn = torch.as_tensor(cmvn) return cmvn def apply_cmvn(inputs, cmvn_file): # noqa """ @@ -72,7 +75,6 @@ LFR_inputs.append(frame) LFR_outputs = torch.vstack(LFR_inputs) return LFR_outputs.type(torch.float32) # class WavFrontend_kaldifeat(AbsFrontend): # """Conventional frontend structure for ASR. @@ -176,4 +178,4 @@ # feats_pad = pad_sequence(feats, # batch_first=True, # padding_value=0.0) # return feats_pad, feats_lens # return feats_pad, feats_lens funasr/models/frontend/windowing.py
@@ -4,19 +4,18 @@ """Sliding Window for raw audio input data.""" from funasr.models.frontend.abs_frontend import AbsFrontend import torch from typeguard import check_argument_types from typing import Tuple class SlidingWindow(torch.nn.Module): class SlidingWindow(AbsFrontend): """Sliding Window. Provides a sliding window over a batched continuous raw audio tensor. Optionally, provides padding (Currently not implemented). Combine this module with a pre-encoder compatible with raw audio data, for example Sinc convolutions. Known issues: Output length is calculated incorrectly if audio shorter than win_length. WARNING: trailing values are discarded - padding not implemented yet. @@ -32,7 +31,6 @@ fs=None, ): """Initialize. Args: win_length: Length of frame. hop_length: Relative starting point of next frame. @@ -52,11 +50,9 @@ self, input: torch.Tensor, input_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Apply a sliding window on the input. Args: input: Input (B, T, C*D) or (B, T*C*D), with D=C=1. input_lengths: Input lengths within batch. Returns: Tensor: Output with dimensions (B, T, C, D), with D=win_length. Tensor: Output lengths within batch. @@ -77,4 +73,4 @@ def output_size(self) -> int: """Return output length of feature dimension D, i.e. the window length.""" return self.win_length return self.win_length