| funasr/export/README.md | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/export/export_model.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/export/models/__init__.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/export/README.md
@@ -16,17 +16,13 @@ output_dir = "../export" # onnx/torchscripts model save path export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True) export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') ``` Export model from local path ```python from funasr.export.export_model import ASRModelExportParaformer output_dir = "../export" # onnx/torchscripts model save path export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True) export_model.export_from_local('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') ``` ## Export torchscripts format model @@ -36,15 +32,12 @@ output_dir = "../export" # onnx/torchscripts model save path export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False) export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') ``` Export model from local path ```python from funasr.export.export_model import ASRModelExportParaformer output_dir = "../export" # onnx/torchscripts model save path export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False) export_model.export_from_local('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') ``` funasr/export/export_model.py
@@ -1,3 +1,4 @@ import json from typing import Union, Dict from pathlib import Path from typeguard import check_argument_types @@ -8,14 +9,15 @@ from funasr.bin.asr_inference_paraformer import Speech2Text from funasr.export.models import get_model import numpy as np import random class ASRModelExportParaformer: def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True): assert check_argument_types() self.set_all_random_seed(0) if cache_dir is None: cache_dir = Path.home() / "cache" / "export" cache_dir = Path.home() / ".cache" / "export" self.cache_dir = Path(cache_dir) self.export_config = dict( @@ -24,8 +26,9 @@ ) logging.info("output dir: {}".format(self.cache_dir)) self.onnx = onnx def export( def _export( self, model: Speech2Text, tag_name: str = None, @@ -60,38 +63,38 @@ model_script = torch.jit.trace(model, dummy_input) model_script.save(os.path.join(path, f'{model.model_name}.torchscripts')) def export_from_modelscope( self, tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', ): def set_all_random_seed(self, seed: int): random.seed(seed) np.random.seed(seed) torch.random.manual_seed(seed) def export(self, tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', mode: str = 'paraformer', ): from funasr.tasks.asr import ASRTaskParaformer as ASRTask from modelscope.hub.snapshot_download import snapshot_download model_dir = snapshot_download(tag_name, cache_dir=self.cache_dir) asr_train_config = os.path.join(model_dir, 'config.yaml') asr_model_file = os.path.join(model_dir, 'model.pb') cmvn_file = os.path.join(model_dir, 'am.mvn') model, asr_train_args = ASRTask.build_model_from_file( asr_train_config, asr_model_file, cmvn_file, 'cpu' ) self.export(model, tag_name) def export_from_local( self, tag_name: str = '/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', ): from funasr.tasks.asr import ASRTaskParaformer as ASRTask model_dir = tag_name if model_dir.startswith('damo/'): from modelscope.hub.snapshot_download import snapshot_download model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir) asr_train_config = os.path.join(model_dir, 'config.yaml') asr_model_file = os.path.join(model_dir, 'model.pb') cmvn_file = os.path.join(model_dir, 'am.mvn') json_file = os.path.join(model_dir, 'configuration.json') if mode is None: import json with open(json_file, 'r') as f: config_data = json.load(f) mode = config_data['model']['model_config']['mode'] if mode == 'paraformer': from funasr.tasks.asr import ASRTaskParaformer as ASRTask elif mode == 'uniasr': from funasr.tasks.asr import ASRTaskUniASR as ASRTask model, asr_train_args = ASRTask.build_model_from_file( asr_train_config, asr_model_file, cmvn_file, 'cpu' ) self.export(model, tag_name) self._export(model, tag_name) def _export_onnx(self, model, verbose, path, enc_size=None): if enc_size: @@ -116,5 +119,5 @@ if __name__ == '__main__': output_dir = "../export" export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False) export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') # export_model.export_from_local('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') funasr/export/models/__init__.py
@@ -1,42 +1,3 @@ # from .ctc import CTC # from .joint_network import JointNetwork # # # encoder # from espnet2.asr.encoder.rnn_encoder import RNNEncoder as espnetRNNEncoder # from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder as espnetVGGRNNEncoder # from espnet2.asr.encoder.contextual_block_transformer_encoder import ContextualBlockTransformerEncoder as espnetContextualTransformer # from espnet2.asr.encoder.contextual_block_conformer_encoder import ContextualBlockConformerEncoder as espnetContextualConformer # from espnet2.asr.encoder.transformer_encoder import TransformerEncoder as espnetTransformerEncoder # from espnet2.asr.encoder.conformer_encoder import ConformerEncoder as espnetConformerEncoder # from funasr.export.models.encoder.rnn import RNNEncoder # from funasr.export.models.encoders import TransformerEncoder # from funasr.export.models.encoders import ConformerEncoder # from funasr.export.models.encoder.contextual_block_xformer import ContextualBlockXformerEncoder # # # decoder # from espnet2.asr.decoder.rnn_decoder import RNNDecoder as espnetRNNDecoder # from espnet2.asr.transducer.transducer_decoder import TransducerDecoder as espnetTransducerDecoder # from funasr.export.models.decoder.rnn import ( # RNNDecoder # ) # from funasr.export.models.decoders import XformerDecoder # from funasr.export.models.decoders import TransducerDecoder # # # lm # from espnet2.lm.seq_rnn_lm import SequentialRNNLM as espnetSequentialRNNLM # from espnet2.lm.transformer_lm import TransformerLM as espnetTransformerLM # from .language_models.seq_rnn import SequentialRNNLM # from .language_models.transformer import TransformerLM # # # frontend # from espnet2.asr.frontend.s3prl import S3prlFrontend as espnetS3PRLModel # from .frontends.s3prl import S3PRLModel # # from espnet2.asr.encoder.sanm_encoder import SANMEncoder_tf, SANMEncoderChunkOpt_tf # from espnet_onnx.export.asr.models.encoders.transformer_sanm import TransformerEncoderSANM_tf # from espnet2.asr.decoder.transformer_decoder import FsmnDecoderSCAMAOpt_tf # from funasr.export.models.decoders import XformerDecoderSANM from funasr.models.e2e_asr_paraformer import Paraformer from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export @@ -45,47 +6,4 @@ if isinstance(model, Paraformer): return Paraformer_export(model, **export_config) else: raise "The model is not exist!" # def get_encoder(model, frontend, preencoder, predictor=None, export_config=None): # if isinstance(model, espnetRNNEncoder) or isinstance(model, espnetVGGRNNEncoder): # return RNNEncoder(model, frontend, preencoder, **export_config) # elif isinstance(model, espnetContextualTransformer) or isinstance(model, espnetContextualConformer): # return ContextualBlockXformerEncoder(model, **export_config) # elif isinstance(model, espnetTransformerEncoder): # return TransformerEncoder(model, frontend, preencoder, **export_config) # elif isinstance(model, espnetConformerEncoder): # return ConformerEncoder(model, frontend, preencoder, **export_config) # elif isinstance(model, SANMEncoder_tf) or isinstance(model, SANMEncoderChunkOpt_tf): # return TransformerEncoderSANM_tf(model, frontend, preencoder, predictor, **export_config) # else: # raise "The model is not exist!" # # def get_decoder(model, export_config): # if isinstance(model, espnetRNNDecoder): # return RNNDecoder(model, **export_config) # elif isinstance(model, espnetTransducerDecoder): # return TransducerDecoder(model, **export_config) # elif isinstance(model, FsmnDecoderSCAMAOpt_tf): # return XformerDecoderSANM(model, **export_config) # else: # return XformerDecoder(model, **export_config) # # # def get_lm(model, export_config): # if isinstance(model, espnetSequentialRNNLM): # return SequentialRNNLM(model, **export_config) # elif isinstance(model, espnetTransformerLM): # return TransformerLM(model, **export_config) # # # def get_frontend_models(model, export_config): # if isinstance(model, espnetS3PRLModel): # return S3PRLModel(model, **export_config) # else: # return None # raise "The model is not exist!"