From e9d2cfc3a134b00f4e98271fbee3838d1ccecbcc Mon Sep 17 00:00:00 2001
From: VirtuosoQ <2416050435@qq.com>
Date: 星期五, 26 四月 2024 14:59:30 +0800
Subject: [PATCH] FunASR java http  client

---
 funasr/models/seaco_paraformer/model.py |   41 ++++++++++++++++++++++++++++-------------
 1 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 5d0f602..b28de94 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -97,7 +97,8 @@
             smoothing=seaco_lsm_weight,
             normalize_length=seaco_length_normalized_loss,
         )
-        self.train_decoder = kwargs.get("train_decoder", False)
+        self.train_decoder = kwargs.get("train_decoder", True)
+        self.seaco_weight = kwargs.get("seaco_weight", 0.01)
         self.NO_BIAS = kwargs.get("NO_BIAS", 8377)
         self.predictor_name = kwargs.get("predictor")
         
@@ -117,7 +118,10 @@
                 text: (Batch, Length)
                 text_lengths: (Batch,)
         """
-        assert text_lengths.dim() == 1, text_lengths.shape
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
         # Check that batch_size is unified
         assert (
                 speech.shape[0]
@@ -128,7 +132,9 @@
     
         hotword_pad = kwargs.get("hotword_pad")
         hotword_lengths = kwargs.get("hotword_lengths")
-        dha_pad = kwargs.get("dha_pad")
+        seaco_label_pad = kwargs.get("seaco_label_pad")
+        if len(hotword_lengths.size()) > 1:
+            hotword_lengths = hotword_lengths[:, 0]
         
         batch_size = speech.shape[0]
         # for data-parallel
@@ -148,23 +154,24 @@
                                         ys_lengths, 
                                         hotword_pad, 
                                         hotword_lengths, 
-                                        dha_pad,
+                                        seaco_label_pad,
                                         )
         if self.train_decoder:
-            loss_att, acc_att = self._calc_att_loss(
+            loss_att, acc_att, _, _, _ = self._calc_att_loss(
                 encoder_out, encoder_out_lens, text, text_lengths
             )
-            loss = loss_seaco + loss_att
+            loss = loss_seaco + loss_att * self.seaco_weight
             stats["loss_att"] = torch.clone(loss_att.detach())
             stats["acc_att"] = acc_att
         else:
             loss = loss_seaco
+            
         stats["loss_seaco"] = torch.clone(loss_seaco.detach())
         stats["loss"] = torch.clone(loss.detach())
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = (text_lengths + self.predictor_bias).sum()
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -185,13 +192,12 @@
             ys_lengths: torch.Tensor,
             hotword_pad: torch.Tensor,
             hotword_lengths: torch.Tensor,
-            dha_pad: torch.Tensor,
+            seaco_label_pad: torch.Tensor,
     ):  
         # predictor forward
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
-                                                                                  ignore_id=self.ignore_id)
+        pre_acoustic_embeds = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)[0]
         # decoder forward
         decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True)
         selected = self._hotword_representation(hotword_pad, 
@@ -204,7 +210,7 @@
         dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_out, ys_lengths)
         merged = self._merge(cif_attended, dec_attended)
         dha_output = self.hotword_output_layer(merged[:, :-1])  # remove the last token in loss calculation
-        loss_att = self.criterion_seaco(dha_output, dha_pad)
+        loss_att = self.criterion_seaco(dha_output, seaco_label_pad)
         return loss_att
 
     def _seaco_decode_with_ASF(self, 
@@ -344,7 +350,7 @@
         pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
-            return []
+            return ([],)
 
         decoder_out = self._seaco_decode_with_ASF(encoder_out, 
                                                   encoder_out_lens,
@@ -429,7 +435,6 @@
                 results.append(result_i)
         
         return results, meta_data
-
 
     def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
         def load_seg_dict(seg_dict_file):
@@ -532,3 +537,13 @@
             hotword_list = None
         return hotword_list
 
+    def export(
+        self,
+        **kwargs,
+    ):
+        if 'max_seq_len' not in kwargs:
+            kwargs['max_seq_len'] = 512
+        from .export_meta import export_rebuild_model
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
+

--
Gitblit v1.9.1