From 4ace5a95b052d338947fc88809a440ccd55cf6b4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 16 十一月 2023 16:39:52 +0800
Subject: [PATCH] funasr pages

---
 funasr/export/export_model.py |  149 ++++++++++++++++++++++++++++++++-----------------
 1 files changed, 97 insertions(+), 52 deletions(-)

diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index b1161cb..6ab9408 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -1,16 +1,12 @@
-import json
-from typing import Union, Dict
-from pathlib import Path
-from typeguard import check_argument_types
-
 import os
-import logging
 import torch
-
-from funasr.export.models import get_model
-import numpy as np
 import random
-from funasr.utils.types import str2bool
+import logging
+import numpy as np
+from pathlib import Path
+from typing import Union, Dict, List
+from funasr.export.models import get_model
+from funasr.utils.types import str2bool, str2triple_str
 # torch_version = float(".".join(torch.__version__.split(".")[:2]))
 # assert torch_version > 1.9
 
@@ -19,28 +15,29 @@
         self,
         cache_dir: Union[Path, str] = None,
         onnx: bool = True,
+        device: str = "cpu",
         quant: bool = True,
         fallback_num: int = 0,
         audio_in: str = None,
         calib_num: int = 200,
+        model_revision: str = None,
     ):
-        assert check_argument_types()
         self.set_all_random_seed(0)
-        if cache_dir is None:
-            cache_dir = Path.home() / ".cache" / "export"
 
-        self.cache_dir = Path(cache_dir)
+        self.cache_dir = cache_dir
         self.export_config = dict(
             feats_dim=560,
             onnx=False,
         )
-        print("output dir: {}".format(self.cache_dir))
+        
         self.onnx = onnx
+        self.device = device
         self.quant = quant
         self.fallback_num = fallback_num
         self.frontend = None
         self.audio_in = audio_in
         self.calib_num = calib_num
+        self.model_revision = model_revision
         
 
     def _export(
@@ -50,7 +47,7 @@
         verbose: bool = False,
     ):
 
-        export_dir = self.cache_dir / tag_name.replace(' ', '-')
+        export_dir = self.cache_dir
         os.makedirs(export_dir, exist_ok=True)
 
         # export encoder1
@@ -59,14 +56,22 @@
             model,
             self.export_config,
         )
-        model.eval()
-        # self._export_onnx(model, verbose, export_dir)
-        if self.onnx:
-            self._export_onnx(model, verbose, export_dir)
+        if isinstance(model, List):
+            for m in model:
+                m.eval()
+                if self.onnx:
+                    self._export_onnx(m, verbose, export_dir)
+                else:
+                    self._export_torchscripts(m, verbose, export_dir)
+                print("output dir: {}".format(export_dir))
         else:
-            self._export_torchscripts(model, verbose, export_dir)
-
-        print("output dir: {}".format(export_dir))
+            model.eval()
+            # self._export_onnx(model, verbose, export_dir)
+            if self.onnx:
+                self._export_onnx(model, verbose, export_dir)
+            else:
+                self._export_torchscripts(model, verbose, export_dir)
+            print("output dir: {}".format(export_dir))
 
 
     def _torch_quantize(self, model):
@@ -111,6 +116,10 @@
             dummy_input = model.get_dummy_inputs(enc_size)
         else:
             dummy_input = model.get_dummy_inputs()
+
+        if self.device == 'cuda':
+            model = model.cuda()
+            dummy_input = tuple([i.cuda() for i in dummy_input])
 
         # model_script = torch.jit.script(model)
         model_script = torch.jit.trace(model, dummy_input)
@@ -161,31 +170,59 @@
     
     def export(self,
                tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
-               mode: str = 'paraformer',
+               mode: str = None,
                ):
         
         model_dir = tag_name
-        if model_dir.startswith('damo/'):
+        if model_dir.startswith('damo'):
             from modelscope.hub.snapshot_download import snapshot_download
-            model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
-        asr_train_config = os.path.join(model_dir, 'config.yaml')
-        asr_model_file = os.path.join(model_dir, 'model.pb')
-        cmvn_file = os.path.join(model_dir, 'am.mvn')
-        json_file = os.path.join(model_dir, 'configuration.json')
+            model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir, revision=self.model_revision)
+        self.cache_dir = model_dir
+
         if mode is None:
             import json
+            json_file = os.path.join(model_dir, 'configuration.json')
             with open(json_file, 'r') as f:
                 config_data = json.load(f)
