From 873cfae5c347b940e38e853d8579a6b4e85ada05 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Sun, 24 Mar 2024 00:45:45 +0800
Subject: [PATCH] seaco_paraformer: rename dha_pad to seaco_label_pad, gate timestamps on CifPredictorV3, add calc_predictor/export
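
This cleans up funasr/models/seaco_paraformer/model.py:

* drop the leftover "import pdb" debug import
* squeeze trailing length dims off text_lengths/speech_lengths before the
  dim checks in forward()
* rename the bias-label batch key from dha_pad to seaco_label_pad (forward()
  kwargs and the _calc_seaco_loss signature)
* drop the stale .type_as(batch_size) cast when length-normalizing the loss
* add a calc_predictor() helper wrapping the CIF predictor forward pass
* replace the hard-coded 8377 no-bias id in _merge_res with self.NO_BIAS
* record the configured predictor name and only run timestamp prediction
  (calc_predictor_timestamp + ts_prediction_lfr6_standard) when the
  predictor is CifPredictorV3; otherwise emit text-only results
* add an export() entry point that rebuilds the model through
  export_meta.export_rebuild_model, defaulting max_seq_len to 512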
---
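Note: a minimal usage sketch for the new export() hook, assuming the model is
driven through FunASR's AutoModel wrapper (the ModelScope model id below is
illustrative only and not introduced by this patch):

    from funasr import AutoModel

    # load a SeACo-Paraformer checkpoint (illustrative model id)
    model = AutoModel(
        model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
    )

    # export() forwards **kwargs to export_rebuild_model and fills in
    # max_seq_len=512 when the caller does not pass one
    model.export(type="onnx", quantize=False)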
funasr/models/seaco_paraformer/model.py | 96 ++++++++++++++++++++++++++++++------------------
 1 file changed, 60 insertions(+), 36 deletions(-)
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index a8b1f1f..21b6aba 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -30,7 +30,7 @@
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-import pdb
+
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
@@ -99,6 +99,7 @@
)
self.train_decoder = kwargs.get("train_decoder", False)
self.NO_BIAS = kwargs.get("NO_BIAS", 8377)
+ self.predictor_name = kwargs.get("predictor")
def forward(
self,
@@ -116,6 +117,8 @@
text: (Batch, Length)
text_lengths: (Batch,)
"""
+ text_lengths = text_lengths.squeeze(-1)  # (B, 1) -> (B,); bare squeeze() would collapse a batch of 1
+ speech_lengths = speech_lengths.squeeze(-1)
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
@@ -127,7 +130,7 @@
hotword_pad = kwargs.get("hotword_pad")
hotword_lengths = kwargs.get("hotword_lengths")
- dha_pad = kwargs.get("dha_pad")
+ seaco_label_pad = kwargs.get("seaco_label_pad")
batch_size = speech.shape[0]
# for data-parallel
@@ -147,7 +150,7 @@
ys_lengths,
hotword_pad,
hotword_lengths,
- dha_pad,
+ seaco_label_pad,
)
if self.train_decoder:
loss_att, acc_att = self._calc_att_loss(
@@ -163,12 +166,18 @@
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
- batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+ batch_size = (text_lengths + self.predictor_bias).sum()
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def _merge(self, cif_attended, dec_attended):
return cif_attended + dec_attended
+
+ def calc_predictor(self, encoder_out, encoder_out_lens):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ predictor_outs = self.predictor(encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id)
+ return predictor_outs[:4]
def _calc_seaco_loss(
self,
@@ -178,13 +187,12 @@
ys_lengths: torch.Tensor,
hotword_pad: torch.Tensor,
hotword_lengths: torch.Tensor,
- dha_pad: torch.Tensor,
+ seaco_label_pad: torch.Tensor,
):
# predictor forward
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
- pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
- ignore_id=self.ignore_id)
+ pre_acoustic_embeds = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)[0]
# decoder forward
decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True)
selected = self._hotword_representation(hotword_pad,
@@ -197,7 +205,7 @@
dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_out, ys_lengths)
merged = self._merge(cif_attended, dec_attended)
dha_output = self.hotword_output_layer(merged[:, :-1]) # remove the last token in loss calculation
- loss_att = self.criterion_seaco(dha_output, dha_pad)
+ loss_att = self.criterion_seaco(dha_output, seaco_label_pad)
return loss_att
def _seaco_decode_with_ASF(self,
@@ -248,7 +256,7 @@
def _merge_res(dec_output, dha_output):
lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0])
dha_ids = dha_output.max(-1)[-1]  # indices of the argmax over the hotword vocab
- dha_mask = (dha_ids == 8377).int().unsqueeze(-1)
+ dha_mask = (dha_ids == self.NO_BIAS).int().unsqueeze(-1)
a = (1 - lmbd) / lmbd
b = 1 / lmbd
a, b = a.to(dec_output.device), b.to(dec_output.device)
@@ -332,23 +340,28 @@
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
-
# predictor
predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
- pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \
- predictor_outs[2], predictor_outs[3]
+ pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
- decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,
- pre_acoustic_embeds,
- pre_token_length,
- hw_list=self.hotword_list)
+ decoder_out = self._seaco_decode_with_ASF(encoder_out,
+ encoder_out_lens,
+ pre_acoustic_embeds,
+ pre_token_length,
+ hw_list=self.hotword_list
+ )
# decoder_out, _ = decoder_outs[0], decoder_outs[1]
- _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
- pre_token_length)
+ if self.predictor_name == "CifPredictorV3":
+ _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out,
+ encoder_out_lens,
+ pre_token_length)
+ else:
+ us_alphas, us_peaks = None, None  # keep us_peaks bound when timestamps are skipped
+
results = []
b, n, d = decoder_out.size()
for i in range(b):
@@ -393,29 +406,30 @@
# Change integer-ids to tokens
token = tokenizer.ids2tokens(token_int)
text = tokenizer.tokens2text(token)
-
- _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
- us_peaks[i][:encoder_out_lens[i] * 3],
- copy.copy(token),
- vad_offset=kwargs.get("begin_time", 0))
-
- text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
- token, timestamp)
-
- result_i = {"key": key[i], "text": text_postprocessed,
- "timestamp": time_stamp_postprocessed
- }
-
- if ibest_writer is not None:
- ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
- ibest_writer["text"][key[i]] = text_postprocessed
+ if us_alphas is not None:
+ _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
+ us_peaks[i][:encoder_out_lens[i] * 3],
+ copy.copy(token),
+ vad_offset=kwargs.get("begin_time", 0))
+ text_postprocessed, time_stamp_postprocessed, _ = \
+ postprocess_utils.sentence_postprocess(token, timestamp)
+ result_i = {"key": key[i], "text": text_postprocessed,
+ "timestamp": time_stamp_postprocessed}
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
+ ibest_writer["text"][key[i]] = text_postprocessed
+ else:
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ result_i = {"key": key[i], "text": text_postprocessed}
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text_postprocessed
else:
result_i = {"key": key[i], "token_int": token_int}
results.append(result_i)
return results, meta_data
-
def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
def load_seg_dict(seg_dict_file):
@@ -518,3 +532,13 @@
hotword_list = None
return hotword_list
+ def export(
+ self,
+ **kwargs,
+ ):
+ if 'max_seq_len' not in kwargs:
+ kwargs['max_seq_len'] = 512
+ from .export_meta import export_rebuild_model
+ models = export_rebuild_model(model=self, **kwargs)
+ return models
+
--
Gitblit v1.9.1