From 149063ced4d2d5269f0472677228eadfcb4a4d8a Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: 星期三, 17 四月 2024 14:33:24 +0800
Subject: [PATCH] update seaco finetune

---
 funasr/models/paraformer/model.py       |    3 ---
 funasr/models/seaco_paraformer/model.py |   15 ++++++++++-----
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index d47db11..6c7957c 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -181,15 +181,12 @@
                 text: (Batch, Length)
                 text_lengths: (Batch,)
         """
-        # import pdb;
-        # pdb.set_trace()
         if len(text_lengths.size()) > 1:
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
             speech_lengths = speech_lengths[:, 0]
         
         batch_size = speech.shape[0]
-        
         
         # Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 21b6aba..8f87340 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -97,7 +97,8 @@
             smoothing=seaco_lsm_weight,
             normalize_length=seaco_length_normalized_loss,
         )
-        self.train_decoder = kwargs.get("train_decoder", False)
+        self.train_decoder = kwargs.get("train_decoder", True)
+        self.seaco_weight = kwargs.get("seaco_weight", 0.01)
         self.NO_BIAS = kwargs.get("NO_BIAS", 8377)
         self.predictor_name = kwargs.get("predictor")
         
@@ -117,9 +118,10 @@
                 text: (Batch, Length)
                 text_lengths: (Batch,)
         """
-        text_lengths = text_lengths.squeeze()
-        speech_lengths = speech_lengths.squeeze()
-        assert text_lengths.dim() == 1, text_lengths.shape
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
         # Check that batch_size is unified
         assert (
                 speech.shape[0]
@@ -131,6 +133,8 @@
         hotword_pad = kwargs.get("hotword_pad")
         hotword_lengths = kwargs.get("hotword_lengths")
         seaco_label_pad = kwargs.get("seaco_label_pad")
+        if len(hotword_lengths.size()) > 1:
+            hotword_lengths = hotword_lengths[:, 0]
         
         batch_size = speech.shape[0]
         # for data-parallel
@@ -156,11 +160,12 @@
             loss_att, acc_att = self._calc_att_loss(
                 encoder_out, encoder_out_lens, text, text_lengths
             )
-            loss = loss_seaco + loss_att
+            loss = loss_seaco + loss_att * self.seaco_weight
             stats["loss_att"] = torch.clone(loss_att.detach())
             stats["acc_att"] = acc_att
         else:
             loss = loss_seaco
+            
         stats["loss_seaco"] = torch.clone(loss_seaco.detach())
         stats["loss"] = torch.clone(loss.detach())
 

--
Gitblit v1.9.1