From 54b6ff57647e28bbe88d8df81f2b112f127660e2 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Thu, 22 Feb 2024 23:52:22 +0800
Subject: [PATCH] trainer: add fp16 mixed-precision training support
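
Add fp16 mixed-precision training support to Trainer:

- new use_fp16 constructor flag; when set, the forward pass runs under
  torch.cuda.amp.autocast() and the loss is scaled with a GradScaler,
  with scaler.step()/scaler.update() replacing the plain optimizer step
- GradScaler state is saved into checkpoints and restored on resume
- the per-step torch.cuda.empty_cache() is now gated behind a
  disable_gpu_cache option (default True, which keeps the old behavior)
- gradients are cleared with zero_grad(set_to_none=True)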
---
 funasr/train_utils/trainer.py | 48 +++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 39 insertions(+), 9 deletions(-)
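
Notes (not part of the commit): for reference, the torch.cuda.amp pattern this
patch wires into the trainer, as a minimal self-contained sketch. The model,
optimizer, and data below are placeholders, not FunASR code:

    import torch
    from torch.cuda.amp import autocast, GradScaler

    model = torch.nn.Linear(16, 4).cuda()
    optim = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = GradScaler()  # rescales the loss so fp16 gradients do not underflow

    for _ in range(10):
        x = torch.randn(8, 16, device="cuda")
        y = torch.randint(0, 4, (8,), device="cuda")
        optim.zero_grad(set_to_none=True)
        with autocast():  # run the forward pass in mixed precision
            loss = torch.nn.functional.cross_entropy(model(x), y)
        scaler.scale(loss).backward()  # backward on the scaled loss
        scaler.step(optim)             # unscales grads; skips the step on inf/NaN
        scaler.update()                # adapts the scale factor for the next step

Checkpointing follows the same shape as the trainer changes below; a sketch of
the save/restore round trip for the scaler state (the file name is illustrative):

    state = {"model": model.state_dict(), "optimizer": optim.state_dict()}
    if scaler is not None:
        state["scaler_state"] = scaler.state_dict()
    torch.save(state, "model.pt.ep0")

    checkpoint = torch.load("model.pt.ep0")
    if scaler is not None and "scaler_state" in checkpoint:
        scaler.load_state_dict(checkpoint["scaler_state"])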
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index d175fbe..5b280bf 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -5,7 +5,8 @@
 from tqdm import tqdm
 from datetime import datetime
 import torch.distributed as dist
-from contextlib import nullcontext
+from torch.cuda.amp import autocast, GradScaler
+from contextlib import nullcontext, contextmanager
 # from torch.utils.tensorboard import SummaryWriter
 from tensorboardX import SummaryWriter
 from pathlib import Path
@@ -13,6 +14,15 @@
 from funasr.train_utils.device_funcs import to_device
 from funasr.train_utils.recursive_op import recursive_average
 from funasr.train_utils.average_nbest_models import average_checkpoints
+
+# Runs its block under torch.cuda.amp autocast when enabled, otherwise as a no-op.
+@contextmanager
+def maybe_autocast(enabled):
+    if enabled:
+        with autocast():
+            yield
+    else:
+        yield
 
 class Trainer:
     """
@@ -36,8 +46,9 @@
                  dataloader_train,
                  dataloader_val,
                  local_rank,
-                 use_ddp=False,
-                 use_fsdp=False,
+                 use_ddp: bool = False,
+                 use_fsdp: bool = False,
+                 use_fp16: bool = False,
                  output_dir: str="./",
                  **kwargs):
         """
@@ -72,6 +83,11 @@
         self.kwargs = kwargs
         self.log_interval = kwargs.get("log_interval", 50)
         self.batch_total = 0
+        self.use_fp16 = use_fp16
+        # When True (the default), torch.cuda.empty_cache() runs after each forward pass.
+        self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
+        # GradScaler applies loss scaling so small fp16 gradients do not underflow to zero.
+        self.scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
 
 
         try:
@@ -103,6 +119,8 @@
             'optimizer': self.optim.state_dict(),
             'scheduler': self.scheduler.state_dict(),
         }
+        if self.scaler:
+            state["scaler_state"] = self.scaler.state_dict()
         # Create output directory if it does not exist
         os.makedirs(self.output_dir, exist_ok=True)
         filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
@@ -141,6 +159,8 @@
             self.model.load_state_dict(dst_state)
             self.optim.load_state_dict(checkpoint['optimizer'])
             self.scheduler.load_state_dict(checkpoint['scheduler'])
+            if self.scaler and 'scaler_state' in checkpoint:
+                self.scaler.load_state_dict(checkpoint['scaler_state'])
             print(f"Checkpoint loaded successfully from '{ckpt}'")
         else:
             print(f"No checkpoint found at '{ckpt}', starting from scratch")
@@ -221,9 +241,11 @@
             my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
             with my_context():
                 time2 = time.perf_counter()
-
-                retval = self.model(**batch)
-                torch.cuda.empty_cache()
+                with maybe_autocast(self.use_fp16):
+                    retval = self.model(**batch)
+
+                if self.disable_gpu_cache:
+                    torch.cuda.empty_cache()
                 time3 = time.perf_counter()
                 speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
 
@@ -241,7 +263,10 @@
                 loss *= self.world_size
                 # Scale the loss since we're not updating for every mini-batch
                 loss = loss / accum_grad
-                loss.backward()
+                if self.use_fp16:
+                    self.scaler.scale(loss).backward()
+                else:
+                    loss.backward()
                 time4 = time.perf_counter()
                 speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
 
@@ -264,10 +289,15 @@
                 # Execute an optimization step (update model parameters)
                 if self.use_ddp or self.use_fsdp:
                     dist.barrier()
-                self.optim.step()
+                if self.use_fp16:
+                    # scaler.step() unscales gradients and skips the update on inf/NaN.
+                    self.scaler.step(self.optim)
+                    self.scaler.update()
+                else:
+                    self.optim.step()
                 self.scheduler.step()
                 # Clear gradients for the next accumulation stage
-                self.optim.zero_grad()
+                self.optim.zero_grad(set_to_none=True)
                 total_time = f"{time.perf_counter() - time5:0.3f}"
                 time5 = time.perf_counter()
                 speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
--
Gitblit v1.9.1