From 3c632e4f1ac48714a449a39784ed955a3b536150 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 24 二月 2023 11:37:10 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/bin/asr_inference_paraformer.py |    4 -
 funasr/bin/asr_inference_uniasr.py     |   25 ++++++------
 funasr/bin/asr_inference_uniasr_vad.py |   25 ++++++------
 docs_cn/build_task.md                  |   19 +++++++++
 docs_cn/modelscope_usages.md           |    2 
 docs/build_task.md                     |   21 ++++++++++
 6 files changed, 67 insertions(+), 29 deletions(-)

diff --git a/docs/build_task.md b/docs/build_task.md
index a45c820..be2d1af 100644
--- a/docs/build_task.md
+++ b/docs/build_task.md
@@ -103,4 +103,23 @@
         )
     return model
 ```
-This function defines the detail of the model. For different speech recognition models, the same speech recognition `Task` can usually be shared and the remaining thing needed to be done is to define a specific model in this function. For example, a speech recognition model with a standard encoder-decoder structure has been shown above. Specifically, it first defines each module of the model, including encoder, decoder, etc. and then combine these modules together to generate a complete model. In FunASR, the model needs to inherit `AbsESPnetModel` and the corresponding code can be seen in `funasr/train/abs_espnet_model.py`. The main function needed to be implemented is the `forward` function.
\ No newline at end of file
+This function defines the detail of the model. For different speech recognition models, the same speech recognition `Task` can usually be shared and the remaining thing needed to be done is to define a specific model in this function. For example, a speech recognition model with a standard encoder-decoder structure has been shown above. Specifically, it first defines each module of the model, including encoder, decoder, etc. and then combine these modules together to generate a complete model. In FunASR, the model needs to inherit `AbsESPnetModel` and the corresponding code can be seen in `funasr/train/abs_espnet_model.py`. The main function needed to be implemented is the `forward` function.
+
+Next, we take `SANMEncoder` as an example to introduce how to use a custom encoder as a part of the model when defining the specified model and the corresponding code can be seen in `funasr/models/encoder/sanm_encoder.py`. For a custom encoder, in addition to inheriting the common encoder class `AbsEncoder`, it is also necessary to define the `forward` function to achieve the forward computation of the `encoder`. After defining the `encoder`, it should also be registered in the `Task`. The corresponding code example can be seen as below:
+```python
+encoder_choices = ClassChoices(
+    "encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+        mfcca_enc=MFCCAEncoder,
+    ),
+    type_check=AbsEncoder,
+    default="rnn",
+)
+```
+In this code, `sanm=SANMEncoder` takes the newly defined `SANMEncoder` as an optional choice of the `encoder`. Once the user specifies the `encoder` as `sanm` in the configuration file, the `SANMEncoder` will be correspondingly employed as the `encoder` module of the model.
\ No newline at end of file
diff --git a/docs_cn/build_task.md b/docs_cn/build_task.md
index 5d78993..c23b19f 100644
--- a/docs_cn/build_task.md
+++ b/docs_cn/build_task.md
@@ -103,3 +103,22 @@
     return model
 ```
 璇ュ嚱鏁板畾涔変簡鍏蜂綋鐨勬ā鍨嬨�傚浜庝笉鍚岀殑璇煶璇嗗埆妯″瀷锛屽線寰�鍙互鍏辩敤鍚屼竴涓闊宠瘑鍒玚Task`锛岄澶栭渶瑕佸仛鐨勬槸鍦ㄦ鍑芥暟涓畾涔夌壒瀹氱殑妯″瀷銆備緥濡傦紝杩欓噷缁欏嚭鐨勬槸涓�涓爣鍑嗙殑encoder-decoder缁撴瀯鐨勮闊宠瘑鍒ā鍨嬨�傚叿浣撳湴锛屽厛瀹氫箟璇ユā鍨嬬殑鍚勪釜妯″潡锛屽寘鎷琫ncoder锛宒ecoder绛夛紝鐒跺悗鍦ㄥ皢杩欎簺妯″潡缁勫悎鍦ㄤ竴璧峰緱鍒颁竴涓畬鏁寸殑妯″瀷銆傚湪FunASR涓紝妯″瀷闇�瑕佺户鎵縛AbsESPnetModel`锛屽叾鍏蜂綋浠g爜瑙乣funasr/train/abs_espnet_model.py`锛屼富瑕侀渶瑕佸疄鐜扮殑鏄痐forward`鍑芥暟銆�
