From adcee8828ef5d78b575043954deb662a35e318f7 Mon Sep 17 00:00:00 2001
From: huangmingming <huangmingming@deepscience.cn>
Date: 星期一, 30 一月 2023 16:02:54 +0800
Subject: [PATCH] update the minimum size of audio

---
 funasr/models/frontend/wav_frontend.py |  222 +++++++++++++++++++++++++++++++++---------------------
 1 files changed, 135 insertions(+), 87 deletions(-)

diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index c0b28ff..57c5976 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -1,22 +1,43 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Part of the implementation is borrowed from espnet/espnet.
 
-import copy
-from typing import Optional, Tuple, Union
+from typing import Tuple
 
-import humanfriendly
 import numpy as np
 import torch
 import torchaudio.compliance.kaldi as kaldi
 from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.layers.log_mel import LogMel
-from funasr.layers.stft import Stft
-from funasr.utils.get_default_kwargs import get_default_kwargs
-from funasr.modules.frontends.frontend import Frontend
 from typeguard import check_argument_types
+from torch.nn.utils.rnn import pad_sequence
 
 
-def apply_cmvn(inputs, mvn):  # noqa
+def load_cmvn(cmvn_file):
+    with open(cmvn_file, 'r', encoding='utf-8') as f:
+        lines = f.readlines()
+    means_list = []
+    vars_list = []
+    for i in range(len(lines)):
+        line_item = lines[i].split()
+        if line_item[0] == '<AddShift>':
+            line_item = lines[i + 1].split()
+            if line_item[0] == '<LearnRateCoef>':
+                add_shift_line = line_item[3:(len(line_item) - 1)]
+                means_list = list(add_shift_line)
+                continue
+        elif line_item[0] == '<Rescale>':
+            line_item = lines[i + 1].split()
+            if line_item[0] == '<LearnRateCoef>':
+                rescale_line = line_item[3:(len(line_item) - 1)]
+                vars_list = list(rescale_line)
+                continue
+    means = np.array(means_list).astype(np.float)
+    vars = np.array(vars_list).astype(np.float)
+    cmvn = np.array([means, vars])
+    cmvn = torch.as_tensor(cmvn) 
+    return cmvn 
+          
+
+def apply_cmvn(inputs, cmvn_file):  # noqa
     """
     Apply CMVN with mvn data
     """
@@ -25,9 +46,10 @@
     dtype = inputs.dtype
     frame, dim = inputs.shape
 
-    meams = np.tile(mvn[0:1, :dim], (frame, 1))
-    vars = np.tile(mvn[1:2, :dim], (frame, 1))
-    inputs += torch.from_numpy(meams).type(dtype).to(device)
+    cmvn = load_cmvn(cmvn_file)
+    means = np.tile(cmvn[0:1, :dim], (frame, 1))
+    vars = np.tile(cmvn[1:2, :dim], (frame, 1))
+    inputs += torch.from_numpy(means).type(dtype).to(device)
     inputs *= torch.from_numpy(vars).type(dtype).to(device)
 
     return inputs.type(torch.float32)
