From db9ec58cb430fa10b273a7a365457bf33dc46adc Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 05 六月 2024 17:30:57 +0800
Subject: [PATCH] Dev gzf exp (#1785)

---
 funasr/models/sense_voice/model.py                |   28 +++++++++++++++++++++-------
 funasr/models/sense_voice/decoder.py              |    1 +
 funasr/datasets/audio_datasets/espnet_samplers.py |    2 ++
 3 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index b358fa3..004201e 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -147,7 +147,9 @@
         start_idx = self.rank * batches_per_rank
         end_idx = start_idx + batches_per_rank
         rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
+
         self.batch_num = len(rank_batches)
+
         logging.info(
             f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {end_idx-start_idx}, batch_num_after_step: {len(rank_batches)}"
         )
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index 60af29a..ff933d7 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -360,6 +360,7 @@
         """Score."""
         ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
         logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
+        logp = torch.log_softmax(logp, dim=-1)
         return logp.squeeze(0)[-1, :], state
 
 
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 127d5a0..22272ee 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -1264,15 +1264,29 @@
         if isinstance(task, str):
             task = [task]
         task = "".join([f"<|{x}|>" for x in task])
-        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
+        
+        sos = kwargs.get("model_conf").get("sos")
+        if isinstance(sos, str):
+            initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
 
-        language = DecodingOptions.get("language", None)
-        language = None if language == "auto" else language
+            language = DecodingOptions.get("language", None)
+            language = None if language == "auto" else language
 
-        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
-        sos_int = tokenizer.encode(sos, allowed_special="all")
+            sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+            sos_int = tokenizer.encode(sos, allowed_special="all")
+        else:
+            language = DecodingOptions.get("language", None)
+            language = None if language == "auto" else language
+            initial_prompt = kwargs.get("initial_prompt", f"{task}")
+            initial_prompt_lid = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+            initial_prompt_lid_int = tokenizer.encode(initial_prompt_lid, allowed_special="all")
+            sos_int = [sos] + initial_prompt_lid_int
         eos = kwargs.get("model_conf").get("eos")
-        eos_int = tokenizer.encode(eos, allowed_special="all")
+        if isinstance(eos, str):
+            eos_int = tokenizer.encode(eos, allowed_special="all")
+        else:
+            eos_int = [eos]
+
         self.beam_search.sos = sos_int
         self.beam_search.eos = eos_int[0]
 
@@ -1298,7 +1312,7 @@
         self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
 
         encoder_out, encoder_out_lens = self.encode(
-            speech[None, :, :].permute(0, 2, 1), speech_lengths
+            speech[None, :, :], speech_lengths
         )
 
         if text_token_int is not None:

--
Gitblit v1.9.1