From ed22e34d654c47017962d3e5758d3a351d8826ab Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 24 三月 2024 15:03:54 +0800
Subject: [PATCH] finetune

---
 funasr/bin/train.py |   18 ++++++++++--------
 1 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 0ff4ba1..6cb486b 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -128,7 +128,8 @@
     else:
         model = model.to(device=kwargs.get("device", "cuda"))
 
-    logging.info(f"{model}")
+    if local_rank == 0:
+        logging.info(f"{model}")
     kwargs["device"] = next(model.parameters()).device
         
     # optim
@@ -175,8 +176,8 @@
     # if use_ddp or use_fsdp:
     #     context = Join([model])
     # else:
+    #     context = nullcontext()
     context = nullcontext()
-
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
         with context:
@@ -191,13 +192,14 @@
                                 epoch=epoch,
                                 writer=writer
                                 )
+        with context:
+            trainer.validate_epoch(
+                model=model,
+                dataloader_val=dataloader_val,
+                epoch=epoch,
+                writer=writer
+            )
         scheduler.step()
-        trainer.validate_epoch(
-            model=model,
-            dataloader_val=dataloader_val,
-            epoch=epoch,
-            writer=writer
-        )
 
         
         trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)

--
Gitblit v1.9.1