| | |
| | | from funasr.train_utils.device_funcs import to_device |
| | | from funasr.train_utils.recursive_op import recursive_average |
| | | from funasr.train_utils.average_nbest_models import average_checkpoints |
| | | from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler |
| | | |
| | | @contextmanager |
| | | def maybe_autocast(enabled): |
| | |
| | | self.batch_total = 0 |
| | | self.use_fp16 = use_fp16 |
| | | self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True) |
| | | self.scaler = GradScaler(enabled=use_fp16) if use_fp16 else None |
| | | scaler = GradScaler(enabled=use_fp16) if use_fp16 else None |
| | | scaler = ShardedGradScaler(enabled=use_fp16) if use_ddp else scaler |
| | | self.scaler = scaler |
| | | |
| | | |
| | | try: |
| | |
| | | self.scaler.load_state_dict(checkpoint['scaler_state']) |
| | | print(f"Checkpoint loaded successfully from '{ckpt}'") |
| | | else: |
| | | print(f"No checkpoint found at '{ckpt}', starting from scratch") |
| | | print(f"No checkpoint found at '{ckpt}', does not resume status!") |
| | | |
| | | if self.use_ddp or self.use_fsdp: |
| | | dist.barrier() |
| | |
| | | epoch * len(self.dataloader_val) + batch_idx) |
| | | for key, var in speed_stats.items(): |
| | | self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', eval(var), |
| | | epoch * len(self.dataloader_val) + batch_idx) |
| | | epoch * len(self.dataloader_val) + batch_idx) |
| | | |
| | | self.model.train() |