嘉渊
2023-05-18 17eaf419c05853a4ecb8dfd3a0e8ebf26a1dfb1b
funasr/train/trainer.py
@@ -3,7 +3,6 @@
"""Trainer module."""
import argparse
from audioop import bias
from contextlib import contextmanager
import dataclasses
from dataclasses import is_dataclass
@@ -40,7 +39,7 @@
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.recursive_op import recursive_average
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
from funasr.train.distributed_utils import DistributedOption
from funasr.train.reporter import Reporter
from funasr.train.reporter import SubReporter
@@ -167,7 +166,7 @@
    @classmethod
    def run(
        cls,
        model: AbsESPnetModel,
        model: FunASRModel,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        train_iter_factory: AbsIterFactory,