| | |
| | | |
| | | from funasr.datasets.collate_fn import CommonCollateFn |
| | | from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.punctuation.espnet_model import ESPnetPunctuationModel |
| | | from funasr.punctuation.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer |
| | | from funasr.train.abs_model import AbsPunctuation |
| | | from funasr.train.abs_model import PunctuationModel |
| | | from funasr.models.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.models.vad_realtime_transformer import VadRealtimeTransformer |
| | | from funasr.tasks.abs_task import AbsTask |
| | | from funasr.text.phoneme_tokenizer import g2p_choices |
| | | from funasr.torch_utils.initialize import initialize |
| | |
| | | group.add_argument( |
| | | "--model_conf", |
| | | action=NestedDictAction, |
| | | default=get_default_kwargs(ESPnetPunctuationModel), |
| | | default=get_default_kwargs(PunctuationModel), |
| | | help="The keyword arguments for model class.", |
| | | ) |
| | | |
| | |
| | | return retval |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace) -> ESPnetPunctuationModel: |
| | | def build_model(cls, args: argparse.Namespace) -> PunctuationModel: |
| | | assert check_argument_types() |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |
| | |
| | | # Assume the last-id is sos_and_eos |
| | | if "punc_weight" in args.model_conf: |
| | | args.model_conf.pop("punc_weight") |
| | | model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf) |
| | | model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf) |
| | | |
| | | # FIXME(kamo): Should be done in model? |
| | | # 3. Initialize |