From 80bd14e6bbb7bb282ff3832194648dc4a16157ca Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 25 四月 2024 10:41:14 +0800
Subject: [PATCH] Dev gzf exp (#1657)
---
funasr/bin/train.py | 5 +++++
1 files changed, 5 insertions(+), 0 deletions(-)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 05942cd..eb1611a 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -32,6 +32,7 @@
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.misc import prepare_model_dir
+from funasr.train_utils.model_summary import model_summary
from funasr import AutoModel
@@ -107,6 +108,7 @@
logging.info(f"Setting {k}.requires_grad = False")
p.requires_grad = False
+ logging.info(f"model info: {model_summary(model)}")
if use_ddp:
model = model.cuda(local_rank)
model = DDP(
@@ -209,6 +211,9 @@
data_split_i=data_split_i,
data_split_num=dataloader.data_split_num,
)
+
+ torch.cuda.empty_cache()
+
trainer.validate_epoch(
model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
--
Gitblit v1.9.1