From 851e3e3ef83d0769d9bde172d8841f6b20e3e377 Mon Sep 17 00:00:00 2001
From: gaochangfeng <54253717+gaochangfeng@users.noreply.github.com>
Date: 星期三, 10 四月 2024 14:37:35 +0800
Subject: [PATCH] Gcf (#1605)

---
 funasr/models/sense_voice/whisper_lib/decoding.py |    6 ++++++
 1 files changed, 6 insertions(+), 0 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/decoding.py b/funasr/models/sense_voice/whisper_lib/decoding.py
index b3fce7e..2239b64 100644
--- a/funasr/models/sense_voice/whisper_lib/decoding.py
+++ b/funasr/models/sense_voice/whisper_lib/decoding.py
@@ -10,6 +10,8 @@
 from .audio import CHUNK_LENGTH
 from .tokenizer import Tokenizer, get_tokenizer
 from .utils import compression_ratio
+from funasr.models.transformer.utils.nets_utils import to_device
+
 
 if TYPE_CHECKING:
     from .model import Whisper
@@ -58,6 +60,10 @@
     # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
     if x is None:
         x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device)  # [n_audio, 1]
+
+    else:
+        x = x.to(mel.device)
+
     logits = model.logits(x[:,:-1], mel)[:, -1]
     # collect detected languages; suppress all non-language tokens
     mask = torch.ones(logits.shape[-1], dtype=torch.bool)

--
Gitblit v1.9.1