From fc68b5ffe453235294a561737d8e84bb6c1689a4 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 25 Apr 2024 21:43:47 +0800
Subject: [PATCH] Dev gzf exp (#1661)

---
 funasr/bin/train.py |   23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 05942cd..448e464 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -32,6 +32,7 @@
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
 from funasr.utils.misc import prepare_model_dir
+from funasr.train_utils.model_summary import model_summary
 from funasr import AutoModel
 
 
@@ -98,7 +99,7 @@
     if freeze_param is not None:
         if "," in freeze_param:
             freeze_param = eval(freeze_param)
-        if isinstance(freeze_param, Sequence):
+        if not isinstance(freeze_param, Sequence):
             freeze_param = (freeze_param,)
         logging.info("freeze_param is not None: %s", freeze_param)
         for t in freeze_param:
@@ -106,6 +107,8 @@
                 if k.startswith(t + ".") or k == t:
                     logging.info(f"Setting {k}.requires_grad = False")
                     p.requires_grad = False
+    if local_rank == 0:
+        logging.info(f"{model_summary(model)}")
 
     if use_ddp:
         model = model.cuda(local_rank)
@@ -143,8 +146,6 @@
     else:
         model = model.to(device=kwargs.get("device", "cuda"))
 
-    if local_rank == 0:
-        logging.info(f"{model}")
     kwargs["device"] = next(model.parameters()).device
 
     # optim
@@ -180,7 +181,12 @@
     scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
     scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
 
-    trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+    trainer.resume_checkpoint(
+        model=model,
+        optim=optim,
+        scheduler=scheduler,
+        scaler=scaler,
+    )
 
     tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
     os.makedirs(tensorboard_dir, exist_ok=True)
@@ -195,8 +201,11 @@
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
 
-        for data_split_i in range(dataloader.data_split_num):
-            dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
+        for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
+            dataloader_tr, dataloader_val = dataloader.build_iter(
+                epoch, data_split_i=data_split_i, start_step=trainer.start_step
+            )
+            trainer.start_step = 0
             trainer.train_epoch(
                 model=model,
                 optim=optim,
@@ -210,6 +219,8 @@
                 data_split_num=dataloader.data_split_num,
             )
 
+            torch.cuda.empty_cache()
+
         trainer.validate_epoch(
             model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
         )

--
Gitblit v1.9.1