| | |
| | | from funasr.utils.build_dataloader import build_dataloader |
| | | from funasr.utils.build_distributed import build_distributed |
| | | from funasr.utils.prepare_data import prepare_data |
| | | from funasr.utils.build_optimizer import build_optimizer |
| | | from funasr.utils.build_scheduler import build_scheduler |
| | | from funasr.utils.types import str2bool |
| | | |
| | | |
| | |
| | | distributed_option.dist_rank, |
| | | distributed_option.local_rank)) |
| | | |
| | | # optimizers = cls.build_optimizers(args, model=model) |
| | | # schedulers = [] |
| | | # for i, optim in enumerate(optimizers, 1): |
| | | # suf = "" if i == 1 else str(i) |
| | | # name = getattr(args, f"scheduler{suf}") |
| | | # conf = getattr(args, f"scheduler{suf}_conf") |
| | | # if name is not None: |
| | | # cls_ = scheduler_classes.get(name) |
| | | # if cls_ is None: |
| | | # raise ValueError( |
| | | # f"must be one of {list(scheduler_classes)}: {name}" |
| | | # ) |
| | | # scheduler = cls_(optim, **conf) |
| | | # else: |
| | | # scheduler = None |
| | | # |
| | | # schedulers.append(scheduler) |
| | | model = build_model(args) |
| | | optimizers = build_optimizer(args, model=model) |
| | | schedule = build_scheduler(args) |