@@ -58,98 +80,124 @@
     """
     def __init__(
         self,
-        fs: Union[int, str] = 16000,
-        n_fft: int = 512,
-        win_length: int = 400,
-        hop_length: int = 160,
-        window: Optional[str] = 'hamming',
-        center: bool = True,
-        normalized: bool = False,
-        onesided: bool = True,
+        cmvn_file: str = None,
+        fs: int = 16000,
+        window: str = 'hamming',
         n_mels: int = 80,
-        fmin: int = None,
-        fmax: int = None,
+        frame_length: int = 25,
+        frame_shift: int = 10,
+        filter_length_min: int = -1,
+        filter_length_max: int = -1,
         lfr_m: int = 1,
         lfr_n: int = 1,
-        htk: bool = False,
-        mvn_data=None,
-        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
-        apply_stft: bool = True,
+        dither: float = 1.0
     ):
         assert check_argument_types()
         super().__init__()
-        if isinstance(fs, str):
-            fs = humanfriendly.parse_size(fs)
-
-        # Deepcopy (In general, dict shouldn't be used as default arg)
-        frontend_conf = copy.deepcopy(frontend_conf)
-        self.hop_length = hop_length
-        self.win_length = win_length
-        self.window = window
         self.fs = fs
-        self.mvn_data = mvn_data
+        self.window = window
+        self.n_mels = n_mels
+        self.frame_length = frame_length
+        self.frame_shift = frame_shift
+        self.filter_length_min = filter_length_min
+        self.filter_length_max = filter_length_max
         self.lfr_m = lfr_m
         self.lfr_n = lfr_n
-
-        if apply_stft:
-            self.stft = Stft(
-                n_fft=n_fft,
-                win_length=win_length,
-                hop_length=hop_length,
-                center=center,
-                window=window,
-                normalized=normalized,
-                onesided=onesided,
-            )
-        else:
-            self.stft = None
-        self.apply_stft = apply_stft
-
-        if frontend_conf is not None:
-            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
-        else:
-            self.frontend = None
-
-        self.logmel = LogMel(
-            fs=fs,
-            n_fft=n_fft,
-            n_mels=n_mels,
-            fmin=fmin,
-            fmax=fmax,
-            htk=htk,
-        )
-        self.n_mels = n_mels
-        self.frontend_type = 'default'
+        self.cmvn_file = cmvn_file
+        self.dither = dither
 
     def output_size(self) -> int:
-        return self.n_mels
+        return self.n_mels * self.lfr_m
 
     def forward(
-            self, input: torch.Tensor,
+            self,
+            input: torch.Tensor,
             input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = input.size(0)
+        feats = []
+        feats_lens = []
+        for i in range(batch_size):
+            waveform_length = input_lengths[i]
+            waveform = input[i][:waveform_length]
+            waveform = waveform * (1 << 15)
+            waveform = waveform.unsqueeze(0)
+            mat = kaldi.fbank(waveform,
+                              num_mel_bins=self.n_mels,
+                              frame_length=self.frame_length,
+                              frame_shift=self.frame_shift,
+                              dither=self.dither,
+                              energy_floor=0.0,
+                              window_type=self.window,
+                              sample_frequency=self.fs)
+     
+            if self.lfr_m != 1 or self.lfr_n != 1:
+                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
+            if self.cmvn_file is not None:
+                mat = apply_cmvn(mat, self.cmvn_file) 
+            feat_length = mat.size(0)
+            feats.append(mat)
+            feats_lens.append(feat_length)
 
-        sample_frequency = self.fs
-        num_mel_bins = self.n_mels
-        frame_length = self.win_length * 1000 / sample_frequency
-        frame_shift = self.hop_length * 1000 / sample_frequency
+        feats_lens = torch.as_tensor(feats_lens)
+        feats_pad = pad_sequence(feats,
+                                 batch_first=True,
+                                 padding_value=0.0)
+        return feats_pad, feats_lens
 
-        waveform = input * (1 << 15)
+    def forward_fbank(
+            self,
+            input: torch.Tensor,
+            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = input.size(0)
+        feats = []
+        feats_lens = []
+        for i in range(batch_size):
+            waveform_length = input_lengths[i]
+            waveform = input[i][:waveform_length]
+            waveform = waveform * (1 << 15)
+            waveform = waveform.unsqueeze(0)
+            mat = kaldi.fbank(waveform,
+                              num_mel_bins=self.n_mels,
+                              frame_length=self.frame_length,
+                              frame_shift=self.frame_shift,
+                              dither=self.dither,
+                              energy_floor=0.0,
+                              window_type=self.window,
+                              sample_frequency=self.fs)
 
-        mat = kaldi.fbank(waveform,
-                          num_mel_bins=num_mel_bins,
-                          frame_length=frame_length,
-                          frame_shift=frame_shift,
-                          dither=1.0,
-                          energy_floor=0.0,
-                          window_type=self.window,
-                          sample_frequency=sample_frequency)
-        if self.lfr_m != 1 or self.lfr_n != 1:
-            mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
-        if self.mvn_data is not None:
-            mat = apply_cmvn(mat, self.mvn_data)
+            # if self.lfr_m != 1 or self.lfr_n != 1:
+            #     mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
+            # if self.cmvn_file is not None:
+            #     mat = apply_cmvn(mat, self.cmvn_file)
+            feat_length = mat.size(0)
+            feats.append(mat)
+            feats_lens.append(feat_length)
 
-        input_feats = mat[None, :]
-        feats_lens = torch.randn(1)
-        feats_lens.fill_(input_feats.shape[1])
+        feats_lens = torch.as_tensor(feats_lens)
+        feats_pad = pad_sequence(feats,
+                                 batch_first=True,
+                                 padding_value=0.0)
+        return feats_pad, feats_lens
 
-        return input_feats, feats_lens
+    def forward_lfr_cmvn(
+            self,
+            input: torch.Tensor,
+            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = input.size(0)
+        feats = []
+        feats_lens = []
+        for i in range(batch_size):
+            mat = input[i, :input_lengths[i], :]
+            if self.lfr_m != 1 or self.lfr_n != 1:
+                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
+            if self.cmvn_file is not None:
+                mat = apply_cmvn(mat, self.cmvn_file)
+            feat_length = mat.size(0)
+            feats.append(mat)
+            feats_lens.append(feat_length)
+
+        feats_lens = torch.as_tensor(feats_lens)
+        feats_pad = pad_sequence(feats,
+                                 batch_first=True,
+                                 padding_value=0.0)
+        return feats_pad, feats_lens

--
Gitblit v1.9.1