From c2e4e3c2e9be855277d9f4fa9cd0544892ff829a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 30 八月 2023 09:57:30 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py |   48 ++++++++++++++++++++++++++++++++++++++----------
 1 files changed, 38 insertions(+), 10 deletions(-)

diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 6fd01e4..cc5daa8 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -10,7 +10,7 @@
 from .utils.utils import (ONNXRuntimeError,
                           OrtInferSession, get_logger,
                           read_yaml)
-from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words)
+from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words,code_mix_split_words_jieba)
 logging = get_logger()
 
 
@@ -24,15 +24,32 @@
                  batch_size: int = 1,
                  device_id: Union[str, int] = "-1",
                  quantize: bool = False,
-                 intra_op_num_threads: int = 4
+                 intra_op_num_threads: int = 4,
+                 cache_dir: str = None,
                  ):
-
+    
         if not Path(model_dir).exists():
-            raise FileNotFoundError(f'{model_dir} does not exist.')
-
+            from modelscope.hub.snapshot_download import snapshot_download
+            try:
+                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
+            except:
+                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
+                    model_dir)
+    
         model_file = os.path.join(model_dir, 'model.onnx')
         if quantize:
             model_file = os.path.join(model_dir, 'model_quant.onnx')
+        if not os.path.exists(model_file):
+            print(".onnx is not exist, begin to export onnx")
+            from funasr.export.export_model import ModelExport
+            export_model = ModelExport(
+                cache_dir=cache_dir,
+                onnx=True,
+                device="cpu",
+                quant=quantize,
+            )
+            export_model.export(model_dir)
+            
         config_file = os.path.join(model_dir, 'punc.yaml')
         config = read_yaml(config_file)
 
@@ -48,9 +65,18 @@
                 self.punc_list[i] = "锛�"
             elif self.punc_list[i] == "銆�":
                 self.period = i
+        if "seg_jieba" in config:
+            self.seg_jieba = True
+            self.jieba_usr_dict_path = os.path.join(model_dir, 'jieba_usr_dict')
+            self.code_mix_split_words_jieba = code_mix_split_words_jieba(self.jieba_usr_dict_path)
+        else:
+            self.seg_jieba = False
 
     def __call__(self, text: Union[list, str], split_size=20):
-        split_text = code_mix_split_words(text)
+        if self.seg_jieba:
+            split_text = self.code_mix_split_words_jieba(text)
+        else:
+            split_text = code_mix_split_words(text)
         split_text_id = self.converter.tokens2ids(split_text)
         mini_sentences = split_to_mini_sentence(split_text, split_size)
         mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
@@ -135,9 +161,10 @@
                  batch_size: int = 1,
                  device_id: Union[str, int] = "-1",
                  quantize: bool = False,
-                 intra_op_num_threads: int = 4
+                 intra_op_num_threads: int = 4,
+                 cache_dir: str = None
                  ):
-        super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads)
+        super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads, cache_dir=cache_dir)
 
     def __call__(self, text: str, param_dict: map, split_size=20):
         cache_key = "cache"
@@ -168,11 +195,12 @@
             mini_sentence = cache_sent + mini_sentence
             mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
             text_length = len(mini_sentence_id)
+            vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32)
             data = {
                 "input": mini_sentence_id[None,:],
                 "text_lengths": np.array([text_length], dtype='int32'),
-                "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32),
-                "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
+                "vad_mask": vad_mask,
+                "sub_masks": vad_mask
             }
             try:
                 outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])

--
Gitblit v1.9.1