From cf00b4a13f5fdedda19c3cae214943fc28df52ac Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 29 三月 2023 00:42:32 +0800
Subject: [PATCH] export

---
 funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py |   13 ++++++++-----
 1 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
index 533b4b7..9568ac9 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
@@ -23,7 +23,7 @@
 	             device_id: Union[str, int] = "-1",
 	             quantize: bool = False,
 	             intra_op_num_threads: int = 4,
-	             max_end_sil: int = 800,
+	             max_end_sil: int = None,
 	             ):
 		
 		if not Path(model_dir).exists():
@@ -43,14 +43,17 @@
 		self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
 		self.batch_size = batch_size
 		self.vad_scorer = E2EVadModel(config["vad_post_conf"])
-		self.max_end_sil = max_end_sil
+		self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"]
+		self.encoder_conf = config["encoder_conf"]
 	
 	def prepare_cache(self, in_cache: list = []):
 		if len(in_cache) > 0:
 			return in_cache
-		
-		for i in range(4):
-			cache = np.random.rand(1, 128, 19, 1).astype(np.float32)
+		fsmn_layers = self.encoder_conf["fsmn_layers"]
+		proj_dim = self.encoder_conf["proj_dim"]
+		lorder = self.encoder_conf["lorder"]
+		for i in range(fsmn_layers):
+			cache = np.random.rand(1, proj_dim, lorder-1, 1).astype(np.float32)
 			in_cache.append(cache)
 		return in_cache
 		

--
Gitblit v1.9.1