zhifu gao
2024-04-25 80bd14e6bbb7bb282ff3832194648dc4a16157ca
funasr/bin/train.py
@@ -32,6 +32,7 @@
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.misc import prepare_model_dir
from funasr.train_utils.model_summary import model_summary
from funasr import AutoModel
@@ -107,6 +108,7 @@
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False
    logging.info(f"model info: {model_summary(model)}")
    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(
@@ -209,6 +211,9 @@
                data_split_i=data_split_i,
                data_split_num=dataloader.data_split_num,
            )
            torch.cuda.empty_cache()
        trainer.validate_epoch(
            model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer