zhifu gao
2023-02-28 a016617c7ec98ab9c7475ff7d3b6150b98d5beeb
funasr/tasks/punctuation.py
@@ -13,10 +13,11 @@
from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import MutliTokenizerCommonPreprocessor
from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.punctuation.espnet_model import ESPnetPunctuationModel
from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
@@ -29,11 +30,9 @@
# Registry of punctuation model implementations. Every class listed here is
# checked against ``AbsPunctuation`` via ``type_check``. ``default`` must name
# one of the keys of ``classes`` (the short registry name such as
# "target_delay"), not the Python class name.
punc_choices = ClassChoices(
    "punctuation",
    classes=dict(
        target_delay=TargetDelayTransformer,
        vad_realtime=VadRealtimeTransformer,
    ),
    type_check=AbsPunctuation,
    default="target_delay",
)
@@ -56,8 +55,6 @@
        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        required = parser.get_default("required")
        #import pdb;pdb.set_trace()
        #required += ["token_list"]
        group.add_argument(
            "--token_list",
@@ -154,7 +151,7 @@
        bpemodels = [args.bpemodel, args.bpemodel]
        text_names = ["text", "punc"]
        if args.use_preprocessor:
            retval = MutliTokenizerCommonPreprocessor(
            retval = PuncTrainTokenizerCommonPreprocessor(
                train=train,
                token_type=token_types,
                token_list=token_lists,
@@ -182,7 +179,7 @@
    def optional_data_names(
            cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ()
        retval = ("vad",)
        return retval
    @classmethod
@@ -197,11 +194,13 @@
            args.token_list = token_list.copy()
        if isinstance(args.punc_list, str):
            with open(args.punc_list, encoding="utf-8") as f2:
                punc_list = [line.rstrip() for line in f2]
                pairs = [line.rstrip().split(":") for line in f2]
            punc_list = [pair[0] for pair in pairs]
            punc_weight_list = [float(pair[1]) for pair in pairs]
            args.punc_list = punc_list.copy()
        elif isinstance(args.punc_list, list):
            # This is in the inference code path.
            punc_list = args.punc_list.copy()
            punc_weight_list = [1] * len(punc_list)
        if isinstance(args.token_list, (tuple, list)):
            token_list = args.token_list.copy()
        else:
@@ -217,7 +216,9 @@
        # 2. Build ESPnetModel
        # Assume the last-id is sos_and_eos
        model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, **args.model_conf)
        if "punc_weight" in args.model_conf:
            args.model_conf.pop("punc_weight")
        model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
        # FIXME(kamo): Should be done in model?
        # 3. Initialize