zhifu gao
2024-03-15 675b4605e8d1d9a406f5e6fc3bc989ddc932b04b
Dev gzf llm (#1506)

* update

* update

* update

* update onnx

* update with main (#1492)

* contextual&seaco ONNX export (#1481)

* contextual&seaco ONNX export

* update ContextualEmbedderExport2

* update ContextualEmbedderExport2

* update code

* onnx (#1482)

* qwenaudio qwenaudiochat

* qwenaudio qwenaudiochat

* whisper

* whisper

* llm

* llm

* llm

* llm

* llm

* llm

* llm

* llm

* export onnx

* export onnx

* export onnx

* dingding

* dingding

* llm

* doc

* onnx

* onnx

* onnx

* onnx

* onnx

* onnx

* v1.0.15

* qwenaudio

* qwenaudio

* issue doc

* update

* update

* bugfix

* onnx

* update export calling

* update codes

* remove useless code

* update code

---------

Co-authored-by: zhifu gao <zhifu.gzf@alibaba-inc.com>

* acknowledge

---------

Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com>

* update onnx

* update onnx

* train update

* train update

* train update

* train update

* punc update

---------

Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
29 files modified, 2 files added, 388 lines changed:
 README.md                                                              |  2
 README_zh.md                                                           |  2
 examples/industrial_data_pretraining/bicif_paraformer/demo.py          |  8
 examples/industrial_data_pretraining/bicif_paraformer/demo.sh          | 11
 examples/industrial_data_pretraining/bicif_paraformer/export.py        |  2
 examples/industrial_data_pretraining/campplus_sv/demo.py               |  2
 examples/industrial_data_pretraining/contextual_paraformer/demo.py     |  2
 examples/industrial_data_pretraining/contextual_paraformer/demo.sh     |  2
 examples/industrial_data_pretraining/ct_transformer/demo.py            |  4
 examples/industrial_data_pretraining/ct_transformer/demo.sh            |  4
 examples/industrial_data_pretraining/ct_transformer_streaming/demo.sh  |  2
 examples/industrial_data_pretraining/emotion2vec/demo.py               |  2
 examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py        |  2
 examples/industrial_data_pretraining/fsmn_vad_streaming/demo.sh        |  2
 examples/industrial_data_pretraining/monotonic_aligner/demo.py         |  2
 examples/industrial_data_pretraining/monotonic_aligner/demo.sh         |  2
 examples/industrial_data_pretraining/paraformer-zh-spk/demo.py         |  8
 examples/industrial_data_pretraining/paraformer-zh-spk/demo.sh         |  8
 examples/industrial_data_pretraining/paraformer/demo.py                |  2
 examples/industrial_data_pretraining/paraformer/finetune.sh            |  2
 examples/industrial_data_pretraining/paraformer/infer.sh               |  2
 examples/industrial_data_pretraining/scama/demo.sh                     |  2
 examples/industrial_data_pretraining/seaco_paraformer/demo.py          |  6
 examples/industrial_data_pretraining/seaco_paraformer/demo.sh          |  6
 examples/industrial_data_pretraining/uniasr/demo.py                    |  2
 examples/industrial_data_pretraining/uniasr/demo.sh                    |  2
 funasr/models/ct_transformer/export_meta.py                            | 67
 funasr/models/ct_transformer/model.py                                  | 78
 funasr/models/ct_transformer_streaming/export_meta.py                  | 77
 funasr/models/ct_transformer_streaming/model.py                        | 67
 setup.py                                                               |  8
README.md
@@ -112,7 +112,7 @@
 from funasr import AutoModel
 # paraformer-zh is a multi-functional asr model
 # use vad, punc, spk or not as you need
-model = AutoModel(model="paraformer-zh",  vad_model="fsmn-vad",  punc_model="ct-punc-c",
+model = AutoModel(model="paraformer-zh",  vad_model="fsmn-vad",  punc_model="ct-punc",
                   # spk_model="cam++",
                   )
 res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
README_zh.md
@@ -106,7 +106,7 @@
 from funasr import AutoModel
 # paraformer-zh is a multi-functional asr model
 # use vad, punc, spk or not as you need
-model = AutoModel(model="paraformer-zh",  vad_model="fsmn-vad", punc_model="ct-punc-c",
+model = AutoModel(model="paraformer-zh",  vad_model="fsmn-vad", punc_model="ct-punc",
                   # spk_model="cam++"
                   )
 res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
examples/industrial_data_pretraining/bicif_paraformer/demo.py
@@ -5,13 +5,13 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4",
-                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.4",
-                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large",
                   punc_model_revision="v2.0.4",
-                  # spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
+                  # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
                   # spk_model_revision="v2.0.2",
                   )
examples/industrial_data_pretraining/bicif_paraformer/demo.sh
@@ -1,11 +1,12 @@
-model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model_revision="v2.0.4"
-vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.4"
-punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.3"
-spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
+#punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large"
+punc_model_revision="v2.0.4"
+spk_model="iic/speech_campplus_sv_zh-cn_16k-common"
 spk_model_revision="v2.0.2"
 python funasr/bin/inference.py \
examples/industrial_data_pretraining/bicif_paraformer/export.py
@@ -7,7 +7,7 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4", device="cpu")
 res = model.export(type="onnx", quantize=False)
examples/industrial_data_pretraining/campplus_sv/demo.py
@@ -5,7 +5,7 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/speech_campplus_sv_zh-cn_16k-common",
+model = AutoModel(model="iic/speech_campplus_sv_zh-cn_16k-common",
                   model_revision="v2.0.2",
                   )
examples/industrial_data_pretraining/contextual_paraformer/demo.py
@@ -5,7 +5,7 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", model_revision="v2.0.4")
+model = AutoModel(model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", model_revision="v2.0.4")
 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
             hotword='达摩院 魔搭')
examples/industrial_data_pretraining/contextual_paraformer/demo.sh
@@ -1,5 +1,5 @@
-model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
+model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
 model_revision="v2.0.4"
 python ../../../funasr/bin/inference.py \
examples/industrial_data_pretraining/ct_transformer/demo.py
@@ -5,7 +5,7 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.4")
+model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.4")
 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
 print(res)
@@ -13,7 +13,7 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.4")
+model = AutoModel(model="iic/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.4")
 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
 print(res)
examples/industrial_data_pretraining/ct_transformer/demo.sh
@@ -1,8 +1,8 @@
-#model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+#model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
 #model_revision="v2.0.4"
-model="damo/punc_ct-transformer_cn-en-common-vocab471067-large"
+model="iic/punc_ct-transformer_cn-en-common-vocab471067-large"
 model_revision="v2.0.4"
 python funasr/bin/inference.py \
examples/industrial_data_pretraining/ct_transformer_streaming/demo.sh
@@ -1,5 +1,5 @@
-model="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
+model="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
 model_revision="v2.0.4"
 python funasr/bin/inference.py \
examples/industrial_data_pretraining/emotion2vec/demo.py
@@ -5,7 +5,7 @@
 from funasr import AutoModel
-# model="damo/emotion2vec_base"
+# model="iic/emotion2vec_base"
 model = AutoModel(model="iic/emotion2vec_base_finetuned", model_revision="v2.0.4")
 wav_file = f"{model.model_path}/example/test.wav"
examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -6,7 +6,7 @@
 from funasr import AutoModel
 wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
-model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4")
+model = AutoModel(model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4")
 res = model.generate(input=wav_file)
 print(res)
examples/industrial_data_pretraining/fsmn_vad_streaming/demo.sh
@@ -1,6 +1,6 @@
-model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 model_revision="v2.0.4"
 python funasr/bin/inference.py \
examples/industrial_data_pretraining/monotonic_aligner/demo.py
@@ -5,7 +5,7 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/speech_timestamp_prediction-v1-16k-offline", model_revision="v2.0.4")
+model = AutoModel(model="iic/speech_timestamp_prediction-v1-16k-offline", model_revision="v2.0.4")
 res = model.generate(input=("https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
                    "欢迎大家来到魔搭社区进行体验"),
examples/industrial_data_pretraining/monotonic_aligner/demo.sh
@@ -1,5 +1,5 @@
-model="damo/speech_timestamp_prediction-v1-16k-offline"
+model="iic/speech_timestamp_prediction-v1-16k-offline"
 model_revision="v2.0.4"
 python funasr/bin/inference.py \
examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
@@ -5,13 +5,13 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4",
-                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.4",
-                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                   punc_model_revision="v2.0.4",
-                  spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
+                  spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
                   spk_model_revision="v2.0.2"
                   )
examples/industrial_data_pretraining/paraformer-zh-spk/demo.sh
@@ -1,11 +1,11 @@
-model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model_revision="v2.0.4"
-vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.4"
-punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
 punc_model_revision="v2.0.4"
-spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
+spk_model="iic/speech_campplus_sv_zh-cn_16k-common"
 spk_model_revision="v2.0.2"
 python funasr/bin/inference.py \
examples/industrial_data_pretraining/paraformer/demo.py
@@ -22,7 +22,7 @@
 ''' can not use currently
 from funasr import AutoFrontend
-frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")
+frontend = AutoFrontend(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")
 fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
examples/industrial_data_pretraining/paraformer/finetune.sh
@@ -32,7 +32,7 @@
 --nnodes 1 \
 --nproc_per_node ${gpu_num} \
 funasr/bin/train.py \
-++model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+++model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
 ++model_revision="v2.0.4" \
 ++train_data_set_list="${train_data}" \
 ++valid_data_set_list="${val_data}" \
examples/industrial_data_pretraining/paraformer/infer.sh
@@ -8,7 +8,7 @@
 output_dir="./outputs/debug"
-model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model_revision="v2.0.4"
 device="cuda:0" # "cuda:0" for gpu0, "cuda:1" for gpu1, "cpu"
examples/industrial_data_pretraining/scama/demo.sh
@@ -1,5 +1,5 @@
-model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
+model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
 model_revision="v2.0.4"
 python funasr/bin/inference.py \
examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -7,11 +7,11 @@
 model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4",
-                  # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   # vad_model_revision="v2.0.4",
-                  # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  # punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                   # punc_model_revision="v2.0.4",
-                  # spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
+                  # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
                   # spk_model_revision="v2.0.2",
                   )
examples/industrial_data_pretraining/seaco_paraformer/demo.sh
@@ -1,9 +1,9 @@
-model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model_revision="v2.0.4"
-vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.4"
-punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
 punc_model_revision="v2.0.4"
 python funasr/bin/inference.py \
examples/industrial_data_pretraining/uniasr/demo.py
@@ -16,7 +16,7 @@
 ''' can not use currently
 from funasr import AutoFrontend
-frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")
+frontend = AutoFrontend(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")
 fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
examples/industrial_data_pretraining/uniasr/demo.sh
@@ -1,5 +1,5 @@
-model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model_revision="v2.0.4"
 python funasr/bin/inference.py \
funasr/models/ct_transformer/export_meta.py
New file
@@ -0,0 +1,67 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import types
+import torch
+from funasr.register import tables
+
+
+def export_rebuild_model(model, **kwargs):
+    is_onnx = kwargs.get("type", "onnx") == "onnx"
+    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+    model.encoder = encoder_class(model.encoder, onnx=is_onnx)
+
+    model.forward = types.MethodType(export_forward, model)
+    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
+    model.export_input_names = types.MethodType(export_input_names, model)
+    model.export_output_names = types.MethodType(export_output_names, model)
+    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
+    model.export_name = types.MethodType(export_name, model)
+    return model
+
+
+def export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
+    """Compute loss value from buffer sequences.
+    Args:
+        input (torch.Tensor): Input ids. (batch, len)
+        hidden (torch.Tensor): Target ids. (batch, len)
+    """
+    x = self.embed(inputs)
+    h, _ = self.encoder(x, text_lengths)
+    y = self.decoder(h)
+    return y
+
+
+def export_dummy_inputs(self):
+    length = 120
+    text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
+    text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
+    return (text_indexes, text_lengths)
+
+
+def export_input_names(self):
+    return ['inputs', 'text_lengths']
+
+
+def export_output_names(self):
+    return ['logits']
+
+
+def export_dynamic_axes(self):
+    return {
+        'inputs': {
+            0: 'batch_size',
+            1: 'feats_length'
+        },
+        'text_lengths': {
+            0: 'batch_size',
+        },
+        'logits': {
+            0: 'batch_size',
+            1: 'logits_length'
+        },
+    }
+
+
+def export_name(self):
+    return "model.onnx"
funasr/models/ct_transformer/model.py
@@ -17,7 +17,10 @@
 from funasr.utils.load_utils import load_audio_text_image_video
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
+try:
+    import jieba
+except:
+    pass
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
 else:
@@ -69,6 +72,10 @@
         self.sos = sos
         self.eos = eos
         self.sentence_end_id = sentence_end_id
+        self.jieba_usr_dict = None
+        if kwargs.get("jieba_usr_dict", None) is not None:
+            jieba.load_userdict(kwargs["jieba_usr_dict"])
+            self.jieba_usr_dict = jieba
@@ -237,14 +244,8 @@
         # text = data_in[0]
         # text_lengths = data_lengths[0] if data_lengths is not None else None
         split_size = kwargs.get("split_size", 20)
-        jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
-        if jieba_usr_dict and isinstance(jieba_usr_dict, str):
-            import jieba
-            jieba.load_userdict(jieba_usr_dict)
-            jieba_usr_dict = jieba
-            kwargs["jieba_usr_dict"] = "jieba_usr_dict"
-        tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
+        tokens = split_words(text, jieba_usr_dict=self.jieba_usr_dict)
         tokens_int = tokenizer.encode(tokens)
         mini_sentences = split_to_mini_sentence(tokens, split_size)
@@ -347,7 +348,7 @@
             else:
                 punc_array = torch.cat([punc_array, punctuations], dim=0)
         # post processing when using word level punc model
-        if jieba_usr_dict:
+        if self.jieba_usr_dict is not None:
             len_tokens = len(tokens)
             new_punc_array = copy.copy(punc_array).tolist()
             # for i, (token, punc_id) in enumerate(zip(tokens[::-1], punc_array.tolist()[::-1])):
@@ -364,57 +365,10 @@
         results.append(result_i)
         return results, meta_data
-    def export(
-        self,
-        **kwargs,
-    ):
+    def export(self, **kwargs):
+        from .export_meta import export_rebuild_model
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
-        is_onnx = kwargs.get("type", "onnx") == "onnx"
-        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
-        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
-        self.forward = self.export_forward
-        return self
-    def export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
-        """Compute loss value from buffer sequences.
-        Args:
-            input (torch.Tensor): Input ids. (batch, len)
-            hidden (torch.Tensor): Target ids. (batch, len)
-        """
-        x = self.embed(inputs)
-        h, _ = self.encoder(x, text_lengths)
-        y = self.decoder(h)
-        return y
-    def export_dummy_inputs(self):
-        length = 120
-        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
-        text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
-        return (text_indexes, text_lengths)
-    def export_input_names(self):
-        return ['inputs', 'text_lengths']
-    def export_output_names(self):
-        return ['logits']
-    def export_dynamic_axes(self):
-        return {
-            'inputs': {
-                0: 'batch_size',
-                1: 'feats_length'
-            },
-            'text_lengths': {
-                0: 'batch_size',
-            },
-            'logits': {
-                0: 'batch_size',
-                1: 'logits_length'
-            },
-        }
-    def export_name(self):
-        return "model.onnx"
funasr/models/ct_transformer_streaming/export_meta.py
New file
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import types
+import torch
+from funasr.register import tables
+
+
+def export_rebuild_model(model, **kwargs):
+    is_onnx = kwargs.get("type", "onnx") == "onnx"
+    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+    model.encoder = encoder_class(model.encoder, onnx=is_onnx)
+
+    model.forward = types.MethodType(export_forward, model)
+    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
+    model.export_input_names = types.MethodType(export_input_names, model)
+    model.export_output_names = types.MethodType(export_output_names, model)
+    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
+    model.export_name = types.MethodType(export_name, model)
+    return model
+
+
+def export_forward(self, inputs: torch.Tensor,
+            text_lengths: torch.Tensor,
+            vad_indexes: torch.Tensor,
+            sub_masks: torch.Tensor,
+            ):
+    """Compute loss value from buffer sequences.
+    Args:
+        input (torch.Tensor): Input ids. (batch, len)
+        hidden (torch.Tensor): Target ids. (batch, len)
+    """
+    x = self.embed(inputs)
+    # mask = self._target_mask(input)
+    h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
+    y = self.decoder(h)
+    return y
+
+
+def export_dummy_inputs(self):
+    length = 120
+    text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
+    text_lengths = torch.tensor([length], dtype=torch.int32)
+    vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
+    sub_masks = torch.ones(length, length, dtype=torch.float32)
+    sub_masks = torch.tril(sub_masks).type(torch.float32)
+    return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
+
+
+def export_input_names(self):
+    return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
+
+
+def export_output_names(self):
+    return ['logits']
+
+
+def export_dynamic_axes(self):
+    return {
+        'inputs': {
+            1: 'feats_length'
+        },
+        'vad_masks': {
+            2: 'feats_length1',
+            3: 'feats_length2'
+        },
+        'sub_masks': {
+            2: 'feats_length1',
+            3: 'feats_length2'
+        },
+        'logits': {
+            1: 'logits_length'
+        },
+    }
+
+
+def export_name(self):
+    return "model.onnx"
funasr/models/ct_transformer_streaming/model.py
@@ -173,68 +173,9 @@
 
         return results, meta_data
-    def export(
-        self,
-        **kwargs,
-    ):
+    def export(self, **kwargs):
-        is_onnx = kwargs.get("type", "onnx") == "onnx"
-        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
-        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
-        self.forward = self.export_forward
-        return self
+        from .export_meta import export_rebuild_model
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
-    def export_forward(self, inputs: torch.Tensor,
-                text_lengths: torch.Tensor,
-                vad_indexes: torch.Tensor,
-                sub_masks: torch.Tensor,
-                ):
-        """Compute loss value from buffer sequences.
-        Args:
-            input (torch.Tensor): Input ids. (batch, len)
-            hidden (torch.Tensor): Target ids. (batch, len)
-        """
-        x = self.embed(inputs)
-        # mask = self._target_mask(input)
-        h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
-        y = self.decoder(h)
-        return y
-    def export_dummy_inputs(self):
-        length = 120
-        text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
-        text_lengths = torch.tensor([length], dtype=torch.int32)
-        vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
-        sub_masks = torch.ones(length, length, dtype=torch.float32)
-        sub_masks = torch.tril(sub_masks).type(torch.float32)
-        return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
-    def export_input_names(self):
-        return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
-    def export_output_names(self):
-        return ['logits']
-    def export_dynamic_axes(self):
-        return {
-            'inputs': {
-                1: 'feats_length'
-            },
-            'vad_masks': {
-                2: 'feats_length1',
-                3: 'feats_length2'
-            },
-            'sub_masks': {
-                2: 'feats_length1',
-                3: 'feats_length2'
-            },
-            'logits': {
-                1: 'logits_length'
-            },
-        }
-    def export_name(self):
-        return "model.onnx"
setup.py
@@ -14,16 +14,14 @@
         "librosa",
         "jamo",  # For kss
         "PyYAML>=5.1.2",
-        # "soundfile>=0.12.1",
+        "soundfile>=0.12.1",
         "kaldiio>=2.17.0",
         "torch_complex",
-        # "nltk>=3.4.5",
         # ASR
         "sentencepiece", # train
         "jieba",
-        # "rotary_embedding_torch",
+        "rotary_embedding_torch",
-        # "ffmpeg-python",
         # TTS
         # "pypinyin>=0.44.0",
         # "espnet_tts_frontend",
         # ENH
@@ -54,6 +52,7 @@
         "torch_optimizer",
         "fairscale",
         "transformers",
+        "openai-whisper"
     ],
     "setup": [
         "numpy",
@@ -96,6 +95,7 @@
     ],
 }
 requirements["all"].extend(requirements["train"])
+requirements["all"].extend(requirements["llm"])
 requirements["test"].extend(requirements["train"])
 install_requires = requirements["install"]
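In sum, the setup.py changes uncomment soundfile and rotary_embedding_torch in the base install list, add openai-whisper to an optional dependency group, and fold the llm group into all. Assuming these groups are exposed as pip extras from a source checkout, the LLM dependencies would come along with something like pip install -e ".[llm]"; the exact extra names depend on how extras_require is keyed in the full file.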