From 0e622e694e6cb4459955f1e5942a7c53349ce640 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 19 Dec 2023 21:58:14 +0800
Subject: [PATCH] funasr2
---
funasr/train_utils/trainer.py | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 425b79f..ea502f7 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -55,7 +55,7 @@
self.dataloader_val = dataloader_val
self.output_dir = kwargs.get('output_dir', './')
self.resume = kwargs.get('resume', True)
- self.start_epoch = 1
+ self.start_epoch = 0
self.max_epoch = kwargs.get('max_epoch', 100)
self.local_rank = local_rank
self.use_ddp = use_ddp
@@ -123,7 +123,7 @@
for epoch in range(self.start_epoch, self.max_epoch + 1):
self._train_epoch(epoch)
# self._validate_epoch(epoch)
- if dist.get_rank() == 0:
+ if self.rank == 0:
self._save_checkpoint(epoch)
self.scheduler.step()
break
@@ -201,21 +201,22 @@
speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
speed_stats["total_time"] = total_time
-
+ # import pdb;
+ # pdb.set_trace()
pbar.update(1)
if self.local_rank == 0:
description = (
f"Epoch: {epoch + 1}/{self.max_epoch}, "
f"step {batch_idx}/{len(self.dataloader_train)}, "
f"{speed_stats}, "
- f"(loss: {loss.detach().float():.3f}), "
+ f"(loss: {loss.detach().cpu().item():.3f}), "
f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
)
pbar.set_description(description)
- if batch_idx == 2:
- break
+ # if batch_idx == 2:
+ # break
pbar.close()
def _validate_epoch(self, epoch):
--
Gitblit v1.9.1