From d3ff05837bbc14749d09f44947633b87e8f2db0e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 17 May 2024 11:47:27 +0800
Subject: [PATCH] deepspeed
---
funasr/models/sense_voice/decoder.py | 4 +
funasr/bin/train_ds.py | 2
funasr/train_utils/trainer_ds.py | 52 +++++++++----------------
funasr/models/sense_voice/whisper_lib/model.py | 4 +
funasr/models/sanm/attention.py | 16 ++++++--
5 files changed, 38 insertions(+), 40 deletions(-)
diff --git a/funasr/bin/train_ds.py b/funasr/bin/train_ds.py
index 032a0cf..b0931b0 100644
--- a/funasr/bin/train_ds.py
+++ b/funasr/bin/train_ds.py
@@ -133,7 +133,7 @@
kwargs["device"] = next(model.parameters()).device
trainer.device = kwargs["device"]
- model, optim, scheduler = trainer.warp_optim_scheduler(model, kwargs)
+ model, optim, scheduler = trainer.warp_optim_scheduler(model, **kwargs)
# dataset
logging.info("Build dataloader")
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index da8850f..08f7dc7 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -100,7 +100,9 @@
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
- min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
+ min_value = -float(
+ "inf"
+ ) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
@@ -269,7 +271,9 @@
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
- min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
+ min_value = -float(
+ "inf"
+ ) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
@@ -673,7 +677,9 @@
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
- min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
+ min_value = -float(
+ "inf"
+ ) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
# logging.info(
# "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
scores = scores.masked_fill(mask, min_value)
@@ -858,7 +864,9 @@
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
- min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
+ min_value = -float(
+ "inf"
+ ) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index 03b7532..60af29a 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -146,7 +146,9 @@
qk = qk + mask[:n_ctx, :n_ctx]
else:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
- min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
+ min_value = -float(
+ "inf"
+ ) # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
qk = qk.masked_fill(mask, min_value)
qk = qk.float()
diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index 40939df..8b3d3ab 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -112,7 +112,9 @@
qk = qk + mask[:n_ctx, :n_ctx]
else:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
- min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
+ min_value = -float(
+ "inf"
+ ) # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
qk = qk.masked_fill(mask, min_value)
qk = qk.float()
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index f3a0309..aa3c5af 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -78,7 +78,7 @@
self.world_size = world_size
self.use_ddp = use_ddp
self.use_fsdp = use_fsdp
- self.use_deepspeed = use_deepspeed
+
self.device = kwargs.get("device", "cuda")
self.output_dir = output_dir
@@ -136,6 +136,9 @@
self.writer = SummaryWriter(tensorboard_dir) # if trainer.rank == 0 else None
except:
self.writer = None
+
+ self.use_deepspeed = use_deepspeed
+ self.deepspeed_config = kwargs.get("deepspeed_config", "")
def save_checkpoint(
self,
@@ -443,7 +446,8 @@
iterator_stop = torch.tensor(0).to(self.device)
def forward_step(self, model, batch, loss_dict={}):
- with maybe_autocast(self.use_fp16):
+ dtype = torch.bfloat16
+ with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
retval = model(**batch)
loss, stats, weight = retval
@@ -465,7 +469,7 @@
else:
loss.backward()
- def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=loss_dict):
+ def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=None):
if self.use_deepspeed:
model.step()
@@ -613,7 +617,7 @@
loss = loss_dict["loss"].detach().cpu().item()
epoch = loss_dict["epoch"]
batch_idx = loss_dict["batch_idx"]
- step_in_epoch = loss_dict["step_in_epoch"]
+ step_in_epoch = self.step_in_epoch
batch_num_epoch = loss_dict["batch_num_epoch"]
lr = loss_dict["lr"]
@@ -732,36 +736,18 @@
"find_unused_parameters", False
),
)
- # elif self.use_fsdp:
- # # model = FSDP(model).cuda(local_rank)
- #
- # def custom_auto_wrap_policy(
- # module: nn.Module,
- # recurse: bool,
- # nonwrapped_numel: int,
- # # Additional custom arguments
- # min_num_params: int = int(1e8),
- # ) -> bool:
- # # 鏍规嵁鑷畾涔夐�昏緫鍐冲畾鏄惁鍖呰妯″潡
- # is_large = unwrapped_params >= min_num_params
- # requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
- # return is_large and requires_grad_uniform
- #
- # # Configure a custom `min_num_params`
- # my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
- # torch.cuda.set_device(local_rank)
- # model = FSDP(
- # model,
- # auto_wrap_policy=custom_auto_wrap_policy,
- # mixed_precision=None,
- # device_id=torch.cuda.current_device(),
- # )
+
else:
model = model.to(device=kwargs.get("device", "cuda"))
return model
def warp_optim_scheduler(self, model, **kwargs):
+ from funasr.optimizers import optim_classes
+ from funasr.schedulers import scheduler_classes
+ from omegaconf import OmegaConf, DictConfig
+ import json
+ import deepspeed
# optim
logging.info("Build optim")
@@ -777,15 +763,16 @@
scheduler_class = scheduler_classes.get(scheduler)
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
- if use_deepspeed:
- deepspeed_config = kwargs.get("deepspeed_config", "")
- with open(deepspeed_config, "r") as fin:
+ if self.use_deepspeed:
+
+ args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
+ with open(self.deepspeed_config, "r") as fin:
ds_configs = json.load(fin)
if "optimizer" in ds_configs:
# NOTE(xcsong): Disable custom optimizer if it is set in ds_config,
# extremely useful when enable cpu_offload, DeepspeedCpuAdam
# could be 4~5x faster than torch native adam
- deepspeed_config = None
+ optim = None
if "scheduler" in ds_configs:
scheduler = None
else:
@@ -793,7 +780,6 @@
def scheduler(opt):
return scheduler_class(opt, **kwargs.get("scheduler_conf"))
- args = OmegaConf.create({"deepspeed_config": deepspeed_config})
model, optimizer, _, scheduler = deepspeed.initialize(
args=args,
model=model,
--
Gitblit v1.9.1