Shi Xian
2024-01-15 ddbc8b5eded1fff6084001d160d46b532020ecb7
funasr/train_utils/trainer.py
@@ -1,13 +1,15 @@
import torch
import os
from funasr.train_utils.device_funcs import to_device
import logging
import time
import torch
import logging
from tqdm import tqdm
from contextlib import nullcontext
import torch.distributed as dist
from contextlib import nullcontext
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
class Trainer:
   """
   A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,