From 0cf5dfec2c8313fc2ed2aab8d10bf3dc4b9c283f Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期四, 14 三月 2024 14:41:49 +0800
Subject: [PATCH] update cmakelist

---
 funasr/models/seaco_paraformer/model.py |    8 ++++----
 1 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 5d0f602..92fc989 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -128,7 +128,7 @@
     
         hotword_pad = kwargs.get("hotword_pad")
         hotword_lengths = kwargs.get("hotword_lengths")
-        dha_pad = kwargs.get("dha_pad")
+        seaco_label_pad = kwargs.get("seaco_label_pad")
         
         batch_size = speech.shape[0]
         # for data-parallel
@@ -148,7 +148,7 @@
                                         ys_lengths, 
                                         hotword_pad, 
                                         hotword_lengths, 
-                                        dha_pad,
+                                        seaco_label_pad,
                                         )
         if self.train_decoder:
             loss_att, acc_att = self._calc_att_loss(
@@ -185,7 +185,7 @@
             ys_lengths: torch.Tensor,
             hotword_pad: torch.Tensor,
             hotword_lengths: torch.Tensor,
-            dha_pad: torch.Tensor,
+            seaco_label_pad: torch.Tensor,
     ):  
         # predictor forward
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
@@ -204,7 +204,7 @@
         dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_out, ys_lengths)
         merged = self._merge(cif_attended, dec_attended)
         dha_output = self.hotword_output_layer(merged[:, :-1])  # remove the last token in loss calculation
-        loss_att = self.criterion_seaco(dha_output, dha_pad)
+        loss_att = self.criterion_seaco(dha_output, seaco_label_pad)
         return loss_att
 
     def _seaco_decode_with_ASF(self, 

--
Gitblit v1.9.1