shixian.shi
2024-01-15 55c09aeaa25b4bb88a50e09ba68fa6ff00a6d676
update readme, fix seaco bug
5个文件已修改
41 ■■■■ 已修改文件
README.md 21 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/seaco_paraformer/demo.py 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/download/name_maps_from_hub.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/ct_transformer/model.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/seaco_paraformer/model.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
README.md
@@ -90,12 +90,15 @@
### Speech Recognition (Non-streaming)
```python
from funasr import AutoModel
model = AutoModel(model="paraformer-zh")
# for the long duration wav, you could add vad model
# model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc")
res = model(input="asr_example_zh.wav", batch_size=64)
# paraformer-zh is a multi-functional asr model
# use vad, punc, spk or not as you need
model = AutoModel(model="paraformer-zh", model_revision="v2.0.2", \
                  vad_model="fsmn-vad", vad_model_revision="v2.0.2", \
                  punc_model="ct-punc-c", punc_model_revision="v2.0.2", \
                  spk_model="cam++", spk_model_revision="v2.0.2")
res = model(input=f"{model.model_path}/example/asr_example.wav",
            batch_size=16,
            hotword='魔搭')
print(res)
```
Note: `model_hub` selects the model repository to download from: `ms` selects ModelScope and `hf` selects Hugging Face.
@@ -108,7 +111,7 @@
encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.0")
model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.2")
import soundfile
import os
@@ -163,7 +166,7 @@
```python
from funasr import AutoModel
model = AutoModel(model="ct-punc", model_revision="v2.0.1")
model = AutoModel(model="ct-punc", model_revision="v2.0.2")
res = model(input="那今天的会就到这里吧 happy new year 明年见")
print(res)
@@ -172,7 +175,7 @@
```python
from funasr import AutoModel
model = AutoModel(model="fa-zh", model_revision="v2.0.0")
model = AutoModel(model="fa-zh", model_revision="v2.0.2")
wav_file = f"{model.model_path}/example/asr_example.wav"
text_file = f"{model.model_path}/example/asr_example.wav"
examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -11,8 +11,10 @@
                  vad_model_revision="v2.0.2",
                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                  punc_model_revision="v2.0.2",
                  spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
                  spk_model="v2.0.2",
                  )
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
            hotword='达摩院 磨搭')
res = model(input=f"{model.model_path}/example/asr_example.wav",
            hotword='达摩院 魔搭')
print(res)
funasr/download/name_maps_from_hub.py
@@ -1,14 +1,13 @@
# Maps short model aliases (e.g. "paraformer-zh", as passed to AutoModel in the
# README examples) to their ModelScope ("ms") model IDs.
# Keys must be unique: in a dict literal a repeated key silently discards the
# earlier entry, so a stale duplicate here is dead configuration.
name_maps_ms = {
    "paraformer-zh": "damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn",
    "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    # NOTE(review): same model ID as "paraformer-en"; confirm a speaker-enabled
    # model is not intended here.
    "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
    "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
    "ct-punc-c": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
    "cam++": "damo/speech_campplus_sv_zh-cn_16k-common",
}
name_maps_hf = {
funasr/models/ct_transformer/model.py
@@ -344,7 +344,6 @@
                punc_array = punctuations
            else:
                punc_array = torch.cat([punc_array, punctuations], dim=0)
        result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
        results.append(result_i)
    
funasr/models/seaco_paraformer/model.py
@@ -212,7 +212,7 @@
                               ys_pad_lens, 
                               hw_list,
                               nfilter=50,
                                 seaco_weight=1.0):
                               seaco_weight=1.0):
        # decoder forward
        decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True)
        decoder_pred = torch.log_softmax(decoder_out, dim=-1)
@@ -254,10 +254,9 @@
            
            dha_output = self.hotword_output_layer(merged)  # remove the last token in loss calculation
            dha_pred = torch.log_softmax(dha_output, dim=-1)
            # import pdb; pdb.set_trace()
            def _merge_res(dec_output, dha_output):
                lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0])
                dha_ids = dha_output.max(-1)[-1][0]
                dha_ids = dha_output.max(-1)[-1]# [0]
                dha_mask = (dha_ids == 8377).int().unsqueeze(-1)
                a = (1 - lmbd) / lmbd
                b = 1 / lmbd
@@ -267,6 +266,7 @@
                logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask)
                return logits
            merged_pred = _merge_res(decoder_pred, dha_pred)
            # import pdb; pdb.set_trace()
            return merged_pred
        else:
            return decoder_pred