import os
import time

import torch
import torch.distributed as dist
from tqdm import tqdm
from datetime import datetime
from contextlib import nullcontext, contextmanager
from torch.cuda.amp import autocast, GradScaler
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
from pathlib import Path

from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints

@contextmanager
def maybe_autocast(enabled):
    """Run the wrapped code under torch.cuda.amp.autocast() only when `enabled` is True."""
    if enabled:
        with autocast():
            yield
    else:
        yield
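
# Note: autocast() itself accepts an `enabled` flag, so the helper above is
# equivalent to writing `with autocast(enabled=use_fp16): ...` at each call
# site; the wrapper presumably exists to keep those call sites short.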

class Trainer:
    """
    ... (class docstring elided in this excerpt) ...
    """

    def __init__(self,
                 # ... (leading parameters such as the model and optimizer elided) ...
                 dataloader_train,
                 dataloader_val,
                 local_rank,
                 use_ddp: bool = False,
                 use_fsdp: bool = False,
                 use_fp16: bool = False,
                 output_dir: str = "./",
                 **kwargs):
        """
        ... (parameter documentation elided in this excerpt) ...
        """
        # ... (earlier attribute assignments elided) ...
        self.kwargs = kwargs
        self.log_interval = kwargs.get("log_interval", 50)
        self.batch_total = 0
        self.use_fp16 = use_fp16
        self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
        self.scaler = GradScaler() if use_fp16 else None
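        # GradScaler applies dynamic loss scaling: it multiplies the loss
        # before backward() so small fp16 gradients don't underflow to zero,
        # and unscales them again inside scaler.step(). Leaving it None when
        # fp16 is off keeps the `if self.scaler:` checks below meaningful.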

    # ...
        try:
            # ... (body of this try block elided in this excerpt) ...

    # ...
        # Build the checkpoint payload
        state = {
            # ... (entries such as the epoch and model state dict elided) ...
            'optimizer': self.optim.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        if self.scaler:
            state["scaler_state"] = self.scaler.state_dict()
        # Create output directory if it does not exist
        os.makedirs(self.output_dir, exist_ok=True)
        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
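        # A sketch of the elided save call, assuming the method follows the
        # usual pattern of serializing the dict built above:
        #
        #     torch.save(state, filename)
        #
        # Persisting the scaler alongside model/optimizer state means a
        # resumed fp16 run keeps the loss scale it had already settled on.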

    # ...
            self.model.load_state_dict(dst_state)
            self.optim.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            if self.scaler and 'scaler_state' in checkpoint:
                self.scaler.load_state_dict(checkpoint['scaler_state'])
            print(f"Checkpoint loaded successfully from '{ckpt}'")
        else:
            print(f"No checkpoint found at '{ckpt}', starting from scratch")

    # ...
            # Defer DDP gradient synchronization except on accumulation boundaries
            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
            with my_context():
                time2 = time.perf_counter()

                # Run the forward pass under autocast only when fp16 is enabled
                with maybe_autocast(self.use_fp16):
                    retval = self.model(**batch)

                if self.disable_gpu_cache:
                    torch.cuda.empty_cache()

                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"

                # ... (loss/stats unpacking and distributed weighting elided) ...
                    loss *= self.world_size
                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accum_grad
                if self.use_fp16:
                    # Backward on the scaled loss; gradients are unscaled
                    # later inside scaler.step()
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"

                # ...

                # Execute an optimization step (update model parameters)
                if self.use_ddp or self.use_fsdp:
                    dist.barrier()
                if self.use_fp16:
                    # scaler.step() unscales the gradients first and skips the
                    # parameter update if any of them overflowed to inf/nan
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()
                self.scheduler.step()
                # Clear gradients for the next accumulation stage
                self.optim.zero_grad(set_to_none=True)
                total_time = f"{time.perf_counter() - time5:0.3f}"
                time5 = time.perf_counter()
                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"