-                mode = config_data['model']['model_config']['mode']
+                if config_data['task'] == "punctuation":
+                    mode = config_data['model']['punc_model_config']['mode']
+                else:
+                    mode = config_data['model']['model_config']['mode']
         if mode.startswith('paraformer'):
             from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-        elif mode.startswith('uniasr'):
-            from funasr.tasks.asr import ASRTaskUniASR as ASRTask
+            config = os.path.join(model_dir, 'config.yaml')
+            model_file = os.path.join(model_dir, 'model.pb')
+            cmvn_file = os.path.join(model_dir, 'am.mvn')
+            model, asr_train_args = ASRTask.build_model_from_file(
+                config, model_file, cmvn_file, 'cpu'
+            )
+            self.frontend = model.frontend
+            self.export_config["feats_dim"] = 560
+        elif mode.startswith('offline'):
+            from funasr.tasks.vad import VADTask
+            config = os.path.join(model_dir, 'vad.yaml')
+            model_file = os.path.join(model_dir, 'vad.pb')
+            cmvn_file = os.path.join(model_dir, 'vad.mvn')
             
-        model, asr_train_args = ASRTask.build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, 'cpu'
-        )
-        self.frontend = model.frontend
+            model, vad_infer_args = VADTask.build_model_from_file(
+                config, model_file, cmvn_file=cmvn_file, device='cpu'
+            )
+            self.export_config["feats_dim"] = 400
+            self.frontend = model.frontend
+        elif mode.startswith('punc'):
+            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
+            punc_train_config = os.path.join(model_dir, 'config.yaml')
+            punc_model_file = os.path.join(model_dir, 'punc.pb')
+            model, punc_train_args = PUNCTask.build_model_from_file(
+                punc_train_config, punc_model_file, 'cpu'
+            )
+        elif mode.startswith('punc_VadRealtime'):
+            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
+            punc_train_config = os.path.join(model_dir, 'config.yaml')
+            punc_model_file = os.path.join(model_dir, 'punc.pb')
+            model, punc_train_args = PUNCTask.build_model_from_file(
+                punc_train_config, punc_model_file, 'cpu'
+            )
         self._export(model, tag_name)
             
 
@@ -198,7 +235,7 @@
         # model_script = torch.jit.script(model)
         model_script = model #torch.jit.trace(model)
         model_path = os.path.join(path, f'{model.model_name}.onnx')
-
+        # if not os.path.exists(model_path):
         torch.onnx.export(
             model_script,
             dummy_input,
@@ -214,38 +251,46 @@
             from onnxruntime.quantization import QuantType, quantize_dynamic
             import onnx
             quant_model_path = os.path.join(path, f'{model.model_name}_quant.onnx')
-            onnx_model = onnx.load(model_path)
-            nodes = [n.name for n in onnx_model.graph.node]
-            nodes_to_exclude = [m for m in nodes if 'output' in m]
-            quantize_dynamic(
-                model_input=model_path,
-                model_output=quant_model_path,
-                op_types_to_quantize=['MatMul'],
-                per_channel=True,
-                reduce_range=False,
-                weight_type=QuantType.QUInt8,
-                nodes_to_exclude=nodes_to_exclude,
-            )
+            if not os.path.exists(quant_model_path):
+                onnx_model = onnx.load(model_path)
+                nodes = [n.name for n in onnx_model.graph.node]
+                nodes_to_exclude = [m for m in nodes if 'output' in m or 'bias_encoder' in m  or 'bias_decoder' in m]
+                quantize_dynamic(
+                    model_input=model_path,
+                    model_output=quant_model_path,
+                    op_types_to_quantize=['MatMul'],
+                    per_channel=True,
+                    reduce_range=False,
+                    weight_type=QuantType.QUInt8,
+                    nodes_to_exclude=nodes_to_exclude,
+                )
 
 
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model-name', type=str, required=True)
+    # parser.add_argument('--model-name', type=str, required=True)
+    parser.add_argument('--model-name', type=str, action="append", required=True, default=[])
     parser.add_argument('--export-dir', type=str, required=True)
     parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
+    parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
     parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
     parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
     parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
     parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
+    parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
     args = parser.parse_args()
 
     export_model = ModelExport(
         cache_dir=args.export_dir,
         onnx=args.type == 'onnx',
+        device=args.device,
         quant=args.quantize,
         fallback_num=args.fallback_num,
         audio_in=args.audio_in,
         calib_num=args.calib_num,
+        model_revision=args.model_revision,
     )
-    export_model.export(args.model_name)
+    for model_name in args.model_name:
+        print("export model: {}".format(model_name))
+        export_model.export(model_name)

--
Gitblit v1.9.1