From 3ac03e448b7673604eb86f619b27521fca55f34d Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 21 Mar 2024 01:36:39 +0800
Subject: [PATCH] train & finetune llm-asr (#1519)

---
 funasr/bin/train_llm.py |   78 ++++++++++++++++++++++++++++----------
 1 files changed, 57 insertions(+), 21 deletions(-)

diff --git a/funasr/bin/train_llm.py b/funasr/bin/train_llm.py
index 8742bf1..89f5db3 100644
--- a/funasr/bin/train_llm.py
+++ b/funasr/bin/train_llm.py
@@ -4,18 +4,21 @@
 import os
 import sys
 import torch
+import torch.nn as nn
 import hydra
 import logging
 import time
 import argparse
 from io import BytesIO
 
+from contextlib import nullcontext
 import torch.distributed as dist
 from collections.abc import Sequence
 from omegaconf import DictConfig, OmegaConf
 from torch.cuda.amp import autocast, GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.algorithms.join import Join
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from funasr.train_utils.average_nbest_models import average_checkpoints
 
@@ -48,7 +51,6 @@
 
 
 def main(**kwargs):
-    print(kwargs)
     
     # set random seed
     set_all_random_seed(kwargs.get("seed", 0))
@@ -61,11 +63,13 @@
         tables.print()
     # Check if we are using DDP or FSDP
     use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
-    use_fsdp = kwargs.get("use_fsdp", None)
+    use_fsdp = kwargs.get("use_fsdp", False)
+    # use_ddp = False if use_fsdp else use_fsdp
     if use_ddp or use_fsdp:
         dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
         torch.cuda.set_device(local_rank)
-        
+
+    logging.info("Build model, frontend, tokenizer")
     device = kwargs.get("device", "cuda")
     kwargs["device"] = "cpu"
     model = AutoModel(**kwargs)
@@ -76,6 +80,7 @@
         os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
         yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
         OmegaConf.save(config=kwargs, f=yaml_file)
+        print(kwargs)
         logging.info("config.yaml is saved to: %s", yaml_file)
     
     # parse kwargs
@@ -105,19 +110,42 @@
         model = DDP(model, device_ids=[local_rank],
                     find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
     elif use_fsdp:
-        model = FSDP(model).cuda(local_rank)
+        # model = FSDP(model).cuda(local_rank)
+
+        def custom_auto_wrap_policy(
+            module: nn.Module,
+            recurse: bool,  # accepted as part of the policy signature; not consulted by this policy
+            nonwrapped_numel: int,
+            # Additional custom arguments
+            min_num_params: int = int(1e8),
+        ) -> bool:
+            """Return True only if nonwrapped_numel >= min_num_params and all of the module's parameters share one requires_grad value."""
+            is_large = nonwrapped_numel >= min_num_params
+            requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
+            return is_large and requires_grad_uniform
+
+        import functools  # function-scope import: not among the module-level imports this patch shows
+        # Configure a custom `min_num_params` and hand the configured policy (not the bare function) to FSDP
+        my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
+        torch.cuda.set_device(local_rank)
+        model = FSDP(model,
+                     auto_wrap_policy=my_auto_wrap_policy,
+                     mixed_precision=None, device_id=torch.cuda.current_device())
     else:
         model = model.to(device=kwargs.get("device", "cuda"))
 
+    logging.info(f"{model}")
     kwargs["device"] = next(model.parameters()).device
         
     # optim
+    logging.info("Build optim")
     optim = kwargs.get("optim", "adam")
     assert optim in optim_classes
     optim_class = optim_classes.get(optim)
     optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
     
     # scheduler
+    logging.info("Build scheduler")
     scheduler = kwargs.get("scheduler", "warmuplr")
     assert scheduler in scheduler_classes
     scheduler_class = scheduler_classes.get(scheduler)
@@ -125,6 +153,7 @@
 
 
     # dataset
+    logging.info("Build dataloader")
     dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
     dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
     dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
@@ -142,8 +171,9 @@
 
     trainer = Trainer(local_rank=local_rank,
                       use_ddp=use_ddp,
-                      resume=kwargs.get("resume", True),
+                      use_fsdp=use_fsdp,
                       device=kwargs["device"],
+                      output_dir=kwargs.get("output_dir", "./exp"),
                       **kwargs.get("train_conf"),
                       )
 
@@ -159,20 +189,27 @@
         writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
     except:
         writer = None
-    
+
+    if isinstance(model, DDP):
+        context = Join([model])  # Join requires Joinable objects; DDP is one, an FSDP-wrapped model is not
+    else:
+        context = nullcontext()
+
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
-        trainer.train_epoch(
-                            model=model,
-                            optim=optim,
-                            scheduler=scheduler,
-                            scaler=scaler,
-                            dataloader_train=dataloader_tr,
-                            dataloader_val=dataloader_val,
-                            epoch=epoch,
-                            writer=writer
-                            )
-
+        with context:
+            
+            trainer.train_epoch(
+                                model=model,
+                                optim=optim,
+                                scheduler=scheduler,
+                                scaler=scaler,
+                                dataloader_train=dataloader_tr,
+                                dataloader_val=dataloader_val,
+                                epoch=epoch,
+                                writer=writer
+                                )
+        scheduler.step()
         trainer.validate_epoch(
             model=model,
             dataloader_val=dataloader_val,
@@ -180,21 +217,20 @@
             writer=writer
         )
 
+        
         trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
-
-        scheduler.step()
 
         time2 = time.perf_counter()
         time_escaped = (time2 - time1) / 3600.0
         logging.info(
-            f"\nrank: {local_rank}, "
+            f"rank: {local_rank}, "
             f"time_escaped_epoch: {time_escaped:.3f} hours, "
             f"estimated to finish {trainer.max_epoch} "
             f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n")
 
 
     if trainer.rank == 0:
-        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
+        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list)
 
     trainer.close()
 

--
Gitblit v1.9.1