import os
import logging

from tqdm import tqdm
import torch.distributed as dist
from contextlib import nullcontext
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter

from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average

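# Helper roles (from FunASR's train_utils): to_device recursively moves a
# possibly nested batch of tensors onto the target device, and
# recursive_average reduces accumulated statistics across distributed workers.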
class Trainer:
    def __init__(self, model, optim, scheduler,
                 dataloader_train, dataloader_val,
                 local_rank,
                 use_ddp=False,
                 use_fsdp=False,
                 output_dir: str = "./",
                 **kwargs):
        """
        Initializes the Trainer class with the model, optimizer, scheduler,
        dataloaders, and other settings.
        """
        self.model = model
        self.optim = optim
        self.scheduler = scheduler
        self.dataloader_train = dataloader_train
        self.dataloader_val = dataloader_val
        self.output_dir = output_dir
        self.resume = kwargs.get('resume', True)
        self.start_epoch = 0
        self.max_epoch = kwargs.get('max_epoch', 100)

        # Fall back to single-process values when torch.distributed is not set up.
        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except Exception:
            rank, world_size = 0, 1
            logging.warning("distributed is not initialized, only single shard")
        self.rank = rank
        self.world_size = world_size

        # Only rank 0 writes TensorBoard event files, avoiding duplicate logs
        # when training with multiple distributed workers.
        os.makedirs(os.path.join(self.output_dir, "tensorboard"), exist_ok=True)
        self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None

    def _save_checkpoint(self, epoch):
        """
        Saves a checkpoint containing the model, optimizer, and scheduler
        states at the given epoch.
        """

    def run(self):
        for epoch in range(self.start_epoch, self.max_epoch):
            self._train_epoch(epoch)
            # Only rank 0 saves checkpoints; every rank steps the scheduler.
            if self.rank == 0:
                self._save_checkpoint(epoch)
            self.scheduler.step()

        # The writer exists only on rank 0, so guard before closing it.
        if self.writer:
            self.writer.close()

    def _train_epoch(self, epoch):
        """
        Trains the model for one epoch over the training dataloader, logging
        loss and statistics to tqdm and TensorBoard.
        """

| | | f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}" |
| | | ) |
| | | pbar.set_description(description) |
| | | |
            if self.writer:
                global_step = epoch * len(self.dataloader_train) + batch_idx
                self.writer.add_scalar('Loss/train', loss.item(), global_step)
                for key, var in stats.items():
                    self.writer.add_scalar(f'{key}/train', var.item(), global_step)
                for key, var in speed_stats.items():
                    # speed_stats values are formatted strings; parse with float()
                    # rather than eval(), which is unsafe on arbitrary input.
                    self.writer.add_scalar(f'{key}/train', float(var), global_step)

            # if batch_idx == 2:
            #     break
        pbar.close()
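

# Example usage (a minimal sketch, not part of the original module): the
# objects below are hypothetical placeholders; in FunASR these are normally
# constructed from a training config by the launch script. Shown only to
# illustrate the Trainer lifecycle: construct, then run(), with checkpoints
# and TensorBoard logs written under output_dir.
#
#   trainer = Trainer(
#       model=model,                        # torch.nn.Module to train
#       optim=optim,                        # torch.optim.Optimizer
#       scheduler=scheduler,                # LR scheduler with a step() method
#       dataloader_train=dataloader_train,
#       dataloader_val=dataloader_val,
#       local_rank=0,
#       output_dir="./exp",
#       max_epoch=50,
#   )
#   trainer.run()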