From 10e37a721fdd2ecfd8e17f7213688927c29343a1 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: Thu, 27 Apr 2023 17:24:47 +0800
Subject: [PATCH] Make frontends subclass AbsFrontend; load CMVN once in WavFrontend
---
funasr/models/frontend/wav_frontend_kaldifeat.py | 12 +++--
funasr/models/frontend/default.py | 9 ++--
funasr/models/frontend/s3prl.py | 4 +-
funasr/models/frontend/windowing.py | 10 +---
funasr/models/frontend/wav_frontend.py | 31 ++++++++-------
funasr/models/frontend/fused.py | 5 +-
6 files changed, 35 insertions(+), 36 deletions(-)
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index 5b034cf..cf6441e 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -11,13 +11,13 @@
from funasr.layers.log_mel import LogMel
from funasr.layers.stft import Stft
+from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
from funasr.utils.get_default_kwargs import get_default_kwargs
-class DefaultFrontend(torch.nn.Module):
+class DefaultFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
-
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
@@ -134,9 +134,8 @@
-class MultiChannelFrontend(torch.nn.Module):
+class MultiChannelFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
-
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
@@ -254,4 +253,4 @@
# Change torch.Tensor to ComplexTensor
# input_stft: (..., F, 2) -> (..., F)
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
- return input_stft, feats_lens
+ return input_stft, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/fused.py b/funasr/models/frontend/fused.py
index 7cebde7..857486d 100644
--- a/funasr/models/frontend/fused.py
+++ b/funasr/models/frontend/fused.py
@@ -1,3 +1,4 @@
+from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.s3prl import S3prlFrontend
import numpy as np
@@ -6,7 +7,7 @@
from typing import Tuple
-class FusedFrontends(torch.nn.Module):
+class FusedFrontends(AbsFrontend):
def __init__(
self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
):
@@ -142,4 +143,4 @@
else:
raise NotImplementedError
- return input_feats, feats_lens
+ return input_feats, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/s3prl.py b/funasr/models/frontend/s3prl.py
index c0a526f..b03d2c9 100644
--- a/funasr/models/frontend/s3prl.py
+++ b/funasr/models/frontend/s3prl.py
@@ -10,6 +10,7 @@
import torch
from typeguard import check_argument_types
+from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
from funasr.modules.nets_utils import pad_list
from funasr.utils.get_default_kwargs import get_default_kwargs
@@ -26,7 +27,7 @@
return args
-class S3prlFrontend(torch.nn.Module):
+class S3prlFrontend(AbsFrontend):
"""Speech Pretrained Representation frontend structure for ASR."""
def __init__(
@@ -99,7 +100,6 @@
def _tile_representations(self, feature):
"""Tile up the representations by `tile_factor`.
-
Input - sequence of representations
shape: (batch_size, seq_len, feature_dim)
Output - sequence of tiled representations
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index fc02dc9..35fab57 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -9,6 +9,7 @@
from typeguard import check_argument_types
import funasr.models.frontend.eend_ola_feature as eend_ola_feature
+from funasr.models.frontend.abs_frontend import AbsFrontend
def load_cmvn(cmvn_file):
@@ -33,11 +34,11 @@
means = np.array(means_list).astype(np.float)
vars = np.array(vars_list).astype(np.float)
cmvn = np.array([means, vars])
- cmvn = torch.as_tensor(cmvn)
+ cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
return cmvn
-def apply_cmvn(inputs, cmvn_file): # noqa
+def apply_cmvn(inputs, cmvn): # noqa
"""
Apply CMVN with mvn data
"""
@@ -46,11 +47,10 @@
dtype = inputs.dtype
frame, dim = inputs.shape
- cmvn = load_cmvn(cmvn_file)
- means = np.tile(cmvn[0:1, :dim], (frame, 1))
- vars = np.tile(cmvn[1:2, :dim], (frame, 1))
- inputs += torch.from_numpy(means).type(dtype).to(device)
- inputs *= torch.from_numpy(vars).type(dtype).to(device)
+ means = cmvn[0:1, :dim]
+ vars = cmvn[1:2, :dim]
+ inputs += means.to(device)
+ inputs *= vars.to(device)
return inputs.type(torch.float32)
@@ -75,7 +75,7 @@
return LFR_outputs.type(torch.float32)
-class WavFrontend(torch.nn.Module):
+class WavFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
"""
@@ -110,6 +110,7 @@
self.dither = dither
self.snip_edges = snip_edges
self.upsacle_samples = upsacle_samples
+ self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
def output_size(self) -> int:
return self.n_mels * self.lfr_m
@@ -139,8 +140,8 @@
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
- if self.cmvn_file is not None:
- mat = apply_cmvn(mat, self.cmvn_file)
+ if self.cmvn is not None:
+ mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
@@ -193,8 +194,8 @@
mat = input[i, :input_lengths[i], :]
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
- if self.cmvn_file is not None:
- mat = apply_cmvn(mat, self.cmvn_file)
+ if self.cmvn is not None:
+ mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
@@ -206,7 +207,7 @@
return feats_pad, feats_lens
-class WavFrontendOnline(torch.nn.Module):
+class WavFrontendOnline(AbsFrontend):
"""Conventional frontend structure for streaming ASR/VAD.
"""
@@ -451,7 +452,7 @@
self.lfr_splice_cache = []
-class WavFrontendMel23(torch.nn.Module):
+class WavFrontendMel23(AbsFrontend):
"""Conventional frontend structure for ASR.
"""
@@ -499,4 +500,4 @@
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
- return feats_pad, feats_lens
+ return feats_pad, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/wav_frontend_kaldifeat.py b/funasr/models/frontend/wav_frontend_kaldifeat.py
index d4e775e..85adbb7 100644
--- a/funasr/models/frontend/wav_frontend_kaldifeat.py
+++ b/funasr/models/frontend/wav_frontend_kaldifeat.py
@@ -6,8 +6,11 @@
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
+from funasr.models.frontend.abs_frontend import AbsFrontend
from typeguard import check_argument_types
from torch.nn.utils.rnn import pad_sequence
+
+
# import kaldifeat
def load_cmvn(cmvn_file):
@@ -32,9 +35,9 @@
means = np.array(means_list).astype(np.float)
vars = np.array(vars_list).astype(np.float)
cmvn = np.array([means, vars])
- cmvn = torch.as_tensor(cmvn)
- return cmvn
-
+ cmvn = torch.as_tensor(cmvn)
+ return cmvn
+
def apply_cmvn(inputs, cmvn_file): # noqa
"""
@@ -72,7 +75,6 @@
LFR_inputs.append(frame)
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
-
# class WavFrontend_kaldifeat(AbsFrontend):
# """Conventional frontend structure for ASR.
@@ -176,4 +178,4 @@
# feats_pad = pad_sequence(feats,
# batch_first=True,
# padding_value=0.0)
-# return feats_pad, feats_lens
+# return feats_pad, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/windowing.py b/funasr/models/frontend/windowing.py
index f7f1dc1..a526758 100644
--- a/funasr/models/frontend/windowing.py
+++ b/funasr/models/frontend/windowing.py
@@ -4,19 +4,18 @@
"""Sliding Window for raw audio input data."""
+from funasr.models.frontend.abs_frontend import AbsFrontend
import torch
from typeguard import check_argument_types
from typing import Tuple
-class SlidingWindow(torch.nn.Module):
+class SlidingWindow(AbsFrontend):
"""Sliding Window.
-
Provides a sliding window over a batched continuous raw audio tensor.
Optionally, provides padding (Currently not implemented).
Combine this module with a pre-encoder compatible with raw audio data,
for example Sinc convolutions.
-
Known issues:
Output length is calculated incorrectly if audio shorter than win_length.
WARNING: trailing values are discarded - padding not implemented yet.
@@ -32,7 +31,6 @@
fs=None,
):
"""Initialize.
-
Args:
win_length: Length of frame.
hop_length: Relative starting point of next frame.
@@ -52,11 +50,9 @@
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply a sliding window on the input.
-
Args:
input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
input_lengths: Input lengths within batch.
-
Returns:
Tensor: Output with dimensions (B, T, C, D), with D=win_length.
Tensor: Output lengths within batch.
@@ -77,4 +73,4 @@
def output_size(self) -> int:
"""Return output length of feature dimension D, i.e. the window length."""
- return self.win_length
+ return self.win_length
\ No newline at end of file
--
Gitblit v1.9.1