From 10e37a721fdd2ecfd8e17f7213688927c29343a1 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 27 四月 2023 17:24:47 +0800
Subject: [PATCH] update

---
 funasr/models/frontend/wav_frontend_kaldifeat.py |   12 +++--
 funasr/models/frontend/default.py                |    9 ++--
 funasr/models/frontend/s3prl.py                  |    4 +-
 funasr/models/frontend/windowing.py              |   10 +---
 funasr/models/frontend/wav_frontend.py           |   31 ++++++++-------
 funasr/models/frontend/fused.py                  |    5 +-
 6 files changed, 35 insertions(+), 36 deletions(-)

diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index 5b034cf..cf6441e 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -11,13 +11,13 @@
 
 from funasr.layers.log_mel import LogMel
 from funasr.layers.stft import Stft
+from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.modules.frontends.frontend import Frontend
 from funasr.utils.get_default_kwargs import get_default_kwargs
 
 
-class DefaultFrontend(torch.nn.Module):
+class DefaultFrontend(AbsFrontend):
     """Conventional frontend structure for ASR.
-
     Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
     """
 
@@ -134,9 +134,8 @@
 
 
 
-class MultiChannelFrontend(torch.nn.Module):
+class MultiChannelFrontend(AbsFrontend):
     """Conventional frontend structure for ASR.
-
     Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
     """
 
@@ -254,4 +253,4 @@
         # Change torch.Tensor to ComplexTensor
         # input_stft: (..., F, 2) -> (..., F)
         input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
-        return input_stft, feats_lens
+        return input_stft, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/fused.py b/funasr/models/frontend/fused.py
index 7cebde7..857486d 100644
--- a/funasr/models/frontend/fused.py
+++ b/funasr/models/frontend/fused.py
@@ -1,3 +1,4 @@
+from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.frontend.default import DefaultFrontend
 from funasr.models.frontend.s3prl import S3prlFrontend
 import numpy as np
@@ -6,7 +7,7 @@
 from typing import Tuple
 
 
-class FusedFrontends(torch.nn.Module):
+class FusedFrontends(AbsFrontend):
     def __init__(
         self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
     ):
@@ -142,4 +143,4 @@
         else:
             raise NotImplementedError
 
-        return input_feats, feats_lens
+        return input_feats, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/s3prl.py b/funasr/models/frontend/s3prl.py
index c0a526f..b03d2c9 100644
--- a/funasr/models/frontend/s3prl.py
+++ b/funasr/models/frontend/s3prl.py
@@ -10,6 +10,7 @@
 import torch
 from typeguard import check_argument_types
 
+from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.modules.frontends.frontend import Frontend
 from funasr.modules.nets_utils import pad_list
 from funasr.utils.get_default_kwargs import get_default_kwargs
@@ -26,7 +27,7 @@
     return args
 
 
