From 26d642bfdf59a50365a9c8158acb223cae1004dc Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 23 四月 2024 20:13:44 +0800
Subject: [PATCH] Dev gzf exp (#1651)
---
funasr/models/seaco_paraformer/model.py | 22 ++++++++++++++--------
1 files changed, 14 insertions(+), 8 deletions(-)
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 27ff5d1..b28de94 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -97,7 +97,8 @@
smoothing=seaco_lsm_weight,
normalize_length=seaco_length_normalized_loss,
)
- self.train_decoder = kwargs.get("train_decoder", False)
+ self.train_decoder = kwargs.get("train_decoder", True)
+ self.seaco_weight = kwargs.get("seaco_weight", 0.01)
self.NO_BIAS = kwargs.get("NO_BIAS", 8377)
self.predictor_name = kwargs.get("predictor")
@@ -117,7 +118,10 @@
text: (Batch, Length)
text_lengths: (Batch,)
"""
- assert text_lengths.dim() == 1, text_lengths.shape
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
# Check that batch_size is unified
assert (
speech.shape[0]
@@ -129,6 +133,8 @@
hotword_pad = kwargs.get("hotword_pad")
hotword_lengths = kwargs.get("hotword_lengths")
seaco_label_pad = kwargs.get("seaco_label_pad")
+ if len(hotword_lengths.size()) > 1:
+ hotword_lengths = hotword_lengths[:, 0]
batch_size = speech.shape[0]
# for data-parallel
@@ -151,20 +157,21 @@
seaco_label_pad,
)
if self.train_decoder:
- loss_att, acc_att = self._calc_att_loss(
+ loss_att, acc_att, _, _, _ = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
- loss = loss_seaco + loss_att
+ loss = loss_seaco + loss_att * self.seaco_weight
stats["loss_att"] = torch.clone(loss_att.detach())
stats["acc_att"] = acc_att
else:
loss = loss_seaco
+
stats["loss_seaco"] = torch.clone(loss_seaco.detach())
stats["loss"] = torch.clone(loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
- batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+ batch_size = (text_lengths + self.predictor_bias).sum()
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
@@ -190,8 +197,7 @@
# predictor forward
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
- pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
- ignore_id=self.ignore_id)
+ pre_acoustic_embeds = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)[0]
# decoder forward
decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True)
selected = self._hotword_representation(hotword_pad,
@@ -344,7 +350,7 @@
pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
- return []
+ return ([],)
decoder_out = self._seaco_decode_with_ASF(encoder_out,
encoder_out_lens,
--
Gitblit v1.9.1