From c553a8db1712c2a5deeef5bbb68bd1fdf8d61ab7 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 13 六月 2024 17:38:01 +0800
Subject: [PATCH] decoding

---
 funasr/train_utils/trainer_ds.py |   12 ++++++++----
 1 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index e6b6284..8a0679c 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -29,8 +29,8 @@
         with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
             yield
     else:
-        if dtype == torch.float16:
-            with autocast(enabled=True):
+        if dtype == torch.float16 or dtype == torch.bfloat16:
+            with autocast(enabled=True, dtype=dtype):
                 yield
         else:
             yield
@@ -60,6 +60,7 @@
         use_ddp: bool = False,
         use_fsdp: bool = False,
         use_fp16: bool = False,
+        use_bf16: bool = False,
         use_deepspeed: bool = False,
         output_dir: str = "./",
         **kwargs,
@@ -98,8 +99,11 @@
         self.batch_total = 0
         self.dtype = torch.float32
         self.use_fp16 = use_fp16
+        self.use_bf16 = use_bf16
         if self.use_fp16:
             self.dtype = torch.float16
+        if self.use_bf16:
+            self.dtype = torch.bfloat16
         self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
         self.validate_interval = kwargs.get("validate_interval", 5000)
         self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
@@ -678,7 +682,7 @@
             scaled_loss = model.backward(loss)
         else:
             loss = loss / self.accum_grad
-            if self.use_fp16:
+            if self.use_fp16 or self.use_bf16:
                 scaler.scale(loss).backward()
             else:
                 loss.backward()
@@ -706,7 +710,7 @@
                 # Execute an optimization step (update model parameters)
                 if self.use_ddp or self.use_fsdp:
                     dist.barrier()
-                if self.use_fp16:
+                if self.use_fp16 or self.use_bf16:
                     scaler.step(optim)
                     scaler.update()
                 else:

--
Gitblit v1.9.1