| | |
| | | |
| | | """Trainer module.""" |
| | | import argparse |
| | | from contextlib import contextmanager |
| | | import dataclasses |
| | | import logging |
| | | import os |
| | | import time |
| | | from contextlib import contextmanager |
| | | from dataclasses import is_dataclass |
| | | from distutils.version import LooseVersion |
| | | import logging |
| | | from io import BytesIO |
| | | from pathlib import Path |
| | | import time |
| | | from typing import Dict |
| | | from typing import Iterable |
| | | from typing import List |
| | |
| | | |
| | | import humanfriendly |
| | | import oss2 |
| | | from io import BytesIO |
| | | import os |
| | | import numpy as np |
| | | import torch |
| | | import torch.nn |
| | | import torch.optim |
| | |
| | | |
| | | from funasr.iterators.abs_iter_factory import AbsIterFactory |
| | | from funasr.main_funcs.average_nbest_models import average_nbest_models |
| | | from funasr.main_funcs.calculate_all_attentions import calculate_all_attentions |
| | | from funasr.models.base_model import FunASRModel |
| | | from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler |
| | | from funasr.schedulers.abs_scheduler import AbsEpochStepScheduler |
| | | from funasr.schedulers.abs_scheduler import AbsScheduler |
| | |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.recursive_op import recursive_average |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.models.base_model import FunASRModel |
| | | from funasr.train.distributed_utils import DistributedOption |
| | | from funasr.train.reporter import Reporter |
| | | from funasr.train.reporter import SubReporter |
| | |
| | | wandb_model_log_interval: int |
| | | use_pai: bool |
| | | oss_bucket: Union[oss2.Bucket, None] |
| | | batch_interval: int |
| | | |
| | | |
| | | class Trainer: |
| | | """Trainer having an optimizer. |
| | | |
| | | If you'd like to use multiple optimizers, then inherit this class |
| | | and override the methods if necessary - at least "train_one_epoch()" |
| | | |
| | | >>> class TwoOptimizerTrainer(Trainer): |
| | | ... @classmethod |
| | | ... def add_arguments(cls, parser): |
| | | ... ... |
| | | ... |
| | | ... @classmethod |
| | | ... def train_one_epoch(cls, model, optimizers, ...): |
| | | ... loss1 = model.model1(...) |
| | | ... loss1.backward() |
| | | ... optimizers[0].step() |
| | | ... |
| | | ... loss2 = model.model2(...) |
| | | ... loss2.backward() |
| | | ... optimizers[1].step() |
| | | """Trainer |
| | | |
| | | """ |
| | | |
| | | def __init__(self): |
| | | raise RuntimeError("This class can't be instantiated.") |
| | | def __init__(self, |
| | | args, |
| | | model: FunASRModel, |
| | | optimizers: Sequence[torch.optim.Optimizer], |
| | | schedulers: Sequence[Optional[AbsScheduler]], |
| | | train_dataloader: AbsIterFactory, |
| | | valid_dataloader: AbsIterFactory, |
| | | trainer_options, |
| | | distributed_option: DistributedOption): |
| | | self.trainer_options = self.build_options(args) |
| | | self.model = model |
| | | self.optimizers = optimizers |
| | | self.schedulers = schedulers |
| | | self.train_dataloader = train_dataloader |
| | | self.valid_dataloader = valid_dataloader |
| | | self.trainer_options = trainer_options |
| | | self.distributed_option = distributed_option |
| | | |
| | | @classmethod |
| | | def build_options(cls, args: argparse.Namespace) -> TrainerOptions: |
| | | def build_options(self, args: argparse.Namespace) -> TrainerOptions: |
| | | """Build options consumed by train(), eval()""" |
| | | assert check_argument_types() |
| | | return build_dataclass(TrainerOptions, args) |
| | |
| | | |
| | | logging.info(f"The training was resumed using {checkpoint}") |
| | | |
| | | @classmethod |
| | | def run( |
| | | cls, |
| | | model: FunASRModel, |
| | | optimizers: Sequence[torch.optim.Optimizer], |
| | | schedulers: Sequence[Optional[AbsScheduler]], |
| | | train_iter_factory: AbsIterFactory, |
| | | valid_iter_factory: AbsIterFactory, |
| | | trainer_options, |
| | | distributed_option: DistributedOption, |
| | | ) -> None: |
| | | def run(self) -> None: |
| | | """Perform training. This method performs the main process of training.""" |
| | | assert check_argument_types() |
| | | # NOTE(kamo): Don't check the type more strictly as far trainer_options |
| | | model = self.model |
| | | optimizers = self.optimizers |
| | | schedulers = self.schedulers |
| | | train_dataloader = self.train_dataloader |
| | | valid_dataloader = self.valid_dataloader |
| | | trainer_options = self.trainer_options |
| | | distributed_option = self.distributed_option |
| | | assert is_dataclass(trainer_options), type(trainer_options) |
| | | assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers)) |
| | | |
| | |
| | | logging.warning("No keep_nbest_models is given. Change to [1]") |
| | | trainer_options.keep_nbest_models = [1] |
| | | keep_nbest_models = trainer_options.keep_nbest_models |
| | | |
| | | # assert batch_interval is set and >0 |
| | | assert trainer_options.batch_interval > 0 |
| | | |
| | | output_dir = Path(trainer_options.output_dir) |
| | | reporter = Reporter() |
| | |
| | | scaler = None |
| | | |
| | | if trainer_options.resume and (output_dir / "checkpoint.pb").exists(): |
| | | cls.resume( |
| | | self.resume( |
| | | checkpoint=output_dir / "checkpoint.pb", |
| | | model=model, |
| | | optimizers=optimizers, |
| | |
| | | ) |
| | | |
| | | if distributed_option.distributed: |
| | | if trainer_options.sharded_ddp: |
| | | dp_model = fairscale.nn.data_parallel.ShardedDataParallel( |
| | | module=model, |
| | | sharded_optimizer=optimizers, |
| | | ) |
| | | else: |
| | | dp_model = torch.nn.parallel.DistributedDataParallel( |
| | | model, find_unused_parameters=trainer_options.unused_parameters) |
| | | dp_model = torch.nn.parallel.DistributedDataParallel( |
| | | model, find_unused_parameters=trainer_options.unused_parameters) |
| | | elif distributed_option.ngpu > 1: |
| | | dp_model = torch.nn.parallel.DataParallel( |
| | | model, |
| | |
| | | reporter.set_epoch(iepoch) |
| | | # 1. Train and validation for one-epoch |
| | | with reporter.observe("train") as sub_reporter: |
| | | all_steps_are_invalid, max_update_stop = cls.train_one_epoch( |
| | | all_steps_are_invalid, max_update_stop = self.train_one_epoch( |
| | | model=dp_model, |
| | | optimizers=optimizers, |
| | | schedulers=schedulers, |
| | | iterator=train_iter_factory.build_iter(iepoch), |
| | | iterator=train_dataloader.build_iter(iepoch), |
| | | reporter=sub_reporter, |
| | | scaler=scaler, |
| | | summary_writer=train_summary_writer, |
| | |
| | | ) |
| | | |
| | | with reporter.observe("valid") as sub_reporter: |
| | | cls.validate_one_epoch( |
| | | self.validate_one_epoch( |
| | | model=dp_model, |
| | | iterator=valid_iter_factory.build_iter(iepoch), |
| | | iterator=valid_dataloader.build_iter(iepoch), |
| | | reporter=sub_reporter, |
| | | options=trainer_options, |
| | | distributed_option=distributed_option, |
| | |
| | | pai_output_dir=trainer_options.output_dir, |
| | | ) |
| | | |
| | | @classmethod |
| | | def train_one_epoch( |
| | | cls, |
| | | self, |
| | | model: torch.nn.Module, |
| | | iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]], |
| | | optimizers: Sequence[torch.optim.Optimizer], |
| | |
| | | # processes, send stop-flag to the other processes if iterator is finished |
| | | iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu") |
| | | |
| | | # get the rank |
| | | rank = distributed_option.dist_rank |
| | | # get the num batch updates |
| | | num_batch_updates = 0 |
| | | # output dir |
| | | output_dir = Path(options.output_dir) |
| | | # batch interval |
| | | batch_interval = options.batch_interval |
| | | assert batch_interval > 0 |
| | | |
| | | start_time = time.perf_counter() |
| | | for iiter, (_, batch) in enumerate( |
| | | reporter.measure_iter_time(iterator, "iter_time"), 1 |
| | | ): |
| | | assert isinstance(batch, dict), type(batch) |
| | | |
| | | if rank == 0: |
| | | if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")): |
| | | num_batch_updates = model.get_num_updates() if hasattr(model, |
| | | "num_updates") else model.module.get_num_updates() |
| | | if (num_batch_updates % batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai: |
| | | buffer = BytesIO() |
| | | torch.save(model.state_dict(), buffer) |
| | | options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"), |
| | | buffer.getvalue()) |
| | | |
| | | if distributed: |
| | | torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) |
| | |
| | | torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) |
| | | return all_steps_are_invalid, max_update_stop |
| | | |
| | | @classmethod |
| | | @torch.no_grad() |
| | | def validate_one_epoch( |
| | | cls, |
| | | self, |
| | | model: torch.nn.Module, |
| | | iterator: Iterable[Dict[str, torch.Tensor]], |
| | | reporter: SubReporter, |