| | |
| | | from funasr.datasets.collate_fn import CommonCollateFn |
| | | from funasr.datasets.preprocessor import CommonPreprocessor |
| | | from funasr.lm.abs_model import AbsLM |
| | | from funasr.lm.espnet_model import ESPnetLanguageModel |
| | | from funasr.lm.espnet_model import LanguageModel |
| | | from funasr.lm.seq_rnn_lm import SequentialRNNLM |
| | | from funasr.lm.transformer_lm import TransformerLM |
| | | from funasr.tasks.abs_task import AbsTask |
| | |
| | | group.add_argument( |
| | | "--model_conf", |
| | | action=NestedDictAction, |
| | | default=get_default_kwargs(ESPnetLanguageModel), |
| | | default=get_default_kwargs(LanguageModel), |
| | | help="The keyword arguments for model class.", |
| | | ) |
| | | |
| | |
| | | return retval |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel: |
| | | def build_model(cls, args: argparse.Namespace) -> LanguageModel: |
| | | assert check_argument_types() |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |
| | |
| | | |
| | | # 2. Build ESPnetModel |
| | | # Assume the last-id is sos_and_eos |
| | | model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf) |
| | | model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf) |
| | | |
| | | # 3. Initialize |
| | | if args.init is not None: |