From 141a4737f779fcf435a0ece5434b9c73eda7d2a9 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 14 Mar 2023 15:54:28 +0800
Subject: [PATCH] wav_frontend: add snip_edges/upscale options to WavFrontend and new WavFrontendMel23 frontend
---
funasr/models/frontend/wav_frontend.py | 106 ++++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 80 insertions(+), 26 deletions(-)
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index 57c5976..6af7074 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -1,14 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
-from typing import Tuple
-
+import funasr.models.frontend.eend_ola_feature
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr.models.frontend.abs_frontend import AbsFrontend
-from typeguard import check_argument_types
+import funasr.models.frontend.eend_ola_feature as eend_ola_feature
from torch.nn.utils.rnn import pad_sequence
+from typeguard import check_argument_types
+from typing import Tuple
def load_cmvn(cmvn_file):
@@ -33,9 +34,9 @@
means = np.array(means_list).astype(np.float)
vars = np.array(vars_list).astype(np.float)
cmvn = np.array([means, vars])
- cmvn = torch.as_tensor(cmvn)
- return cmvn
-
+ cmvn = torch.as_tensor(cmvn)
+ return cmvn
+
def apply_cmvn(inputs, cmvn_file): # noqa
"""
@@ -78,19 +79,22 @@
class WavFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
"""
+
def __init__(
- self,
- cmvn_file: str = None,
- fs: int = 16000,
- window: str = 'hamming',
- n_mels: int = 80,
- frame_length: int = 25,
- frame_shift: int = 10,
- filter_length_min: int = -1,
- filter_length_max: int = -1,
- lfr_m: int = 1,
- lfr_n: int = 1,
- dither: float = 1.0
+ self,
+ cmvn_file: str = None,
+ fs: int = 16000,
+ window: str = 'hamming',
+ n_mels: int = 80,
+ frame_length: int = 25,
+ frame_shift: int = 10,
+ filter_length_min: int = -1,
+ filter_length_max: int = -1,
+ lfr_m: int = 1,
+ lfr_n: int = 1,
+ dither: float = 1.0,
+ snip_edges: bool = True,
+ upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
@@ -105,6 +109,8 @@
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
self.dither = dither
+ self.snip_edges = snip_edges
+ self.upsacle_samples = upsacle_samples
def output_size(self) -> int:
return self.n_mels * self.lfr_m
@@ -119,7 +125,8 @@
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
- waveform = waveform * (1 << 15)
+ if self.upsacle_samples:
+ waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(waveform,
num_mel_bins=self.n_mels,
@@ -128,12 +135,13 @@
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
- sample_frequency=self.fs)
-
+ sample_frequency=self.fs,
+ snip_edges=self.snip_edges)
+
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
if self.cmvn_file is not None:
- mat = apply_cmvn(mat, self.cmvn_file)
+ mat = apply_cmvn(mat, self.cmvn_file)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
@@ -165,10 +173,6 @@
window_type=self.window,
sample_frequency=self.fs)
- # if self.lfr_m != 1 or self.lfr_n != 1:
- # mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
- # if self.cmvn_file is not None:
- # mat = apply_cmvn(mat, self.cmvn_file)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
@@ -201,3 +205,53 @@
batch_first=True,
padding_value=0.0)
return feats_pad, feats_lens
+
+
+class WavFrontendMel23(AbsFrontend):
+    """EEND-OLA style frontend: STFT -> 23-dim log-mel -> splice -> subsample.
+    """
+
+    def __init__(
+            self,
+            fs: int = 16000,
+            frame_length: int = 25,
+            frame_shift: int = 10,
+            lfr_m: int = 1,
+            lfr_n: int = 1,
+    ):
+        assert check_argument_types()
+        super().__init__()
+        self.fs = fs
+        self.frame_length = frame_length
+        self.frame_shift = frame_shift
+        self.lfr_m = lfr_m
+        self.lfr_n = lfr_n
+
+    def output_size(self) -> int:
+        return 23 * self.lfr_m  # fix: self.n_mels was never set (AttributeError); 23 mel bins per class name — NOTE(review): confirm vs splice() output width
+
+    def forward(
+            self,
+            input: torch.Tensor,
+            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = input.size(0)
+        feats = []
+        feats_lens = []
+        for i in range(batch_size):
+            waveform_length = input_lengths[i]
+            waveform = input[i][:waveform_length]
+            waveform = waveform.unsqueeze(0).numpy()  # NOTE(review): eend stft may expect a 1-D signal — confirm the added batch dim
+            mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
+            mat = eend_ola_feature.transform(mat)  # log-mel (23 bins)
+            mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)  # fix: ndarray has no .splice(); use module function
+            mat = mat[::self.lfr_n]  # low-frame-rate subsampling
+            mat = torch.from_numpy(mat)
+            feat_length = mat.size(0)
+            feats.append(mat)
+            feats_lens.append(feat_length)
+
+        feats_lens = torch.as_tensor(feats_lens)
+        feats_pad = pad_sequence(feats,
+                                 batch_first=True,
+                                 padding_value=0.0)
+        return feats_pad, feats_lens
--
Gitblit v1.9.1