From b7cb19b01a1454f7a1388e24dcd4e10fc654bd7c Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: Tue, 16 Jan 2024 11:30:25 +0800
Subject: [PATCH] update demo, readme
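
Rename the user-facing inference entry point from model(...) to
model.generate(...) across the READMEs and demo scripts, and rename the
internal AutoModel.generate helper to AutoModel.inference so the public
name stays free for the new API. Unused imports in funasr/bin/inference.py
are removed along the way.

A minimal sketch of the call-site change (model name and audio path are
illustrative placeholders):

    from funasr import AutoModel

    model = AutoModel(model="paraformer-zh")
    # old call style: res = model(input="asr_example.wav")
    res = model.generate(input="asr_example.wav")
    print(res)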
---
 README.md                                                             | 24 +++++---
 README_zh.md                                                          | 18 ++++--
 examples/industrial_data_pretraining/bicif_paraformer/demo.py         | 16 ++--
 examples/industrial_data_pretraining/campplus_sv/demo.py              |  2
 examples/industrial_data_pretraining/contextual_paraformer/demo.py    |  2
 examples/industrial_data_pretraining/ct_transformer/demo.py           |  4
 examples/industrial_data_pretraining/ct_transformer_streaming/demo.py |  2
 examples/industrial_data_pretraining/emotion2vec/demo.py              |  2
 examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py       |  4
 examples/industrial_data_pretraining/monotonic_aligner/demo.py        |  2
 examples/industrial_data_pretraining/paraformer-zh-spk/demo.py        |  4
 examples/industrial_data_pretraining/paraformer/demo.py               |  4
 examples/industrial_data_pretraining/paraformer_streaming/demo.py     | 16 ++--
 examples/industrial_data_pretraining/seaco_paraformer/demo.py         |  4
 funasr/auto/auto_model.py                                             |  8 +-
 funasr/bin/inference.py                                               | 21 ------
 16 files changed, 63 insertions(+), 70 deletions(-)

diff --git a/README.md b/README.md
index a53ce4d..2bd28e2 100644
--- a/README.md
+++ b/README.md
@@ -95,9 +95,9 @@
vad_model="fsmn-vad", vad_model_revision="v2.0.2", \
punc_model="ct-punc-c", punc_model_revision="v2.0.2", \
spk_model="cam++", spk_model_revision="v2.0.2")
-res = model(input=f"{model.model_path}/example/asr_example.wav",
- batch_size=64,
- hotword='魔搭')
+res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
+ batch_size=64,
+ hotword='魔搭')
print(res)
```
Note: `model_hub` specifies the model repository; `ms` selects download from ModelScope, `hf` selects download from Hugging Face.
@@ -124,7 +124,7 @@
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
is_final = i == total_chunk_num - 1
- res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
+ res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
print(res)
```
Note: `chunk_size` is the configuration for streaming latency. `[0,10,5]` indicates that the real-time display granularity is `10*60=600ms`, and the lookahead information is `5*60=300ms`. Each inference input is `600ms` (sample points are `16000*0.6=9600`), and the output is the corresponding text. For the last speech segment, `is_final=True` needs to be set to force the output of the last word.
@@ -135,7 +135,7 @@
model = AutoModel(model="fsmn-vad", model_revision="v2.0.2")
wav_file = f"{model.model_path}/example/asr_example.wav"
-res = model(input=wav_file)
+res = model.generate(input=wav_file)
print(res)
```
### Voice Activity Detection (Non-streaming)
@@ -156,7 +156,7 @@
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
is_final = i == total_chunk_num - 1
- res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
+ res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
if len(res[0]["value"]):
print(res)
```
@@ -165,7 +165,7 @@
from funasr import AutoModel
model = AutoModel(model="ct-punc", model_revision="v2.0.2")
-res = model(input="閭d粖澶╃殑浼氬氨鍒拌繖閲屽惂 happy new year 鏄庡勾瑙�")
+res = model.generate(input="閭d粖澶╃殑浼氬氨鍒拌繖閲屽惂 happy new year 鏄庡勾瑙�")
print(res)
```
### Timestamp Prediction
@@ -175,7 +175,7 @@
model = AutoModel(model="fa-zh", model_revision="v2.0.2")
wav_file = f"{model.model_path}/example/asr_example.wav"
text_file = f"{model.model_path}/example/text.txt"
-res = model(input=(wav_file, text_file), data_type=("sound", "text"))
+res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
print(res)
```
[//]: # (FunASR supports inference and fine-tuning of models trained on industrial datasets of tens of thousands of hours. For more details, please refer to ([modelscope_egs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)). It also supports training and fine-tuning of models on academic standard datasets. For more details, please refer to([egs](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html)). The models include speech recognition (ASR), speech activity detection (VAD), punctuation recovery, language model, speaker verification, speaker separation, and multi-party conversation speech recognition. For a detailed list of models, please refer to the [Model Zoo](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/model_zoo/modelscope_models.md):)
@@ -229,10 +229,16 @@
}
@inproceedings{gao22b_interspeech,
author={Zhifu Gao and ShiLiang Zhang and Ian McLoughlin and Zhijie Yan},
- title={{Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition}},
+ title={Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition},
year=2022,
booktitle={Proc. Interspeech 2022},
pages={2063--2067},
doi={10.21437/Interspeech.2022-9996}
}
+@inproceedings{shi2023seaco,
+ author={Xian Shi and Yexin Yang and Zerui Li and Yanni Chen and Zhifu Gao and Shiliang Zhang},
+ title={SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability},
+ year={2023},
+ booktitle={ICASSP 2024}
+}
```
diff --git a/README_zh.md b/README_zh.md
index 861e61c..dc20302 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -91,7 +91,7 @@
vad_model="fsmn-vad", vad_model_revision="v2.0.2", \
punc_model="ct-punc-c", punc_model_revision="v2.0.2", \
spk_model="cam++", spk_model_revision="v2.0.2")
-res = model(input=f"{model.model_path}/example/asr_example.wav",
+res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
batch_size=64,
hotword='魔搭')
print(res)
@@ -121,7 +121,7 @@
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
is_final = i == total_chunk_num - 1
- res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
+ res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
print(res)
```
@@ -134,7 +134,7 @@
model = AutoModel(model="fsmn-vad", model_revision="v2.0.2")
wav_file = f"{model.model_path}/example/asr_example.wav"
-res = model(input=wav_file)
+res = model.generate(input=wav_file)
print(res)
```
@@ -156,7 +156,7 @@
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
is_final = i == total_chunk_num - 1
- res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
+ res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
if len(res[0]["value"]):
print(res)
```
@@ -167,7 +167,7 @@
model = AutoModel(model="ct-punc", model_revision="v2.0.2")
-res = model(input="閭d粖澶╃殑浼氬氨鍒拌繖閲屽惂 happy new year 鏄庡勾瑙�")
+res = model.generate(input="閭d粖澶╃殑浼氬氨鍒拌繖閲屽惂 happy new year 鏄庡勾瑙�")
print(res)
```
@@ -179,7 +179,7 @@
wav_file = f"{model.model_path}/example/asr_example.wav"
text_file = f"{model.model_path}/example/text.txt"
-res = model(input=(wav_file, text_file), data_type=("sound", "text"))
+res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
print(res)
```
更多详细用法（[示例](examples/industrial_data_pretraining)）
@@ -242,4 +242,10 @@
pages={2063--2067},
doi={10.21437/Interspeech.2022-9996}
}
+@inproceedings{shi2023seaco,
+ author={Xian Shi and Yexin Yang and Zerui Li and Yanni Chen and Zhifu Gao and Shiliang Zhang},
+ title={SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability},
+ year={2023},
+ booktitle={ICASSP 2024}
+}
```
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/demo.py b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
index 60718de..a06b308 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
@@ -6,14 +6,14 @@
from funasr import AutoModel
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
- model_revision="v2.0.2",
- vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
- vad_model_revision="v2.0.2",
- punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
- punc_model_revision="v2.0.2",
- spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
- spk_model_revision="v2.0.2",
+ model_revision="v2.0.2",
+ vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+ vad_model_revision="v2.0.2",
+ punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+ punc_model_revision="v2.0.2",
+ spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
+ spk_model_revision="v2.0.2",
)
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60)
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60)
print(res)
diff --git a/examples/industrial_data_pretraining/campplus_sv/demo.py b/examples/industrial_data_pretraining/campplus_sv/demo.py
index 6a7f105..16d629b 100644
--- a/examples/industrial_data_pretraining/campplus_sv/demo.py
+++ b/examples/industrial_data_pretraining/campplus_sv/demo.py
@@ -9,5 +9,5 @@
model_revision="v2.0.2",
)
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.py b/examples/industrial_data_pretraining/contextual_paraformer/demo.py
index 78693c5..d1378ca 100644
--- a/examples/industrial_data_pretraining/contextual_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/contextual_paraformer/demo.py
@@ -7,6 +7,6 @@
model = AutoModel(model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", model_revision="v2.0.2")
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
hotword='达摩院 魔搭')
print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer/demo.py b/examples/industrial_data_pretraining/ct_transformer/demo.py
index d648f3d..f547f03 100644
--- a/examples/industrial_data_pretraining/ct_transformer/demo.py
+++ b/examples/industrial_data_pretraining/ct_transformer/demo.py
@@ -7,7 +7,7 @@
model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.2")
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
print(res)
@@ -15,5 +15,5 @@
model = AutoModel(model="damo/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.2")
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py b/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py
index 5ef8381..081fd19 100644
--- a/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py
@@ -12,7 +12,7 @@
rec_result_all = "outputs: "
cache = {}
for vad in vads:
- rec_result = model(input=vad, cache=cache)
+ rec_result = model.generate(input=vad, cache=cache)
print(rec_result)
rec_result_all += rec_result[0]['text']
diff --git a/examples/industrial_data_pretraining/emotion2vec/demo.py b/examples/industrial_data_pretraining/emotion2vec/demo.py
index abaa9f4..ea8da99 100644
--- a/examples/industrial_data_pretraining/emotion2vec/demo.py
+++ b/examples/industrial_data_pretraining/emotion2vec/demo.py
@@ -7,5 +7,5 @@
model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.1")
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", output_dir="./outputs")
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", output_dir="./outputs")
print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
index 459dfff..8084dec 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -9,7 +9,7 @@
chunk_size = 60000 # ms
model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.2")
-res = model(input=wav_file, chunk_size=chunk_size, )
+res = model.generate(input=wav_file, chunk_size=chunk_size, )
print(res)
@@ -28,7 +28,7 @@
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
is_final = i == total_chunk_num - 1
- res = model(input=speech_chunk,
+ res = model.generate(input=speech_chunk,
cache=cache,
is_final=is_final,
chunk_size=chunk_size,
diff --git a/examples/industrial_data_pretraining/monotonic_aligner/demo.py b/examples/industrial_data_pretraining/monotonic_aligner/demo.py
index def6b7d..cad9aab 100644
--- a/examples/industrial_data_pretraining/monotonic_aligner/demo.py
+++ b/examples/industrial_data_pretraining/monotonic_aligner/demo.py
@@ -7,7 +7,7 @@
model = AutoModel(model="damo/speech_timestamp_prediction-v1-16k-offline", model_revision="v2.0.2")
-res = model(input=("https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+res = model.generate(input=("https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
"娆㈣繋澶у鏉ュ埌榄旀惌绀惧尯杩涜浣撻獙"),
data_type=("sound", "text"),
batch_size=2,
diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
index aa895eb..b4453e9 100644
--- a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
@@ -15,6 +15,6 @@
spk_model_revision="v2.0.2"
)
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
- hotword='达摩院 磨搭')
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+ hotword='达摩院 磨搭')
print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/demo.py b/examples/industrial_data_pretraining/paraformer/demo.py
index 6dbe33d..ef33bf4 100644
--- a/examples/industrial_data_pretraining/paraformer/demo.py
+++ b/examples/industrial_data_pretraining/paraformer/demo.py
@@ -7,7 +7,7 @@
model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.2")
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)
@@ -18,5 +18,5 @@
fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
for batch_idx, fbank_dict in enumerate(fbanks):
- res = model(**fbank_dict)
+ res = model.generate(**fbank_dict)
print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
index 8f7eef3..07efde6 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -11,7 +11,7 @@
model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.2")
cache = {}
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
chunk_size=chunk_size,
encoder_chunk_look_back=encoder_chunk_look_back,
decoder_chunk_look_back=decoder_chunk_look_back,
@@ -32,11 +32,11 @@
for i in range(total_chunk_num):
speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
is_final = i == total_chunk_num - 1
- res = model(input=speech_chunk,
- cache=cache,
- is_final=is_final,
- chunk_size=chunk_size,
- encoder_chunk_look_back=encoder_chunk_look_back,
- decoder_chunk_look_back=decoder_chunk_look_back,
- )
+ res = model.generate(input=speech_chunk,
+ cache=cache,
+ is_final=is_final,
+ chunk_size=chunk_size,
+ encoder_chunk_look_back=encoder_chunk_look_back,
+ decoder_chunk_look_back=decoder_chunk_look_back,
+ )
print(res)
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index cf49e42..5f17252 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -15,6 +15,6 @@
spk_model_revision="v2.0.2",
)
-res = model(input=f"{model.model_path}/example/asr_example.wav",
- hotword='达摩院 魔搭')
+res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
+ hotword='达摩院 魔搭')
print(res)
\ No newline at end of file
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 25edeb7..580cca8 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -264,7 +264,7 @@
# step.1: compute the vad model
self.vad_kwargs.update(cfg)
beg_vad = time.time()
- res = self.generate(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
+ res = self.inference(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
end_vad = time.time()
print(f"time cost vad: {end_vad - beg_vad:0.3f}")
@@ -316,7 +316,7 @@
batch_size_ms_cum = 0
end_idx = j + 1
speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
- results = self.generate(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
+ results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
if self.spk_model is not None:
all_segments = []
# compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
@@ -327,7 +327,7 @@
segments = sv_chunk(vad_segments)
all_segments.extend(segments)
speech_b = [i[2] for i in segments]
- spk_res = self.generate(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
+ spk_res = self.inference(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
results[_b]['spk_embedding'] = spk_res[0]['spk_embedding']
beg_idx = end_idx
if len(results) < 1:
@@ -378,7 +378,7 @@
# step.3 compute punc model
if self.punc_model is not None:
self.punc_kwargs.update(cfg)
- punc_res = self.generate(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
+ punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
result["text_with_punc"] = punc_res[0]["text"]
# speaker embedding cluster after resorted
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index bc435c4..d2f0c14 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -1,25 +1,7 @@
-import json
-import time
-import torch
import hydra
-import random
-import string
import logging
-import os.path
-from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, ListConfig
-from funasr.register import tables
-from funasr.utils.load_utils import load_bytes
-from funasr.download.file import download_from_url
-from funasr.download.download_from_hub import download_model
-from funasr.utils.vad_utils import slice_padding_audio_samples
-from funasr.train_utils.set_all_random_seed import set_all_random_seed
-from funasr.train_utils.load_pretrained_model import load_pretrained_model
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils.timestamp_tools import timestamp_sentence
-from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
-from funasr.models.campplus.cluster_backend import ClusterBackend
from funasr.auto.auto_model import AutoModel
@@ -41,9 +23,8 @@
if kwargs.get("debug", False):
import pdb; pdb.set_trace()
model = AutoModel(**kwargs)
- res = model(input=kwargs["input"])
+ res = model.generate(input=kwargs["input"])
print(res)
-
if __name__ == '__main__':
--
Gitblit v1.9.1