From 3ac03e448b7673604eb86f619b27521fca55f34d Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 21 Mar 2024 01:36:39 +0800
Subject: [PATCH] train & finetune llm-asr (#1519)
---
funasr/models/paraformer/model.py | 1
funasr/train_utils/trainer_llm.py | 178 ++++++++++++-------
funasr/schedulers/lambdalr_cus.py | 17 +
funasr/train_utils/trainer.py | 1
funasr/train_utils/average_nbest_models.py | 2
funasr/datasets/llm_datasets_vicuna/samplers.py | 197 +++++++++++++++++++++
funasr/bin/train_llm.py | 78 ++++++--
examples/industrial_data_pretraining/paraformer_streaming/demo.py | 7
8 files changed, 386 insertions(+), 95 deletions(-)
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
index 9885c0b..601a531 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -8,20 +8,21 @@
chunk_size = [5, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
encoder_chunk_look_back = 0 #number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 0 #number of encoder chunks to lookback for decoder cross-attention
-
+wav_file="/Users/zhifu/Downloads/NCYzUhAtZNI_0015.wav"
model = AutoModel(model="iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.4")
-res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+res = model.generate(input=wav_file,
chunk_size=chunk_size,
encoder_chunk_look_back=encoder_chunk_look_back,
decoder_chunk_look_back=decoder_chunk_look_back,
)
print(res)
+# exit()
import soundfile
import os
-wav_file = os.path.join(model.model_path, "example/asr_example.wav")
+# wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)
chunk_stride = chunk_size[1] * 960 # 600ms, 480ms
diff --git a/funasr/bin/train_llm.py b/funasr/bin/train_llm.py
index 8742bf1..89f5db3 100644
--- a/funasr/bin/train_llm.py
+++ b/funasr/bin/train_llm.py
@@ -4,18 +4,21 @@
import os
import sys
import torch
+import torch.nn as nn
+import functools
import hydra
import logging
import time
import argparse
from io import BytesIO
+from contextlib import nullcontext
import torch.distributed as dist
from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.algorithms.join import Join
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from funasr.train_utils.average_nbest_models import average_checkpoints
@@ -48,7 +51,6 @@
def main(**kwargs):
- print(kwargs)
# set random seed
set_all_random_seed(kwargs.get("seed", 0))
@@ -61,11 +63,13 @@
tables.print()
# Check if we are using DDP or FSDP
use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
- use_fsdp = kwargs.get("use_fsdp", None)
+ use_fsdp = kwargs.get("use_fsdp", False)
+ # use_ddp = False if use_fsdp else use_fsdp
if use_ddp or use_fsdp:
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
torch.cuda.set_device(local_rank)
-
+
+ logging.info("Build model, frontend, tokenizer")
device = kwargs.get("device", "cuda")
kwargs["device"] = "cpu"
model = AutoModel(**kwargs)
@@ -76,6 +80,7 @@
os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
OmegaConf.save(config=kwargs, f=yaml_file)
+ print(kwargs)
logging.info("config.yaml is saved to: %s", yaml_file)
# parse kwargs
@@ -105,19 +110,42 @@
model = DDP(model, device_ids=[local_rank],
find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
elif use_fsdp:
- model = FSDP(model).cuda(local_rank)
+ # model = FSDP(model).cuda(local_rank)
+
+ def custom_auto_wrap_policy(
+ module: nn.Module,
+ recurse: bool,
+ nonwrapped_numel: int,
+ # Additional custom arguments
+ min_num_params: int = int(1e8),
+ ) -> bool:
+            # Decide whether to wrap this module, based on the custom logic below
+            is_large = nonwrapped_numel >= min_num_params
+ requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
+ return is_large and requires_grad_uniform
+
+ # Configure a custom `min_num_params`
+ my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
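+        # Wrap any submodule whose parameters are uniformly trainable (or uniformly frozen) and total at least min_num_params into its own FSDP unit.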
+ torch.cuda.set_device(local_rank)
+ model = FSDP(model,
+                     auto_wrap_policy=my_auto_wrap_policy,
+ mixed_precision=None,
+ device_id=torch.cuda.current_device())
else:
model = model.to(device=kwargs.get("device", "cuda"))
+ logging.info(f"{model}")
kwargs["device"] = next(model.parameters()).device
# optim
+ logging.info("Build optim")
optim = kwargs.get("optim", "adam")
assert optim in optim_classes
optim_class = optim_classes.get(optim)
optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
# scheduler
+ logging.info("Build scheduler")
scheduler = kwargs.get("scheduler", "warmuplr")
assert scheduler in scheduler_classes
scheduler_class = scheduler_classes.get(scheduler)
@@ -125,6 +153,7 @@
# dataset
+ logging.info("Build dataloader")
dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
@@ -142,8 +171,9 @@
trainer = Trainer(local_rank=local_rank,
use_ddp=use_ddp,
- resume=kwargs.get("resume", True),
+ use_fsdp=use_fsdp,
device=kwargs["device"],
+ output_dir=kwargs.get("output_dir", "./exp"),
**kwargs.get("train_conf"),
)
@@ -159,20 +189,27 @@
writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
except:
writer = None
-
+
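+    # Join lets ranks whose dataloaders yield fewer batches finish the epoch without hanging the other ranks' collectives.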
+ if use_ddp or use_fsdp:
+ context = Join([model])
+ else:
+ context = nullcontext()
+
for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
time1 = time.perf_counter()
- trainer.train_epoch(
- model=model,
- optim=optim,
- scheduler=scheduler,
- scaler=scaler,
- dataloader_train=dataloader_tr,
- dataloader_val=dataloader_val,
- epoch=epoch,
- writer=writer
- )
-
+ with context:
+
+ trainer.train_epoch(
+ model=model,
+ optim=optim,
+ scheduler=scheduler,
+ scaler=scaler,
+ dataloader_train=dataloader_tr,
+ dataloader_val=dataloader_val,
+ epoch=epoch,
+ writer=writer
+ )
+ scheduler.step()
trainer.validate_epoch(
model=model,
dataloader_val=dataloader_val,
@@ -180,21 +217,20 @@
writer=writer
)
+
trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
-
- scheduler.step()
time2 = time.perf_counter()
time_escaped = (time2 - time1) / 3600.0
logging.info(
- f"\nrank: {local_rank}, "
+ f"rank: {local_rank}, "
f"time_escaped_epoch: {time_escaped:.3f} hours, "
f"estimated to finish {trainer.max_epoch} "
f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n")
if trainer.rank == 0:
- average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
+ average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list)
trainer.close()
diff --git a/funasr/datasets/llm_datasets_vicuna/samplers.py b/funasr/datasets/llm_datasets_vicuna/samplers.py
index c728d9c..61f7d94 100644
--- a/funasr/datasets/llm_datasets_vicuna/samplers.py
+++ b/funasr/datasets/llm_datasets_vicuna/samplers.py
@@ -232,3 +232,200 @@
def set_epoch(self, epoch):
self.epoch = epoch
+
+
+@tables.register("batch_sampler_classes", "CustomDistributedBufferBatchSampler_fn")
+def CustomDistributedBufferBatchSampler_fn(dataset, **kwargs):
+ dataloader_args = {}
+ dataloader_args["batch_sampler"] = CustomDistributedBufferBatchSampler(dataset, **kwargs)
+ dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
+ dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
+
+ return dataloader_args
+
+
+@tables.register("batch_sampler_classes", "CustomDistributedBufferBatchSampler")
+class CustomDistributedBufferBatchSampler(Sampler):
+ def __init__(self, dataset,
+ batch_size,
+ num_replicas=None,
+ rank=None,
+ shuffle=True,
+ drop_last=False,
+ is_training: bool = True,
+ sort_size: int = 1024,
+ **kwargs,
+ ):
+
+ try:
+ rank = dist.get_rank()
+ num_replicas = dist.get_world_size()
+ except:
+ rank = 0
+ num_replicas = 1
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.is_training = is_training
+ self.shuffle = shuffle and is_training
+ self.drop_last = drop_last
+ # self.total_size = len(dataset)
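+        # Pad (or, with drop_last, truncate) total_size to a multiple of batch_size * num_replicas so every rank gets the same number of full batches.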
+ if self.drop_last:
+ self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (batch_size * num_replicas)
+ else:
+ self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (batch_size * num_replicas)
+ self.num_samples = int(self.total_size // self.num_replicas)
+ self.epoch = 0
+ self.max_token_length = kwargs.get("max_token_length", None)
+ self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+ self.sort_size = sort_size
+
+ def __iter__(self):
+ # Generate a list of indices
+ if self.shuffle:
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = list(range(len(self.dataset)))
+
+ # Add extra samples to make it evenly divisible
+ padding_size = self.total_size - len(indices)
+ if padding_size <= len(indices):
+ indices += indices[:padding_size]
+ else:
+ indices += (indices * (padding_size // len(indices)) + indices[:padding_size % len(indices)])
+
+ assert len(indices) == self.total_size
+
+ # Subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ # Filter out indices with length greater than the max length, if provided
+ if self.max_token_length is not None:
+ filtered_indices = []
+ for idx in indices:
+ source_len = self.dataset.get_source_len(idx) / self.length_scale_source
+ if source_len <= self.max_token_length:
+ filtered_indices.append(idx)
+ indices = filtered_indices
+
+ # Buffer sorting logic
+ sorted_batches = []
+ buffer = []
+
+ for idx in indices:
+ buffer.append(idx)
+ if len(buffer) >= self.sort_size:
+ # Sort the buffer based on some criteria, e.g., dataset sample length
+ buffer.sort(key=lambda x: self.dataset.get_source_len(x))
+ sorted_batches.extend(self._create_batches_from_buffer(buffer))
+ buffer = []
+
+ # Handle the remaining items in the buffer
+ if buffer:
+ buffer.sort(key=lambda x: self.dataset.get_source_len(x))
+ sorted_batches.extend(self._create_batches_from_buffer(buffer))
+
+ return iter(sorted_batches)
+
+ def _create_batches_from_buffer(self, buffer):
+ # Function to convert the sorted buffer into batches
+ batched_buffer = [buffer[i:i + self.batch_size] for i in range(0, len(buffer), self.batch_size)]
+ if self.drop_last and len(batched_buffer[-1]) != self.batch_size:
+ batched_buffer = batched_buffer[:-1]
+ return batched_buffer
+
+ def __len__(self):
+
+ return self.num_samples // self.batch_size
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+
+@tables.register("batch_sampler_classes", "CustomDistributedDynamicBatchSampler_fn")
+def CustomDistributedDynamicBatchSampler_fn(dataset, **kwargs):
+ dataloader_args = {}
+ dataloader_args["batch_sampler"] = CustomDistributedDynamicBatchSampler(dataset, **kwargs)
+ dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
+ dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
+
+ return dataloader_args
+
+
+@tables.register("batch_sampler_classes", "CustomDistributedDynamicBatchSampler")
+class CustomDistributedDynamicBatchSampler(Sampler):
+ def __init__(self, dataset,
+ batch_size,
+ num_replicas=None,
+ rank=None,
+ shuffle=True,
+ drop_last=False,
+ is_training: bool = True,
+ **kwargs,
+ ):
+
+ try:
+ rank = dist.get_rank()
+ num_replicas = dist.get_world_size()
+ except:
+ rank = 0
+ num_replicas = 1
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.is_training = is_training
+ self.shuffle = shuffle and is_training
+ self.drop_last = drop_last
+
+ self.total_size = len(self.dataset)
+ # self.num_samples = int(math.ceil(self.total_size / self.num_replicas))
+ self.epoch = 0
+
+ def __iter__(self):
+ if self.shuffle:
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = list(range(len(self.dataset)))
+
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+
+ batches = []
+ batch = []
+ max_len_in_batch = 0
+ current_batch_length = 0
+
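+        # batch_size acts as a length budget: a sample joins the current batch only while max_len_in_batch * (len(batch) + 1) stays within it.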
+ for idx in indices:
+ sample_length = self.dataset.get_source_len(idx)
+ potential_batch_length = (max_len_in_batch if sample_length < max_len_in_batch else sample_length) * (
+ len(batch) + 1)
+
+ if potential_batch_length <= self.batch_size:
+ batch.append(idx)
+ if sample_length > max_len_in_batch:
+ max_len_in_batch = sample_length
+ current_batch_length = max_len_in_batch * len(batch)
+ else:
+ batches.append(batch)
+ batch = [idx]
+ max_len_in_batch = sample_length
+ current_batch_length = max_len_in_batch
+
+ # Add the last batch if it's not empty and we're not dropping it
+ if batch and (not self.drop_last or len(batch) * max_len_in_batch == self.batch_size):
+ batches.append(batch)
+
+ return iter(batches)
+
+ def __len__(self):
+
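+        # The number of dynamic batches cannot be known in advance.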
+ return -1
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 316255d..bd85df0 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -231,6 +231,7 @@
stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
stats["loss"] = torch.clone(loss.detach())
+ stats["batch_size"] = batch_size
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
diff --git a/funasr/schedulers/lambdalr_cus.py b/funasr/schedulers/lambdalr_cus.py
index 0123cc2..5aad049 100644
--- a/funasr/schedulers/lambdalr_cus.py
+++ b/funasr/schedulers/lambdalr_cus.py
@@ -14,4 +14,19 @@
for base_lr in self.base_lrs
]
else:
- return [base_lr for base_lr in self.base_lrs]
\ No newline at end of file
+ return [base_lr for base_lr in self.base_lrs]
+
+
+class CustomLambdaLR(_LRScheduler):
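+    """Linear warmup to the base LR over warmup_steps, then linear decay to zero by total_steps."""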
+ def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
+ self.warmup_steps = train_config.warmup_steps
+ self.total_steps = train_config.total_steps
+ super(CustomLambdaLR, self).__init__(optimizer, last_epoch, verbose)
+
+ def get_lr(self):
+ step = self._step_count
+ if step < self.warmup_steps:
+ lr_scale = step / self.warmup_steps
+ else:
+ lr_scale = max(0.0, 1 - (step - self.warmup_steps) / (self.total_steps - self.warmup_steps))
+ return [base_lr * lr_scale for base_lr in self.base_lrs]
diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py
index f117804..3603a44 100644
--- a/funasr/train_utils/average_nbest_models.py
+++ b/funasr/train_utils/average_nbest_models.py
@@ -143,7 +143,7 @@
return checkpoint_paths
@torch.no_grad()
-def average_checkpoints(output_dir: str, last_n: int=5):
+def average_checkpoints(output_dir: str, last_n: int=5, val_acc_list=[]):
"""
Average the last 'last_n' checkpoints' model state_dicts.
If a tensor is of type torch.int, perform sum instead of average.
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 14abd6c..aae4513 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -103,6 +103,7 @@
os.makedirs(os.path.join(self.output_dir, "tensorboard"), exist_ok=True)
self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None
+
def _save_checkpoint(self, epoch, step=None):
diff --git a/funasr/train_utils/trainer_llm.py b/funasr/train_utils/trainer_llm.py
index 6a3b83b..5f13b5a 100644
--- a/funasr/train_utils/trainer_llm.py
+++ b/funasr/train_utils/trainer_llm.py
@@ -1,3 +1,4 @@
+import math
import os
import time
import torch
@@ -61,6 +62,8 @@
"""
self.output_dir = output_dir
+ if not os.path.exists(self.output_dir):
+ os.makedirs(self.output_dir, exist_ok=True)
self.resume = kwargs.get('resume', True)
self.start_epoch = 0
self.max_epoch = kwargs.get('max_epoch', 100)
@@ -78,6 +81,7 @@
# scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
# self.scaler = scaler
self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
+ self.keep_nbest_models = kwargs.get("keep_nbest_models", -1)
self.accum_grad = kwargs.get("accum_grad", 1)
self.grad_clip = kwargs.get("grad_clip", 10.0)
self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
@@ -93,6 +97,15 @@
logging.warning("distributed is not initialized, only single shard")
self.rank = rank
self.world_size = world_size
+ self.train_acc_avg = 0.0
+ self.train_loss_avg = 0.0
+ self.val_acc_avg = 0.0
+ self.val_loss_avg = 0.0
+ self.best_acc_idx = 0
+ self.saved_ckpts = {}
+ self.val_acc_list = []
+ self.step_or_epoch = -1
+
@@ -112,28 +125,56 @@
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
+
if self.rank == 0:
+ logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
+ self.step_or_epoch += 1
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optim.state_dict(),
'scheduler': scheduler.state_dict(),
+ "acc": self.val_acc_list,
+ "step_or_epoch": self.step_or_epoch,
}
+ if hasattr(model, "module"):
+ state["state_dict"] = model.module.state_dict()
+
if scaler:
state["scaler_state"] = scaler.state_dict()
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
if step is None:
- filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
+ ckpt_name = f'model.pt.ep{epoch}'
else:
- filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}')
-
+ ckpt_name = f'model.pt.ep{epoch}.{step}'
+ filename = os.path.join(self.output_dir, ckpt_name)
torch.save(state, filename)
- print(f'\nCheckpoint saved to {filename}\n')
+ logging.info(f'\nCheckpoint saved to {filename}\n')
latest = Path(os.path.join(self.output_dir, f'model.pt'))
torch.save(state, latest)
-
+
+ if self.val_acc_list[self.step_or_epoch] >= self.val_acc_list[self.best_acc_idx]:
+ self.best_acc_idx = self.step_or_epoch
+ best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
+ torch.save(state, best_ckpt)
+ logging.info(f"Update best acc: {self.val_acc_list[self.best_acc_idx]}, {best_ckpt}")
+ else:
+ logging.info(f"No improvement in acc: {self.val_acc_list[self.best_acc_idx]}")
+
+ if self.keep_nbest_models > 0:
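+                # Track each checkpoint's validation acc and drop the lowest-scoring one once more than keep_nbest_models are kept.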
+ self.saved_ckpts[ckpt_name] = self.val_acc_list[-1]
+ if len(self.saved_ckpts) > self.keep_nbest_models:
+
+ min_key = min(self.saved_ckpts, key=self.saved_ckpts.get)
+ if min_key in self.saved_ckpts:
+ del self.saved_ckpts[min_key]
+ filename = os.path.join(self.output_dir, min_key)
+ logging.info(f"Delete: {filename}")
+ if os.path.exists(filename):
+ os.remove(filename)
+
if self.use_ddp or self.use_fsdp:
dist.barrier()
@@ -173,6 +214,10 @@
scheduler.load_state_dict(checkpoint['scheduler'])
if scaler is not None and 'scaler_state' in checkpoint:
scaler.load_state_dict(checkpoint['scaler_state'])
+
+ self.val_acc_list = checkpoint["acc"]
+ self.step_or_epoch = checkpoint["step_or_epoch"]
+
print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
print(f"No checkpoint found at '{ckpt}', does not resume status!")
@@ -180,52 +225,7 @@
if self.use_ddp or self.use_fsdp:
dist.barrier()
- # def train(self):
- # """
- # Starts the training process, iterating over epochs, training the model,
- # and saving checkpoints at the end of each epoch.
- # """
- # if self.resume:
- # self.resume_checkpoint(self.output_dir)
- #
- # for epoch in range(self.start_epoch, self.max_epoch + 1):
- # time1 = time.perf_counter()
- # self.train_epoch(epoch)
- #
- #
- #
- # if self.use_ddp or self.use_fsdp:
- # dist.barrier()
- #
- # self._validate_epoch(epoch)
- #
- # if self.use_ddp or self.use_fsdp:
- # dist.barrier()
- #
- #
- # if self.rank == 0:
- # self._save_checkpoint(epoch)
- #
- # if self.use_ddp or self.use_fsdp:
- # dist.barrier()
- #
- # self.scheduler.step()
- #
- # time2 = time.perf_counter()
- # time_escaped = (time2 - time1)/3600.0
- # print(f"\nrank: {self.local_rank}, time_escaped_epoch: {time_escaped:.3f} hours, estimated to finish {self.max_epoch} epoch: {(self.max_epoch-epoch)*time_escaped:.3f} hours\n")
- #
- # if self.rank == 0:
- # average_checkpoints(self.output_dir, self.avg_nbest_model)
- #
- # if self.use_ddp or self.use_fsdp:
- # dist.barrier()
- #
- #
- # if writer:
- # writer.close()
- #
-
+
def train_epoch(self,
model=None,
optim=None,
@@ -241,9 +241,9 @@
Args:
epoch (int): The current epoch number.
"""
+ logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
model.train()
-
# Set the number of steps for gradient accumulation
accum_grad = self.accum_grad
# Initialize the gradient accumulation
@@ -288,6 +288,18 @@
loss.backward()
time4 = time.perf_counter()
speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
+
+ self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
+ if "acc" in stats:
+ self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
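+            # Average the running epoch loss/acc across ranks so the logged values are global rather than per-rank.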
+ if self.use_ddp or self.use_fsdp:
+ train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
+ train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
+ dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+ dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+ self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+ self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+
# Perform an optimizer step only after accumulating enough gradients
if (batch_idx + 1) % accum_grad == 0:
@@ -322,9 +334,11 @@
speed_stats["total_time"] = total_time
lr = scheduler.get_last_lr()[0]
-
+ batch_num_epoch = -1
+ if hasattr(dataloader_train, "__len__"):
+ batch_num_epoch = len(dataloader_train)
self.log(epoch, batch_idx,
- batch_num_epoch=len(dataloader_train),
+ batch_num_epoch=batch_num_epoch,
lr=lr,
loss=loss.detach().cpu().item(),
speed_stats=speed_stats,
@@ -341,7 +355,7 @@
writer=writer
)
- if (batch_idx+1) % self.save_checkpoint_interval == 0 and self.rank == 0:
+ if (batch_idx+1) % self.save_checkpoint_interval == 0:
self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
@@ -364,6 +378,7 @@
Args:
epoch (int): The current epoch number.
"""
+ logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
model.eval()
with torch.no_grad():
@@ -394,18 +409,35 @@
loss = loss
time4 = time.perf_counter()
+ self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
+ if "acc" in stats:
+ self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
+ if self.use_ddp or self.use_fsdp:
+ val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
+ val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
+ dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+ self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+ self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+ batch_num_epoch = -1
+ if hasattr(dataloader_val, "__len__"):
+ batch_num_epoch = len(dataloader_val)
self.log(epoch, batch_idx,
- batch_num_epoch=len(dataloader_val),
+ batch_num_epoch=batch_num_epoch,
lr=0.0,
loss=loss.detach().cpu().item(),
speed_stats=speed_stats,
stats=stats,
writer=writer,
- tag="train",
+ tag="val",
)
+ self.val_acc_list.append(self.val_acc_avg)
model.train()
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
def log(self,
@@ -422,39 +454,47 @@
if (batch_idx + 1) % self.log_interval == 0:
- gpu_info = "GPU, memory: {:.3f} GB, " \
- "{:.3f} GB, " \
- "{:.3f} GB, " \
- "{:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+ gpu_info = "GPU, memory: usage: {:.3f} GB, " \
+ "peak: {:.3f} GB, " \
+ "cache: {:.3f} GB, " \
+ "cache_peak: {:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
)
- time_now = datetime.now()
- time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
+ loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
+ acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
description = (
- f"{time_now}, "
+ f"{tag}, "
f"rank: {self.local_rank}, "
f"epoch: {epoch}/{self.max_epoch}, "
f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
- f"(loss: {loss:.3f}), "
+ f"(loss_avg_rank: {loss:.3f}), "
+ f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
+ f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3f}), "
+ f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
f"(lr: {lr:.3e}), "
- f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
+ f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
f"{speed_stats}, "
f"{gpu_info}"
)
logging.info(description)
if writer is not None:
- writer.add_scalar(f'rank{self.local_rank}_Loss/{tag}', loss, self.batch_total)
+ writer.add_scalar(f'rank{self.local_rank}_loss/{tag}', loss, self.batch_total)
writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
for key, var in stats.items():
- writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
+ writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
for key, var in speed_stats.items():
- writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', eval(var), self.batch_total)
+ writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', eval(var), self.batch_total)
def close(self, writer=None):
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+
if writer is not None:
writer.close()
--
Gitblit v1.9.1