From 80bd14e6bbb7bb282ff3832194648dc4a16157ca Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 25 四月 2024 10:41:14 +0800
Subject: [PATCH] Dev gzf exp (#1657)

---
 funasr/bin/train.py |    5 +++++
 1 files changed, 5 insertions(+), 0 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 05942cd..eb1611a 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -32,6 +32,7 @@
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
 from funasr.utils.misc import prepare_model_dir
+from funasr.train_utils.model_summary import model_summary
 from funasr import AutoModel
 
 
@@ -107,6 +108,7 @@
                     logging.info(f"Setting {k}.requires_grad = False")
                     p.requires_grad = False
 
+    logging.info(f"model info: {model_summary(model)}")
     if use_ddp:
         model = model.cuda(local_rank)
         model = DDP(
@@ -209,6 +211,9 @@
                 data_split_i=data_split_i,
                 data_split_num=dataloader.data_split_num,
             )
+            
+            torch.cuda.empty_cache()
+
 
         trainer.validate_epoch(
             model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer

--
Gitblit v1.9.1