| | |
| | | |
| | | if __name__ == '__main__': |
| | | parser = get_parser() |
| | | args = parser.parse_args() |
| | | task_args = build_args(args) |
| | | args = argparse.Namespace(**vars(args), **vars(task_args)) |
| | | args, extra_task_params = parser.parse_known_args() |
| | | if extra_task_params: |
| | | task_args = build_args(args, extra_task_params) |
| | | args = argparse.Namespace(**vars(args), **vars(task_args)) |
| | | |
| | | # set random seed |
| | | set_all_random_seed(args.seed) |
| | |
| | | from funasr.utils.types import str_or_none |
| | | |
| | | |
| | | def build_args(args): |
| | | def build_args(args, extra_task_params): |
| | | parser = argparse.ArgumentParser("Task related config") |
| | | if args.task_name == "asr": |
| | | from funasr.build_utils.build_asr_model import class_choices_list |
| | |
| | | else: |
| | | raise NotImplementedError("Not supported task: {}".format(args.task_name)) |
| | | |
| | | task_args = parser.parse_args() |
| | | task_args = parser.parse_args(extra_task_params) |
| | | return task_args |