From 09a28d19df5854bdd4bd4d3a05dcb6f502ec6b07 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期五, 12 一月 2024 18:02:10 +0800
Subject: [PATCH] update

---
 funasr/models/fsmn_vad/model.py |   11 +++++++----
 1 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index 1ed0773..15d2af5 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -12,6 +12,7 @@
 from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
 from funasr.utils.datadir_writer import DatadirWriter
 from torch.nn.utils.rnn import pad_sequence
+from funasr.train_utils.device_funcs import to_device
 
 class VadStateMachine(Enum):
     kVadInStateStartPointNotDetected = 1
@@ -579,7 +580,8 @@
                 "cache": cache
             }
 
-
+            
+            batch = to_device(batch, device=kwargs["device"])
             segments_part, cache = self.forward(**batch)
             if segments_part:
                 for batch_num in range(0, batch_size):
@@ -593,15 +595,16 @@
         results = []
         for i in range(batch_size):
             
-            if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
-                results[i] = json.dumps(results[i])
                 
             if ibest_writer is not None:
                 ibest_writer["text"][key[i]] = segments[i]
 
             result_i = {"key": key[i], "value": segments[i]}
             results.append(result_i)
- 
+            
+            if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+                results[i] = json.dumps(results[i])
+                
         return results, meta_data
 
     def DetectCommonFrames(self) -> int:

--
Gitblit v1.9.1