| | |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.recursive_op import recursive_average |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.models.base_model import FunASRModel |
| | | from funasr.train.distributed_utils import DistributedOption |
| | | from funasr.train.reporter import Reporter |
| | | from funasr.train.reporter import SubReporter |
| | |
| | | @classmethod |
| | | def run( |
| | | cls, |
| | | model: AbsESPnetModel, |
| | | model: FunASRModel, |
| | | optimizers: Sequence[torch.optim.Optimizer], |
| | | schedulers: Sequence[Optional[AbsScheduler]], |
| | | train_iter_factory: AbsIterFactory, |
| | |
| | | reporter.measure_iter_time(iterator, "iter_time"), 1 |
| | | ): |
| | | assert isinstance(batch, dict), type(batch) |
| | | |
| | | if rank == 0 and hasattr(model.module, "num_updates"): |
| | | num_batch_updates = model.module.get_num_updates() |
| | | |
| | | if rank == 0: |
| | | if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")): |
| | | num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates() |
| | | if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai: |
| | | buffer = BytesIO() |
| | | torch.save(model.state_dict(), buffer) |