From 2ed3f46f40cd5da19cad76a97b52c46b2869d5ed Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 16 Jan 2024 18:42:37 +0800
Subject: [PATCH] funasr1.0 finetune: add TensorBoard logging to Trainer
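
Add TensorBoard logging to the Trainer used for funasr1.0 fine-tuning:

- make output_dir an explicit constructor argument instead of reading it
  from kwargs
- create a SummaryWriter (from tensorboardX rather than
  torch.utils.tensorboard) under <output_dir>/tensorboard on rank 0 only,
  so DDP/FSDP workers do not write duplicate event files
- log the training loss, model stats, and speed stats once per batch,
  indexed by global step (epoch * len(dataloader_train) + batch_idx)
- close the writer when training finishes

The resulting event files can be inspected with:

    tensorboard --logdir <output_dir>/tensorboard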
---
funasr/train_utils/trainer.py | 22 ++++++++++++++++++++--
1 file changed, 20 insertions(+), 2 deletions(-)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 0f0acc2..da346c3 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -5,6 +5,8 @@
from tqdm import tqdm
import torch.distributed as dist
from contextlib import nullcontext
+# from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
@@ -34,6 +36,7 @@
local_rank,
use_ddp=False,
use_fsdp=False,
+ output_dir: str="./",
**kwargs):
"""
Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
@@ -55,7 +58,7 @@
self.scheduler = scheduler
self.dataloader_train = dataloader_train
self.dataloader_val = dataloader_val
- self.output_dir = kwargs.get('output_dir', './')
+ self.output_dir = output_dir
self.resume = kwargs.get('resume', True)
self.start_epoch = 0
self.max_epoch = kwargs.get('max_epoch', 100)
@@ -77,6 +80,10 @@
logging.warning("distributed is not initialized, only single shard")
self.rank = rank
self.world_size = world_size
+
+ os.makedirs(os.path.join(self.output_dir, "tensorboard"), exist_ok=True)
+ self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None
+
def _save_checkpoint(self, epoch):
"""
@@ -128,6 +135,8 @@
if self.rank == 0:
self._save_checkpoint(epoch)
self.scheduler.step()
+        if self.writer:
+            self.writer.close()
def _train_epoch(self, epoch):
"""
@@ -215,7 +224,16 @@
f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
)
pbar.set_description(description)
-
+            if self.writer:
+                self.writer.add_scalar('Loss/train', loss.item(),
+                                       epoch * len(self.dataloader_train) + batch_idx)
+                for key, var in stats.items():
+                    self.writer.add_scalar(f'{key}/train', var.item(),
+                                           epoch * len(self.dataloader_train) + batch_idx)
+                for key, var in speed_stats.items():  # values are formatted strings, not tensors
+                    self.writer.add_scalar(f'{key}/train', float(var),
+                                           epoch * len(self.dataloader_train) + batch_idx)
+
# if batch_idx == 2:
# break
pbar.close()
--
Gitblit v1.9.1