-class S3prlFrontend(torch.nn.Module):
+class S3prlFrontend(AbsFrontend):
     """Speech Pretrained Representation frontend structure for ASR."""
 
     def __init__(
@@ -99,7 +100,6 @@
 
     def _tile_representations(self, feature):
         """Tile up the representations by `tile_factor`.
-
         Input - sequence of representations
                 shape: (batch_size, seq_len, feature_dim)
         Output - sequence of tiled representations
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index fc02dc9..35fab57 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -9,6 +9,7 @@
 from typeguard import check_argument_types
 
 import funasr.models.frontend.eend_ola_feature as eend_ola_feature
+from funasr.models.frontend.abs_frontend import AbsFrontend
 
 
 def load_cmvn(cmvn_file):
@@ -33,11 +34,11 @@
     means = np.array(means_list).astype(np.float)
     vars = np.array(vars_list).astype(np.float)
     cmvn = np.array([means, vars])
-    cmvn = torch.as_tensor(cmvn)
+    cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
     return cmvn
 
 
-def apply_cmvn(inputs, cmvn_file):  # noqa
+def apply_cmvn(inputs, cmvn):  # noqa
     """
     Apply CMVN with mvn data
     """
@@ -46,11 +47,10 @@
     dtype = inputs.dtype
     frame, dim = inputs.shape
 
-    cmvn = load_cmvn(cmvn_file)
-    means = np.tile(cmvn[0:1, :dim], (frame, 1))
-    vars = np.tile(cmvn[1:2, :dim], (frame, 1))
-    inputs += torch.from_numpy(means).type(dtype).to(device)
-    inputs *= torch.from_numpy(vars).type(dtype).to(device)
+    means = cmvn[0:1, :dim]
+    vars = cmvn[1:2, :dim]
+    inputs += means.to(device)
+    inputs *= vars.to(device)
 
     return inputs.type(torch.float32)
 
@@ -75,7 +75,7 @@
     return LFR_outputs.type(torch.float32)
 
 
-class WavFrontend(torch.nn.Module):
+class WavFrontend(AbsFrontend):
     """Conventional frontend structure for ASR.
     """
 
@@ -110,6 +110,7 @@
         self.dither = dither
         self.snip_edges = snip_edges
         self.upsacle_samples = upsacle_samples
+        self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
 
     def output_size(self) -> int:
         return self.n_mels * self.lfr_m
@@ -139,8 +140,8 @@
 
             if self.lfr_m != 1 or self.lfr_n != 1:
                 mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
-            if self.cmvn_file is not None:
-                mat = apply_cmvn(mat, self.cmvn_file)
+            if self.cmvn is not None:
+                mat = apply_cmvn(mat, self.cmvn)
             feat_length = mat.size(0)
             feats.append(mat)
             feats_lens.append(feat_length)
@@ -193,8 +194,8 @@
             mat = input[i, :input_lengths[i], :]
             if self.lfr_m != 1 or self.lfr_n != 1:
                 mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
-            if self.cmvn_file is not None:
-                mat = apply_cmvn(mat, self.cmvn_file)
+            if self.cmvn is not None:
+                mat = apply_cmvn(mat, self.cmvn)
             feat_length = mat.size(0)
             feats.append(mat)
             feats_lens.append(feat_length)
@@ -206,7 +207,7 @@
         return feats_pad, feats_lens
 
 
-class WavFrontendOnline(torch.nn.Module):
+class WavFrontendOnline(AbsFrontend):
     """Conventional frontend structure for streaming ASR/VAD.
     """
 
@@ -451,7 +452,7 @@
         self.lfr_splice_cache = []
 
 
-class WavFrontendMel23(torch.nn.Module):
+class WavFrontendMel23(AbsFrontend):
     """Conventional frontend structure for ASR.
     """
 
@@ -499,4 +500,4 @@
         feats_pad = pad_sequence(feats,
                                  batch_first=True,
                                  padding_value=0.0)
-        return feats_pad, feats_lens
+        return feats_pad, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/wav_frontend_kaldifeat.py b/funasr/models/frontend/wav_frontend_kaldifeat.py
index d4e775e..85adbb7 100644
--- a/funasr/models/frontend/wav_frontend_kaldifeat.py
+++ b/funasr/models/frontend/wav_frontend_kaldifeat.py
@@ -6,8 +6,11 @@
 import numpy as np
 import torch
 import torchaudio.compliance.kaldi as kaldi
+from funasr.models.frontend.abs_frontend import AbsFrontend
 from typeguard import check_argument_types
 from torch.nn.utils.rnn import pad_sequence
+
+
 # import kaldifeat
 
 def load_cmvn(cmvn_file):
@@ -32,9 +35,9 @@
     means = np.array(means_list).astype(np.float)
     vars = np.array(vars_list).astype(np.float)
     cmvn = np.array([means, vars])
-    cmvn = torch.as_tensor(cmvn) 
-    return cmvn 
-          
+    cmvn = torch.as_tensor(cmvn)
+    return cmvn
+
 
 def apply_cmvn(inputs, cmvn_file):  # noqa
     """
@@ -72,7 +75,6 @@
             LFR_inputs.append(frame)
     LFR_outputs = torch.vstack(LFR_inputs)
     return LFR_outputs.type(torch.float32)
-
 
 # class WavFrontend_kaldifeat(AbsFrontend):
 #     """Conventional frontend structure for ASR.
@@ -176,4 +178,4 @@
 #         feats_pad = pad_sequence(feats,
 #                                  batch_first=True,
 #                                  padding_value=0.0)
-#         return feats_pad, feats_lens
+#         return feats_pad, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/windowing.py b/funasr/models/frontend/windowing.py
index f7f1dc1..a526758 100644
--- a/funasr/models/frontend/windowing.py
+++ b/funasr/models/frontend/windowing.py
@@ -4,19 +4,18 @@
 
 """Sliding Window for raw audio input data."""
 
+from funasr.models.frontend.abs_frontend import AbsFrontend
 import torch
 from typeguard import check_argument_types
 from typing import Tuple
 
 
-class SlidingWindow(torch.nn.Module):
+class SlidingWindow(AbsFrontend):
     """Sliding Window.
-
     Provides a sliding window over a batched continuous raw audio tensor.
     Optionally, provides padding (Currently not implemented).
     Combine this module with a pre-encoder compatible with raw audio data,
     for example Sinc convolutions.
-
     Known issues:
     Output length is calculated incorrectly if audio shorter than win_length.
     WARNING: trailing values are discarded - padding not implemented yet.
@@ -32,7 +31,6 @@
         fs=None,
     ):
         """Initialize.
-
         Args:
             win_length: Length of frame.
             hop_length: Relative starting point of next frame.
@@ -52,11 +50,9 @@
         self, input: torch.Tensor, input_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Apply a sliding window on the input.
-
         Args:
             input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
             input_lengths: Input lengths within batch.
-
         Returns:
             Tensor: Output with dimensions (B, T, C, D), with D=win_length.
             Tensor: Output lengths within batch.
@@ -77,4 +73,4 @@
 
     def output_size(self) -> int:
         """Return output length of feature dimension D, i.e. the window length."""
-        return self.win_length
+        return self.win_length
\ No newline at end of file

--
Gitblit v1.9.1