| | |
| | | def parse_args(): |
| | | parser = ASRTask.get_parser() |
| | | parser.add_argument( |
| | | "--mode", |
| | | type=str, |
| | | default="asr", |
| | | help=" ", |
| | | ) |
| | | parser.add_argument( |
| | | "--gpu_id", |
| | | type=int, |
| | | default=0, |
| | |
| | | |
| | | |
def main(args=None, cmd=None):
    """Run ASR training with the task class selected by ``args.mode``.

    The task class is imported from ``funasr.tasks.asr`` inside the branch
    that matches ``args.mode``:

    * ``"asr"``        -> ``ASRTask``
    * ``"paraformer"`` -> ``ASRTaskParaformer``
    * ``"uniasr"``     -> ``ASRTaskUniASR``
    * ``"rnnt"``       -> ``ASRTransducerTask``

    The selected class's ``main`` is then called with ``args`` and ``cmd``
    passed through unchanged.

    Args:
        args: Parsed argument namespace; must have a ``mode`` attribute.
            Despite the ``None`` default, it is required, because
            ``args.mode`` is read before dispatching.
        cmd: Forwarded unchanged to the selected task's ``main``
            (presumably an optional argv override -- confirm against the
            task API).

    Raises:
        ValueError: If ``args.mode`` is not one of the supported modes.
    """
    # for ASR Training
    # The modes are mutually exclusive, so use a single elif chain.
    if args.mode == "asr":
        from funasr.tasks.asr import ASRTask
    elif args.mode == "paraformer":
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    elif args.mode == "uniasr":
        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
    elif args.mode == "rnnt":
        from funasr.tasks.asr import ASRTransducerTask as ASRTask
    else:
        # Without this branch ASRTask would be unbound below and the call
        # would fail with an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported --mode {args.mode!r}; "
            "expected one of 'asr', 'paraformer', 'uniasr', 'rnnt'."
        )

    ASRTask.main(args=args, cmd=cmd)
| | | |
| | | |
| | |
| | | args.batch_bins = args.batch_bins * args.ngpu |
| | | |
| | | main(args=args) |
| | | |