From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/train_utils/trainer_ds.py |   19 ++++++++++++-------
 1 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 0fbac96..85513a5 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -29,9 +29,10 @@
         with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
             yield
     else:
-        if dtype == torch.float16:
-            with autocast(enabled=True):
-                yield
+        if dtype == torch.float16 or dtype == torch.bfloat16:
+            yield
+            # with autocast(enabled=True, dtype=dtype):
+            #     yield
         else:
             yield
 
@@ -60,6 +61,7 @@
         use_ddp: bool = False,
         use_fsdp: bool = False,
         use_fp16: bool = False,
+        use_bf16: bool = False,
         use_deepspeed: bool = False,
         output_dir: str = "./",
         **kwargs,
@@ -78,7 +80,7 @@
                       output_dir (str): The directory where model checkpoints will be saved. Default is './'.
                       resume (str, optional): The file path to a checkpoint to resume training from.
         """
-        self.rank = kwargs.get("rank", 0)
+        self.rank = rank
         self.local_rank = local_rank
         self.world_size = world_size
         self.use_ddp = use_ddp
@@ -98,8 +100,11 @@
         self.batch_total = 0
         self.dtype = torch.float32
         self.use_fp16 = use_fp16
+        self.use_bf16 = use_bf16
         if self.use_fp16:
             self.dtype = torch.float16
+        if self.use_bf16:
+            self.dtype = torch.bfloat16
         self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
         self.validate_interval = kwargs.get("validate_interval", 5000)
         self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
@@ -473,7 +478,7 @@
                             for k_ex in self.excludes:
                                 k_tmp = k.replace("module.", "")
                                 if k_tmp.startswith(k_ex):
-                                    logging.info(f"key: {{k}} matching: {k_ex}, excluded")
+                                    logging.info(f"key: {k} matching: {k_ex}, excluded")
                                     excludes_flag = True
                                     break
                         if excludes_flag:
@@ -678,7 +683,7 @@
             scaled_loss = model.backward(loss)
         else:
             loss = loss / self.accum_grad
-            if self.use_fp16:
+            if self.use_fp16 or self.use_bf16:
                 scaler.scale(loss).backward()
             else:
                 loss.backward()
@@ -706,7 +711,7 @@
                 # Execute an optimization step (update model parameters)
                 if self.use_ddp or self.use_fsdp:
                     dist.barrier()
-                if self.use_fp16:
+                if self.use_fp16 or self.use_bf16:
                     scaler.step(optim)
                     scaler.update()
                 else:

--
Gitblit v1.9.1