From 09a28d19df5854bdd4bd4d3a05dcb6f502ec6b07 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期五, 12 一月 2024 18:02:10 +0800
Subject: [PATCH] update
---
funasr/models/fsmn_vad/model.py | 11 +++++++----
1 files changed, 7 insertions(+), 4 deletions(-)
diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index 1ed0773..15d2af5 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -12,6 +12,7 @@
from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
from funasr.utils.datadir_writer import DatadirWriter
from torch.nn.utils.rnn import pad_sequence
+from funasr.train_utils.device_funcs import to_device
class VadStateMachine(Enum):
kVadInStateStartPointNotDetected = 1
@@ -579,7 +580,8 @@
"cache": cache
}
-
+
+ batch = to_device(batch, device=kwargs["device"])
segments_part, cache = self.forward(**batch)
if segments_part:
for batch_num in range(0, batch_size):
@@ -593,15 +595,16 @@
results = []
for i in range(batch_size):
- if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
- results[i] = json.dumps(results[i])
if ibest_writer is not None:
ibest_writer["text"][key[i]] = segments[i]
result_i = {"key": key[i], "value": segments[i]}
results.append(result_i)
-
+
+ if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+ results[i] = json.dumps(results[i])
+
return results, meta_data
def DetectCommonFrames(self) -> int:
--
Gitblit v1.9.1