From 1c2eb051cdcc6890af9ba64b10b9a0152288469a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 06 十二月 2023 19:45:49 +0800
Subject: [PATCH] funasr2
---
funasr/cli/trainer.py | 8 +++++---
1 files changed, 5 insertions(+), 3 deletions(-)
diff --git a/funasr/cli/trainer.py b/funasr/cli/trainer.py
index 74e058f..ee5af0f 100644
--- a/funasr/cli/trainer.py
+++ b/funasr/cli/trainer.py
@@ -4,6 +4,7 @@
import logging
from tqdm import tqdm
from contextlib import nullcontext
+import torch.distributed as dist
class Trainer:
"""
@@ -80,7 +81,7 @@
}
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
- filename = os.path.join(self.output_dir, f'model.{epoch}.pb')
+ filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
torch.save(state, filename)
print(f'Checkpoint saved to {filename}')
@@ -110,8 +111,9 @@
for epoch in range(self.start_epoch, self.max_epoch + 1):
self._train_epoch(epoch)
# self._validate_epoch(epoch)
- self._save_checkpoint(epoch)
- self.scheduler.step()
+ if dist.get_rank() == 0:
+ self._save_checkpoint(epoch)
+ # self.scheduler.step()
def _train_epoch(self, epoch):
"""
--
Gitblit v1.9.1