From 502b6a5e97480d48a8f40ab198519660ed8ef557 Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: Fri, 05 May 2023 11:50:41 +0800
Subject: [PATCH] Merge pull request #452 from alibaba-damo-academy/dev_lhn

---
 egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py       |    4 ++--
 funasr/bin/asr_inference_paraformer_streaming.py                                                         |   22 +++++-----------------
 egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py |    2 +-
 3 files changed, 8 insertions(+), 20 deletions(-)

diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
index 4fd4cdf..808084f 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
@@ -34,6 +34,6 @@
     rec_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + stride_size],
                                     param_dict=param_dict)
     if len(rec_result) != 0:
-        final_result += rec_result['text'][0]
+        final_result += " ".join(rec_result['text']) + " "
         print(rec_result)
 print(final_result)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
index 0066c7b..0ecf1ab 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
@@ -34,6 +34,6 @@
     rec_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + stride_size],
                                     param_dict=param_dict)
     if len(rec_result) != 0:
-        final_result += rec_result['text'][0]
+        final_result += " ".join(rec_result['text']) + " "
         print(rec_result)
-print(final_result)
+print(final_result.strip())
diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
index bf5590c..be0d752 100644
--- a/funasr/bin/asr_inference_paraformer_streaming.py
+++ b/funasr/bin/asr_inference_paraformer_streaming.py
@@ -239,7 +239,7 @@
                         feats_len = torch.tensor([feats_chunk2.shape[1]])
                         results_chunk2 = self.infer(feats_chunk2, feats_len, cache)
 
-                        return ["".join(results_chunk1 + results_chunk2)]
+                        return [" ".join(results_chunk1 + results_chunk2)]
 
                 results = self.infer(feats, feats_len, cache)
 
@@ -299,12 +299,9 @@
 
                 # Change integer-ids to tokens
                 token = self.converter.ids2tokens(token_int)
+                token = " ".join(token)
 
-                if self.tokenizer is not None:
-                    text = self.tokenizer.tokens2text(token)
-                else:
-                    text = None
-                results.append(text)
+                results.append(token)
 
         # assert check_return_type(results)
         return results
@@ -555,8 +552,8 @@
                 input_lens = torch.tensor([stride_size])
                 asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
                 if len(asr_result) != 0: 
-                    final_result += asr_result[0]
-            item = {'key': "utt", 'value': [final_result]}
+                    final_result += " ".join(asr_result) + " "
+            item = {'key': "utt", 'value': [final_result.strip()]}
         else:
             input_lens = torch.tensor([raw_inputs.shape[1]])
             cache["encoder"]["is_final"] = is_final
@@ -750,12 +747,3 @@
 if __name__ == "__main__":
     main()
 
-    # from modelscope.pipelines import pipeline
-    # from modelscope.utils.constant import Tasks
-    #
-    # inference_16k_pipline = pipeline(
-    #     task=Tasks.auto_speech_recognition,
-    #     model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
-    #
-    # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
-    # print(rec_result)

--
Gitblit v1.9.1