From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid naming conflict with line 365
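
Why: reusing one variable name for two unrelated values in the same scope lets
the later assignment silently clobber the earlier one. A minimal, purely
illustrative sketch of that pattern (the names and values below are assumptions
for illustration, not the actual contents of lines 365 and 514 in model.py):

    def decode():
        res = {"text": "hello"}      # earlier result, cf. the first 'res'
        # ... many lines later in the same scope ...
        res = [0.1, 0.2, 0.3]        # same name reused: the dict above is lost
        return res                   # now returns the list, not the text result

    def decode_fixed():
        res = {"text": "hello"}
        ts_res = [0.1, 0.2, 0.3]     # renamed, so both values stay available
        return res, ts_res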
---
funasr/models/monotonic_aligner/model.py | 128 +++++++++++++++++++++++++-----------------
1 file changed, 76 insertions(+), 52 deletions(-)
diff --git a/funasr/models/monotonic_aligner/model.py b/funasr/models/monotonic_aligner/model.py
index 718923b..4754d1f 100644
--- a/funasr/models/monotonic_aligner/model.py
+++ b/funasr/models/monotonic_aligner/model.py
@@ -28,6 +28,7 @@
Achieving timestamp prediction while recognizing with non-autoregressive end-to-end ASR model
https://arxiv.org/abs/2301.12343
"""
+
def __init__(
self,
input_size: int = 80,
@@ -64,11 +65,11 @@
self.predictor_bias = predictor_bias
def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
@@ -80,25 +81,25 @@
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
- speech.shape[0]
- == speech_lengths.shape[0]
- == text.shape[0]
- == text_lengths.shape[0]
+ speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
- speech = speech[:, :speech_lengths.max()]
+ speech = speech[:, : speech_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
+ encoder_out_mask = (
+ ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
+ ).to(encoder_out.device)
if self.predictor_bias == 1:
_, text = add_sos_eos(text, 1, 2, -1)
text_lengths = text_lengths + self.predictor_bias
- _, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
+ _, _, _, _, pre_token_length2 = self.predictor(
+ encoder_out, text, encoder_out_mask, ignore_id=-1
+ )
# loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
@@ -115,15 +116,19 @@
return loss, stats, weight
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
- encoder_out_mask,
- token_num)
+ encoder_out_mask = (
+ ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
+ ).to(encoder_out.device)
+ ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(
+ encoder_out, encoder_out_mask, token_num
+ )
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encoder. Note that this method is used by asr_inference.py
Args:
@@ -136,51 +141,62 @@
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
-
+
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
-
+
# Forward encoder
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
return encoder_out, encoder_out_lens
-
- def inference(self,
- data_in,
- data_lengths=None,
- key: list=None,
- tokenizer=None,
- frontend=None,
- **kwargs,
- ):
+
+ def inference(
+ self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
meta_data = {}
# extract fbank feats
time1 = time.perf_counter()
- audio_list, text_token_int_list = load_audio_text_image_video(data_in,
- fs=frontend.fs,
- audio_fs=kwargs.get("fs", 16000),
- data_type=kwargs.get("data_type", "sound"),
- tokenizer=tokenizer)
+ audio_list, text_token_int_list = load_audio_text_image_video(
+ data_in,
+ fs=frontend.fs,
+ audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer,
+ )
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend)
+ speech, speech_lengths = extract_fbank(
+ audio_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+ )
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-
+ meta_data["batch_data_time"] = (
+ speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+ )
+
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
+
# predictor
- text_lengths = torch.tensor([len(i)+1 for i in text_token_int_list]).to(encoder_out.device)
- _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens, token_num=text_lengths)
-
+ text_lengths = torch.tensor([len(i) + 1 for i in text_token_int_list]).to(
+ encoder_out.device
+ )
+ _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(
+ encoder_out, encoder_out_lens, token_num=text_lengths
+ )
+
results = []
ibest_writer = None
if kwargs.get("output_dir") is not None:
@@ -188,20 +204,28 @@
self.writer = DatadirWriter(kwargs.get("output_dir"))
ibest_writer = self.writer["tp_res"]
- for i, (us_alpha, us_peak, token_int) in enumerate(zip(us_alphas, us_peaks, text_token_int_list)):
+ for i, (us_alpha, us_peak, token_int) in enumerate(
+ zip(us_alphas, us_peaks, text_token_int_list)
+ ):
token = tokenizer.ids2tokens(token_int)
- timestamp_str, timestamp = ts_prediction_lfr6_standard(us_alpha[:encoder_out_lens[i] * 3],
- us_peak[:encoder_out_lens[i] * 3],
- copy.copy(token))
- text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
- result_i = {"key": key[i], "text": text_postprocessed,
- "timestamp": time_stamp_postprocessed,
- }
+ timestamp_str, timestamp = ts_prediction_lfr6_standard(
+ us_alpha[: encoder_out_lens[i] * 3],
+ us_peak[: encoder_out_lens[i] * 3],
+ copy.copy(token),
+ )
+ text_postprocessed, time_stamp_postprocessed, _ = (
+ postprocess_utils.sentence_postprocess(token, timestamp)
+ )
+ result_i = {
+ "key": key[i],
+ "text": text_postprocessed,
+ "timestamp": time_stamp_postprocessed,
+ }
results.append(result_i)
if ibest_writer:
# ibest_writer["token"][key[i]] = " ".join(token)
ibest_writer["timestamp_list"][key[i]] = time_stamp_postprocessed
ibest_writer["timestamp_str"][key[i]] = timestamp_str
-
- return results, meta_data
\ No newline at end of file
+
+ return results, meta_data
--
Gitblit v1.9.1