From ada76b631223c51253840496d12423d8a2640eaf Mon Sep 17 00:00:00 2001
From: 北念 <lzr265946@alibaba-inc.com>
Date: 星期一, 17 六月 2024 13:36:22 +0800
Subject: [PATCH] sensevoice

---
 funasr/models/sense_voice/model.py |   26 ++++++++++++++++++++++----
 1 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index c77930d..697f50c 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -1454,6 +1454,10 @@
         self.length_normalized_loss = length_normalized_loss
         self.encoder_output_size = encoder_output_size
 
+        self.lid_dict = {"zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
+        self.textnorm_dict = {"withtextnorm": 14, "wotextnorm": 15}
+        self.embed = torch.nn.Embedding(8 + len(self.lid_dict) + len(self.textnorm_dict), 560)
+
     def forward(
         self,
         speech: torch.Tensor,
@@ -1587,6 +1591,22 @@
 
         speech = speech.to(device=kwargs["device"])
         speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+        language = kwargs.get("language", None)
+        if language is not None:
+            language_query = self.embed(torch.LongTensor([[self.lid_dict[language] if language in self.lid_dict else 0]]).to(speech.device)).repeat(speech.size(0), 1, 1)
+        else:
+            language_query = self.embed(torch.LongTensor([[0]]).to(speech.device)).repeat(speech.size(0), 1, 1)
+        textnorm = kwargs.get("text_norm", "wotextnorm")
+        textnorm_query = self.embed(torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)).repeat(speech.size(0), 1, 1)
+        speech = torch.cat((textnorm_query, speech), dim=1)
+        speech_lengths += 1
+
+        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(speech.size(0), 1, 1)
+        input_query = torch.cat((language_query, event_emo_query), dim=1)
+        speech = torch.cat((input_query, speech), dim=1)
+        speech_lengths += 3
+
         # Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
         if isinstance(encoder_out, tuple):
@@ -1630,11 +1650,9 @@
                 )
 
                 # Change integer-ids to tokens
-                token = tokenizer.ids2tokens(token_int)
-                text = tokenizer.tokens2text(token)
+                text = tokenizer.decode(token_int)
 
-                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
+                result_i = {"key": key[i], "text": text}
                 results.append(result_i)
 
                 if ibest_writer is not None:

--
Gitblit v1.9.1