From a750595594321833b48dc19798eed66876a100b4 Mon Sep 17 00:00:00 2001
From: ming030890 <67713085+ming030890@users.noreply.github.com>
Date: 星期五, 04 七月 2025 14:25:54 +0800
Subject: [PATCH] Fix a few issues found during fine-tuning (#2582)
---
funasr/bin/train_ds.py | 20 ++++++++++++++++++--
1 files changed, 18 insertions(+), 2 deletions(-)
diff --git a/funasr/bin/train_ds.py b/funasr/bin/train_ds.py
index 10a5d08..2241b0c 100644
--- a/funasr/bin/train_ds.py
+++ b/funasr/bin/train_ds.py
@@ -149,7 +149,7 @@
dataloader = dataloader_class(**kwargs)
# dataloader_tr, dataloader_val = dataloader_class(**kwargs)
- scaler = GradScaler(enabled=True) if trainer.use_fp16 or trainer.use_bf16 else None
+ scaler = GradScaler(enabled=True) if trainer.use_fp16 else None
scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
trainer.resume_checkpoint(
@@ -158,6 +158,10 @@
scheduler=scheduler,
scaler=scaler,
)
+
+ early_stopping_patience = kwargs.get("train_conf", {}).get("early_stopping_patience", 0)
+ best_val_loss = float("inf")
+ epochs_no_improve = 0
dataloader_tr, dataloader_val = None, None
for epoch in range(trainer.start_epoch, trainer.max_epoch):
@@ -199,7 +203,19 @@
trainer.start_data_split_i = 0
trainer.validate_epoch(model=model, dataloader_val=dataloader_val, epoch=epoch + 1)
- scheduler.step()
+ current_val = trainer.val_loss_avg
+
+ if current_val < best_val_loss:
+ logging.info(f"current_val: {current_val}, best_val_loss: {best_val_loss}")
+ best_val_loss = current_val
+ epochs_no_improve = 0
+ else:
+ epochs_no_improve += 1
+ logging.info(f"No val_loss improvement for {epochs_no_improve}/{early_stopping_patience} epochs")
+ if early_stopping_patience > 0 and epochs_no_improve >= early_stopping_patience:
+ logging.info(f"Early stopping triggered at epoch {epoch+1}")
+ break
+
trainer.step_in_epoch = 0
trainer.save_checkpoint(
epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
--
Gitblit v1.9.1