From 7c5fdf30f428e22fd0fdb98055834e0d2616d308 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 29 三月 2023 00:27:11 +0800
Subject: [PATCH] export
---
funasr/export/export_model.py | 36 ++++++++++++++++++++++--------------
1 files changed, 22 insertions(+), 14 deletions(-)
diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index f6ba616..cad3367 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -14,7 +14,7 @@
# torch_version = float(".".join(torch.__version__.split(".")[:2]))
# assert torch_version > 1.9
-class ASRModelExportParaformer:
+class ModelExport:
def __init__(
self,
cache_dir: Union[Path, str] = None,
@@ -161,31 +161,39 @@
def export(self,
tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
- mode: str = 'paraformer',
+ mode: str = None,
):
model_dir = tag_name
- if model_dir.startswith('damo/'):
+ if model_dir.startswith('damo'):
from modelscope.hub.snapshot_download import snapshot_download
model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
- asr_train_config = os.path.join(model_dir, 'config.yaml')
- asr_model_file = os.path.join(model_dir, 'model.pb')
- cmvn_file = os.path.join(model_dir, 'am.mvn')
- json_file = os.path.join(model_dir, 'configuration.json')
+
if mode is None:
import json
+ json_file = os.path.join(model_dir, 'configuration.json')
with open(json_file, 'r') as f:
config_data = json.load(f)
mode = config_data['model']['model_config']['mode']
if mode.startswith('paraformer'):
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
- elif mode.startswith('uniasr'):
- from funasr.tasks.asr import ASRTaskUniASR as ASRTask
+ config = os.path.join(model_dir, 'config.yaml')
+ model_file = os.path.join(model_dir, 'model.pb')
+ cmvn_file = os.path.join(model_dir, 'am.mvn')
+ model, asr_train_args = ASRTask.build_model_from_file(
+ config, model_file, cmvn_file, 'cpu'
+ )
+ self.frontend = model.frontend
+ elif mode.startswith('offline'):
+ from funasr.tasks.vad import VADTask
+ config = os.path.join(model_dir, 'vad.yaml')
+ model_file = os.path.join(model_dir, 'vad.pb')
+ cmvn_file = os.path.join(model_dir, 'vad.mvn')
- model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, 'cpu'
- )
- self.frontend = model.frontend
+ model, vad_infer_args = VADTask.build_model_from_file(
+ config, model_file, 'cpu'
+ )
+ self.export_config["feats_dim"] = 400
self._export(model, tag_name)
@@ -240,7 +248,7 @@
parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
args = parser.parse_args()
- export_model = ASRModelExportParaformer(
+ export_model = ModelExport(
cache_dir=args.export_dir,
onnx=args.type == 'onnx',
quant=args.quantize,
--
Gitblit v1.9.1