雾聪
2023-08-10 ffb05b9ae7eccc47416e9e7fae9dea54d400a245
funasr/bin/punc_train.py
@@ -1,4 +1,8 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import os
from funasr.tasks.punctuation import PunctuationTask
@@ -40,4 +44,10 @@
    else:
        args.distributed = False
    if args.dataset_type == "small":
        if args.batch_size is not None:
            args.batch_size = args.batch_size * args.ngpu * args.num_worker_count
        if args.batch_bins is not None:
            args.batch_bins = args.batch_bins * args.ngpu * args.num_worker_count
    main(args=args)