From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords

---
 funasr/models/sense_voice/model.py |   26 ++++++++++++++------------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 697f50c..a9b2149 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -74,8 +74,6 @@
     ):
         target_mask = kwargs.get("target_mask", None)
 
-        # import pdb;
-        # pdb.set_trace()
         if len(text_lengths.size()) > 1:
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
@@ -304,8 +302,6 @@
     ):
         target_mask = kwargs.get("target_mask", None)
 
-        # import pdb;
-        # pdb.set_trace()
         if len(text_lengths.size()) > 1:
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
@@ -649,8 +645,6 @@
     ):
         target_mask = kwargs.get("target_mask", None)
 
-        # import pdb;
-        # pdb.set_trace()
         if len(text_lengths.size()) > 1:
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
@@ -1054,8 +1048,6 @@
     ):
         target_mask = kwargs.get("target_mask", None)
 
-        # import pdb;
-        # pdb.set_trace()
         if len(text_lengths.size()) > 1:
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
@@ -1594,15 +1586,25 @@
 
         language = kwargs.get("language", None)
         if language is not None:
-            language_query = self.embed(torch.LongTensor([[self.lid_dict[language] if language in self.lid_dict else 0]]).to(speech.device)).repeat(speech.size(0), 1, 1)
+            language_query = self.embed(
+                torch.LongTensor(
+                    [[self.lid_dict[language] if language in self.lid_dict else 0]]
+                ).to(speech.device)
+            ).repeat(speech.size(0), 1, 1)
         else:
-            language_query = self.embed(torch.LongTensor([[0]]).to(speech.device)).repeat(speech.size(0), 1, 1)
+            language_query = self.embed(torch.LongTensor([[0]]).to(speech.device)).repeat(
+                speech.size(0), 1, 1
+            )
         textnorm = kwargs.get("text_norm", "wotextnorm")
-        textnorm_query = self.embed(torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)).repeat(speech.size(0), 1, 1)
+        textnorm_query = self.embed(
+            torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)
+        ).repeat(speech.size(0), 1, 1)
         speech = torch.cat((textnorm_query, speech), dim=1)
         speech_lengths += 1
 
-        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(speech.size(0), 1, 1)
+        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
+            speech.size(0), 1, 1
+        )
         input_query = torch.cat((language_query, event_emo_query), dim=1)
         speech = torch.cat((input_query, speech), dim=1)
         speech_lengths += 3

--
Gitblit v1.9.1