mengzhe.cmz
2023-07-18 b6b63936c7f4320b30b5f907514f6e8d39ed7239
funasr/bin/punc_infer.py
@@ -8,6 +8,7 @@
import numpy as np
import torch
import os
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
@@ -41,6 +42,11 @@
                self.punc_list[i] = "?"
            elif self.punc_list[i] == "。":
                self.period = i
        self.seg_dict_file = None
        self.seg_jieba = False
        if "seg_jieba" in train_args:
            self.seg_jieba = train_args.seg_jieba
            self.seg_dict_file = os.path.dirname(model_file)+"/"+ "jieba_usr_dict"
        self.preprocessor = CodeMixTokenizerCommonPreprocessor(
            train=False,
            token_type=train_args.token_type,
@@ -50,6 +56,8 @@
            g2p_type=train_args.g2p,
            text_name="text",
            non_linguistic_symbols=train_args.non_linguistic_symbols,
            seg_jieba=self.seg_jieba,
            seg_dict_file=self.seg_dict_file
        )
    @torch.no_grad()