From b15db52e4e67da8a133a67e8ffa415386de48b40 Mon Sep 17 00:00:00 2001
From: zhuyunfeng <10596244@qq.com>
Date: 星期二, 09 五月 2023 23:03:15 +0800
Subject: [PATCH] Add contributor

---
 funasr/bin/vad_inference_online.py |   27 +++++++++++++--------------
 1 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/funasr/bin/vad_inference_online.py b/funasr/bin/vad_inference_online.py
index d18488e..a363309 100644
--- a/funasr/bin/vad_inference_online.py
+++ b/funasr/bin/vad_inference_online.py
@@ -1,5 +1,6 @@
 import argparse
 import logging
+import os
 import sys
 import json
 from pathlib import Path
@@ -32,12 +33,6 @@
 header_colors = '\033[95m'
 end_colors = '\033[0m'
 
-global_asr_language: str = 'zh-cn'
-global_sample_rate: Union[int, Dict[Any, int]] = {
-    'audio_fs': 16000,
-    'model_fs': 16000
-}
-
 
 class Speech2VadSegmentOnline(Speech2VadSegment):
     """Speech2VadSegmentOnline class
@@ -61,7 +56,7 @@
     @torch.no_grad()
     def __call__(
             self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
-            in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False
+            in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False, max_end_sil: int = 800
     ) -> Tuple[torch.Tensor, List[List[int]], torch.Tensor]:
         """Inference
 
@@ -92,7 +87,8 @@
                 "feats": feats,
                 "waveform": waveforms,
                 "in_cache": in_cache,
-                "is_final": is_final
+                "is_final": is_final,
+                "max_end_sil": max_end_sil
             }
             # a. To device
             batch = to_device(batch, device=self.device)
@@ -155,10 +151,11 @@
         **kwargs,
 ):
     assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
+    
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
-    if ngpu > 1:
-        raise NotImplementedError("only single GPU decoding is supported")
 
     logging.basicConfig(
         level=log_level,
@@ -169,7 +166,7 @@
         device = "cuda"
     else:
         device = "cpu"
-
+        batch_size = 1
     # 1. Set random-seed
     set_all_random_seed(seed)
 
@@ -222,7 +219,8 @@
 
         vad_results = []
         batch_in_cache = param_dict['in_cache'] if param_dict is not None else dict()
-        is_final = param_dict['is_final'] if param_dict is not None else False
+        is_final = param_dict.get('is_final', False) if param_dict is not None else False
+        max_end_sil = param_dict.get('max_end_sil', 800) if param_dict is not None else 800
         for keys, batch in loader:
             assert isinstance(batch, dict), type(batch)
             assert all(isinstance(s, str) for s in keys), keys
@@ -230,6 +228,7 @@
             assert len(keys) == _bs, f"{len(keys)} != {_bs}"
             batch['in_cache'] = batch_in_cache
             batch['is_final'] = is_final
+            batch['max_end_sil'] = max_end_sil
 
             # do vad segment
             _, results, param_dict['in_cache'] = speech2vadsegment(**batch)
@@ -237,11 +236,11 @@
             if results:
                 for i, _ in enumerate(keys):
                     if results[i]:
-                        results[i] = json.dumps(results[i])
+                        if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+                            results[i] = json.dumps(results[i])
                         item = {'key': keys[i], 'value': results[i]}
                         vad_results.append(item)
                         if writer is not None:
-                            results[i] = json.loads(results[i])
                             ibest_writer["text"][keys[i]] = "{}".format(results[i])
 
         return vad_results

--
Gitblit v1.9.1