From 3a4281f4959534b1bf5d01acf0085f4f8e6f2ec8 Mon Sep 17 00:00:00 2001
From: wuhongsheng <664116298@qq.com>
Date: 星期五, 05 七月 2024 00:55:32 +0800
Subject: [PATCH] 优化speakid和语句匹配逻辑,部分解决speakid不从0递增问题 (#1870)
---
funasr/train_utils/trainer_ds.py | 23 +++++++++++++++--------
1 files changed, 15 insertions(+), 8 deletions(-)
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index ba8dd16..85513a5 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -29,9 +29,10 @@
with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
yield
else:
- if dtype == torch.float16:
- with autocast(enabled=True):
- yield
+ if dtype == torch.float16 or dtype == torch.bfloat16:
+ yield
+ # with autocast(enabled=True, dtype=dtype):
+ # yield
else:
yield
@@ -60,6 +61,7 @@
use_ddp: bool = False,
use_fsdp: bool = False,
use_fp16: bool = False,
+ use_bf16: bool = False,
use_deepspeed: bool = False,
output_dir: str = "./",
**kwargs,
@@ -78,7 +80,7 @@
output_dir (str): The directory where model checkpoints will be saved. Default is './'.
resume (str, optional): The file path to a checkpoint to resume training from.
"""
- self.rank = kwargs.get("rank", 0)
+ self.rank = rank
self.local_rank = local_rank
self.world_size = world_size
self.use_ddp = use_ddp
@@ -98,8 +100,11 @@
self.batch_total = 0
self.dtype = torch.float32
self.use_fp16 = use_fp16
+ self.use_bf16 = use_bf16
if self.use_fp16:
self.dtype = torch.float16
+ if self.use_bf16:
+ self.dtype = torch.bfloat16
self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
self.validate_interval = kwargs.get("validate_interval", 5000)
self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
@@ -287,7 +292,9 @@
elif self.use_fsdp:
pass
elif self.rank == 0:
- logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
+ logging.info(
+ f"Save checkpoint: {epoch}, rank: {self.rank}, local_rank: {self.local_rank}\n"
+ )
# self.step_or_epoch += 1
state = {
"epoch": epoch,
@@ -471,7 +478,7 @@
for k_ex in self.excludes:
k_tmp = k.replace("module.", "")
if k_tmp.startswith(k_ex):
- logging.info(f"key: {{k}} matching: {k_ex}, excluded")
+ logging.info(f"key: {k} matching: {k_ex}, excluded")
excludes_flag = True
break
if excludes_flag:
@@ -676,7 +683,7 @@
scaled_loss = model.backward(loss)
else:
loss = loss / self.accum_grad
- if self.use_fp16:
+ if self.use_fp16 or self.use_bf16:
scaler.scale(loss).backward()
else:
loss.backward()
@@ -704,7 +711,7 @@
# Execute an optimization step (update model parameters)
if self.use_ddp or self.use_fsdp:
dist.barrier()
- if self.use_fp16:
+ if self.use_fp16 or self.use_bf16:
scaler.step(optim)
scaler.update()
else:
--
Gitblit v1.9.1