| | |
| | | from funasr.train_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.train_utils.load_pretrained_model import load_pretrained_model |
| | | from funasr.utils.misc import prepare_model_dir |
| | | from funasr.train_utils.model_summary import model_summary |
| | | from funasr import AutoModel |
| | | |
| | | |
| | |
| | | logging.info(f"Setting {k}.requires_grad = False") |
| | | p.requires_grad = False |
| | | |
| | | logging.info(f"model info: {model_summary(model)}") |
| | | if use_ddp: |
| | | model = model.cuda(local_rank) |
| | | model = DDP( |
| | |
| | | data_split_i=data_split_i, |
| | | data_split_num=dataloader.data_split_num, |
| | | ) |
| | | |
| | | torch.cuda.empty_cache() |
| | | |
| | | |
| | | trainer.validate_epoch( |
| | | model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer |