From b6b63936c7f4320b30b5f907514f6e8d39ed7239 Mon Sep 17 00:00:00 2001
From: mengzhe.cmz <mengzhe.cmz@alibaba-inc.com>
Date: Tue, 18 Jul 2023 17:32:38 +0800
Subject: [PATCH] add punc large model modelscope runtime; fix train bug

---
 funasr/bin/punc_infer.py |    8 ++++++++
 1 file changed, 8 insertions(+), 0 deletions(-)

diff --git a/funasr/bin/punc_infer.py b/funasr/bin/punc_infer.py
index ac96811..7b61717 100644
--- a/funasr/bin/punc_infer.py
+++ b/funasr/bin/punc_infer.py
@@ -8,6 +8,7 @@
 
 import numpy as np
 import torch
+import os
 
 from funasr.build_utils.build_model_from_file import build_model_from_file
 from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
@@ -41,6 +42,11 @@
                 self.punc_list[i] = "？"
             elif self.punc_list[i] == "。":
                 self.period = i
+        self.seg_dict_file = None
+        self.seg_jieba = False
+        if "seg_jieba" in train_args:
+            self.seg_jieba = train_args.seg_jieba
+            self.seg_dict_file = os.path.dirname(model_file)+"/"+ "jieba_usr_dict"
         self.preprocessor = CodeMixTokenizerCommonPreprocessor(
             train=False,
             token_type=train_args.token_type,
@@ -50,6 +56,8 @@
             g2p_type=train_args.g2p,
             text_name="text",
             non_linguistic_symbols=train_args.non_linguistic_symbols,
+            seg_jieba=self.seg_jieba,
+            seg_dict_file=self.seg_dict_file
         )
 
     @torch.no_grad()

--
Gitblit v1.9.1