From 6e69d784e4814c3dbe35e8f70c6cf4b920c8b20b Mon Sep 17 00:00:00 2001
From: 天地 <tiandiweizun@gmail.com>
Date: 星期三, 19 三月 2025 23:10:13 +0800
Subject: [PATCH] 1. bug fix:list(mean)和list(var),由于mean和var是numpy,导致写入到文件的格式错误,参考上面的话,大概率是list(mean.tolist()),其实外层list没有必要 (#2437)

---
 funasr/bin/train_ds.py |   16 +++++++++++-----
 1 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/funasr/bin/train_ds.py b/funasr/bin/train_ds.py
index c55f950..10a5d08 100644
--- a/funasr/bin/train_ds.py
+++ b/funasr/bin/train_ds.py
@@ -27,7 +27,7 @@
 from funasr.train_utils.trainer_ds import Trainer
 from funasr.schedulers import scheduler_classes
 from funasr.train_utils.initialize import initialize
-from funasr.download.download_from_hub import download_model
+from funasr.download.download_model_from_hub import download_model
 from funasr.models.lora.utils import mark_only_lora_as_trainable
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
@@ -81,7 +81,10 @@
         deepspeed.init_distributed(dist_backend=kwargs.get("backend", "nccl"))
     elif use_ddp or use_fsdp:
         logging.info(f"use_ddp: {use_ddp}, use_fsdp: {use_fsdp}")
-        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
+        dist.init_process_group(
+            backend=kwargs.get("backend", "nccl"),
+            init_method="env://",
+        )
         torch.cuda.set_device(local_rank)
 
     # rank = dist.get_rank()
@@ -131,7 +134,7 @@
         **kwargs.get("train_conf"),
     )
 
-    model = trainer.warp_model(model)
+    model = trainer.warp_model(model, **kwargs)
 
     kwargs["device"] = int(os.environ.get("LOCAL_RANK", 0))
     trainer.device = int(os.environ.get("LOCAL_RANK", 0))
@@ -146,7 +149,7 @@
     dataloader = dataloader_class(**kwargs)
     # dataloader_tr, dataloader_val = dataloader_class(**kwargs)
 
-    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
+    scaler = GradScaler(enabled=True) if trainer.use_fp16 or trainer.use_bf16 else None
     scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
 
     trainer.resume_checkpoint(
@@ -181,7 +184,10 @@
             )
             trainer.start_step = 0
 
-            torch.cuda.empty_cache()
+            device = next(model.parameters()).device
+            if device.type == "cuda":
+                with torch.cuda.device(device):
+                    torch.cuda.empty_cache()
 
             time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
             logging.info(

--
Gitblit v1.9.1