游雁
2023-12-06 1c2eb051cdcc6890af9ba64b10b9a0152288469a
funasr2
1个文件已修改
8 ■■■■■ 已修改文件
funasr/cli/trainer.py 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/cli/trainer.py
@@ -4,6 +4,7 @@
import logging
from tqdm import tqdm
from contextlib import nullcontext
import torch.distributed as dist
class Trainer:
    """
@@ -80,7 +81,7 @@
        }
        # Create output directory if it does not exist
        os.makedirs(self.output_dir, exist_ok=True)
        filename = os.path.join(self.output_dir, f'model.{epoch}.pb')
        filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
        torch.save(state, filename)
        print(f'Checkpoint saved to {filename}')
    
@@ -110,8 +111,9 @@
        for epoch in range(self.start_epoch, self.max_epoch + 1):
            self._train_epoch(epoch)
            # self._validate_epoch(epoch)
            self._save_checkpoint(epoch)
            self.scheduler.step()
            if dist.get_rank() == 0:
                self._save_checkpoint(epoch)
            # self.scheduler.step()
    
    def _train_epoch(self, epoch):
        """