From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交
---
funasr/bin/train_ds.py | 45 +++++++++++++++++++++++++++++++++++----------
1 files changed, 35 insertions(+), 10 deletions(-)
diff --git a/funasr/bin/train_ds.py b/funasr/bin/train_ds.py
index d9b7679..2241b0c 100644
--- a/funasr/bin/train_ds.py
+++ b/funasr/bin/train_ds.py
@@ -27,7 +27,7 @@
from funasr.train_utils.trainer_ds import Trainer
from funasr.schedulers import scheduler_classes
from funasr.train_utils.initialize import initialize
-from funasr.download.download_from_hub import download_model
+from funasr.download.download_model_from_hub import download_model
from funasr.models.lora.utils import mark_only_lora_as_trainable
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
@@ -81,8 +81,13 @@
deepspeed.init_distributed(dist_backend=kwargs.get("backend", "nccl"))
elif use_ddp or use_fsdp:
logging.info(f"use_ddp: {use_ddp}, use_fsdp: {use_fsdp}")
- dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
+ dist.init_process_group(
+ backend=kwargs.get("backend", "nccl"),
+ init_method="env://",
+ )
torch.cuda.set_device(local_rank)
+
+ # rank = dist.get_rank()
logging.info("Build model, frontend, tokenizer")
device = kwargs.get("device", "cuda")
@@ -124,11 +129,12 @@
use_ddp=use_ddp,
use_fsdp=use_fsdp,
device=kwargs["device"],
+ excludes=kwargs.get("excludes", None),
output_dir=kwargs.get("output_dir", "./exp"),
**kwargs.get("train_conf"),
)
- model = trainer.warp_model(model)
+ model = trainer.warp_model(model, **kwargs)
kwargs["device"] = int(os.environ.get("LOCAL_RANK", 0))
trainer.device = int(os.environ.get("LOCAL_RANK", 0))
@@ -143,7 +149,7 @@
dataloader = dataloader_class(**kwargs)
# dataloader_tr, dataloader_val = dataloader_class(**kwargs)
- scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
+ scaler = GradScaler(enabled=True) if trainer.use_fp16 else None
scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
trainer.resume_checkpoint(
@@ -152,6 +158,10 @@
scheduler=scheduler,
scaler=scaler,
)
+
+ early_stopping_patience = kwargs.get("train_conf", {}).get("early_stopping_patience", 0)
+ best_val_loss = float("inf")
+ epochs_no_improve = 0
dataloader_tr, dataloader_val = None, None
for epoch in range(trainer.start_epoch, trainer.max_epoch):
@@ -178,19 +188,34 @@
)
trainer.start_step = 0
- torch.cuda.empty_cache()
+ device = next(model.parameters()).device
+ if device.type == "cuda":
+ with torch.cuda.device(device):
+ torch.cuda.empty_cache()
time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
logging.info(
- f"rank: {local_rank}, "
+ f"\n\nrank: {local_rank}, "
f"time_escaped_epoch: {time_escaped:.3f} hours, "
- f"estimated to finish {dataloader.data_split_num} data_slices, remaining: {(dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours"
- f"epoch: {((trainer.max_epoch - epoch - 1)*dataloader.data_split_num + dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours\n"
+ f"estimated to finish {dataloader.data_split_num} data_slices, remaining: {dataloader.data_split_num-data_split_i} slices, {(dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours, "
+ f"epoch: {trainer.max_epoch - epoch} epochs, {((trainer.max_epoch - epoch - 1)*dataloader.data_split_num + dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours\n"
)
trainer.start_data_split_i = 0
trainer.validate_epoch(model=model, dataloader_val=dataloader_val, epoch=epoch + 1)
- scheduler.step()
+ current_val = trainer.val_loss_avg
+
+ if current_val < best_val_loss:
+ logging.info(f"current_val: {current_val}, best_val_loss: {best_val_loss}")
+ best_val_loss = current_val
+ epochs_no_improve = 0
+ else:
+ epochs_no_improve += 1
+ logging.info(f"No val_loss improvement for {epochs_no_improve}/{early_stopping_patience} epochs")
+ if early_stopping_patience > 0 and epochs_no_improve >= early_stopping_patience:
+ logging.info(f"Early stopping triggered at epoch {epoch+1}")
+ break
+
trainer.step_in_epoch = 0
trainer.save_checkpoint(
epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
@@ -199,7 +224,7 @@
time2 = time.perf_counter()
time_escaped = (time2 - time1) / 3600.0
logging.info(
- f"rank: {local_rank}, "
+ f"\n\nrank: {local_rank}, "
f"time_escaped_epoch: {time_escaped:.3f} hours, "
f"estimated to finish {trainer.max_epoch} "
f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n"
--
Gitblit v1.9.1