From 3ac03e448b7673604eb86f619b27521fca55f34d Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 21 三月 2024 01:36:39 +0800
Subject: [PATCH] train & finetune llm-asr (#1519)
---
funasr/bin/train_llm.py | 78 ++++++++++++++++++++++++++++----------
1 files changed, 57 insertions(+), 21 deletions(-)
diff --git a/funasr/bin/train_llm.py b/funasr/bin/train_llm.py
index 8742bf1..89f5db3 100644
--- a/funasr/bin/train_llm.py
+++ b/funasr/bin/train_llm.py
@@ -4,18 +4,21 @@
import os
import sys
import torch
+import torch.nn as nn
import hydra
import logging
import time
import argparse
from io import BytesIO
+from contextlib import nullcontext
import torch.distributed as dist
from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.algorithms.join import Join
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from funasr.train_utils.average_nbest_models import average_checkpoints
@@ -48,7 +51,6 @@
def main(**kwargs):
- print(kwargs)
# set random seed
set_all_random_seed(kwargs.get("seed", 0))
@@ -61,11 +63,13 @@
tables.print()
# Check if we are using DDP or FSDP
use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
- use_fsdp = kwargs.get("use_fsdp", None)
+ use_fsdp = kwargs.get("use_fsdp", False)
+ # use_ddp = False if use_fsdp else use_fsdp
if use_ddp or use_fsdp:
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
torch.cuda.set_device(local_rank)
-
+
+ logging.info("Build model, frontend, tokenizer")
device = kwargs.get("device", "cuda")
kwargs["device"] = "cpu"
model = AutoModel(**kwargs)
@@ -76,6 +80,7 @@
os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
OmegaConf.save(config=kwargs, f=yaml_file)
+ print(kwargs)
logging.info("config.yaml is saved to: %s", yaml_file)
# parse kwargs
@@ -105,19 +110,42 @@
model = DDP(model, device_ids=[local_rank],
find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
elif use_fsdp:
- model = FSDP(model).cuda(local_rank)
+ # model = FSDP(model).cuda(local_rank)
+
+ def custom_auto_wrap_policy(
+ module: nn.Module,
+ recurse: bool,
+ nonwrapped_numel: int,
+ # Additional custom arguments
+ min_num_params: int = int(1e8),
+ ) -> bool:
+ # 鏍规嵁鑷畾涔夐�昏緫鍐冲畾鏄惁鍖呰妯″潡
+ is_large = unwrapped_params >= min_num_params
+ requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
+ return is_large and requires_grad_uniform
+
+ # Configure a custom `min_num_params`
+ my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
+ torch.cuda.set_device(local_rank)
+ model = FSDP(model,
+ auto_wrap_policy=custom_auto_wrap_policy,
+ mixed_precision=None,
+ device_id=torch.cuda.current_device())
else:
model = model.to(device=kwargs.get("device", "cuda"))
+ logging.info(f"{model}")
kwargs["device"] = next(model.parameters()).device
# optim
+ logging.info("Build optim")
optim = kwargs.get("optim", "adam")
assert optim in optim_classes
optim_class = optim_classes.get(optim)
optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
# scheduler
+ logging.info("Build scheduler")
scheduler = kwargs.get("scheduler", "warmuplr")
assert scheduler in scheduler_classes
scheduler_class = scheduler_classes.get(scheduler)
@@ -125,6 +153,7 @@
# dataset
+ logging.info("Build dataloader")
dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
@@ -142,8 +171,9 @@
trainer = Trainer(local_rank=local_rank,
use_ddp=use_ddp,
- resume=kwargs.get("resume", True),
+ use_fsdp=use_fsdp,
device=kwargs["device"],
+ output_dir=kwargs.get("output_dir", "./exp"),
**kwargs.get("train_conf"),
)
@@ -159,20 +189,27 @@
writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
except:
writer = None
-
+
+ if use_ddp or use_fsdp:
+ context = Join([model])
+ else:
+ context = nullcontext()
+
for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
time1 = time.perf_counter()
- trainer.train_epoch(
- model=model,
- optim=optim,
- scheduler=scheduler,
- scaler=scaler,
- dataloader_train=dataloader_tr,
- dataloader_val=dataloader_val,
- epoch=epoch,
- writer=writer
- )
-
+ with context:
+
+ trainer.train_epoch(
+ model=model,
+ optim=optim,
+ scheduler=scheduler,
+ scaler=scaler,
+ dataloader_train=dataloader_tr,
+ dataloader_val=dataloader_val,
+ epoch=epoch,
+ writer=writer
+ )
+ scheduler.step()
trainer.validate_epoch(
model=model,
dataloader_val=dataloader_val,
@@ -180,21 +217,20 @@
writer=writer
)
+
trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
-
- scheduler.step()
time2 = time.perf_counter()
time_escaped = (time2 - time1) / 3600.0
logging.info(
- f"\nrank: {local_rank}, "
+ f"rank: {local_rank}, "
f"time_escaped_epoch: {time_escaped:.3f} hours, "
f"estimated to finish {trainer.max_epoch} "
f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n")
if trainer.rank == 0:
- average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
+ average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list)
trainer.close()
--
Gitblit v1.9.1