VirtuosoQ
2024-04-26 e9d2cfc3a134b00f4e98271fbee3838d1ccecbcc
funasr/bin/train.py
@@ -55,6 +55,8 @@
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
+    # open tf32
+    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
    
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    if local_rank == 0:
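
The first hunk only flips backend flags; all of them are stock torch.backends switches. A minimal standalone sketch of the same setup follows (the configure_backends helper is a hypothetical name for illustration; only the cudnn_*/enable_tf32 kwarg names come from the diff):

    import torch

    def configure_backends(**kwargs):
        # cuDNN autotuner / determinism knobs (pre-existing lines in train.py)
        torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
        torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
        torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
        # New in this commit: allow TF32 for CUDA matmuls (effective on Ampere or newer GPUs).
        # Note the diff does not touch torch.backends.cudnn.allow_tf32.
        torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
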
@@ -88,7 +90,8 @@
    # freeze_param
    freeze_param = kwargs.get("freeze_param", None)
    if freeze_param is not None:
-        freeze_param = eval(freeze_param)
+        if "," in freeze_param:
+            freeze_param = eval(freeze_param)
        if isinstance(freeze_param, Sequence):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
@@ -173,15 +176,12 @@
    except:
        writer = None
    # if use_ddp or use_fsdp:
    #     context = Join([model])
    # else:
    #     context = nullcontext()
    context = nullcontext()
    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
        time1 = time.perf_counter()
-        with context:
-            dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
+        for data_split_i in range(dataloader.data_split_num):
+            dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
            trainer.train_epoch(
                                model=model,
                                optim=optim,
@@ -190,15 +190,17 @@
                                dataloader_train=dataloader_tr,
                                dataloader_val=dataloader_val,
                                epoch=epoch,
-                                writer=writer
+                                writer=writer,
+                                data_split_i=data_split_i,
+                                data_split_num=dataloader.data_split_num,
                                )
-        with context:
-            trainer.validate_epoch(
-                model=model,
-                dataloader_val=dataloader_val,
-                epoch=epoch,
-                writer=writer
-            )
+        trainer.validate_epoch(
+            model=model,
+            dataloader_val=dataloader_val,
+            epoch=epoch,
+            writer=writer
+        )
        scheduler.step()
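
Taken together, the last two hunks replace the Join/nullcontext wrapper with a per-epoch loop over data splits: build_iter is called once per (epoch, data_split_i) pair, train_epoch is told which split it is on, and validation runs once per epoch on the last split's validation iterator. A skeleton of the resulting control flow, written as a hypothetical helper (trainer, dataloader, scheduler and their attributes are taken on trust from the diff, not defined here):

    def run_epochs(trainer, model, optim, scheduler, dataloader, writer=None):
        # Hypothetical consolidation of the post-commit epoch loop in train.py.
        for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
            for data_split_i in range(dataloader.data_split_num):
                # Each split of the epoch gets fresh train/val iterators.
                dataloader_tr, dataloader_val = dataloader.build_iter(
                    epoch, data_split_i=data_split_i
                )
                trainer.train_epoch(
                    model=model, optim=optim,
                    dataloader_train=dataloader_tr, dataloader_val=dataloader_val,
                    epoch=epoch, writer=writer,
                    data_split_i=data_split_i, data_split_num=dataloader.data_split_num,
                )
            # Validation now runs outside the (commented-out) Join context,
            # once per epoch, on the last split's validation iterator.
            trainer.validate_epoch(
                model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
            )
            scheduler.step()
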