From abb33d6b2097e5b0643326bc1b376a63cdc2f967 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 六月 2024 17:06:21 +0800
Subject: [PATCH] Dev gzf deepspeed (#1844)

---
 funasr/models/sense_voice/model.py |   56 +++++++++++++++++++++++++++++++++++---------------------
 1 files changed, 35 insertions(+), 21 deletions(-)

diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 127d5a0..9db6539 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -10,12 +10,13 @@
 from torch import Tensor
 from torch import nn
 from torch.cuda.amp import autocast
-from funasr.metrics.compute_acc import compute_accuracy
+from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
 from funasr.train_utils.device_funcs import force_gatherable
 from . import whisper_lib as whisper
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.ctc.ctc import CTC
 
 from funasr.register import tables
 
@@ -73,8 +74,6 @@
     ):
         target_mask = kwargs.get("target_mask", None)
 
-        # import pdb;
-        # pdb.set_trace()
         if len(text_lengths.size()) > 1:
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
@@ -303,8 +302,6 @@
     ):
         target_mask = kwargs.get("target_mask", None)
 
-        # import pdb;
-        # pdb.set_trace()
         if len(text_lengths.size()) > 1:
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
@@ -648,8 +645,6 @@
     ):
         target_mask = kwargs.get("target_mask", None)
 
-        # import pdb;
-        # pdb.set_trace()
         if len(text_lengths.size()) > 1:
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
@@ -667,9 +662,11 @@
         else:
             encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
-        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
-            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
-        )
+        with autocast(False):
+            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+                encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
+            )
+
         loss = loss_att
         stats = {}
         stats["acc"] = acc_att
@@ -1041,6 +1038,7 @@
         self.length_normalized_loss = length_normalized_loss
         self.beam_search = None
         self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
+        self.encoder_output_size = encoder_output_size
 
     def forward(
         self,
@@ -1052,8 +1050,6 @@
     ):
         target_mask = kwargs.get("target_mask", None)
 
-        # import pdb;
-        # pdb.set_trace()
         if len(text_lengths.size()) > 1:
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
@@ -1264,15 +1260,31 @@
         if isinstance(task, str):
             task = [task]
         task = "".join([f"<|{x}|>" for x in task])
-        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
 
-        language = DecodingOptions.get("language", None)
-        language = None if language == "auto" else language
+        sos = kwargs.get("model_conf").get("sos")
+        if isinstance(sos, str):
+            initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
 
-        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
-        sos_int = tokenizer.encode(sos, allowed_special="all")
+            language = DecodingOptions.get("language", None)
+            language = None if language == "auto" else language
+
+            sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+            sos_int = tokenizer.encode(sos, allowed_special="all")
+        else:
+            language = DecodingOptions.get("language", None)
+            language = None if language == "auto" else language
+            initial_prompt = kwargs.get("initial_prompt", f"{task}")
+            initial_prompt_lid = (
+                f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+            )
+            initial_prompt_lid_int = tokenizer.encode(initial_prompt_lid, allowed_special="all")
+            sos_int = [sos] + initial_prompt_lid_int
         eos = kwargs.get("model_conf").get("eos")
-        eos_int = tokenizer.encode(eos, allowed_special="all")
+        if isinstance(eos, str):
+            eos_int = tokenizer.encode(eos, allowed_special="all")
+        else:
+            eos_int = [eos]
+
         self.beam_search.sos = sos_int
         self.beam_search.eos = eos_int[0]
 
@@ -1297,9 +1309,7 @@
         )
         self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
 
-        encoder_out, encoder_out_lens = self.encode(
-            speech[None, :, :].permute(0, 2, 1), speech_lengths
-        )
+        encoder_out, encoder_out_lens = self.encode(speech[None, :, :], speech_lengths)
 
         if text_token_int is not None:
             i = 0
@@ -1378,3 +1388,7 @@
                     ibest_writer["text"][key[i]] = text
 
         return results, meta_data
+
+
+from funasr.models.paraformer.search import Hypothesis
+from funasr.utils import postprocess_utils

--
Gitblit v1.9.1