| | |
| | | json_file = os.path.join(model_dir, 'configuration.json') |
| | | with open(json_file, 'r') as f: |
| | | config_data = json.load(f) |
| | | mode = config_data['model']['model_config']['mode'] |
| | | if config_data['task'] == "punctuation": |
| | | mode = config_data['model']['punc_model_config']['mode'] |
| | | else: |
| | | mode = config_data['model']['model_config']['mode'] |
| | | if mode.startswith('paraformer'): |
| | | from funasr.tasks.asr import ASRTaskParaformer as ASRTask |
| | | config = os.path.join(model_dir, 'config.yaml') |
| | |
| | | ) |
| | | self.export_config["feats_dim"] = 400 |
| | | self.frontend = model.frontend |
| | | elif mode.startswith('punc'): |
| | | from funasr.tasks.punctuation import PunctuationTask as PUNCTask |
| | | punc_train_config = os.path.join(model_dir, 'config.yaml') |
| | | punc_model_file = os.path.join(model_dir, 'punc.pb') |
| | | model, punc_train_args = PUNCTask.build_model_from_file( |
| | | punc_train_config, punc_model_file, 'cpu' |
| | | ) |
| | | self._export(model, tag_name) |
| | | |
| | | |