游雁
2023-03-31 d0cd484fdc21c06b8bc892bb2ab1c2a25fb1da8a
funasr/export/export_model.py
@@ -174,7 +174,10 @@
            json_file = os.path.join(model_dir, 'configuration.json')
            with open(json_file, 'r') as f:
                config_data = json.load(f)
                mode = config_data['model']['model_config']['mode']
                if config_data['task'] == "punctuation":
                    mode = config_data['model']['punc_model_config']['mode']
                else:
                    mode = config_data['model']['model_config']['mode']
        if mode.startswith('paraformer'):
            from funasr.tasks.asr import ASRTaskParaformer as ASRTask
            config = os.path.join(model_dir, 'config.yaml')
@@ -191,9 +194,24 @@
            cmvn_file = os.path.join(model_dir, 'vad.mvn')
            
            model, vad_infer_args = VADTask.build_model_from_file(
                config, model_file, 'cpu'
                config, model_file, cmvn_file=cmvn_file, device='cpu'
            )
            self.export_config["feats_dim"] = 400
            self.frontend = model.frontend
        elif mode.startswith('punc'):
            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
            punc_train_config = os.path.join(model_dir, 'config.yaml')
            punc_model_file = os.path.join(model_dir, 'punc.pb')
            model, punc_train_args = PUNCTask.build_model_from_file(
                punc_train_config, punc_model_file, 'cpu'
            )
        elif mode.startswith('punc_VadRealtime'):
            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
            punc_train_config = os.path.join(model_dir, 'config.yaml')
            punc_model_file = os.path.join(model_dir, 'punc.pb')
            model, punc_train_args = PUNCTask.build_model_from_file(
                punc_train_config, punc_model_file, 'cpu'
            )
        self._export(model, tag_name)