From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
funasr/train_utils/trainer_ds.py | 19 ++++++++++++-------
1 files changed, 12 insertions(+), 7 deletions(-)
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 0fbac96..85513a5 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -29,9 +29,10 @@
with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
yield
else:
- if dtype == torch.float16:
- with autocast(enabled=True):
- yield
+ if dtype == torch.float16 or dtype == torch.bfloat16:
+ yield
+ # with autocast(enabled=True, dtype=dtype):
+ # yield
else:
yield
@@ -60,6 +61,7 @@
use_ddp: bool = False,
use_fsdp: bool = False,
use_fp16: bool = False,
+ use_bf16: bool = False,
use_deepspeed: bool = False,
output_dir: str = "./",
**kwargs,
@@ -78,7 +80,7 @@
output_dir (str): The directory where model checkpoints will be saved. Default is './'.
resume (str, optional): The file path to a checkpoint to resume training from.
"""
- self.rank = kwargs.get("rank", 0)
+ self.rank = rank
self.local_rank = local_rank
self.world_size = world_size
self.use_ddp = use_ddp
@@ -98,8 +100,11 @@
self.batch_total = 0
self.dtype = torch.float32
self.use_fp16 = use_fp16
+ self.use_bf16 = use_bf16
if self.use_fp16:
self.dtype = torch.float16
+ if self.use_bf16:
+ self.dtype = torch.bfloat16
self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
self.validate_interval = kwargs.get("validate_interval", 5000)
self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
@@ -473,7 +478,7 @@
for k_ex in self.excludes:
k_tmp = k.replace("module.", "")
if k_tmp.startswith(k_ex):
- logging.info(f"key: {{k}} matching: {k_ex}, excluded")
+ logging.info(f"key: {k} matching: {k_ex}, excluded")
excludes_flag = True
break
if excludes_flag:
@@ -678,7 +683,7 @@
scaled_loss = model.backward(loss)
else:
loss = loss / self.accum_grad
- if self.use_fp16:
+ if self.use_fp16 or self.use_bf16:
scaler.scale(loss).backward()
else:
loss.backward()
@@ -706,7 +711,7 @@
# Execute an optimization step (update model parameters)
if self.use_ddp or self.use_fsdp:
dist.barrier()
- if self.use_fp16:
+ if self.use_fp16 or self.use_bf16:
scaler.step(optim)
scaler.update()
else:
--
Gitblit v1.9.1