游雁
2023-02-10 8bf1a6adbb9b412615a1d121ae2fa0c1776f1c48
export model
3个文件已修改
180 ■■■■■ 已修改文件
funasr/export/README.md 15 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/export_model.py 81 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/__init__.py 84 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/README.md
@@ -16,17 +16,13 @@
output_dir = "../export"  # onnx/torchscripts model save path
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
Export model from local path
```python
from funasr.export.export_model import ASRModelExportParaformer
output_dir = "../export"  # onnx/torchscripts model save path
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
export_model.export_from_local('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
## Export TorchScript format model
@@ -36,15 +32,12 @@
output_dir = "../export"  # onnx/torchscripts model save path
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
Export model from local path
```python
from funasr.export.export_model import ASRModelExportParaformer
output_dir = "../export"  # onnx/torchscripts model save path
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
export_model.export_from_local('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
funasr/export/export_model.py
@@ -1,3 +1,4 @@
import json
from typing import Union, Dict
from pathlib import Path
from typeguard import check_argument_types
@@ -25,7 +26,7 @@
        logging.info("output dir: {}".format(self.cache_dir))
        self.onnx = onnx
    def export(
    def _export(
        self,
        model: Speech2Text,
        tag_name: str = None,
@@ -60,38 +61,66 @@
        model_script = torch.jit.trace(model, dummy_input)
        model_script.save(os.path.join(path, f'{model.model_name}.torchscripts'))
    def export_from_modelscope(
        self,
        tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
    ):
    def export(self,
               tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
               mode: str = 'paraformer',
               ):
        
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
        from modelscope.hub.snapshot_download import snapshot_download
        model_dir = snapshot_download(tag_name, cache_dir=self.cache_dir)
        asr_train_config = os.path.join(model_dir, 'config.yaml')
        asr_model_file = os.path.join(model_dir, 'model.pb')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, 'cpu'
        )
        self.export(model, tag_name)
    def export_from_local(
        self,
        tag_name: str = '/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
    ):
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
        model_dir = tag_name
        if model_dir.startswith('damo/'):
            from modelscope.hub.snapshot_download import snapshot_download
            model_dir = snapshot_download(tag_name, cache_dir=self.cache_dir)
        asr_train_config = os.path.join(model_dir, 'config.yaml')
        asr_model_file = os.path.join(model_dir, 'model.pb')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        json_file = os.path.join(model_dir, 'configuration.json')
        if mode is None:
            import json
            with open(json_file, 'r') as f:
                config_data = json.load(f)
                mode = config_data['model']['model_config']['mode']
        if mode == 'paraformer':
            from funasr.tasks.asr import ASRTaskParaformer as ASRTask
        elif mode == 'uniasr':
            from funasr.tasks.asr import ASRTaskUniASR as ASRTask
        model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, 'cpu'
        )
        self.export(model, tag_name)
    # def export_from_modelscope(
    #     self,
    #     tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
    # ):
    #
    #     from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    #     from modelscope.hub.snapshot_download import snapshot_download
    #
    #     model_dir = snapshot_download(tag_name, cache_dir=self.cache_dir)
    #     asr_train_config = os.path.join(model_dir, 'config.yaml')
    #     asr_model_file = os.path.join(model_dir, 'model.pb')
    #     cmvn_file = os.path.join(model_dir, 'am.mvn')
    #     model, asr_train_args = ASRTask.build_model_from_file(
    #         asr_train_config, asr_model_file, cmvn_file, 'cpu'
    #     )
    #     self.export(model, tag_name)
    #
    # def export_from_local(
    #     self,
    #     tag_name: str = '/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
    # ):
    #
    #     from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    #
    #     model_dir = tag_name
    #     asr_train_config = os.path.join(model_dir, 'config.yaml')
    #     asr_model_file = os.path.join(model_dir, 'model.pb')
    #     cmvn_file = os.path.join(model_dir, 'am.mvn')
    #     model, asr_train_args = ASRTask.build_model_from_file(
    #         asr_train_config, asr_model_file, cmvn_file, 'cpu'
    #     )
    #     self.export(model, tag_name)
    def _export_onnx(self, model, verbose, path, enc_size=None):
        if enc_size:
@@ -116,5 +145,5 @@
if __name__ == '__main__':
    output_dir = "../export"
    export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
    export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
    # export_model.export_from_local('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
    export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
    # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
funasr/export/models/__init__.py
@@ -1,42 +1,3 @@
# from .ctc import CTC
# from .joint_network import JointNetwork
#
# # encoder
# from espnet2.asr.encoder.rnn_encoder import RNNEncoder as espnetRNNEncoder
# from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder as espnetVGGRNNEncoder
# from espnet2.asr.encoder.contextual_block_transformer_encoder import ContextualBlockTransformerEncoder as espnetContextualTransformer
# from espnet2.asr.encoder.contextual_block_conformer_encoder import ContextualBlockConformerEncoder as espnetContextualConformer
# from espnet2.asr.encoder.transformer_encoder import TransformerEncoder as espnetTransformerEncoder
# from espnet2.asr.encoder.conformer_encoder import ConformerEncoder as espnetConformerEncoder
# from funasr.export.models.encoder.rnn import RNNEncoder
# from funasr.export.models.encoders import TransformerEncoder
# from funasr.export.models.encoders import ConformerEncoder
# from funasr.export.models.encoder.contextual_block_xformer import ContextualBlockXformerEncoder
#
# # decoder
# from espnet2.asr.decoder.rnn_decoder import RNNDecoder as espnetRNNDecoder
# from espnet2.asr.transducer.transducer_decoder import TransducerDecoder as espnetTransducerDecoder
# from funasr.export.models.decoder.rnn import (
#     RNNDecoder
# )
# from funasr.export.models.decoders import XformerDecoder
# from funasr.export.models.decoders import TransducerDecoder
#
# # lm
# from espnet2.lm.seq_rnn_lm import SequentialRNNLM as espnetSequentialRNNLM
# from espnet2.lm.transformer_lm import TransformerLM as espnetTransformerLM
# from .language_models.seq_rnn import SequentialRNNLM
# from .language_models.transformer import TransformerLM
#
# # frontend
# from espnet2.asr.frontend.s3prl import S3prlFrontend as espnetS3PRLModel
# from .frontends.s3prl import S3PRLModel
#
# from espnet2.asr.encoder.sanm_encoder import SANMEncoder_tf, SANMEncoderChunkOpt_tf
# from espnet_onnx.export.asr.models.encoders.transformer_sanm import TransformerEncoderSANM_tf
# from espnet2.asr.decoder.transformer_decoder import FsmnDecoderSCAMAOpt_tf
# from funasr.export.models.decoders import XformerDecoderSANM
from funasr.models.e2e_asr_paraformer import Paraformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
@@ -45,47 +6,4 @@
    if isinstance(model, Paraformer):
        return Paraformer_export(model, **export_config)
    else:
        raise "The model is not exist!"
# def get_encoder(model, frontend, preencoder, predictor=None, export_config=None):
#     if isinstance(model, espnetRNNEncoder) or isinstance(model, espnetVGGRNNEncoder):
#         return RNNEncoder(model, frontend, preencoder, **export_config)
#     elif isinstance(model, espnetContextualTransformer) or isinstance(model, espnetContextualConformer):
#         return ContextualBlockXformerEncoder(model, **export_config)
#     elif isinstance(model, espnetTransformerEncoder):
#         return TransformerEncoder(model, frontend, preencoder, **export_config)
#     elif isinstance(model, espnetConformerEncoder):
#         return ConformerEncoder(model, frontend, preencoder, **export_config)
#     elif isinstance(model, SANMEncoder_tf) or isinstance(model, SANMEncoderChunkOpt_tf):
#         return TransformerEncoderSANM_tf(model, frontend, preencoder, predictor, **export_config)
#     else:
#         raise "The model is not exist!"
#
# def get_decoder(model, export_config):
#     if isinstance(model, espnetRNNDecoder):
#         return RNNDecoder(model, **export_config)
#     elif isinstance(model, espnetTransducerDecoder):
#         return TransducerDecoder(model, **export_config)
#     elif isinstance(model, FsmnDecoderSCAMAOpt_tf):
#         return XformerDecoderSANM(model, **export_config)
#     else:
#         return XformerDecoder(model, **export_config)
#
#
# def get_lm(model, export_config):
#     if isinstance(model, espnetSequentialRNNLM):
#         return SequentialRNNLM(model, **export_config)
#     elif isinstance(model, espnetTransformerLM):
#         return TransformerLM(model, **export_config)
#
#
# def get_frontend_models(model, export_config):
#     if isinstance(model, espnetS3PRLModel):
#         return S3PRLModel(model, **export_config)
#     else:
#         return None
#
        raise "The model is not exist!"