chenmengzheAAA
2023-09-14 30c40c643c19f6e2ac8679fa76d09d0f9ceccc65
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -10,7 +10,7 @@
from .utils.utils import (ONNXRuntimeError,
                          OrtInferSession, get_logger,
                          read_yaml)
from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words)
from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words,code_mix_split_words_jieba)
logging = get_logger()
@@ -29,7 +29,13 @@
                 ):
    
        if not Path(model_dir).exists():
            from modelscope.hub.snapshot_download import snapshot_download
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except:
                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
                      "\npip3 install -U modelscope\n" \
                      "For the users in China, you could install with the command:\n" \
                      "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except:
@@ -41,7 +47,13 @@
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            print(".onnx is not exist, begin to export onnx")
            from funasr.export.export_model import ModelExport
            try:
                from funasr.export.export_model import ModelExport
            except:
                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
                      "\npip3 install -U funasr\n" \
                      "For the users in China, you could install with the command:\n" \
                      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
@@ -65,9 +77,18 @@
                self.punc_list[i] = "?"
            elif self.punc_list[i] == "。":
                self.period = i
        if "seg_jieba" in config:
            self.seg_jieba = True
            self.jieba_usr_dict_path = os.path.join(model_dir, 'jieba_usr_dict')
            self.code_mix_split_words_jieba = code_mix_split_words_jieba(self.jieba_usr_dict_path)
        else:
            self.seg_jieba = False
    def __call__(self, text: Union[list, str], split_size=20):
        split_text = code_mix_split_words(text)
        if self.seg_jieba:
            split_text = self.code_mix_split_words_jieba(text)
        else:
            split_text = code_mix_split_words(text)
        split_text_id = self.converter.tokens2ids(split_text)
        mini_sentences = split_to_mini_sentence(split_text, split_size)
        mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
@@ -186,11 +207,12 @@
            mini_sentence = cache_sent + mini_sentence
            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
            text_length = len(mini_sentence_id)
            vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32)
            data = {
                "input": mini_sentence_id[None,:],
                "text_lengths": np.array([text_length], dtype='int32'),
                "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32),
                "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
                "vad_mask": vad_mask,
                "sub_masks": vad_mask
            }
            try:
                outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])