From fc68b5ffe453235294a561737d8e84bb6c1689a4 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 25 四月 2024 21:43:47 +0800
Subject: [PATCH] Dev gzf exp (#1661)

---
 funasr/bin/train.py |   24 +++++++++++++++---------
 1 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index eb1611a..448e464 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -99,7 +99,7 @@
     if freeze_param is not None:
         if "," in freeze_param:
             freeze_param = eval(freeze_param)
-        if isinstance(freeze_param, Sequence):
+        if not isinstance(freeze_param, Sequence):
             freeze_param = (freeze_param,)
         logging.info("freeze_param is not None: %s", freeze_param)
         for t in freeze_param:
@@ -107,8 +107,9 @@
                 if k.startswith(t + ".") or k == t:
                     logging.info(f"Setting {k}.requires_grad = False")
                     p.requires_grad = False
+    if local_rank == 0:
+        logging.info(f"{model_summary(model)}")
 
-    logging.info(f"model info: {model_summary(model)}")
     if use_ddp:
         model = model.cuda(local_rank)
         model = DDP(
@@ -145,8 +146,6 @@
     else:
         model = model.to(device=kwargs.get("device", "cuda"))
 
-    if local_rank == 0:
-        logging.info(f"{model}")
     kwargs["device"] = next(model.parameters()).device
 
     # optim
@@ -182,7 +181,12 @@
     scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
     scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
 
-    trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+    trainer.resume_checkpoint(
+        model=model,
+        optim=optim,
+        scheduler=scheduler,
+        scaler=scaler,
+    )
 
     tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
     os.makedirs(tensorboard_dir, exist_ok=True)
@@ -197,8 +201,11 @@
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
 
-        for data_split_i in range(dataloader.data_split_num):
-            dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
+        for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
+            dataloader_tr, dataloader_val = dataloader.build_iter(
+                epoch, data_split_i=data_split_i, start_step=trainer.start_step
+            )
+            trainer.start_step = 0
             trainer.train_epoch(
                 model=model,
                 optim=optim,
@@ -211,9 +218,8 @@
                 data_split_i=data_split_i,
                 data_split_num=dataloader.data_split_num,
             )
-            
-            torch.cuda.empty_cache()
 
+            torch.cuda.empty_cache()
 
         trainer.validate_epoch(
             model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer

--
Gitblit v1.9.1