From 0cf5dfec2c8313fc2ed2aab8d10bf3dc4b9c283f Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期四, 14 三月 2024 14:41:49 +0800
Subject: [PATCH] update cmakelist
---
funasr/train_utils/trainer.py | 116 +++++++++++++++++++++++++++++++++++++++++-----------------
1 files changed, 82 insertions(+), 34 deletions(-)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 10f7f80..723a149 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -3,8 +3,10 @@
import torch
import logging
from tqdm import tqdm
+from datetime import datetime
import torch.distributed as dist
-from contextlib import nullcontext
+from torch.cuda.amp import autocast, GradScaler
+from contextlib import nullcontext, contextmanager
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
from pathlib import Path
@@ -12,6 +14,15 @@
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+@contextmanager
+def maybe_autocast(enabled):
+ if enabled:
+ with autocast():
+ yield
+ else:
+ yield
class Trainer:
"""
@@ -35,8 +46,9 @@
dataloader_train,
dataloader_val,
local_rank,
- use_ddp=False,
- use_fsdp=False,
+ use_ddp: bool = False,
+ use_fsdp: bool = False,
+ use_fp16: bool = False,
output_dir: str="./",
**kwargs):
"""
@@ -71,6 +83,11 @@
self.kwargs = kwargs
self.log_interval = kwargs.get("log_interval", 50)
self.batch_total = 0
+ self.use_fp16 = use_fp16
+ self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
+ scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
+ scaler = ShardedGradScaler(enabled=use_fp16) if use_ddp else scaler
+ self.scaler = scaler
try:
@@ -102,19 +119,17 @@
'optimizer': self.optim.state_dict(),
'scheduler': self.scheduler.state_dict(),
}
+ if self.scaler:
+ state["scaler_state"] = self.scaler.state_dict()
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
torch.save(state, filename)
- print(f'Checkpoint saved to {filename}')
+ print(f'\nCheckpoint saved to {filename}\n')
latest = Path(os.path.join(self.output_dir, f'model.pt'))
- try:
- latest.unlink()
- except:
- pass
+ torch.save(state, latest)
- latest.symlink_to(filename)
def _resume_checkpoint(self, resume_path):
"""
@@ -128,12 +143,27 @@
if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt)
self.start_epoch = checkpoint['epoch'] + 1
- self.model.load_state_dict(checkpoint['state_dict'])
+ # self.model.load_state_dict(checkpoint['state_dict'])
+ src_state = checkpoint['state_dict']
+ dst_state = self.model.state_dict()
+ for k in dst_state.keys():
+ if not k.startswith("module.") and "module."+k in src_state.keys():
+ k_ddp = "module."+k
+ else:
+ k_ddp = k
+ if k_ddp in src_state.keys():
+ dst_state[k] = src_state[k_ddp]
+ else:
+ print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
+
+ self.model.load_state_dict(dst_state)
self.optim.load_state_dict(checkpoint['optimizer'])
self.scheduler.load_state_dict(checkpoint['scheduler'])
+ if self.scaler and 'scaler_state' in checkpoint:
+ self.scaler.load_state_dict(checkpoint['scaler_state'])
print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
- print(f"No checkpoint found at '{ckpt}', starting from scratch")
+ print(f"No checkpoint found at '{ckpt}', does not resume status!")
if self.use_ddp or self.use_fsdp:
dist.barrier()
@@ -147,7 +177,7 @@
self._resume_checkpoint(self.output_dir)
for epoch in range(self.start_epoch, self.max_epoch + 1):
-
+ time1 = time.perf_counter()
self._train_epoch(epoch)
@@ -169,6 +199,9 @@
self.scheduler.step()
+ time2 = time.perf_counter()
+ time_escaped = (time2 - time1)/3600.0
+ print(f"\nrank: {self.local_rank}, time_escaped_epoch: {time_escaped:.3f} hours, estimated to finish {self.max_epoch} epoch: {(self.max_epoch-epoch)*time_escaped:.3f} hours\n")
if self.rank == 0:
average_checkpoints(self.output_dir, self.avg_nbest_model)
@@ -188,7 +221,7 @@
epoch (int): The current epoch number.
"""
self.model.train()
- pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
+ pbar = tqdm(colour="blue", desc=f"rank: {self.local_rank}, Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
dynamic_ncols=True)
# Set the number of steps for gradient accumulation
@@ -208,9 +241,10 @@
my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
with my_context():
time2 = time.perf_counter()
-
- retval = self.model(**batch)
- torch.cuda.empty_cache()
+ with maybe_autocast(self.use_fp16):
+ retval = self.model(**batch)
+
+ if self.disable_gpu_cache: torch.cuda.empty_cache()
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
@@ -228,7 +262,10 @@
loss *= self.world_size
# Scale the loss since we're not updating for every mini-batch
loss = loss / accum_grad
- loss.backward()
+ if self.use_fp16:
+ self.scaler.scale(loss).backward()
+ else:
+ loss.backward()
time4 = time.perf_counter()
speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
@@ -251,10 +288,14 @@
# Execute an optimization step (update model parameters)
if self.use_ddp or self.use_fsdp:
dist.barrier()
- self.optim.step()
+ if self.use_fp16:
+ self.scaler.step(self.optim)
+ self.scaler.update()
+ else:
+ self.optim.step()
self.scheduler.step()
# Clear gradients for the next accumulation stage
- self.optim.zero_grad()
+ self.optim.zero_grad(set_to_none=True)
total_time = f"{time.perf_counter() - time5:0.3f}"
time5 = time.perf_counter()
speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
@@ -273,28 +314,30 @@
torch.cuda.memory_reserved()/1024/1024/1024,
torch.cuda.max_memory_reserved()/1024/1024/1024,
)
+ lr = self.scheduler.get_last_lr()[0]
+ time_now = datetime.now()
+ time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
description = (
+ f"{time_now}, "
f"rank: {self.local_rank}, "
f"epoch: {epoch}/{self.max_epoch}, "
- f"step: {batch_idx}/{len(self.dataloader_train)}, total: {self.batch_total}, "
+ f"step: {batch_idx+1}/{len(self.dataloader_train)}, total step: {self.batch_total}, "
f"(loss: {loss.detach().cpu().item():.3f}), "
- f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+ f"(lr: {lr:.3e}), "
+ f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
f"{speed_stats}, "
f"{gpu_info}"
)
pbar.set_description(description)
if self.writer:
- self.writer.add_scalar(f'rank{self.local_rank}_Loss/train', loss.item(),
- epoch*len(self.dataloader_train) + batch_idx)
+ self.writer.add_scalar(f'rank{self.local_rank}_Loss/train', loss.item(), self.batch_total)
+ self.writer.add_scalar(f'rank{self.local_rank}_lr/train', lr, self.batch_total)
for key, var in stats.items():
- self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', var.item(),
- epoch * len(self.dataloader_train) + batch_idx)
+ self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', var.item(), self.batch_total)
for key, var in speed_stats.items():
- self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', eval(var),
- epoch * len(self.dataloader_train) + batch_idx)
-
- # if batch_idx == 2:
- # break
+ self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', eval(var), self.batch_total)
+
+
pbar.close()
def _validate_epoch(self, epoch):
@@ -307,7 +350,7 @@
"""
self.model.eval()
with torch.no_grad():
- pbar = tqdm(colour="red", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_val),
+ pbar = tqdm(colour="red", desc=f"rank: {self.local_rank}, Validation Epoch: {epoch + 1}", total=len(self.dataloader_val),
dynamic_ncols=True)
speed_stats = {}
time5 = time.perf_counter()
@@ -338,12 +381,15 @@
if (batch_idx+1) % self.log_interval == 0 or (batch_idx+1) == len(self.dataloader_val):
pbar.update(self.log_interval)
+ time_now = datetime.now()
+ time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
description = (
+ f"{time_now}, "
f"rank: {self.local_rank}, "
f"validation epoch: {epoch}/{self.max_epoch}, "
- f"step: {batch_idx}/{len(self.dataloader_val)}, "
+ f"step: {batch_idx+1}/{len(self.dataloader_val)}, "
f"(loss: {loss.detach().cpu().item():.3f}), "
- f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+ f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
f"{speed_stats}, "
)
pbar.set_description(description)
@@ -355,4 +401,6 @@
epoch * len(self.dataloader_val) + batch_idx)
for key, var in speed_stats.items():
self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', eval(var),
- epoch * len(self.dataloader_val) + batch_idx)
\ No newline at end of file
+ epoch * len(self.dataloader_val) + batch_idx)
+
+ self.model.train()
\ No newline at end of file
--
Gitblit v1.9.1