From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords
---
funasr/models/sense_voice/model.py | 26 ++++++++++++++------------
1 files changed, 14 insertions(+), 12 deletions(-)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 697f50c..a9b2149 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -74,8 +74,6 @@
):
target_mask = kwargs.get("target_mask", None)
- # import pdb;
- # pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
@@ -304,8 +302,6 @@
):
target_mask = kwargs.get("target_mask", None)
- # import pdb;
- # pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
@@ -649,8 +645,6 @@
):
target_mask = kwargs.get("target_mask", None)
- # import pdb;
- # pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
@@ -1054,8 +1048,6 @@
):
target_mask = kwargs.get("target_mask", None)
- # import pdb;
- # pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
@@ -1594,15 +1586,25 @@
language = kwargs.get("language", None)
if language is not None:
- language_query = self.embed(torch.LongTensor([[self.lid_dict[language] if language in self.lid_dict else 0]]).to(speech.device)).repeat(speech.size(0), 1, 1)
+ language_query = self.embed(
+ torch.LongTensor(
+ [[self.lid_dict[language] if language in self.lid_dict else 0]]
+ ).to(speech.device)
+ ).repeat(speech.size(0), 1, 1)
else:
- language_query = self.embed(torch.LongTensor([[0]]).to(speech.device)).repeat(speech.size(0), 1, 1)
+ language_query = self.embed(torch.LongTensor([[0]]).to(speech.device)).repeat(
+ speech.size(0), 1, 1
+ )
textnorm = kwargs.get("text_norm", "wotextnorm")
- textnorm_query = self.embed(torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)).repeat(speech.size(0), 1, 1)
+ textnorm_query = self.embed(
+ torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)
+ ).repeat(speech.size(0), 1, 1)
speech = torch.cat((textnorm_query, speech), dim=1)
speech_lengths += 1
- event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(speech.size(0), 1, 1)
+ event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
+ speech.size(0), 1, 1
+ )
input_query = torch.cat((language_query, event_emo_query), dim=1)
speech = torch.cat((input_query, speech), dim=1)
speech_lengths += 3
--
Gitblit v1.9.1