From 9ba0dbd98bf69c830dfcfde8f109a400cb65e4e5 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Fri, 29 Mar 2024 17:24:59 +0800
Subject: [PATCH] fix func Forward
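
This patch goes beyond the forward-call fix named in the subject:

- wrap the training forward pass in torch.cuda.amp autocast when use_fp16
  is enabled, backed by GradScaler (ShardedGradScaler under FSDP)
- save scaler state in checkpoints, write a rolling model.pt, and resume
  from <output_dir>/model.pt with "module."-prefix key remapping
- implement _validate_epoch and run it after every training epoch
- log every log_interval steps (lr, loss, GPU memory, timing) and average
  the n best checkpoints via average_checkpoints at the end of training

A minimal usage sketch (assuming the constructor also takes model, optim,
scheduler, and resume, which the hunks below use but do not show):

    trainer = Trainer(model=model, optim=optim, scheduler=scheduler,
                      dataloader_train=dl_train, dataloader_val=dl_val,
                      local_rank=0, use_fp16=True, output_dir="./exp",
                      log_interval=50, avg_nbest_model=5)
    trainer.run()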
---
funasr/train_utils/trainer.py | 237 ++++++++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 195 insertions(+), 42 deletions(-)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index da346c3..723a149 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -3,14 +3,26 @@
import torch
import logging
from tqdm import tqdm
+from datetime import datetime
import torch.distributed as dist
-from contextlib import nullcontext
+from torch.cuda.amp import autocast, GradScaler
+from contextlib import nullcontext, contextmanager
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
+from pathlib import Path
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
+from funasr.train_utils.average_nbest_models import average_checkpoints
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+
+@contextmanager
+def maybe_autocast(enabled):
+    # Enter torch.cuda.amp autocast only when mixed precision is requested
+    if enabled:
+        with autocast():
+            yield
+    else:
+        yield
class Trainer:
"""
@@ -34,8 +46,9 @@
dataloader_train,
dataloader_val,
local_rank,
- use_ddp=False,
- use_fsdp=False,
+ use_ddp: bool = False,
+ use_fsdp: bool = False,
+ use_fp16: bool = False,
output_dir: str="./",
**kwargs):
"""
@@ -66,10 +79,16 @@
self.use_ddp = use_ddp
self.use_fsdp = use_fsdp
self.device = next(model.parameters()).device
+ self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
self.kwargs = kwargs
+ self.log_interval = kwargs.get("log_interval", 50)
+ self.batch_total = 0
+ self.use_fp16 = use_fp16
+ self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
+        scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
+        # ShardedGradScaler is the FSDP-aware variant; plain DDP uses GradScaler
+        scaler = ShardedGradScaler(enabled=use_fp16) if (use_fp16 and use_fsdp) else scaler
+        self.scaler = scaler
- if self.resume:
- self._resume_checkpoint(self.resume)
try:
rank = dist.get_rank()
@@ -100,11 +119,17 @@
'optimizer': self.optim.state_dict(),
'scheduler': self.scheduler.state_dict(),
}
+        if self.scaler:
+            state["scaler_state"] = self.scaler.state_dict()
         # Create output directory if it does not exist
         os.makedirs(self.output_dir, exist_ok=True)
-        filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
+        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
         torch.save(state, filename)
-        print(f'Checkpoint saved to {filename}')
+
+        print(f'\nCheckpoint saved to {filename}\n')
+        # Refresh the rolling "latest" checkpoint that _resume_checkpoint loads
+        latest = Path(os.path.join(self.output_dir, 'model.pt'))
+        torch.save(state, latest)
+
def _resume_checkpoint(self, resume_path):
"""
@@ -114,29 +139,80 @@
Args:
resume_path (str): The file path to the checkpoint to resume from.
"""
-        if os.path.isfile(resume_path):
-            checkpoint = torch.load(resume_path)
+        ckpt = os.path.join(resume_path, "model.pt")
+        if os.path.isfile(ckpt):
+            checkpoint = torch.load(ckpt)
             self.start_epoch = checkpoint['epoch'] + 1
-            self.model.load_state_dict(checkpoint['state_dict'])
+            # Remap keys so a checkpoint saved from a DDP-wrapped model
+            # ("module."-prefixed keys) also loads into an unwrapped model
+            src_state = checkpoint['state_dict']
+            dst_state = self.model.state_dict()
+            for k in dst_state.keys():
+                if not k.startswith("module.") and "module." + k in src_state:
+                    k_ddp = "module." + k
+                else:
+                    k_ddp = k
+                if k_ddp in src_state:
+                    dst_state[k] = src_state[k_ddp]
+                else:
+                    print(f"Missing key in checkpoint: model: {k}, ckpt: {k_ddp}")
+
+            self.model.load_state_dict(dst_state)
             self.optim.load_state_dict(checkpoint['optimizer'])
             self.scheduler.load_state_dict(checkpoint['scheduler'])
-            print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
+            if self.scaler and 'scaler_state' in checkpoint:
+                self.scaler.load_state_dict(checkpoint['scaler_state'])
+            print(f"Checkpoint loaded successfully from '{ckpt}'")
         else:
-            print(f"No checkpoint found at '{resume_path}', starting from scratch")
+            print(f"No checkpoint found at '{ckpt}', not resuming training state!")
+
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
def run(self):
"""
Starts the training process, iterating over epochs, training the model,
and saving checkpoints at the end of each epoch.
"""
+        if self.resume:
+            self._resume_checkpoint(self.output_dir)
+
         for epoch in range(self.start_epoch, self.max_epoch + 1):
+            time1 = time.perf_counter()
             self._train_epoch(epoch)
-            # self._validate_epoch(epoch)
+
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
+            self._validate_epoch(epoch)
+
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
             if self.rank == 0:
                 self._save_checkpoint(epoch)
-            self.scheduler.step()
-        self.writer.close()
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
+            self.scheduler.step()
+
+            time2 = time.perf_counter()
+            time_elapsed = (time2 - time1) / 3600.0
+            print(f"\nrank: {self.local_rank}, time_elapsed_epoch: {time_elapsed:.3f} hours, "
+                  f"estimated time to finish {self.max_epoch} epochs: {(self.max_epoch - epoch) * time_elapsed:.3f} hours\n")
+
+        if self.rank == 0:
+            # Average the n best checkpoints once training is complete
+            average_checkpoints(self.output_dir, self.avg_nbest_model)
+
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+
+        if self.writer:
+            self.writer.close()
+
def _train_epoch(self, epoch):
"""
@@ -145,7 +221,7 @@
epoch (int): The current epoch number.
"""
self.model.train()
- pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
+ pbar = tqdm(colour="blue", desc=f"rank: {self.local_rank}, Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
dynamic_ncols=True)
# Set the number of steps for gradient accumulation
@@ -154,17 +230,22 @@
self.optim.zero_grad()
speed_stats = {}
time5 = time.perf_counter()
+
for batch_idx, batch in enumerate(self.dataloader_train):
+ self.batch_total += 1
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1-time5:0.3f}"
- # import pdb;
- # pdb.set_trace()
+
batch = to_device(batch, self.device)
my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
with my_context():
time2 = time.perf_counter()
- retval = self.model(**batch)
+                with maybe_autocast(self.use_fp16):
+                    retval = self.model(**batch)
+
+                if self.disable_gpu_cache:
+                    torch.cuda.empty_cache()
+
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
loss, stats, weight = retval
@@ -181,7 +262,10 @@
loss *= self.world_size
# Scale the loss since we're not updating for every mini-batch
loss = loss / accum_grad
- loss.backward()
+                if self.use_fp16:
+                    self.scaler.scale(loss).backward()
+                else:
+                    loss.backward()
time4 = time.perf_counter()
speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
@@ -202,40 +286,58 @@
continue
# Execute an optimization step (update model parameters)
- self.optim.step()
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+            if self.use_fp16:
+                # scaler.step() unscales gradients and skips the step on inf/NaN;
+                # update() then adjusts the loss scale for the next iteration
+                self.scaler.step(self.optim)
+                self.scaler.update()
+            else:
+                self.optim.step()
             self.scheduler.step()
             # Clear gradients for the next accumulation stage
-            self.optim.zero_grad()
+            self.optim.zero_grad(set_to_none=True)
total_time = f"{time.perf_counter() - time5:0.3f}"
time5 = time.perf_counter()
speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
speed_stats["total_time"] = total_time
+
-            # import pdb;
-            # pdb.set_trace()
-            pbar.update(1)
-            if self.local_rank == 0:
+            if (batch_idx + 1) % self.log_interval == 0 or (batch_idx + 1) == len(self.dataloader_train):
+                pbar.update(self.log_interval)
+                gpu_info = "GPU, memory: {:.3f} GB, " \
+                           "{:.3f} GB, " \
+                           "{:.3f} GB, " \
+                           "{:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+                                              torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
+                                              torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
+                                              torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
+                                              )
+                lr = self.scheduler.get_last_lr()[0]
+                time_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                 description = (
-                    f"Epoch: {epoch + 1}/{self.max_epoch}, "
-                    f"step {batch_idx}/{len(self.dataloader_train)}, "
-                    f"{speed_stats}, "
+                    f"{time_now}, "
+                    f"rank: {self.local_rank}, "
+                    f"epoch: {epoch}/{self.max_epoch}, "
+                    f"step: {batch_idx + 1}/{len(self.dataloader_train)}, total step: {self.batch_total}, "
                     f"(loss: {loss.detach().cpu().item():.3f}), "
-                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+                    f"(lr: {lr:.3e}), "
+                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
+                    f"{speed_stats}, "
+                    f"{gpu_info}"
                 )
                 pbar.set_description(description)
                 if self.writer:
-                    self.writer.add_scalar('Loss/train', loss.item(),
-                                           epoch*len(self.dataloader_train) + batch_idx)
+                    self.writer.add_scalar(f'rank{self.local_rank}_Loss/train', loss.item(), self.batch_total)
+                    self.writer.add_scalar(f'rank{self.local_rank}_lr/train', lr, self.batch_total)
                     for key, var in stats.items():
-                        self.writer.add_scalar(f'{key}/train', var.item(),
-                                               epoch * len(self.dataloader_train) + batch_idx)
+                        self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', var.item(), self.batch_total)
                     for key, var in speed_stats.items():
-                        self.writer.add_scalar(f'{key}/train', eval(var),
-                                               epoch * len(self.dataloader_train) + batch_idx)
-
-        # if batch_idx == 2:
-        #     break
+                        # speed stats are formatted strings; float() parses them safely, unlike eval()
+                        self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', float(var), self.batch_total)
+
pbar.close()
def _validate_epoch(self, epoch):
@@ -248,6 +350,57 @@
"""
self.model.eval()
with torch.no_grad():
-            for data, target in self.dataloader_val:
-                # Implement the model validation steps here
-                pass
+            pbar = tqdm(colour="red", desc=f"rank: {self.local_rank}, Validation Epoch: {epoch}", total=len(self.dataloader_val),
+                        dynamic_ncols=True)
+            speed_stats = {}
+            time5 = time.perf_counter()
+            for batch_idx, batch in enumerate(self.dataloader_val):
+                time1 = time.perf_counter()
+                speed_stats["data_load"] = f"{time1 - time5:0.3f}"
+                batch = to_device(batch, self.device)
+                time2 = time.perf_counter()
+                retval = self.model(**batch)
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+                loss, stats, weight = retval
+                stats = {k: v for k, v in stats.items() if v is not None}
+                if self.use_ddp or self.use_fsdp:
+                    # Apply weighted averaging for loss and stats
+                    loss = (loss * weight.type(loss.dtype)).sum()
+                    # if distributed, this method can also apply all_reduce()
+                    stats, weight = recursive_average(stats, weight, distributed=True)
+                    # Now weight is summation over all workers
+                    loss /= weight
+                    # Multiply by world_size to keep the reported loss on the
+                    # same scale as the training loss
+                    loss *= self.world_size
+                # Reset the timer so data_load measures only the next batch
+                time5 = time.perf_counter()
+
+                if (batch_idx + 1) % self.log_interval == 0 or (batch_idx + 1) == len(self.dataloader_val):
+                    pbar.update(self.log_interval)
+                    time_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                    description = (
+                        f"{time_now}, "
+                        f"rank: {self.local_rank}, "
+                        f"validation epoch: {epoch}/{self.max_epoch}, "
+                        f"step: {batch_idx + 1}/{len(self.dataloader_val)}, "
+                        f"(loss: {loss.detach().cpu().item():.3f}), "
+                        f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
+                        f"{speed_stats}, "
+                    )
+                    pbar.set_description(description)
+                    if self.writer:
+                        self.writer.add_scalar(f"rank{self.local_rank}_Loss/val", loss.item(),
+                                               epoch * len(self.dataloader_val) + batch_idx)
+                        for key, var in stats.items():
+                            self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', var.item(),
+                                                   epoch * len(self.dataloader_val) + batch_idx)
+                        for key, var in speed_stats.items():
+                            # speed stats are strings; parse with float() instead of eval()
+                            self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', float(var),
+                                                   epoch * len(self.dataloader_val) + batch_idx)
+
+        self.model.train()
\ No newline at end of file
--
Gitblit v1.9.1