| | |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.recursive_op import recursive_average |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.models.base_model import FunASRModel |
| | | from funasr.train.distributed_utils import DistributedOption |
| | | from funasr.train.reporter import Reporter |
| | | from funasr.train.reporter import SubReporter |
| | |
| | | @classmethod |
| | | def run( |
| | | cls, |
| | | model: AbsESPnetModel, |
| | | model: FunASRModel, |
| | | optimizers: Sequence[torch.optim.Optimizer], |
| | | schedulers: Sequence[Optional[AbsScheduler]], |
| | | train_iter_factory: AbsIterFactory, |