From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/models/frontend/wav_frontend.py | 68 ++++++++++++++++++----------------
1 files changed, 36 insertions(+), 32 deletions(-)
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index c4b7910..ca5aed6 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -1,15 +1,14 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
-from abc import ABC
from typing import Tuple
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
-from funasr.models.frontend.abs_frontend import AbsFrontend
-import funasr.models.frontend.eend_ola_feature as eend_ola_feature
-from typeguard import check_argument_types
from torch.nn.utils.rnn import pad_sequence
+
+import funasr.models.frontend.eend_ola_feature as eend_ola_feature
+from funasr.models.frontend.abs_frontend import AbsFrontend
def load_cmvn(cmvn_file):
@@ -31,14 +30,14 @@
rescale_line = line_item[3:(len(line_item) - 1)]
vars_list = list(rescale_line)
continue
- means = np.array(means_list).astype(np.float)
- vars = np.array(vars_list).astype(np.float)
+ means = np.array(means_list).astype(np.float32)
+ vars = np.array(vars_list).astype(np.float32)
cmvn = np.array([means, vars])
- cmvn = torch.as_tensor(cmvn)
+ cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
return cmvn
-def apply_cmvn(inputs, cmvn_file): # noqa
+def apply_cmvn(inputs, cmvn): # noqa
"""
Apply CMVN with mvn data
"""
@@ -47,11 +46,10 @@
dtype = inputs.dtype
frame, dim = inputs.shape
- cmvn = load_cmvn(cmvn_file)
- means = np.tile(cmvn[0:1, :dim], (frame, 1))
- vars = np.tile(cmvn[1:2, :dim], (frame, 1))
- inputs += torch.from_numpy(means).type(dtype).to(device)
- inputs *= torch.from_numpy(vars).type(dtype).to(device)
+ means = cmvn[0:1, :dim]
+ vars = cmvn[1:2, :dim]
+ inputs += means.to(device)
+ inputs *= vars.to(device)
return inputs.type(torch.float32)
@@ -96,7 +94,6 @@
snip_edges: bool = True,
upsacle_samples: bool = True,
):
- assert check_argument_types()
super().__init__()
self.fs = fs
self.window = window
@@ -111,6 +108,7 @@
self.dither = dither
self.snip_edges = snip_edges
self.upsacle_samples = upsacle_samples
+ self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
def output_size(self) -> int:
return self.n_mels * self.lfr_m
@@ -140,8 +138,8 @@
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
- if self.cmvn_file is not None:
- mat = apply_cmvn(mat, self.cmvn_file)
+ if self.cmvn is not None:
+ mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
@@ -194,8 +192,8 @@
mat = input[i, :input_lengths[i], :]
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
- if self.cmvn_file is not None:
- mat = apply_cmvn(mat, self.cmvn_file)
+ if self.cmvn is not None:
+ mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
@@ -227,7 +225,6 @@
snip_edges: bool = True,
upsacle_samples: bool = True,
):
- assert check_argument_types()
super().__init__()
self.fs = fs
self.window = window
@@ -276,7 +273,8 @@
# inputs tensor has catted the cache tensor
# def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, inputs_lfr_cache: torch.Tensor = None,
# is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
- def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
+ def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[
+ torch.Tensor, torch.Tensor, int]:
"""
Apply lfr with data
"""
@@ -377,7 +375,8 @@
if self.lfr_m != 1 or self.lfr_n != 1:
# update self.lfr_splice_cache in self.apply_lfr
# mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
- mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, is_final)
+ mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n,
+ is_final)
if self.cmvn_file is not None:
mat = self.apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
@@ -393,15 +392,18 @@
return feats_pad, feats_lens, lfr_splice_frame_idxs
def forward(
- self, input: torch.Tensor, input_lengths: torch.Tensor, is_final: bool = False
+ self, input: torch.Tensor, input_lengths: torch.Tensor, is_final: bool = False, reset: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
+ if reset:
+ self.cache_reset()
batch_size = input.shape[0]
assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths) # input shape: B T D
if feats.shape[0]:
- #if self.reserve_waveforms is None and self.lfr_m > 1:
+ # if self.reserve_waveforms is None and self.lfr_m > 1:
# self.reserve_waveforms = waveforms[:, :(self.lfr_m - 1) // 2 * self.frame_shift_sample_length]
- self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat((self.reserve_waveforms, waveforms), dim=1)
+ self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat(
+ (self.reserve_waveforms, waveforms), dim=1)
if not self.lfr_splice_cache: # 鍒濆鍖杝plice_cache
for i in range(batch_size):
self.lfr_splice_cache.append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
@@ -410,7 +412,8 @@
lfr_splice_cache_tensor = torch.stack(self.lfr_splice_cache) # B T D
feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
feats_lengths += lfr_splice_cache_tensor[0].shape[0]
- frame_from_waveforms = int((self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
+ frame_from_waveforms = int(
+ (self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
if self.lfr_m == 1:
@@ -424,14 +427,15 @@
self.waveforms = self.waveforms[:, :sample_length]
else:
# update self.reserve_waveforms and self.lfr_splice_cache
- self.reserve_waveforms = self.waveforms[:, :-(self.frame_sample_length - self.frame_shift_sample_length)]
+ self.reserve_waveforms = self.waveforms[:,
+ :-(self.frame_sample_length - self.frame_shift_sample_length)]
for i in range(batch_size):
self.lfr_splice_cache[i] = torch.cat((self.lfr_splice_cache[i], feats[i]), dim=0)
return torch.empty(0), feats_lengths
else:
if is_final:
self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
- feats = torch.stack(self.lfr_splice_cache)
+ feats = torch.stack(self.lfr_splice_cache)
feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
if is_final:
@@ -459,16 +463,16 @@
lfr_m: int = 1,
lfr_n: int = 1,
):
- assert check_argument_types()
super().__init__()
self.fs = fs
self.frame_length = frame_length
self.frame_shift = frame_shift
self.lfr_m = lfr_m
self.lfr_n = lfr_n
+ self.n_mels = 23
def output_size(self) -> int:
- return self.n_mels * self.lfr_m
+ return self.n_mels * (2 * self.lfr_m + 1)
def forward(
self,
@@ -480,10 +484,10 @@
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
- waveform = waveform.unsqueeze(0).numpy()
+ waveform = waveform.numpy()
mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
mat = eend_ola_feature.transform(mat)
- mat = mat.splice(mat, context_size=self.lfr_m)
+ mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)
mat = mat[::self.lfr_n]
mat = torch.from_numpy(mat)
feat_length = mat.size(0)
@@ -494,4 +498,4 @@
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
- return feats_pad, feats_lens
\ No newline at end of file
+ return feats_pad, feats_lens
--
Gitblit v1.9.1