From 94de39dde2e616a01683c518023d0fab72b4e103 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 19 二月 2024 22:21:50 +0800
Subject: [PATCH] aishell example
---
funasr/train_utils/trainer.py | 170 ++++++++++++++++++++++++++++++++++++++++++++++++++------
1 files changed, 150 insertions(+), 20 deletions(-)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 0f0acc2..3cd61a1 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -5,10 +5,13 @@
from tqdm import tqdm
import torch.distributed as dist
from contextlib import nullcontext
+# from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
+from pathlib import Path
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
-
+from funasr.train_utils.average_nbest_models import average_checkpoints
class Trainer:
"""
@@ -34,6 +37,7 @@
local_rank,
use_ddp=False,
use_fsdp=False,
+ output_dir: str="./",
**kwargs):
"""
Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
@@ -55,7 +59,7 @@
self.scheduler = scheduler
self.dataloader_train = dataloader_train
self.dataloader_val = dataloader_val
- self.output_dir = kwargs.get('output_dir', './')
+ self.output_dir = output_dir
self.resume = kwargs.get('resume', True)
self.start_epoch = 0
self.max_epoch = kwargs.get('max_epoch', 100)
@@ -63,10 +67,9 @@
self.use_ddp = use_ddp
self.use_fsdp = use_fsdp
self.device = next(model.parameters()).device
+ self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
self.kwargs = kwargs
- if self.resume:
- self._resume_checkpoint(self.resume)
try:
rank = dist.get_rank()
@@ -77,6 +80,10 @@
logging.warning("distributed is not initialized, only single shard")
self.rank = rank
self.world_size = world_size
+
+ os.makedirs(os.path.join(self.output_dir, "tensorboard"), exist_ok=True)
+ self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None
+
def _save_checkpoint(self, epoch):
"""
@@ -95,9 +102,17 @@
}
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
- filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
+ filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
torch.save(state, filename)
+
print(f'Checkpoint saved to {filename}')
+ latest = Path(os.path.join(self.output_dir, f'model.pt'))
+ try:
+ latest.unlink()
+ except:
+ pass
+
+ latest.symlink_to(filename)
def _resume_checkpoint(self, resume_path):
"""
@@ -107,27 +122,62 @@
Args:
resume_path (str): The file path to the checkpoint to resume from.
"""
- if os.path.isfile(resume_path):
- checkpoint = torch.load(resume_path)
+ ckpt = os.path.join(resume_path, "model.pt")
+ if os.path.isfile(ckpt):
+ checkpoint = torch.load(ckpt)
self.start_epoch = checkpoint['epoch'] + 1
self.model.load_state_dict(checkpoint['state_dict'])
self.optim.load_state_dict(checkpoint['optimizer'])
self.scheduler.load_state_dict(checkpoint['scheduler'])
- print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
+ print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
- print(f"No checkpoint found at '{resume_path}', starting from scratch")
+ print(f"No checkpoint found at '{ckpt}', starting from scratch")
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
def run(self):
"""
Starts the training process, iterating over epochs, training the model,
and saving checkpoints at the end of each epoch.
"""
+ if self.resume:
+ self._resume_checkpoint(self.output_dir)
+
for epoch in range(self.start_epoch, self.max_epoch + 1):
+
self._train_epoch(epoch)
- # self._validate_epoch(epoch)
+
+
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+
+ self._validate_epoch(epoch)
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+
+
if self.rank == 0:
self._save_checkpoint(epoch)
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+
self.scheduler.step()
+
+
+ if self.rank == 0:
+ average_checkpoints(self.output_dir, self.avg_nbest_model)
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+
+
+ if self.writer:
+ self.writer.close()
+
def _train_epoch(self, epoch):
"""
@@ -148,14 +198,31 @@
for batch_idx, batch in enumerate(self.dataloader_train):
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1-time5:0.3f}"
- # import pdb;
- # pdb.set_trace()
+
batch = to_device(batch, self.device)
my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
with my_context():
time2 = time.perf_counter()
+ # print("before, GPU, memory: {:.3f} GB, "
+ # "{:.3f} GB, "
+ # "{:.3f} GB, "
+ # "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
+ # torch.cuda.max_memory_allocated()/1024/1024/1024,
+ # torch.cuda.memory_reserved()/1024/1024/1024,
+ # torch.cuda.max_memory_reserved()/1024/1024/1024,
+ # ))
+
retval = self.model(**batch)
+ torch.cuda.empty_cache()
+ # print("after, GPU, memory: {:.3f} GB, "
+ # "{:.3f} GB, "
+ # "{:.3f} GB, "
+ # "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
+ # torch.cuda.max_memory_allocated()/1024/1024/1024,
+ # torch.cuda.memory_reserved()/1024/1024/1024,
+ # torch.cuda.max_memory_reserved()/1024/1024/1024,
+ # ))
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
loss, stats, weight = retval
@@ -193,6 +260,8 @@
continue
# Execute an optimization step (update model parameters)
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
self.optim.step()
self.scheduler.step()
# Clear gradients for the next accumulation stage
@@ -202,20 +271,37 @@
speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
speed_stats["total_time"] = total_time
-
- # import pdb;
- # pdb.set_trace()
+
+
pbar.update(1)
if self.local_rank == 0:
+ gpu_info = "GPU, memory: {:.3f} GB, " \
+ "{:.3f} GB, "\
+ "{:.3f} GB, "\
+ "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
+ torch.cuda.max_memory_allocated()/1024/1024/1024,
+ torch.cuda.memory_reserved()/1024/1024/1024,
+ torch.cuda.max_memory_reserved()/1024/1024/1024,
+ )
description = (
- f"Epoch: {epoch + 1}/{self.max_epoch}, "
+ f"Train epoch: {epoch}/{self.max_epoch}, "
f"step {batch_idx}/{len(self.dataloader_train)}, "
f"{speed_stats}, "
f"(loss: {loss.detach().cpu().item():.3f}), "
f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+ f"{gpu_info}"
)
pbar.set_description(description)
-
+ if self.writer:
+ self.writer.add_scalar('Loss/train', loss.item(),
+ epoch*len(self.dataloader_train) + batch_idx)
+ for key, var in stats.items():
+ self.writer.add_scalar(f'{key}/train', var.item(),
+ epoch * len(self.dataloader_train) + batch_idx)
+ for key, var in speed_stats.items():
+ self.writer.add_scalar(f'{key}/train', eval(var),
+ epoch * len(self.dataloader_train) + batch_idx)
+
# if batch_idx == 2:
# break
pbar.close()
@@ -230,6 +316,50 @@
"""
self.model.eval()
with torch.no_grad():
- for data, target in self.dataloader_val:
- # Implement the model validation steps here
- pass
+ pbar = tqdm(colour="red", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_val),
+ dynamic_ncols=True)
+ speed_stats = {}
+ time5 = time.perf_counter()
+ for batch_idx, batch in enumerate(self.dataloader_val):
+ time1 = time.perf_counter()
+ speed_stats["data_load"] = f"{time1 - time5:0.3f}"
+ batch = to_device(batch, self.device)
+ time2 = time.perf_counter()
+ retval = self.model(**batch)
+ time3 = time.perf_counter()
+ speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+ loss, stats, weight = retval
+ stats = {k: v for k, v in stats.items() if v is not None}
+ if self.use_ddp or self.use_fsdp:
+ # Apply weighted averaging for loss and stats
+ loss = (loss * weight.type(loss.dtype)).sum()
+ # if distributed, this method can also apply all_reduce()
+ stats, weight = recursive_average(stats, weight, distributed=True)
+ # Now weight is summation over all workers
+ loss /= weight
+ # Multiply world_size because DistributedDataParallel
+ # automatically normalizes the gradient by world_size.
+ loss *= self.world_size
+ # Scale the loss since we're not updating for every mini-batch
+ loss = loss
+ time4 = time.perf_counter()
+
+ pbar.update(1)
+ if self.local_rank == 0:
+ description = (
+ f"validation epoch: {epoch}/{self.max_epoch}, "
+ f"step {batch_idx}/{len(self.dataloader_train)}, "
+ f"{speed_stats}, "
+ f"(loss: {loss.detach().cpu().item():.3f}), "
+ f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+ )
+ pbar.set_description(description)
+ if self.writer:
+ self.writer.add_scalar('Loss/val', loss.item(),
+ epoch*len(self.dataloader_train) + batch_idx)
+ for key, var in stats.items():
+ self.writer.add_scalar(f'{key}/val', var.item(),
+ epoch * len(self.dataloader_train) + batch_idx)
+ for key, var in speed_stats.items():
+ self.writer.add_scalar(f'{key}/val', eval(var),
+ epoch * len(self.dataloader_train) + batch_idx)
\ No newline at end of file
--
Gitblit v1.9.1