| | |
| | | import logging |
| | | from funasr.utils.build_asr_model import build_asr_model |
| | | from funasr.utils.build_pretrain_model import build_pretrain_model |
| | | |
| | | |
def build_model(args):
    """Build the model selected by ``args.task_name``.

    If ``args.token_list`` is set, it is normalized in place to a list of
    tokens (``args.token_list`` is overwritten) before the task-specific
    builder is called, so the builder sees the loaded vocabulary.

    Args:
        args: Namespace-like config object. This function reads
            ``token_list`` (``None``, a path to a UTF-8 text file with one
            token per line, or an already-loaded list/tuple of tokens) and
            ``task_name`` (``"asr"`` or ``"pretrain"``). Any other attributes
            are passed through to the task-specific builder via ``args``.

    Returns:
        Whatever ``build_asr_model`` or ``build_pretrain_model`` returns for
        the selected task.

    Raises:
        NotImplementedError: If ``args.task_name`` is not a supported task.
        OSError: If ``args.token_list`` is a path that cannot be opened.
        UnicodeDecodeError: If the token file is not valid UTF-8.
    """
    if args.token_list is not None:
        if isinstance(args.token_list, (list, tuple)):
            # Already loaded, e.g. build_model was called before on the same
            # args; passing a list to open() would raise TypeError.
            token_list = list(args.token_list)
        else:
            # Token files may contain non-ASCII tokens; decode explicitly as
            # UTF-8 instead of relying on the platform's locale encoding.
            with open(args.token_list, encoding="utf-8") as f:
                # rstrip() drops the newline and any other trailing whitespace.
                token_list = [line.rstrip() for line in f]
        args.token_list = token_list
        logging.info("Vocabulary size: %d", len(token_list))

    if args.task_name == "asr":
        model = build_asr_model(args)
    elif args.task_name == "pretrain":
        model = build_pretrain_model(args)
    else:
        raise NotImplementedError("Not supported task: {}".format(args.task_name))

    return model