From 4482bbcbb912f699a4faecaafd65aa15aec64a51 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 21 三月 2024 11:49:30 +0800
Subject: [PATCH] train (#1521)

---
 funasr/bin/train.py |  215 +++++++++++++++++++++++++++++------------------------
 1 files changed, 119 insertions(+), 96 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 3c93371..3f97f9e 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -4,15 +4,23 @@
 import os
 import sys
 import torch
+import torch.nn as nn
 import hydra
 import logging
+import time
 import argparse
 from io import BytesIO
+
+from contextlib import nullcontext
 import torch.distributed as dist
 from collections.abc import Sequence
 from omegaconf import DictConfig, OmegaConf
+from torch.cuda.amp import autocast, GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.algorithms.join import Join
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from funasr.train_utils.average_nbest_models import average_checkpoints
 
 from funasr.register import tables
 from funasr.optimizers import optim_classes
@@ -23,10 +31,8 @@
 from funasr.models.lora.utils import mark_only_lora_as_trainable
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
-# from funasr.tokenizer.build_tokenizer import build_tokenizer
-# from funasr.tokenizer.token_id_converter import TokenIDConverter
-# from funasr.tokenizer.funtoken import build_tokenizer
-
+from funasr.utils.misc import prepare_model_dir
+from funasr import AutoModel
 
 @hydra.main(config_name=None, version_base=None)
 def main_hydra(kwargs: DictConfig):
@@ -43,7 +49,6 @@
 
 
 def main(**kwargs):
-    print(kwargs)
     
     # set random seed
     set_all_random_seed(kwargs.get("seed", 0))
@@ -56,65 +61,29 @@
         tables.print()
     # Check if we are using DDP or FSDP
     use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
-    use_fsdp = kwargs.get("use_fsdp", None)
+    use_fsdp = kwargs.get("use_fsdp", False)
+    # use_ddp = False if use_fsdp else use_fsdp
     if use_ddp or use_fsdp:
         dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
         torch.cuda.set_device(local_rank)
+
+    logging.info("Build model, frontend, tokenizer")
+    device = kwargs.get("device", "cuda")
+    kwargs["device"] = "cpu"
+    model = AutoModel(**kwargs)
+    
     
     # save config.yaml
     if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
-        os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
-        yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
-        OmegaConf.save(config=kwargs, f=yaml_file)
-        logging.info("config.yaml is saved to: %s", yaml_file)
-
-    tokenizer = kwargs.get("tokenizer", None)
-    if tokenizer is not None:
-        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
-        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
-        kwargs["tokenizer"] = tokenizer
+        prepare_model_dir(**kwargs)
     
-    # build frontend if frontend is none None
-    frontend = kwargs.get("frontend", None)
-    if frontend is not None:
-        frontend_class = tables.frontend_classes.get(frontend)
-        frontend = frontend_class(**kwargs["frontend_conf"])
-        kwargs["frontend"] = frontend
-        kwargs["input_size"] = frontend.output_size()
-
-
-    # build model
-    model_class = tables.model_classes.get(kwargs["model"])
-    vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
-    vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
-    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
-
-
-
-    # init_param
-    init_param = kwargs.get("init_param", None)
-    if init_param is not None:
-        if not isinstance(init_param, (list, tuple)):
-            init_param = (init_param,)
-        logging.info("init_param is not None: %s", init_param)
-        for p in init_param:
-            if os.path.exists(p):
-                logging.info(f"Loading pretrained params from {p}")
-                load_pretrained_model(
-                    model=model,
-                    path=p,
-                    ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
-                    oss_bucket=kwargs.get("oss_bucket", None),
-                    scope_map=kwargs.get("scope_map", []),
-                    excludes=kwargs.get("excludes", None),
-                )
-            else:
-                logging.info(f"Checkpoint does not exist, init randomly: {p}")
-    elif kwargs.get("init", None):
-        initialize(model, kwargs.get("init", "kaiming_normal"))
-    else:
-        print("No initialize method")
-
+    # parse kwargs
+    kwargs = model.kwargs
+    kwargs["device"] = device
+    tokenizer = kwargs["tokenizer"]
+    frontend = kwargs["frontend"]
+    model = model.model
+    del kwargs["model"]
 
     # freeze_param
     freeze_param = kwargs.get("freeze_param", None)
@@ -135,18 +104,42 @@
         model = DDP(model, device_ids=[local_rank],
                     find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
     elif use_fsdp:
-        model = FSDP(model).cuda(local_rank)
+        # model = FSDP(model).cuda(local_rank)
+
+        def custom_auto_wrap_policy(
+            module: nn.Module,
+            recurse: bool,
+            nonwrapped_numel: int,
+            # Additional custom arguments
+            min_num_params: int = int(1e8),
+        ) -> bool:
+            # 鏍规嵁鑷畾涔夐�昏緫鍐冲畾鏄惁鍖呰妯″潡
+            is_large = unwrapped_params >= min_num_params
+            requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
+            return is_large and requires_grad_uniform
+
+        # Configure a custom `min_num_params`
+        my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
+        torch.cuda.set_device(local_rank)
+        model = FSDP(model,
+                     auto_wrap_policy=custom_auto_wrap_policy,
+                     mixed_precision=None,
+                     device_id=torch.cuda.current_device())
     else:
         model = model.to(device=kwargs.get("device", "cuda"))
-        
+
+    logging.info(f"{model}")
+    kwargs["device"] = next(model.parameters()).device
         
     # optim
+    logging.info("Build optim")
     optim = kwargs.get("optim", "adam")
     assert optim in optim_classes
     optim_class = optim_classes.get(optim)
     optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
     
     # scheduler
+    logging.info("Build scheduler")
     scheduler = kwargs.get("scheduler", "warmuplr")
     assert scheduler in scheduler_classes
     scheduler_class = scheduler_classes.get(scheduler)
@@ -154,45 +147,75 @@
 
 
     # dataset
-    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
-    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
-    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
+    logging.info("Build dataloader")
+    dataloader_class = tables.dataloader_classes.get( kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
+    dataloader_tr, dataloader_val = dataloader_class(**kwargs)
 
-    # dataloader
-    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
-    batch_sampler_val = None
-    if batch_sampler is not None:
-        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
-        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
-        batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
-    dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
-                                                collate_fn=dataset_tr.collator,
-                                                batch_sampler=batch_sampler,
-                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
-                                                pin_memory=True)
-    
-    dataloader_val = torch.utils.data.DataLoader(dataset_val,
-                                                collate_fn=dataset_val.collator,
-                                                batch_sampler=batch_sampler_val,
-                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
-                                                pin_memory=True)
-    trainer = Trainer(
-        model=model,
-        optim=optim,
-        scheduler=scheduler,
-        dataloader_train=dataloader_tr,
-        dataloader_val=dataloader_val,
-        local_rank=local_rank,
-        use_ddp=use_ddp,
-        use_fsdp=use_fsdp,
-        output_dir=kwargs.get("output_dir", "./exp"),
-        resume=kwargs.get("resume", True),
-        **kwargs.get("train_conf"),
-    )
-    trainer.run()
-    
+    trainer = Trainer(local_rank=local_rank,
+                      use_ddp=use_ddp,
+                      use_fsdp=use_fsdp,
+                      device=kwargs["device"],
+                      output_dir=kwargs.get("output_dir", "./exp"),
+                      **kwargs.get("train_conf"),
+                      )
+
+    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
+    scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
+
+    trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+
+    tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
+    os.makedirs(tensorboard_dir, exist_ok=True)
+    try:
+        from tensorboardX import SummaryWriter
+        writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
+    except:
+        writer = None
+
     if use_ddp or use_fsdp:
-        torch.distributed.destroy_process_group()
+        context = Join([model])
+    else:
+        context = nullcontext()
+
+    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
+        time1 = time.perf_counter()
+        with context:
+            
+            trainer.train_epoch(
+                                model=model,
+                                optim=optim,
+                                scheduler=scheduler,
+                                scaler=scaler,
+                                dataloader_train=dataloader_tr,
+                                dataloader_val=dataloader_val,
+                                epoch=epoch,
+                                writer=writer
+                                )
+        scheduler.step()
+        trainer.validate_epoch(
+            model=model,
+            dataloader_val=dataloader_val,
+            epoch=epoch,
+            writer=writer
+        )
+
+        
+        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+
+        time2 = time.perf_counter()
+        time_escaped = (time2 - time1) / 3600.0
+        logging.info(
+            f"rank: {local_rank}, "
+            f"time_escaped_epoch: {time_escaped:.3f} hours, "
+            f"estimated to finish {trainer.max_epoch} "
+            f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n")
+
+
+    if trainer.rank == 0:
+        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list)
+
+    trainer.close()
+
 
     
 

--
Gitblit v1.9.1