From 0f3d2d1266fec1b1666be1a4e0a2d234eb152567 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 11 十二月 2024 14:21:57 +0800
Subject: [PATCH] v1.1.17

---
 funasr/bin/train_ds.py |   14 +++++++++-----
 1 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/funasr/bin/train_ds.py b/funasr/bin/train_ds.py
index da99adc..dc7fb42 100644
--- a/funasr/bin/train_ds.py
+++ b/funasr/bin/train_ds.py
@@ -27,7 +27,7 @@
 from funasr.train_utils.trainer_ds import Trainer
 from funasr.schedulers import scheduler_classes
 from funasr.train_utils.initialize import initialize
-from funasr.download.download_from_hub import download_model
+from funasr.download.download_model_from_hub import download_model
 from funasr.models.lora.utils import mark_only_lora_as_trainable
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
@@ -66,6 +66,7 @@
     # open tf32
     torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
 
+    rank = int(os.environ.get("RANK", 0))
     local_rank = int(os.environ.get("LOCAL_RANK", 0))
     world_size = int(os.environ.get("WORLD_SIZE", 1))
 
@@ -80,10 +81,13 @@
         deepspeed.init_distributed(dist_backend=kwargs.get("backend", "nccl"))
     elif use_ddp or use_fsdp:
         logging.info(f"use_ddp: {use_ddp}, use_fsdp: {use_fsdp}")
-        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
+        dist.init_process_group(
+            backend=kwargs.get("backend", "nccl"),
+            init_method="env://",
+        )
         torch.cuda.set_device(local_rank)
 
-    rank = dist.get_rank()
+    # rank = dist.get_rank()
 
     logging.info("Build model, frontend, tokenizer")
     device = kwargs.get("device", "cuda")
@@ -130,7 +134,7 @@
         **kwargs.get("train_conf"),
     )
 
-    model = trainer.warp_model(model)
+    model = trainer.warp_model(model, **kwargs)
 
     kwargs["device"] = int(os.environ.get("LOCAL_RANK", 0))
     trainer.device = int(os.environ.get("LOCAL_RANK", 0))
@@ -145,7 +149,7 @@
     dataloader = dataloader_class(**kwargs)
     # dataloader_tr, dataloader_val = dataloader_class(**kwargs)
 
-    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
+    scaler = GradScaler(enabled=True) if trainer.use_fp16 or trainer.use_bf16 else None
     scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
 
     trainer.resume_checkpoint(

--
Gitblit v1.9.1