From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365

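This patch also reworks the training loop in funasr/bin/train.py so that an
interrupted run can resume mid-epoch: the loop over data splits starts from
trainer.start_data_split_i, dataloader.build_iter() and trainer.train_epoch()
receive the saved trainer.start_step, and both counters are reset to zero once
the resumed split and epoch have been consumed. Validation and checkpointing
are keyed on epoch + 1, per-slice timing is logged, download_model is imported
from funasr.download.download_model_from_hub, and freeze_param is wrapped in a
tuple only when it is not already a list or tuple.

A minimal sketch of the resume flow introduced here (train_one_split is a
hypothetical stand-in for dataloader.build_iter plus trainer.train_epoch; the
counter names are the ones this patch uses):

    for epoch in range(trainer.start_epoch, trainer.max_epoch):
        for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
            # the first resumed split fast-forwards to the saved step ...
            train_one_split(epoch, data_split_i, start_step=trainer.start_step)
            trainer.start_step = 0        # ... every later split starts at step 0
        trainer.start_data_split_i = 0    # every later epoch starts at split 0
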
---
 funasr/bin/train.py |   52 +++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 37 insertions(+), 15 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index eb1611a..2729b80 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -13,7 +13,7 @@
 
 from contextlib import nullcontext
 import torch.distributed as dist
-from collections.abc import Sequence
+
 from omegaconf import DictConfig, OmegaConf
 from torch.cuda.amp import autocast, GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -27,7 +27,7 @@
 from funasr.train_utils.trainer import Trainer
 from funasr.schedulers import scheduler_classes
 from funasr.train_utils.initialize import initialize
-from funasr.download.download_from_hub import download_model
+from funasr.download.download_model_from_hub import download_model
 from funasr.models.lora.utils import mark_only_lora_as_trainable
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
@@ -99,7 +99,7 @@
     if freeze_param is not None:
         if "," in freeze_param:
             freeze_param = eval(freeze_param)
-        if isinstance(freeze_param, Sequence):
+        if not isinstance(freeze_param, (list, tuple)):
             freeze_param = (freeze_param,)
         logging.info("freeze_param is not None: %s", freeze_param)
         for t in freeze_param:
@@ -107,8 +107,9 @@
                 if k.startswith(t + ".") or k == t:
                     logging.info(f"Setting {k}.requires_grad = False")
                     p.requires_grad = False
+    if local_rank == 0:
+        logging.info(f"{model_summary(model)}")
 
-    logging.info(f"model info: {model_summary(model)}")
     if use_ddp:
         model = model.cuda(local_rank)
         model = DDP(
@@ -145,8 +146,6 @@
     else:
         model = model.to(device=kwargs.get("device", "cuda"))
 
-    if local_rank == 0:
-        logging.info(f"{model}")
     kwargs["device"] = next(model.parameters()).device
 
     # optim
@@ -182,23 +181,32 @@
     scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
     scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
 
-    trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+    trainer.resume_checkpoint(
+        model=model,
+        optim=optim,
+        scheduler=scheduler,
+        scaler=scaler,
+    )
 
     tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
     os.makedirs(tensorboard_dir, exist_ok=True)
     try:
         from tensorboardX import SummaryWriter
 
-        writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
+        writer = SummaryWriter(tensorboard_dir)  # if trainer.rank == 0 else None
     except:
         writer = None
 
     dataloader_tr, dataloader_val = None, None
-    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
+    for epoch in range(trainer.start_epoch, trainer.max_epoch):
         time1 = time.perf_counter()
 
-        for data_split_i in range(dataloader.data_split_num):
-            dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
+        for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
+            time_slice_i = time.perf_counter()
+            dataloader_tr, dataloader_val = dataloader.build_iter(
+                epoch, data_split_i=data_split_i, start_step=trainer.start_step
+            )
+
             trainer.train_epoch(
                 model=model,
                 optim=optim,
@@ -210,17 +218,29 @@
                 writer=writer,
                 data_split_i=data_split_i,
                 data_split_num=dataloader.data_split_num,
+                start_step=trainer.start_step,
             )
-            
+            trainer.start_step = 0  # only the first resumed split skips steps; later splits start at step 0
+
             torch.cuda.empty_cache()
 
+            time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
+            logging.info(
+                f"rank: {local_rank}, "
+                f"time_escaped_epoch: {time_escaped:.3f} hours, "
+                f"estimated to finish {dataloader.data_split_num} data_slices, remaining: {dataloader.data_split_num-data_split_i} slices, {(dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours, "
+                f"epoch: {trainer.max_epoch - epoch} epochs, {((trainer.max_epoch - epoch - 1)*dataloader.data_split_num + dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours\n"
+            )
 
+        trainer.start_data_split_i = 0  # later epochs start from the first data split
         trainer.validate_epoch(
-            model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
+            model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
         )
         scheduler.step()
-
-        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+        trainer.step_in_epoch = 0  # reset the mid-epoch step counter at the epoch boundary
+        trainer.save_checkpoint(
+            epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
+        )
 
         time2 = time.perf_counter()
         time_escaped = (time2 - time1) / 3600.0
@@ -230,6 +250,8 @@
             f"estimated to finish {trainer.max_epoch} "
             f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n"
         )
+        trainer.train_acc_avg = 0.0  # reset running training averages for the next epoch
+        trainer.train_loss_avg = 0.0
 
     if trainer.rank == 0:
         average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)

--
Gitblit v1.9.1