From 4137f5cf26e7c4b40853959cd2574edfde03aa60 Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: 星期五, 07 四月 2023 21:03:34 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR into dev_dzh
---
funasr/bin/vad_inference_online.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)
diff --git a/funasr/bin/vad_inference_online.py b/funasr/bin/vad_inference_online.py
index d18488e..9ed0721 100644
--- a/funasr/bin/vad_inference_online.py
+++ b/funasr/bin/vad_inference_online.py
@@ -1,5 +1,6 @@
import argparse
import logging
+import os
import sys
import json
from pathlib import Path
@@ -32,12 +33,6 @@
header_colors = '\033[95m'
end_colors = '\033[0m'
-global_asr_language: str = 'zh-cn'
-global_sample_rate: Union[int, Dict[Any, int]] = {
- 'audio_fs': 16000,
- 'model_fs': 16000
-}
-
class Speech2VadSegmentOnline(Speech2VadSegment):
"""Speech2VadSegmentOnline class
@@ -61,7 +56,7 @@
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
- in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False
+ in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False, max_end_sil: int = 800
) -> Tuple[torch.Tensor, List[List[int]], torch.Tensor]:
"""Inference
@@ -92,7 +87,8 @@
"feats": feats,
"waveform": waveforms,
"in_cache": in_cache,
- "is_final": is_final
+ "is_final": is_final,
+ "max_end_sil": max_end_sil
}
# a. To device
batch = to_device(batch, device=self.device)
@@ -222,7 +218,8 @@
vad_results = []
batch_in_cache = param_dict['in_cache'] if param_dict is not None else dict()
- is_final = param_dict['is_final'] if param_dict is not None else False
+ is_final = param_dict.get('is_final', False) if param_dict is not None else False
+ max_end_sil = param_dict.get('max_end_sil', 800) if param_dict is not None else 800
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
@@ -230,6 +227,7 @@
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
batch['in_cache'] = batch_in_cache
batch['is_final'] = is_final
+ batch['max_end_sil'] = max_end_sil
# do vad segment
_, results, param_dict['in_cache'] = speech2vadsegment(**batch)
@@ -237,7 +235,8 @@
if results:
for i, _ in enumerate(keys):
if results[i]:
- results[i] = json.dumps(results[i])
+ if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+ results[i] = json.dumps(results[i])
item = {'key': keys[i], 'value': results[i]}
vad_results.append(item)
if writer is not None:
--
Gitblit v1.9.1