From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365
---
funasr/frontends/wav_frontend.py | 353 ++++++++++++++++++++++++++++++++--------------------------
1 files changed, 192 insertions(+), 161 deletions(-)
diff --git a/funasr/frontends/wav_frontend.py b/funasr/frontends/wav_frontend.py
index fe22335..a4002df 100644
--- a/funasr/frontends/wav_frontend.py
+++ b/funasr/frontends/wav_frontend.py
@@ -12,24 +12,23 @@
from funasr.register import tables
-
def load_cmvn(cmvn_file):
- with open(cmvn_file, 'r', encoding='utf-8') as f:
+ with open(cmvn_file, "r", encoding="utf-8") as f:
lines = f.readlines()
means_list = []
vars_list = []
for i in range(len(lines)):
line_item = lines[i].split()
- if line_item[0] == '<AddShift>':
+ if line_item[0] == "<AddShift>":
line_item = lines[i + 1].split()
- if line_item[0] == '<LearnRateCoef>':
- add_shift_line = line_item[3:(len(line_item) - 1)]
+ if line_item[0] == "<LearnRateCoef>":
+ add_shift_line = line_item[3 : (len(line_item) - 1)]
means_list = list(add_shift_line)
continue
- elif line_item[0] == '<Rescale>':
+ elif line_item[0] == "<Rescale>":
line_item = lines[i + 1].split()
- if line_item[0] == '<LearnRateCoef>':
- rescale_line = line_item[3:(len(line_item) - 1)]
+ if line_item[0] == "<LearnRateCoef>":
+ rescale_line = line_item[3 : (len(line_item) - 1)]
vars_list = list(rescale_line)
continue
means = np.array(means_list).astype(np.float32)
@@ -65,37 +64,38 @@
T = T + (lfr_m - 1) // 2
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
- LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
+ LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
else: # process last LFR frame
num_padding = lfr_m - (T - i * lfr_n)
- frame = (inputs[i * lfr_n:]).view(-1)
+ frame = (inputs[i * lfr_n :]).view(-1)
for _ in range(num_padding):
frame = torch.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
+
+@tables.register("frontend_classes", "wav_frontend")
@tables.register("frontend_classes", "WavFrontend")
class WavFrontend(nn.Module):
- """Conventional frontend structure for ASR.
- """
+ """Conventional frontend structure for ASR."""
def __init__(
- self,
- cmvn_file: str = None,
- fs: int = 16000,
- window: str = 'hamming',
- n_mels: int = 80,
- frame_length: int = 25,
- frame_shift: int = 10,
- filter_length_min: int = -1,
- filter_length_max: int = -1,
- lfr_m: int = 1,
- lfr_n: int = 1,
- dither: float = 1.0,
- snip_edges: bool = True,
- upsacle_samples: bool = True,
- **kwargs,
+ self,
+ cmvn_file: str = None,
+ fs: int = 16000,
+ window: str = "hamming",
+ n_mels: int = 80,
+ frame_length: int = 25,
+ frame_shift: int = 10,
+ filter_length_min: int = -1,
+ filter_length_max: int = -1,
+ lfr_m: int = 1,
+ lfr_n: int = 1,
+ dither: float = 1.0,
+ snip_edges: bool = True,
+ upsacle_samples: bool = True,
+ **kwargs,
):
super().__init__()
self.fs = fs
@@ -117,10 +117,10 @@
return self.n_mels * self.lfr_m
def forward(
- self,
- input: torch.Tensor,
- input_lengths,
- **kwargs,
+ self,
+ input: torch.Tensor,
+ input_lengths,
+ **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
@@ -131,15 +131,17 @@
if self.upsacle_samples:
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
- mat = kaldi.fbank(waveform,
- num_mel_bins=self.n_mels,
- frame_length=self.frame_length,
- frame_shift=self.frame_shift,
- dither=self.dither,
- energy_floor=0.0,
- window_type=self.window,
- sample_frequency=self.fs,
- snip_edges=self.snip_edges)
+ mat = kaldi.fbank(
+ waveform,
+ num_mel_bins=self.n_mels,
+ frame_length=self.frame_length,
+ frame_shift=self.frame_shift,
+ dither=self.dither,
+ energy_floor=0.0,
+ window_type=self.window,
+ sample_frequency=self.fs,
+ snip_edges=self.snip_edges,
+ )
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
@@ -153,15 +155,12 @@
if batch_size == 1:
feats_pad = feats[0][None, :, :]
else:
- feats_pad = pad_sequence(feats,
- batch_first=True,
- padding_value=0.0)
+ feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
return feats_pad, feats_lens
def forward_fbank(
- self,
- input: torch.Tensor,
- input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ self, input: torch.Tensor, input_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
@@ -170,34 +169,33 @@
waveform = input[i][:waveform_length]
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
- mat = kaldi.fbank(waveform,
- num_mel_bins=self.n_mels,
- frame_length=self.frame_length,
- frame_shift=self.frame_shift,
- dither=self.dither,
- energy_floor=0.0,
- window_type=self.window,
- sample_frequency=self.fs)
+ mat = kaldi.fbank(
+ waveform,
+ num_mel_bins=self.n_mels,
+ frame_length=self.frame_length,
+ frame_shift=self.frame_shift,
+ dither=self.dither,
+ energy_floor=0.0,
+ window_type=self.window,
+ sample_frequency=self.fs,
+ )
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
- feats_pad = pad_sequence(feats,
- batch_first=True,
- padding_value=0.0)
+ feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
return feats_pad, feats_lens
def forward_lfr_cmvn(
- self,
- input: torch.Tensor,
- input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ self, input: torch.Tensor, input_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
- mat = input[i, :input_lengths[i], :]
+ mat = input[i, : input_lengths[i], :]
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
if self.cmvn is not None:
@@ -207,33 +205,30 @@
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
- feats_pad = pad_sequence(feats,
- batch_first=True,
- padding_value=0.0)
+ feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
return feats_pad, feats_lens
@tables.register("frontend_classes", "WavFrontendOnline")
class WavFrontendOnline(nn.Module):
- """Conventional frontend structure for streaming ASR/VAD.
- """
+ """Conventional frontend structure for streaming ASR/VAD."""
def __init__(
- self,
- cmvn_file: str = None,
- fs: int = 16000,
- window: str = 'hamming',
- n_mels: int = 80,
- frame_length: int = 25,
- frame_shift: int = 10,
- filter_length_min: int = -1,
- filter_length_max: int = -1,
- lfr_m: int = 1,
- lfr_n: int = 1,
- dither: float = 1.0,
- snip_edges: bool = True,
- upsacle_samples: bool = True,
- **kwargs,
+ self,
+ cmvn_file: str = None,
+ fs: int = 16000,
+ window: str = "hamming",
+ n_mels: int = 80,
+ frame_length: int = 25,
+ frame_shift: int = 10,
+ filter_length_min: int = -1,
+ filter_length_max: int = -1,
+ lfr_m: int = 1,
+ lfr_n: int = 1,
+ dither: float = 1.0,
+ snip_edges: bool = True,
+ upsacle_samples: bool = True,
+ **kwargs,
):
super().__init__()
self.fs = fs
@@ -280,8 +275,9 @@
return inputs.type(torch.float32)
@staticmethod
- def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[
- torch.Tensor, torch.Tensor, int]:
+ def apply_lfr(
+ inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""
Apply lfr with data
"""
@@ -289,15 +285,17 @@
LFR_inputs = []
# inputs = torch.vstack((inputs_lfr_cache, inputs))
T = inputs.shape[0] # include the right context
- T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n)) # minus the right context: (lfr_m - 1) // 2
+ T_lfr = int(
+ np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
+ ) # minus the right context: (lfr_m - 1) // 2
splice_idx = T_lfr
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
- LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
+ LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
else: # process last LFR frame
if is_final:
num_padding = lfr_m - (T - i * lfr_n)
- frame = (inputs[i * lfr_n:]).view(-1)
+ frame = (inputs[i * lfr_n :]).view(-1)
for _ in range(num_padding):
frame = torch.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
@@ -311,23 +309,29 @@
return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx
@staticmethod
- def compute_frame_num(sample_length: int, frame_sample_length: int, frame_shift_sample_length: int) -> int:
+ def compute_frame_num(
+ sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
+ ) -> int:
frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
def forward_fbank(
- self,
- input: torch.Tensor,
- input_lengths: torch.Tensor,
- cache: dict = {},
- **kwargs,
+ self,
+ input: torch.Tensor,
+ input_lengths: torch.Tensor,
+ cache: dict = {},
+ **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
input = torch.cat((cache["input_cache"], input), dim=1)
- frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length)
+ frame_num = self.compute_frame_num(
+ input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
+ )
# update self.in_cache
- cache["input_cache"] = input[:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length):]
+ cache["input_cache"] = input[
+ :, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
+ ]
waveforms = torch.empty(0)
feats_pad = torch.empty(0)
feats_lens = torch.empty(0)
@@ -339,17 +343,25 @@
waveform = input[i]
# we need accurate wave samples that used for fbank extracting
waveforms.append(
- waveform[:((frame_num - 1) * self.frame_shift_sample_length + self.frame_sample_length)])
+ waveform[
+ : (
+ (frame_num - 1) * self.frame_shift_sample_length
+ + self.frame_sample_length
+ )
+ ]
+ )
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
- mat = kaldi.fbank(waveform,
- num_mel_bins=self.n_mels,
- frame_length=self.frame_length,
- frame_shift=self.frame_shift,
- dither=self.dither,
- energy_floor=0.0,
- window_type=self.window,
- sample_frequency=self.fs)
+ mat = kaldi.fbank(
+ waveform,
+ num_mel_bins=self.n_mels,
+ frame_length=self.frame_length,
+ frame_shift=self.frame_shift,
+ dither=self.dither,
+ energy_floor=0.0,
+ window_type=self.window,
+ sample_frequency=self.fs,
+ )
feat_length = mat.size(0)
feats.append(mat)
@@ -357,33 +369,31 @@
waveforms = torch.stack(waveforms)
feats_lens = torch.as_tensor(feats_lens)
- feats_pad = pad_sequence(feats,
- batch_first=True,
- padding_value=0.0)
+ feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
cache["fbanks"] = feats_pad
- cache["fbanks_lens"]= copy.deepcopy(feats_lens)
+ cache["fbanks_lens"] = copy.deepcopy(feats_lens)
return waveforms, feats_pad, feats_lens
-
def forward_lfr_cmvn(
- self,
- input: torch.Tensor,
- input_lengths: torch.Tensor,
- is_final: bool = False,
- cache: dict = {},
- **kwargs,
+ self,
+ input: torch.Tensor,
+ input_lengths: torch.Tensor,
+ is_final: bool = False,
+ cache: dict = {},
+ **kwargs,
):
batch_size = input.size(0)
feats = []
feats_lens = []
lfr_splice_frame_idxs = []
for i in range(batch_size):
- mat = input[i, :input_lengths[i], :]
+ mat = input[i, : input_lengths[i], :]
if self.lfr_m != 1 or self.lfr_n != 1:
# update self.lfr_splice_cache in self.apply_lfr
# mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
- mat, cache["lfr_splice_cache"][i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n,
- is_final)
+ mat, cache["lfr_splice_cache"][i], lfr_splice_frame_idx = self.apply_lfr(
+ mat, self.lfr_m, self.lfr_n, is_final
+ )
if self.cmvn_file is not None:
mat = self.apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
@@ -392,68 +402,93 @@
lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
feats_lens = torch.as_tensor(feats_lens)
- feats_pad = pad_sequence(feats,
- batch_first=True,
- padding_value=0.0)
+ feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
lfr_splice_frame_idxs = torch.as_tensor(lfr_splice_frame_idxs)
return feats_pad, feats_lens, lfr_splice_frame_idxs
- def forward(
- self, input: torch.Tensor, input_lengths: torch.Tensor, cache: dict = {}, **kwargs
- ):
+ def forward(self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs):
is_final = kwargs.get("is_final", False)
- reset = kwargs.get("reset", False)
- if len(cache) == 0 or reset:
+ cache = kwargs.get("cache", {})
+ if len(cache) == 0:
self.init_cache(cache)
-
+
batch_size = input.shape[0]
- assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
-
- waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths, cache=cache) # input shape: B T D
-
+ assert (
+ batch_size == 1
+ ), "we support to extract feature online only when the batch size is equal to 1 now"
+
+ waveforms, feats, feats_lengths = self.forward_fbank(
+ input, input_lengths, cache=cache
+ ) # input shape: B T D
+
if feats.shape[0]:
cache["waveforms"] = torch.cat((cache["reserve_waveforms"], waveforms), dim=1)
-
+
if not cache["lfr_splice_cache"]: # 初始化splice_cache
for i in range(batch_size):
- cache["lfr_splice_cache"].append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
+ cache["lfr_splice_cache"].append(
+ feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1)
+ )
# need the number of the input frames + self.lfr_splice_cache[0].shape[0] is greater than self.lfr_m
if feats_lengths[0] + cache["lfr_splice_cache"][0].shape[0] >= self.lfr_m:
lfr_splice_cache_tensor = torch.stack(cache["lfr_splice_cache"]) # B T D
feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
feats_lengths += lfr_splice_cache_tensor[0].shape[0]
frame_from_waveforms = int(
- (cache["waveforms"].shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
- minus_frame = (self.lfr_m - 1) // 2 if cache["reserve_waveforms"].numel() == 0 else 0
- feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final, cache=cache)
+ (cache["waveforms"].shape[1] - self.frame_sample_length)
+ / self.frame_shift_sample_length
+ + 1
+ )
+ minus_frame = (
+ (self.lfr_m - 1) // 2 if cache["reserve_waveforms"].numel() == 0 else 0
+ )
+ feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(
+ feats, feats_lengths, is_final, cache=cache
+ )
if self.lfr_m == 1:
cache["reserve_waveforms"] = torch.empty(0)
else:
reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
# print('reserve_frame_idx: ' + str(reserve_frame_idx))
# print('frame_frame: ' + str(frame_from_waveforms))
- cache["reserve_waveforms"] = cache["waveforms"][:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length]
- sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length
+ cache["reserve_waveforms"] = cache["waveforms"][
+ :,
+ reserve_frame_idx
+ * self.frame_shift_sample_length : frame_from_waveforms
+ * self.frame_shift_sample_length,
+ ]
+ sample_length = (
+ frame_from_waveforms - 1
+ ) * self.frame_shift_sample_length + self.frame_sample_length
cache["waveforms"] = cache["waveforms"][:, :sample_length]
else:
# update self.reserve_waveforms and self.lfr_splice_cache
- cache["reserve_waveforms"] = cache["waveforms"][:, :-(self.frame_sample_length - self.frame_shift_sample_length)]
+ cache["reserve_waveforms"] = cache["waveforms"][
+ :, : -(self.frame_sample_length - self.frame_shift_sample_length)
+ ]
for i in range(batch_size):
- cache["lfr_splice_cache"][i] = torch.cat((cache["lfr_splice_cache"][i], feats[i]), dim=0)
+ cache["lfr_splice_cache"][i] = torch.cat(
+ (cache["lfr_splice_cache"][i], feats[i]), dim=0
+ )
return torch.empty(0), feats_lengths
else:
if is_final:
- cache["waveforms"] = waveforms if cache["reserve_waveforms"].numel() == 0 else cache["reserve_waveforms"]
+ cache["waveforms"] = (
+ waveforms
+ if cache["reserve_waveforms"].numel() == 0
+ else cache["reserve_waveforms"]
+ )
feats = torch.stack(cache["lfr_splice_cache"])
feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
- feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final, cache=cache)
- if is_final:
- self.init_cache(cache)
+ feats, feats_lengths, _ = self.forward_lfr_cmvn(
+ feats, feats_lengths, is_final, cache=cache
+ )
+ # if is_final:
+ # self.init_cache(cache)
return feats, feats_lengths
-
- def init_cache(self, cache: dict = {}):
+ def init_cache(self, cache: dict = {}):
cache["reserve_waveforms"] = torch.empty(0)
cache["input_cache"] = torch.empty(0)
cache["lfr_splice_cache"] = []
@@ -464,17 +499,16 @@
class WavFrontendMel23(nn.Module):
- """Conventional frontend structure for ASR.
- """
+ """Conventional frontend structure for ASR."""
def __init__(
- self,
- fs: int = 16000,
- frame_length: int = 25,
- frame_shift: int = 10,
- lfr_m: int = 1,
- lfr_n: int = 1,
- **kwargs,
+ self,
+ fs: int = 16000,
+ frame_length: int = 25,
+ frame_shift: int = 10,
+ lfr_m: int = 1,
+ lfr_n: int = 1,
+ **kwargs,
):
super().__init__()
self.fs = fs
@@ -488,9 +522,8 @@
return self.n_mels * (2 * self.lfr_m + 1)
def forward(
- self,
- input: torch.Tensor,
- input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ self, input: torch.Tensor, input_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
@@ -501,14 +534,12 @@
mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
mat = eend_ola_feature.transform(mat)
mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)
- mat = mat[::self.lfr_n]
+ mat = mat[:: self.lfr_n]
mat = torch.from_numpy(mat)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
- feats_pad = pad_sequence(feats,
- batch_first=True,
- padding_value=0.0)
+ feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
return feats_pad, feats_lens
--
Gitblit v1.9.1