| | |
| | | |
def export(self,
           tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
           mode: str = None,
           ):
    """Load the model identified by ``tag_name`` and pass it to ``self._export``.

    Args:
        tag_name: Either a local model directory or a hub identifier. Names
            starting with ``'damo'`` are fetched with modelscope's
            ``snapshot_download`` into ``self.cache_dir`` first.
        mode: Model family selector, matched by prefix: ``'paraformer'`` /
            ``'uniasr'`` (ASR), ``'offline'`` (VAD) or ``'punc'``
            (punctuation, which also covers ``'punc_VadRealtime'``).
            When ``None`` the mode is read from ``configuration.json``
            inside the model directory.

    Raises:
        ValueError: If ``mode`` matches none of the supported prefixes.
    """
    model_dir = tag_name
    if model_dir.startswith('damo'):
        from modelscope.hub.snapshot_download import snapshot_download
        model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)

    if mode is None:
        import json
        json_file = os.path.join(model_dir, 'configuration.json')
        with open(json_file, 'r', encoding='utf-8') as f:
            config_data = json.load(f)
        # Punctuation configs keep their mode under a different key than the
        # other tasks, so pick the key from the declared task.
        if config_data['task'] == "punctuation":
            mode = config_data['model']['punc_model_config']['mode']
        else:
            mode = config_data['model']['model_config']['mode']

    if mode.startswith(('paraformer', 'uniasr')):
        # Both ASR families are loaded from the same three files; only the
        # task class differs.
        if mode.startswith('paraformer'):
            from funasr.tasks.asr import ASRTaskParaformer as ASRTask
        else:
            from funasr.tasks.asr import ASRTaskUniASR as ASRTask
        config = os.path.join(model_dir, 'config.yaml')
        model_file = os.path.join(model_dir, 'model.pb')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        model, _ = ASRTask.build_model_from_file(
            config, model_file, cmvn_file, 'cpu'
        )
        self.frontend = model.frontend
    elif mode.startswith('offline'):
        from funasr.tasks.vad import VADTask
        config = os.path.join(model_dir, 'vad.yaml')
        model_file = os.path.join(model_dir, 'vad.pb')
        cmvn_file = os.path.join(model_dir, 'vad.mvn')
        model, _ = VADTask.build_model_from_file(
            config, model_file, cmvn_file=cmvn_file, device='cpu'
        )
        # NOTE(review): 400 is hard-coded for the VAD export; presumably the
        # VAD input feature dimension -- confirm against the VAD config.
        self.export_config["feats_dim"] = 400
        self.frontend = model.frontend
    elif mode.startswith('punc'):
        # 'punc_VadRealtime' also starts with 'punc' and is loaded from the
        # same files in the same way, so one branch handles both.
        from funasr.tasks.punctuation import PunctuationTask as PUNCTask
        punc_train_config = os.path.join(model_dir, 'config.yaml')
        punc_model_file = os.path.join(model_dir, 'punc.pb')
        model, _ = PUNCTask.build_model_from_file(
            punc_train_config, punc_model_file, 'cpu'
        )
    else:
        # Fail with a clear message instead of hitting an unbound `model`
        # at the export call below.
        raise ValueError(
            "Unsupported export mode {!r}: expected a mode starting with "
            "'paraformer', 'uniasr', 'offline' or 'punc'.".format(mode)
        )

    self._export(model, tag_name)
| | | |
| | | |