funasr/train_utils/trainer.py
@@ -1,13 +1,15 @@ import torch import os from funasr.train_utils.device_funcs import to_device import logging import time import torch import logging from tqdm import tqdm from contextlib import nullcontext import torch.distributed as dist from contextlib import nullcontext from funasr.train_utils.device_funcs import to_device from funasr.train_utils.recursive_op import recursive_average class Trainer: """ A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,