From 0143122a4e2ee86cc27ba137b2bb0530577cbf12 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 12 一月 2024 10:27:36 +0800
Subject: [PATCH] funasr1.0 streaming demo

---
 funasr/bin/inference.py                                           |    3 ++-
 funasr/bin/train.py                                               |    2 +-
 funasr/download/download_from_hub.py                              |    1 +
 examples/industrial_data_pretraining/paraformer_streaming/demo.py |    6 ++----
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
index b62cc29..d4dd34e 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -22,10 +22,8 @@
 import soundfile
 import os
 
-speech, sample_rate = soundfile.read(os.path.expanduser('~')+
-                                     "/.cache/modelscope/hub/damo/"+
-                                     "speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/"+
-                                     "example/asr_example.wav")
+wav_file = os.path.join(model.model_path, "example/asr_example.wav")
+speech, sample_rate = soundfile.read(wav_file)
 
 chunk_stride = chunk_size[1] * 960 # 600ms銆�480ms
 
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index e239747..7d9c1b9 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -83,7 +83,7 @@
     
     return key_list, data_list
 
-@hydra.main(config_name=None)
+@hydra.main(config_name=None, version_base=None)
 def main_hydra(cfg: DictConfig):
     def to_plain_list(cfg_item):
         if isinstance(cfg_item, ListConfig):
@@ -150,6 +150,7 @@
         self.punc_kwargs = punc_kwargs
         self.spk_model = spk_model
         self.spk_kwargs = spk_kwargs
+        self.model_path = kwargs["model_path"]
   
         
     def build_model(self, **kwargs):
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 1f896b7..af3e8af 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -23,7 +23,7 @@
 from funasr.download.download_from_hub import download_model
 from funasr.register import tables
 
-@hydra.main(config_name=None)
+@hydra.main(config_name=None, version_base=None)
 def main_hydra(kwargs: DictConfig):
 	if kwargs.get("debug", False):
 		import pdb; pdb.set_trace()
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 73578f2..946572f 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -18,6 +18,7 @@
 	model_revision = kwargs.get("model_revision")
 	if not os.path.exists(model_or_path):
 		model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"), check_latest=kwargs.get("kwargs", True))
+	kwargs["model_path"] = model_or_path
 	
 	config = os.path.join(model_or_path, "config.yaml")
 	if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):

--
Gitblit v1.9.1