+
+涓嬮潰鎴戜滑灏嗕互`SANMEncoder`涓轰緥锛屼粙缁嶅浣曞湪瀹氫箟妯″瀷鐨勬椂鍊欙紝浣跨敤鑷畾涔夌殑`encoder`鏉ヤ綔涓烘ā鍨嬬殑缁勬垚閮ㄥ垎锛屽叾鍏蜂綋鐨勪唬鐮佽`funasr/models/encoder/sanm_encoder.py`銆傚浜庤嚜瀹氫箟鐨刞encoder`锛岄櫎浜嗛渶瑕佺户鎵块�氱敤鐨刞encoder`绫籤AbsEncoder`澶栵紝杩橀渶瑕佽嚜瀹氫箟`forward`鍑芥暟锛屽疄鐜癭encoder`鐨勫墠鍚戣绠椼�傚湪瀹氫箟瀹宍encoder`鍚庯紝杩橀渶瑕佸湪`Task`涓鍏惰繘琛屾敞鍐岋紝涓嬮潰缁欏嚭浜嗙浉搴旂殑浠g爜绀轰緥锛�
+```python
+encoder_choices = ClassChoices(
+    "encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+        mfcca_enc=MFCCAEncoder,
+    ),
+    type_check=AbsEncoder,
+    default="rnn",
+)
+```
+鍙互鐪嬪埌锛宍sanm=SANMEncoder`灏嗘柊瀹氫箟鐨刞SANMEncoder`浣滀负浜哷encoder`鐨勪竴绉嶅彲閫夐」锛屽綋鐢ㄦ埛鍦ㄩ厤缃枃浠朵腑鎸囧畾`encoder`涓篳sanm`鏃讹紝鍗充細鐩稿簲鍦板皢`SANMEncoder`浣滀负妯″瀷鐨刞encoder`妯″潡銆�
\ No newline at end of file
diff --git a/docs_cn/modelscope_usages.md b/docs_cn/modelscope_usages.md
index e704e5e..c91de76 100644
--- a/docs_cn/modelscope_usages.md
+++ b/docs_cn/modelscope_usages.md
@@ -1,4 +1,4 @@
-# ModelScope Usage
+# ModelScope 浣跨敤璇存槑
 ModelScope鏄樋閲屽反宸存帹鍑虹殑寮�婧愭ā鍨嬪嵆鏈嶅姟鍏变韩骞冲彴锛屼负骞垮ぇ瀛︽湳鐣岀敤鎴峰拰宸ヤ笟鐣岀敤鎴锋彁渚涚伒娲汇�佷究鎹风殑妯″瀷搴旂敤鏀寔銆傚叿浣撶殑浣跨敤鏂规硶鍜屽紑婧愭ā鍨嬪彲浠ュ弬瑙乕ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition) 銆傚湪璇煶鏂瑰悜锛屾垜浠彁渚涗簡鑷洖褰�/闈炶嚜鍥炲綊璇煶璇嗗埆锛岃闊抽璁粌锛屾爣鐐归娴嬬瓑妯″瀷锛岀敤鎴峰彲浠ユ柟渚夸娇鐢ㄣ��
 
 ## 鏁翠綋浠嬬粛
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 487f750..b807a34 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -660,11 +660,9 @@
         hotword_list_or_file = None
         if param_dict is not None:
             hotword_list_or_file = param_dict.get('hotword')
-
         if 'hotword' in kwargs:
             hotword_list_or_file = kwargs['hotword']
-
-        if speech2text.hotword_list is None:
+        if hotword_list_or_file is not None or 'hotword' in kwargs:
             speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
 
         # 3. Build data-iterator
diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index c50bf17..8b31fad 100644
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -398,6 +398,19 @@
     else:
         device = "cpu"
     
+    if param_dict is not None and "decoding_model" in param_dict:
+        if param_dict["decoding_model"] == "fast":
+            decoding_ind = 0
+            decoding_mode = "model1"
+        elif param_dict["decoding_model"] == "normal":
+            decoding_ind = 0
+            decoding_mode = "model2"
+        elif param_dict["decoding_model"] == "offline":
+            decoding_ind = 1
+            decoding_mode = "model2"
+        else:
+            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
+
     # 1. Set random-seed
     set_all_random_seed(seed)
 
@@ -440,18 +453,6 @@
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        if param_dict is not None and "decoding_model" in param_dict:
-            if param_dict["decoding_model"] == "fast":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model1"
-            elif param_dict["decoding_model"] == "normal":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model2"
-            elif param_dict["decoding_model"] == "offline":
-                speech2text.decoding_ind = 1
-                speech2text.decoding_mode = "model2"
-            else:
-                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
         loader = ASRTask.build_streaming_iterator(
             data_path_and_name_and_type,
             dtype=dtype,
diff --git a/funasr/bin/asr_inference_uniasr_vad.py b/funasr/bin/asr_inference_uniasr_vad.py
index ac3b4b6..e5815df 100644
--- a/funasr/bin/asr_inference_uniasr_vad.py
+++ b/funasr/bin/asr_inference_uniasr_vad.py
@@ -398,6 +398,19 @@
     else:
         device = "cpu"
 
+    if param_dict is not None and "decoding_model" in param_dict:
+        if param_dict["decoding_model"] == "fast":
+            decoding_ind = 0
+            decoding_mode = "model1"
+        elif param_dict["decoding_model"] == "normal":
+            decoding_ind = 0
+            decoding_mode = "model2"
+        elif param_dict["decoding_model"] == "offline":
+            decoding_ind = 1
+            decoding_mode = "model2"
+        else:
+            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
+
     # 1. Set random-seed
     set_all_random_seed(seed)
 
@@ -440,18 +453,6 @@
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        if param_dict is not None and "decoding_model" in param_dict:
-            if param_dict["decoding_model"] == "fast":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model1"
-            elif param_dict["decoding_model"] == "normal":
-                speech2text.decoding_ind = 0
-                speech2text.decoding_mode = "model2"
-            elif param_dict["decoding_model"] == "offline":
-                speech2text.decoding_ind = 1
-                speech2text.decoding_mode = "model2"
-            else:
-                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
         loader = ASRTask.build_streaming_iterator(
             data_path_and_name_and_type,
             dtype=dtype,

--
Gitblit v1.9.1