Binbin Gu
2023-09-22 c1d01605bf5c4dc383f4c397ae4f566ce91b214a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import logging
import os
 
import torch
 
from funasr.train.distributed_utils import DistributedOption
from funasr.utils.build_dataclass import build_dataclass
 
 
def build_distributed(args):
    """Initialize distributed-training options and (re)configure logging.

    Builds a ``DistributedOption`` from ``args`` and initializes torch
    distributed according to the launch mode:

      * ``args.use_pai``                       -> PAI-specific init
      * otherwise, ``not args.simple_ddp``     -> standard torch distributed init
      * ``args.distributed and args.simple_ddp`` -> PAI-style init, and
        ``args.ngpu`` is overwritten with the actual world size reported
        by ``torch.distributed``.

    Logging is reset so that rank 0 (or a non-distributed run) logs at
    INFO level while every other rank logs at ERROR level; each record is
    prefixed with the short hostname.

    Args:
        args: parsed argument namespace; must provide ``use_pai``,
            ``simple_ddp``, ``distributed`` and ``ngpu``, plus whatever
            fields ``DistributedOption`` consumes.

    Returns:
        DistributedOption: the initialized distributed options object.
    """
    distributed_option = build_dataclass(DistributedOption, args)
    if args.use_pai:
        distributed_option.init_options_pai()
        distributed_option.init_torch_distributed_pai(args)
    elif not args.simple_ddp:
        distributed_option.init_torch_distributed(args)
    elif args.distributed and args.simple_ddp:
        distributed_option.init_torch_distributed_pai(args)
        # Simple DDP derives the GPU count from the actual world size.
        args.ngpu = torch.distributed.get_world_size()

    # Remove any previously installed handlers so basicConfig takes effect
    # again (basicConfig is a no-op if root already has handlers).
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    # Hoisted: the format string was previously duplicated in both branches.
    # NOTE(review): os.uname() is POSIX-only; socket.gethostname() would be
    # portable — kept as-is to avoid any behavior change.
    hostname = os.uname()[1].split(".")[0]
    log_format = (
        f"[{hostname}]"
        " %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    )
    # Only rank 0 (or a non-distributed run) logs at INFO; others at ERROR.
    is_main = not distributed_option.distributed or distributed_option.dist_rank == 0
    logging.basicConfig(level="INFO" if is_main else "ERROR", format=log_format)

    # Lazy %-style arguments: formatting is skipped when the record is
    # filtered out (e.g. INFO records on non-main ranks).
    logging.info(
        "world size: %s, rank: %s, local_rank: %s",
        distributed_option.dist_world_size,
        distributed_option.dist_rank,
        distributed_option.local_rank,
    )
    return distributed_option