From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001 From: 游雁 <zhifu.gzf@alibaba-inc.com> Date: 星期三, 13 九月 2023 09:33:54 +0800 Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add --- funasr/bin/sa_asr_train.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 files changed, 50 insertions(+), 0 deletions(-) diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py new file mode 100755 index 0000000..67106cf --- /dev/null +++ b/funasr/bin/sa_asr_train.py @@ -0,0 +1,50 @@ +# -*- encoding: utf-8 -*- +#!/usr/bin/env python3 +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +import os + +from funasr.tasks.sa_asr import ASRTask + + +# for ASR Training +def parse_args(): + parser = ASRTask.get_parser() + parser.add_argument( + "--gpu_id", + type=int, + default=0, + help="local gpu id.", + ) + args = parser.parse_args() + return args + + +def main(args=None, cmd=None): + # for ASR Training + ASRTask.main(args=args, cmd=cmd) + + +if __name__ == '__main__': + args = parse_args() + + # setup local gpu_id + if args.ngpu > 0: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) + + # DDP settings + if args.ngpu > 1: + args.distributed = True + else: + args.distributed = False + assert args.num_worker_count == 1 + + # re-compute batch size: when dataset type is small + if args.dataset_type == "small": + if args.batch_size is not None and args.ngpu > 0: + args.batch_size = args.batch_size * args.ngpu + if args.batch_bins is not None and args.ngpu > 0: + args.batch_bins = args.batch_bins * args.ngpu + + main(args=args) -- Gitblit v1.9.1