From 3a4281f4959534b1bf5d01acf0085f4f8e6f2ec8 Mon Sep 17 00:00:00 2001
From: wuhongsheng <664116298@qq.com>
Date: Fri, 05 Jul 2024 00:55:32 +0800
Subject: [PATCH] Optimize speakid-to-utterance matching logic; partially fix speakid not incrementing from 0 (#1870)

---
 funasr/bin/train.py |  166 +++++++++++++++++++++++++++++++++---------------------
 1 file changed, 101 insertions(+), 65 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index c02a66f..2729b80 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -13,7 +13,7 @@
 
 from contextlib import nullcontext
 import torch.distributed as dist
-from collections.abc import Sequence
+
 from omegaconf import DictConfig, OmegaConf
 from torch.cuda.amp import autocast, GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
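For reference, the autocast/GradScaler pair imported above is PyTorch's standard mixed-precision recipe, which the trainer enables when use_fp16 is set. A minimal, self-contained training step with the same two APIs (toy model and data, assumes a CUDA device; none of these names are FunASR objects):

    import torch
    from torch.cuda.amp import autocast, GradScaler

    model = torch.nn.Linear(16, 4).cuda()
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = GradScaler(enabled=True)

    for _ in range(10):
        x = torch.randn(8, 16, device="cuda")
        y = torch.randint(0, 4, (8,), device="cuda")
        optim.zero_grad(set_to_none=True)
        with autocast():  # run the forward pass in fp16 where it is safe
            loss = torch.nn.functional.cross_entropy(model(x), y)
        scaler.scale(loss).backward()  # scale the loss so fp16 grads do not underflow
        scaler.step(optim)             # unscales the grads, then steps the optimizer
        scaler.update()                # adapt the scale factor for the next step
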
@@ -27,29 +27,32 @@
 from funasr.train_utils.trainer import Trainer
 from funasr.schedulers import scheduler_classes
 from funasr.train_utils.initialize import initialize
-from funasr.download.download_from_hub import download_model
+from funasr.download.download_model_from_hub import download_model
 from funasr.models.lora.utils import mark_only_lora_as_trainable
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
 from funasr.utils.misc import prepare_model_dir
+from funasr.train_utils.model_summary import model_summary
 from funasr import AutoModel
+
 
 @hydra.main(config_name=None, version_base=None)
 def main_hydra(kwargs: DictConfig):
     if kwargs.get("debug", False):
-        import pdb; pdb.set_trace()
+        import pdb
+
+        pdb.set_trace()
 
     assert "model" in kwargs
     if "model_conf" not in kwargs:
         logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
         kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
-    
 
     main(**kwargs)
 
 
 def main(**kwargs):
-    
+
     # set random seed
     set_all_random_seed(kwargs.get("seed", 0))
     torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
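main_hydra above receives an OmegaConf DictConfig from Hydra, which is why flags such as debug and is_training can be read with dict-style .get plus a default. A small sketch of that access pattern (the keys here are illustrative, not the full FunASR config):

    from omegaconf import DictConfig, OmegaConf

    cfg = OmegaConf.create({"model": "paraformer", "debug": False})
    assert isinstance(cfg, DictConfig)
    print(cfg.get("debug", False))       # dict-style read of a present key
    print(cfg.get("is_training", True))  # missing keys fall back to the default
    print(OmegaConf.to_yaml(cfg))        # the YAML form of the config
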
@@ -57,28 +60,32 @@
     torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
     # open tf32
     torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
-    
-    local_rank = int(os.environ.get('LOCAL_RANK', 0))
+
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
     if local_rank == 0:
         tables.print()
     # Check if we are using DDP or FSDP
-    use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
+    use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
     use_fsdp = kwargs.get("use_fsdp", False)
     # use_ddp = False if use_fsdp else use_fsdp
     if use_ddp or use_fsdp:
-        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
+        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
         torch.cuda.set_device(local_rank)
 
     logging.info("Build model, frontend, tokenizer")
     device = kwargs.get("device", "cuda")
     kwargs["device"] = "cpu"
     model = AutoModel(**kwargs)
-    
-    
+
     # save config.yaml
-    if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
+    if (
+        (use_ddp or use_fsdp)
+        and dist.get_rank() == 0
+        or not (use_ddp or use_fsdp)
+        and local_rank == 0
+    ):
         prepare_model_dir(**kwargs)
-    
+
     # parse kwargs
     kwargs = model.kwargs
     kwargs["device"] = device
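The quoting changes above sit on top of the standard torchrun contract: every worker reads LOCAL_RANK and WORLD_SIZE from the environment, joins the process group over env://, and pins its CUDA device. That pattern in isolation (assuming a torchrun launch; MASTER_ADDR and MASTER_PORT are exported by torchrun itself):

    import os
    import torch
    import torch.distributed as dist

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if world_size > 1:
        # env:// needs no arguments because torchrun exports the rendezvous info
        dist.init_process_group(backend="nccl", init_method="env://")
        torch.cuda.set_device(local_rank)

    is_rank0 = dist.get_rank() == 0 if dist.is_initialized() else local_rank == 0
    if is_rank0:
        pass  # rank-0-only work, e.g. saving config.yaml as in the hunk above
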
@@ -92,7 +99,7 @@
     if freeze_param is not None:
         if "," in freeze_param:
             freeze_param = eval(freeze_param)
-        if isinstance(freeze_param, Sequence):
+        if not isinstance(freeze_param, (list, tuple)):
             freeze_param = (freeze_param,)
         logging.info("freeze_param is not None: %s", freeze_param)
         for t in freeze_param:
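The isinstance change above is the substantive fix in this hunk: a str is itself a collections.abc.Sequence, so the old check also matched lists parsed by eval and wrapped them a second time, after which t + "." in the loop raised TypeError for a list operand. Only scalars should be wrapped, which is what the new tuple check does:

    from collections.abc import Sequence

    def normalize(freeze_param):
        # new check: wrap only non-list/tuple values into a one-element tuple
        if not isinstance(freeze_param, (list, tuple)):
            freeze_param = (freeze_param,)
        return freeze_param

    assert isinstance("encoder", Sequence)    # why the old check misfired
    print(normalize("encoder"))               # ('encoder',)
    print(normalize(["encoder", "decoder"]))  # left untouched
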
@@ -100,12 +107,18 @@
                 if k.startswith(t + ".") or k == t:
                     logging.info(f"Setting {k}.requires_grad = False")
                     p.requires_grad = False
-    
+    if local_rank == 0:
+        logging.info(f"{model_summary(model)}")
 
     if use_ddp:
         model = model.cuda(local_rank)
-        model = DDP(model, device_ids=[local_rank],
-                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
+        model = DDP(
+            model,
+            device_ids=[local_rank],
+            find_unused_parameters=kwargs.get("train_conf", {}).get(
+                "find_unused_parameters", False
+            ),
+        )
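The reflowed DDP call keeps find_unused_parameters configurable through train_conf. As a reminder of what the flag trades off, a minimal wrap (assumes the process group has already been initialized, as in the earlier sketch; the toy model is a placeholder):

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    local_rank = 0  # normally int(os.environ["LOCAL_RANK"])
    model = torch.nn.Linear(16, 4).cuda(local_rank)
    model = DDP(
        model,
        device_ids=[local_rank],
        # True tolerates forward passes that skip some parameters, at the cost
        # of an extra graph traversal after every backward; keep False when all
        # parameters always receive gradients
        find_unused_parameters=False,
    )
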
     elif use_fsdp:
         # model = FSDP(model).cuda(local_rank)
 
@@ -124,24 +137,24 @@
         # Configure a custom `min_num_params`
         my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
         torch.cuda.set_device(local_rank)
-        model = FSDP(model,
-                     auto_wrap_policy=custom_auto_wrap_policy,
-                     mixed_precision=None,
-                     device_id=torch.cuda.current_device())
+        model = FSDP(
+            model,
+            auto_wrap_policy=my_auto_wrap_policy,
+            mixed_precision=None,
+            device_id=torch.cuda.current_device(),
+        )
     else:
         model = model.to(device=kwargs.get("device", "cuda"))
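FSDP's auto_wrap_policy decides which submodules become separately sharded FSDP units, and the functools.partial built above exists to carry min_num_params into that decision. The sketch below shows the same shape with torch's stock size_based_auto_wrap_policy (custom_auto_wrap_policy in the patch is FunASR's own callable; a process group is assumed to be initialized already):

    import functools
    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e5))
    model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.Linear(512, 512))
    model = FSDP(
        model.cuda(),
        auto_wrap_policy=policy,  # modules above the threshold shard separately
        mixed_precision=None,     # fp32; pass a MixedPrecision(...) config for bf16/fp16
        device_id=torch.cuda.current_device(),
    )
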
 
-    if local_rank == 0:
-        logging.info(f"{model}")
     kwargs["device"] = next(model.parameters()).device
-        
+
     # optim
     logging.info("Build optim")
     optim = kwargs.get("optim", "adam")
     assert optim in optim_classes
     optim_class = optim_classes.get(optim)
     optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
-    
+
     # scheduler
     logging.info("Build scheduler")
     scheduler = kwargs.get("scheduler", "warmuplr")
@@ -149,62 +162,85 @@
     scheduler_class = scheduler_classes.get(scheduler)
     scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
 
-
     # dataset
     logging.info("Build dataloader")
-    dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
+    dataloader_class = tables.dataloader_classes.get(
+        kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle")
+    )
     dataloader = dataloader_class(**kwargs)
     # dataloader_tr, dataloader_val = dataloader_class(**kwargs)
-    trainer = Trainer(local_rank=local_rank,
-                      use_ddp=use_ddp,
-                      use_fsdp=use_fsdp,
-                      device=kwargs["device"],
-                      output_dir=kwargs.get("output_dir", "./exp"),
-                      **kwargs.get("train_conf"),
-                      )
+    trainer = Trainer(
+        local_rank=local_rank,
+        use_ddp=use_ddp,
+        use_fsdp=use_fsdp,
+        device=kwargs["device"],
+        output_dir=kwargs.get("output_dir", "./exp"),
+        **kwargs.get("train_conf"),
+    )
 
     scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
     scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
 
-    trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+    trainer.resume_checkpoint(
+        model=model,
+        optim=optim,
+        scheduler=scheduler,
+        scaler=scaler,
+    )
 
     tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
     os.makedirs(tensorboard_dir, exist_ok=True)
     try:
         from tensorboardX import SummaryWriter
-        writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
+
+        writer = SummaryWriter(tensorboard_dir)  # if trainer.rank == 0 else None
     except:
         writer = None
 
-
-    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
+    dataloader_tr, dataloader_val = None, None
+    for epoch in range(trainer.start_epoch, trainer.max_epoch):
         time1 = time.perf_counter()
-        
-        for data_split_i in range(dataloader.data_split_num):
-            dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
+
+        for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
+            time_slice_i = time.perf_counter()
+            dataloader_tr, dataloader_val = dataloader.build_iter(
+                epoch, data_split_i=data_split_i, start_step=trainer.start_step
+            )
+
             trainer.train_epoch(
-                                model=model,
-                                optim=optim,
-                                scheduler=scheduler,
-                                scaler=scaler,
-                                dataloader_train=dataloader_tr,
-                                dataloader_val=dataloader_val,
-                                epoch=epoch,
-                                writer=writer,
-                                data_split_i=data_split_i,
-                                data_split_num=dataloader.data_split_num,
-                                )
-        
+                model=model,
+                optim=optim,
+                scheduler=scheduler,
+                scaler=scaler,
+                dataloader_train=dataloader_tr,
+                dataloader_val=dataloader_val,
+                epoch=epoch,
+                writer=writer,
+                data_split_i=data_split_i,
+                data_split_num=dataloader.data_split_num,
+                start_step=trainer.start_step,
+            )
+            trainer.start_step = 0
+
+            torch.cuda.empty_cache()
+
+            time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
+            logging.info(
+                f"rank: {local_rank}, "
+                f"time_escaped_epoch: {time_escaped:.3f} hours, "
+                f"remaining in this epoch: {dataloader.data_split_num-data_split_i} of {dataloader.data_split_num} data slices, estimated {(dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours, "
+                f"remaining epochs: {trainer.max_epoch - epoch}, estimated {((trainer.max_epoch - epoch - 1)*dataloader.data_split_num + dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours\n"
+            )
+
+        trainer.start_data_split_i = 0
         trainer.validate_epoch(
-            model=model,
-            dataloader_val=dataloader_val,
-            epoch=epoch,
-            writer=writer
+            model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
         )
         scheduler.step()
-
-        
-        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+        trainer.step_in_epoch = 0
+        trainer.save_checkpoint(
+            epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
+        )
 
         time2 = time.perf_counter()
         time_escaped = (time2 - time1) / 3600.0
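The rewritten loop threads three resume cursors through the trainer: start_epoch, start_data_split_i, and start_step, each zeroed immediately after its first use, so only the resumed epoch skips data slices and only the resumed slice skips steps. The control flow reduced to a runnable skeleton (field names as in the patch, data handling stubbed out):

    start_epoch, start_data_split_i, start_step = 2, 3, 1500  # as restored from a checkpoint
    max_epoch, data_split_num = 5, 8

    for epoch in range(start_epoch, max_epoch):
        for data_split_i in range(start_data_split_i, data_split_num):
            print(f"epoch {epoch}, slice {data_split_i}, resume at step {start_step}")
            start_step = 0       # every later slice starts from step 0
        start_data_split_i = 0   # every later epoch starts from slice 0
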
@@ -212,8 +248,10 @@
             f"rank: {local_rank}, "
             f"time_escaped_epoch: {time_escaped:.3f} hours, "
             f"estimated to finish {trainer.max_epoch} "
-            f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n")
-
+            f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n"
+        )
+        trainer.train_acc_avg = 0.0
+        trainer.train_loss_avg = 0.0
 
     if trainer.rank == 0:
         average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
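average_checkpoints collapses the n best checkpoints into a single model. The usual recipe behind such helpers is an elementwise mean over state dicts, sketched below on that assumption (the real FunASR helper may select and weight checkpoints differently):

    import torch

    def average_state_dicts(paths):
        avg = None
        for p in paths:
            sd = torch.load(p, map_location="cpu")
            if avg is None:
                avg = {k: v.clone().float() for k, v in sd.items()}
            else:
                for k in avg:
                    avg[k] += sd[k].float()
        return {k: v / len(paths) for k, v in avg.items()}

    # usage: torch.save(average_state_dicts(nbest_paths), "model.avg.pt")
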
@@ -221,7 +259,5 @@
     trainer.close()
 
 
-    
-
 if __name__ == "__main__":
-    main_hydra()
\ No newline at end of file
+    main_hydra()

--
Gitblit v1.9.1