| | |
| | | |
| | | from funasr.bin.asr_inference_paraformer import Speech2Text |
| | | from funasr.export.models import get_model |
| | | |
| | | |
| | | import numpy as np |
| | | import random |
| | | |
| | | class ASRModelExportParaformer: |
| | | def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True): |
| | | assert check_argument_types() |
| | | self.set_all_random_seed(0) |
| | | if cache_dir is None: |
| | | cache_dir = Path.home() / "cache" / "export" |
| | | cache_dir = Path.home() / ".cache" / "export" |
| | | |
| | | self.cache_dir = Path(cache_dir) |
| | | self.export_config = dict( |
| | |
| | | ) |
| | | logging.info("output dir: {}".format(self.cache_dir)) |
| | | self.onnx = onnx |
| | | |
| | | |
| | | def _export( |
| | | self, |
| | |
| | | model_script = torch.jit.trace(model, dummy_input) |
| | | model_script.save(os.path.join(path, f'{model.model_name}.torchscripts')) |
| | | |
| | | def set_all_random_seed(self, seed: int): |
| | | random.seed(seed) |
| | | np.random.seed(seed) |
| | | torch.random.manual_seed(seed) |
| | | def export(self, |
| | | tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', |
| | | mode: str = 'paraformer', |
| | |
| | | ) |
| | | self._export(model, tag_name) |
| | | |
| | | # def export_from_modelscope( |
| | | # self, |
| | | # tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', |
| | | # ): |
| | | # |
| | | # from funasr.tasks.asr import ASRTaskParaformer as ASRTask |
| | | # from modelscope.hub.snapshot_download import snapshot_download |
| | | # |
| | | # model_dir = snapshot_download(tag_name, cache_dir=self.cache_dir) |
| | | # asr_train_config = os.path.join(model_dir, 'config.yaml') |
| | | # asr_model_file = os.path.join(model_dir, 'model.pb') |
| | | # cmvn_file = os.path.join(model_dir, 'am.mvn') |
| | | # model, asr_train_args = ASRTask.build_model_from_file( |
| | | # asr_train_config, asr_model_file, cmvn_file, 'cpu' |
| | | # ) |
| | | # self.export(model, tag_name) |
| | | # |
| | | # def export_from_local( |
| | | # self, |
| | | # tag_name: str = '/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', |
| | | # ): |
| | | # |
| | | # from funasr.tasks.asr import ASRTaskParaformer as ASRTask |
| | | # |
| | | # model_dir = tag_name |
| | | # asr_train_config = os.path.join(model_dir, 'config.yaml') |
| | | # asr_model_file = os.path.join(model_dir, 'model.pb') |
| | | # cmvn_file = os.path.join(model_dir, 'am.mvn') |
| | | # model, asr_train_args = ASRTask.build_model_from_file( |
| | | # asr_train_config, asr_model_file, cmvn_file, 'cpu' |
| | | # ) |
| | | # self.export(model, tag_name) |
| | | |
| | | def _export_onnx(self, model, verbose, path, enc_size=None): |
| | | if enc_size: |