From 4a7a984a5f3e3f894f86ce82e76ddd13d8a42a20 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 11 三月 2024 17:56:30 +0800
Subject: [PATCH] Dev gzf (#1465)
---
funasr/models/seaco_paraformer/model.py | 14 +++++---------
1 file changed, 5 insertions(+), 9 deletions(-)
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index f671db6..92fc989 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -128,7 +128,7 @@
hotword_pad = kwargs.get("hotword_pad")
hotword_lengths = kwargs.get("hotword_lengths")
- dha_pad = kwargs.get("dha_pad")
+ seaco_label_pad = kwargs.get("seaco_label_pad")
batch_size = speech.shape[0]
# for data-parallel
@@ -148,7 +148,7 @@
ys_lengths,
hotword_pad,
hotword_lengths,
- dha_pad,
+ seaco_label_pad,
)
if self.train_decoder:
loss_att, acc_att = self._calc_att_loss(
@@ -175,11 +175,7 @@
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
predictor_outs = self.predictor(encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id)
- if len(predictor_outs) == 4:
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs
- else:
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = predictor_outs
- return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+ return predictor_outs[:4]
def _calc_seaco_loss(
self,
@@ -189,7 +185,7 @@
ys_lengths: torch.Tensor,
hotword_pad: torch.Tensor,
hotword_lengths: torch.Tensor,
- dha_pad: torch.Tensor,
+ seaco_label_pad: torch.Tensor,
):
# predictor forward
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
@@ -208,7 +204,7 @@
dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_out, ys_lengths)
merged = self._merge(cif_attended, dec_attended)
dha_output = self.hotword_output_layer(merged[:, :-1]) # remove the last token in loss calculation
- loss_att = self.criterion_seaco(dha_output, dha_pad)
+ loss_att = self.criterion_seaco(dha_output, seaco_label_pad)
return loss_att
def _seaco_decode_with_ASF(self,
--
Gitblit v1.9.1