From 9fcb3cc06b4e324f0913d2f61b89becc2baeef1b Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期一, 11 九月 2023 17:40:03 +0800
Subject: [PATCH] Merge pull request #932 from alibaba-damo-academy/dev_lhn
---
funasr/bin/punc_infer.py | 8 ++++++++
1 files changed, 8 insertions(+), 0 deletions(-)
diff --git a/funasr/bin/punc_infer.py b/funasr/bin/punc_infer.py
index ac96811..7b61717 100644
--- a/funasr/bin/punc_infer.py
+++ b/funasr/bin/punc_infer.py
@@ -8,6 +8,7 @@
import numpy as np
import torch
+import os
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
@@ -41,6 +42,11 @@
self.punc_list[i] = "锛�"
elif self.punc_list[i] == "銆�":
self.period = i
+ self.seg_dict_file = None
+ self.seg_jieba = False
+ if "seg_jieba" in train_args:
+ self.seg_jieba = train_args.seg_jieba
+ self.seg_dict_file = os.path.dirname(model_file)+"/"+ "jieba_usr_dict"
self.preprocessor = CodeMixTokenizerCommonPreprocessor(
train=False,
token_type=train_args.token_type,
@@ -50,6 +56,8 @@
g2p_type=train_args.g2p,
text_name="text",
non_linguistic_symbols=train_args.non_linguistic_symbols,
+ seg_jieba=self.seg_jieba,
+ seg_dict_file=self.seg_dict_file
)
@torch.no_grad()
--
Gitblit v1.9.1