From b7cb19b01a1454f7a1388e24dcd4e10fc654bd7c Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: Tue, 16 Jan 2024 11:30:25 +0800
Subject: [PATCH] update demo, readme
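
Demos and READMEs now call the public `model.generate(...)` entry point instead of
calling the model object directly, AutoModel's internal per-model helper calls switch
from `self.generate` to `self.inference`, and funasr/bin/inference.py drops unused
imports. A minimal sketch of the new calling convention (the model id is a
placeholder; the API surface is taken from the diffs below):

    from funasr import AutoModel

    model = AutoModel(model="paraformer-zh")        # placeholder model id
    # before: res = model(input="asr_example.wav")
    res = model.generate(input="asr_example.wav")   # after this change
    print(res)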

---
 examples/industrial_data_pretraining/ct_transformer_streaming/demo.py |    2 
 examples/industrial_data_pretraining/ct_transformer/demo.py           |    4 
 examples/industrial_data_pretraining/paraformer-zh-spk/demo.py        |    4 
 funasr/auto/auto_model.py                                             |    8 +-
 examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py       |    4 
 examples/industrial_data_pretraining/contextual_paraformer/demo.py    |    2 
 examples/industrial_data_pretraining/monotonic_aligner/demo.py        |    2 
 README_zh.md                                                          |   18 ++++--
 examples/industrial_data_pretraining/campplus_sv/demo.py              |    2 
 README.md                                                             |   24 +++++---
 funasr/bin/inference.py                                               |   21 ------
 examples/industrial_data_pretraining/emotion2vec/demo.py              |    2 
 examples/industrial_data_pretraining/paraformer/demo.py               |    4 
 examples/industrial_data_pretraining/bicif_paraformer/demo.py         |   16 ++--
 examples/industrial_data_pretraining/seaco_paraformer/demo.py         |    4 
 examples/industrial_data_pretraining/paraformer_streaming/demo.py     |   16 ++--
 16 files changed, 63 insertions(+), 70 deletions(-)

diff --git a/README.md b/README.md
index a53ce4d..2bd28e2 100644
--- a/README.md
+++ b/README.md
@@ -95,9 +95,9 @@
                   vad_model="fsmn-vad", vad_model_revision="v2.0.2", \
                   punc_model="ct-punc-c", punc_model_revision="v2.0.2", \
                   spk_model="cam++", spk_model_revision="v2.0.2")
-res = model(input=f"{model.model_path}/example/asr_example.wav", 
-            batch_size=64, 
-            hotword='魔搭')
+res = model.generate(input=f"{model.model_path}/example/asr_example.wav", 
+                     batch_size=64, 
+                     hotword='魔搭')
 print(res)
 ```
 Note: `model_hub` specifies the model repository; `ms` selects download from ModelScope, and `hf` selects download from Hugging Face.
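A minimal sketch illustrating the `model_hub` note above (the parameter name and the `ms`/`hf` values come from the note; the model id is a placeholder and the keyword usage is an assumption, not confirmed by this patch):
```python
from funasr import AutoModel

# "ms" selects download from ModelScope, "hf" from Hugging Face (per the note above)
model_ms = AutoModel(model="paraformer-zh", model_hub="ms")
model_hf = AutoModel(model="paraformer-zh", model_hub="hf")
```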
@@ -124,7 +124,7 @@
 for i in range(total_chunk_num):
     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
     is_final = i == total_chunk_num - 1
-    res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
+    res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
     print(res)
 ```
 Note: `chunk_size` is the configuration for streaming latency. `[0,10,5]` indicates that the real-time display granularity is `10*60=600ms`, and the lookahead information is `5*60=300ms`. Each inference input is `600ms` (sample points are `16000*0.6=9600`), and the output is the corresponding text. For the last speech segment input, `is_final=True` needs to be set to force the output of the last word.
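A minimal worked example of the arithmetic in the note above, assuming 16 kHz audio and the `[0,10,5]` configuration (values taken from the note; not tied to any specific model):
```python
chunk_size = [0, 10, 5]   # configuration from the note above
unit_ms = 60              # one chunk unit corresponds to 60 ms of audio

step_ms = chunk_size[1] * unit_ms         # 10 * 60 = 600 ms per inference step
lookahead_ms = chunk_size[2] * unit_ms    # 5 * 60 = 300 ms of lookahead
samples_per_step = int(16000 * step_ms / 1000)  # 16000 * 0.6 = 9600 samples

print(step_ms, lookahead_ms, samples_per_step)  # 600 300 9600
```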
@@ -135,7 +135,7 @@
 
 model = AutoModel(model="fsmn-vad", model_revision="v2.0.2")
 wav_file = f"{model.model_path}/example/asr_example.wav"
-res = model(input=wav_file)
+res = model.generate(input=wav_file)
 print(res)
 ```
 ### Voice Activity Detection (Streaming)
@@ -156,7 +156,7 @@
 for i in range(total_chunk_num):
     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
     is_final = i == total_chunk_num - 1
-    res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
+    res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
     if len(res[0]["value"]):
         print(res)
 ```
@@ -165,7 +165,7 @@
 from funasr import AutoModel
 
 model = AutoModel(model="ct-punc", model_revision="v2.0.2")
-res = model(input="那今天的会就到这里吧 happy new year 明年见")
+res = model.generate(input="那今天的会就到这里吧 happy new year 明年见")
 print(res)
 ```
 ### Timestamp Prediction
@@ -175,7 +175,7 @@
 model = AutoModel(model="fa-zh", model_revision="v2.0.2")
 wav_file = f"{model.model_path}/example/asr_example.wav"
 text_file = f"{model.model_path}/example/text.txt"
-res = model(input=(wav_file, text_file), data_type=("sound", "text"))
+res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
 print(res)
 ```
 [//]: # (FunASR supports inference and fine-tuning of models trained on industrial datasets of tens of thousands of hours. For more details, please refer to &#40;[modelscope_egs]&#40;https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html&#41;&#41;. It also supports training and fine-tuning of models on academic standard datasets. For more details, please refer to&#40;[egs]&#40;https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html&#41;&#41;. The models include speech recognition &#40;ASR&#41;, speech activity detection &#40;VAD&#41;, punctuation recovery, language model, speaker verification, speaker separation, and multi-party conversation speech recognition. For a detailed list of models, please refer to the [Model Zoo]&#40;https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/model_zoo/modelscope_models.md&#41;:)
@@ -229,10 +229,16 @@
 }
 @inproceedings{gao22b_interspeech,
   author={Zhifu Gao and ShiLiang Zhang and Ian McLoughlin and Zhijie Yan},
-  title={{Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition}},
+  title={Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition},
   year=2022,
   booktitle={Proc. Interspeech 2022},
   pages={2063--2067},
   doi={10.21437/Interspeech.2022-9996}
 }
+@inproceedings{shi2023seaco,
+  author={Xian Shi and Yexin Yang and Zerui Li and Yanni Chen and Zhifu Gao and Shiliang Zhang},
+  title={SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability},
+  year={2023},
+  booktitle={ICASSP 2024}
+}
 ```
diff --git a/README_zh.md b/README_zh.md
index 861e61c..dc20302 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -91,7 +91,7 @@
                   vad_model="fsmn-vad", vad_model_revision="v2.0.2", \
                   punc_model="ct-punc-c", punc_model_revision="v2.0.2", \
                   spk_model="cam++", spk_model_revision="v2.0.2")
-res = model(input=f"{model.model_path}/example/asr_example.wav", 
+res = model.generate(input=f"{model.model_path}/example/asr_example.wav", 
             batch_size=64, 
             hotword='魔搭')
 print(res)
@@ -121,7 +121,7 @@
 for i in range(total_chunk_num):
     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
     is_final = i == total_chunk_num - 1
-    res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
+    res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
     print(res)
 ```
 
@@ -134,7 +134,7 @@
 model = AutoModel(model="fsmn-vad", model_revision="v2.0.2")
 
 wav_file = f"{model.model_path}/example/asr_example.wav"
-res = model(input=wav_file)
+res = model.generate(input=wav_file)
 print(res)
 ```
 
@@ -156,7 +156,7 @@
 for i in range(total_chunk_num):
     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
     is_final = i == total_chunk_num - 1
-    res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
+    res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
     if len(res[0]["value"]):
         print(res)
 ```
@@ -167,7 +167,7 @@
 
 model = AutoModel(model="ct-punc", model_revision="v2.0.2")
 
-res = model(input="那今天的会就到这里吧 happy new year 明年见")
+res = model.generate(input="那今天的会就到这里吧 happy new year 明年见")
 print(res)
 ```
 
@@ -179,7 +179,7 @@
 
 wav_file = f"{model.model_path}/example/asr_example.wav"
 text_file = f"{model.model_path}/example/text.txt"
-res = model(input=(wav_file, text_file), data_type=("sound", "text"))
+res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
 print(res)
 ```
 更多详细用法（[示例](examples/industrial_data_pretraining)）
@@ -242,4 +242,10 @@
   pages={2063--2067},
   doi={10.21437/Interspeech.2022-9996}
 }
+@article{shi2023seaco,
+  author={Xian Shi and Yexin Yang and Zerui Li and Yanni Chen and Zhifu Gao and Shiliang Zhang},
+  title={{SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability}},
+  year=2023,
+  journal={arXiv preprint arXiv:2308.03266 (accepted by ICASSP 2024)},
+}
 ```
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/demo.py b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
index 60718de..a06b308 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
@@ -6,14 +6,14 @@
 from funasr import AutoModel
 
 model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
-                    model_revision="v2.0.2",
-                    vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                    vad_model_revision="v2.0.2",
-                    punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                    punc_model_revision="v2.0.2",
-                    spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
-                    spk_model_revision="v2.0.2",
+                  model_revision="v2.0.2",
+                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  vad_model_revision="v2.0.2",
+                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  punc_model_revision="v2.0.2",
+                  spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
+                  spk_model_revision="v2.0.2",
                   )
 
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60)
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60)
 print(res)
diff --git a/examples/industrial_data_pretraining/campplus_sv/demo.py b/examples/industrial_data_pretraining/campplus_sv/demo.py
index 6a7f105..16d629b 100644
--- a/examples/industrial_data_pretraining/campplus_sv/demo.py
+++ b/examples/industrial_data_pretraining/campplus_sv/demo.py
@@ -9,5 +9,5 @@
                   model_revision="v2.0.2",
                   )
 
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.py b/examples/industrial_data_pretraining/contextual_paraformer/demo.py
index 78693c5..d1378ca 100644
--- a/examples/industrial_data_pretraining/contextual_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/contextual_paraformer/demo.py
@@ -7,6 +7,6 @@
 
 model = AutoModel(model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", model_revision="v2.0.2")
 
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
             hotword='达摩院 魔搭')
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer/demo.py b/examples/industrial_data_pretraining/ct_transformer/demo.py
index d648f3d..f547f03 100644
--- a/examples/industrial_data_pretraining/ct_transformer/demo.py
+++ b/examples/industrial_data_pretraining/ct_transformer/demo.py
@@ -7,7 +7,7 @@
 
 model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.2")
 
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
 print(res)
 
 
@@ -15,5 +15,5 @@
 
 model = AutoModel(model="damo/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.2")
 
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py b/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py
index 5ef8381..081fd19 100644
--- a/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py
@@ -12,7 +12,7 @@
 rec_result_all = "outputs: "
 cache = {}
 for vad in vads:
-    rec_result = model(input=vad, cache=cache)
+    rec_result = model.generate(input=vad, cache=cache)
     print(rec_result)
     rec_result_all += rec_result[0]['text']
 
diff --git a/examples/industrial_data_pretraining/emotion2vec/demo.py b/examples/industrial_data_pretraining/emotion2vec/demo.py
index abaa9f4..ea8da99 100644
--- a/examples/industrial_data_pretraining/emotion2vec/demo.py
+++ b/examples/industrial_data_pretraining/emotion2vec/demo.py
@@ -7,5 +7,5 @@
 
 model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.1")
 
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", output_dir="./outputs")
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", output_dir="./outputs")
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
index 459dfff..8084dec 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -9,7 +9,7 @@
 chunk_size = 60000 # ms
 model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.2")
 
-res = model(input=wav_file, chunk_size=chunk_size, )
+res = model.generate(input=wav_file, chunk_size=chunk_size, )
 print(res)
 
 
@@ -28,7 +28,7 @@
 for i in range(total_chunk_num):
     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
     is_final = i == total_chunk_num - 1
-    res = model(input=speech_chunk,
+    res = model.generate(input=speech_chunk,
                 cache=cache,
                 is_final=is_final,
                 chunk_size=chunk_size,
diff --git a/examples/industrial_data_pretraining/monotonic_aligner/demo.py b/examples/industrial_data_pretraining/monotonic_aligner/demo.py
index def6b7d..cad9aab 100644
--- a/examples/industrial_data_pretraining/monotonic_aligner/demo.py
+++ b/examples/industrial_data_pretraining/monotonic_aligner/demo.py
@@ -7,7 +7,7 @@
 
 model = AutoModel(model="damo/speech_timestamp_prediction-v1-16k-offline", model_revision="v2.0.2")
 
-res = model(input=("https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+res = model.generate(input=("https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
                    "娆㈣繋澶у鏉ュ埌榄旀惌绀惧尯杩涜浣撻獙"),
             data_type=("sound", "text"),
             batch_size=2,
diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
index aa895eb..b4453e9 100644
--- a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
@@ -15,6 +15,6 @@
                   spk_model_revision="v2.0.2"
                   )
 
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
-            hotword='达摩院 磨搭')
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+                     hotword='达摩院 磨搭')
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/demo.py b/examples/industrial_data_pretraining/paraformer/demo.py
index 6dbe33d..ef33bf4 100644
--- a/examples/industrial_data_pretraining/paraformer/demo.py
+++ b/examples/industrial_data_pretraining/paraformer/demo.py
@@ -7,7 +7,7 @@
 
 model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.2")
 
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
 print(res)
 
 
@@ -18,5 +18,5 @@
 fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
 
 for batch_idx, fbank_dict in enumerate(fbanks):
-    res = model(**fbank_dict)
+    res = model.generate(**fbank_dict)
     print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
index 8f7eef3..07efde6 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -11,7 +11,7 @@
 
 model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.2")
 cache = {}
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
             chunk_size=chunk_size,
             encoder_chunk_look_back=encoder_chunk_look_back,
             decoder_chunk_look_back=decoder_chunk_look_back,
@@ -32,11 +32,11 @@
 for i in range(total_chunk_num):
     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
     is_final = i == total_chunk_num - 1
-    res = model(input=speech_chunk,
-                cache=cache,
-                is_final=is_final,
-                chunk_size=chunk_size,
-                encoder_chunk_look_back=encoder_chunk_look_back,
-                decoder_chunk_look_back=decoder_chunk_look_back,
-                )
+    res = model.generate(input=speech_chunk,
+                         cache=cache,
+                         is_final=is_final,
+                         chunk_size=chunk_size,
+                         encoder_chunk_look_back=encoder_chunk_look_back,
+                         decoder_chunk_look_back=decoder_chunk_look_back,
+                         )
     print(res)
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index cf49e42..5f17252 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -15,6 +15,6 @@
                   spk_model_revision="v2.0.2",
                   )
 
-res = model(input=f"{model.model_path}/example/asr_example.wav",
-            hotword='杈炬懇闄� 榄旀惌')
+res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
+                     hotword='达摩院 魔搭')
 print(res)
\ No newline at end of file
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 25edeb7..580cca8 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -264,7 +264,7 @@
         # step.1: compute the vad model
         self.vad_kwargs.update(cfg)
         beg_vad = time.time()
-        res = self.generate(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
+        res = self.inference(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
         end_vad = time.time()
         print(f"time cost vad: {end_vad - beg_vad:0.3f}")
 
@@ -316,7 +316,7 @@
                 batch_size_ms_cum = 0
                 end_idx = j + 1
                 speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])       
-                results = self.generate(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
+                results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
                 if self.spk_model is not None:
                     all_segments = []
                     # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
@@ -327,7 +327,7 @@
                         segments = sv_chunk(vad_segments)
                         all_segments.extend(segments)
                         speech_b = [i[2] for i in segments]
-                        spk_res = self.generate(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
+                        spk_res = self.inference(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
                         results[_b]['spk_embedding'] = spk_res[0]['spk_embedding']
                 beg_idx = end_idx
                 if len(results) < 1:
@@ -378,7 +378,7 @@
             # step.3 compute punc model
             if self.punc_model is not None:
                 self.punc_kwargs.update(cfg)
-                punc_res = self.generate(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
+                punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
                 result["text_with_punc"] = punc_res[0]["text"]
                      
             # speaker embedding cluster after resorted
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index bc435c4..d2f0c14 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -1,25 +1,7 @@
-import json
-import time
-import torch
 import hydra
-import random
-import string
 import logging
-import os.path
-from tqdm import tqdm
 from omegaconf import DictConfig, OmegaConf, ListConfig
 
-from funasr.register import tables
-from funasr.utils.load_utils import load_bytes
-from funasr.download.file import download_from_url
-from funasr.download.download_from_hub import download_model
-from funasr.utils.vad_utils import slice_padding_audio_samples
-from funasr.train_utils.set_all_random_seed import set_all_random_seed
-from funasr.train_utils.load_pretrained_model import load_pretrained_model
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils.timestamp_tools import timestamp_sentence
-from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
-from funasr.models.campplus.cluster_backend import ClusterBackend
 from funasr.auto.auto_model import AutoModel
 
 
@@ -41,9 +23,8 @@
     if kwargs.get("debug", False):
         import pdb; pdb.set_trace()
     model = AutoModel(**kwargs)
-    res = model(input=kwargs["input"])
+    res = model.generate(input=kwargs["input"])
     print(res)
-
 
 
 if __name__ == '__main__':

--
Gitblit v1.9.1