From c553a8db1712c2a5deeef5bbb68bd1fdf8d61ab7 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Thu, 13 Jun 2024 17:38:01 +0800
Subject: [PATCH] trainer: add bf16 (torch.bfloat16) support alongside fp16
---
funasr/train_utils/trainer_ds.py | 12 ++++++++----
1 file changed, 8 insertions(+), 4 deletions(-)
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index e6b6284..8a0679c 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -29,8 +29,8 @@
with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
yield
else:
- if dtype == torch.float16:
- with autocast(enabled=True):
+ if dtype == torch.float16 or dtype == torch.bfloat16:
+ with autocast(enabled=True, dtype=dtype):
yield
else:
yield
@@ -60,6 +60,7 @@
use_ddp: bool = False,
use_fsdp: bool = False,
use_fp16: bool = False,
+ use_bf16: bool = False,
use_deepspeed: bool = False,
output_dir: str = "./",
**kwargs,
@@ -98,8 +99,11 @@
self.batch_total = 0
self.dtype = torch.float32
self.use_fp16 = use_fp16
+ self.use_bf16 = use_bf16
if self.use_fp16:
self.dtype = torch.float16
+ if self.use_bf16:
+ self.dtype = torch.bfloat16
self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
self.validate_interval = kwargs.get("validate_interval", 5000)
self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
@@ -678,7 +682,7 @@
scaled_loss = model.backward(loss)
else:
loss = loss / self.accum_grad
- if self.use_fp16:
+ if self.use_fp16 or self.use_bf16:
scaler.scale(loss).backward()
else:
loss.backward()
@@ -706,7 +710,7 @@
# Execute an optimization step (update model parameters)
if self.use_ddp or self.use_fsdp:
dist.barrier()
- if self.use_fp16:
+ if self.use_fp16 or self.use_bf16:
scaler.step(optim)
scaler.update()
else:
--
Gitblit v1.9.1