| | |
| | | model=model, |
| | | dataloader_val=dataloader_val, |
| | | epoch=epoch, |
| | | writer=writer |
| | | writer=writer, |
| | | step=batch_idx+1, |
| | | ) |
| | | |
| | | if (batch_idx+1) % self.save_checkpoint_interval == 0: |
| | |
| | | f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, " |
| | | f"(loss_avg_rank: {loss:.3f}), " |
| | | f"(loss_avg_epoch: {loss_avg_epoch:.3f}), " |
| | | f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3f}), " |
| | | f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3e}), " |
| | | f"(acc_avg_epoch: {acc_avg_epoch:.3f}), " |
| | | f"(lr: {lr:.3e}), " |
| | | f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, " |