| | |
| | | |
| | | """Trainer module.""" |
| | | import argparse |
| | | from audioop import bias |
| | | from contextlib import contextmanager |
| | | import dataclasses |
| | | from dataclasses import is_dataclass |
| | |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.recursive_op import recursive_average |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.models.base_model import FunASRModel |
| | | from funasr.train.distributed_utils import DistributedOption |
| | | from funasr.train.reporter import Reporter |
| | | from funasr.train.reporter import SubReporter |
| | | from funasr.utils.build_dataclass import build_dataclass |
| | | from funasr.utils.kwargs2args import kwargs2args |
| | | |
| | | if torch.distributed.is_available(): |
| | | from torch.distributed import ReduceOp |
| | |
| | | schedulers: Sequence[Optional[AbsScheduler]], |
| | | scaler: Optional[GradScaler], |
| | | ngpu: int = 0, |
| | | oss_bucket=None, |
| | | ): |
| | | states = torch.load( |
| | | checkpoint, |
| | | map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu", |
| | | ) |
| | | if oss_bucket is None: |
| | | if os.path.exists(checkpoint): |
| | | states = torch.load( |
| | | checkpoint, |
| | | map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu", |
| | | ) |
| | | |
| | | else: |
| | | return 0 |
| | | else: |
| | | if oss_bucket.object_exists(checkpoint): |
| | | buffer = BytesIO(oss_bucket.get_object(checkpoint).read()) |
| | | states = torch.load(buffer, map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",) |
| | | else: |
| | | return 0 |
| | | model.load_state_dict(states["model"]) |
| | | reporter.load_state_dict(states["reporter"]) |
| | | for optimizer, state in zip(optimizers, states["optimizers"]): |
| | |
| | | @classmethod |
| | | def run( |
| | | cls, |
| | | model: AbsESPnetModel, |
| | | model: FunASRModel, |
| | | optimizers: Sequence[torch.optim.Optimizer], |
| | | schedulers: Sequence[Optional[AbsScheduler]], |
| | | train_iter_factory: AbsIterFactory, |
| | |
| | | else: |
| | | scaler = None |
| | | |
| | | if trainer_options.resume and (output_dir / "checkpoint.pb").exists(): |
| | | if trainer_options.resume: |
| | | cls.resume( |
| | | checkpoint=output_dir / "checkpoint.pb", |
| | | checkpoint=os.path.join(trainer_options.output_dir, "checkpoint.pb") if trainer_options.use_pai else output_dir / "checkpoint.pb", |
| | | model=model, |
| | | optimizers=optimizers, |
| | | schedulers=schedulers, |
| | | reporter=reporter, |
| | | scaler=scaler, |
| | | ngpu=trainer_options.ngpu, |
| | | oss_bucket=trainer_options.oss_bucket if trainer_options.use_pai else None, |
| | | ) |
| | | |
| | | start_epoch = reporter.get_epoch() + 1 |
| | |
| | | all_steps_are_invalid = False |
| | | continue |
| | | |
| | | if iiter == 1 and summary_writer is not None: |
| | | try: |
| | | args = kwargs2args(model.forward, batch) |
| | | except (ValueError, TypeError): |
| | | logging.warning( |
| | | "inpect.signature() is failed for the model. " |
| | | "The graph can't be added for tensorboard." |
| | | ) |
| | | else: |
| | | try: |
| | | summary_writer.add_graph(model, args, use_strict_trace=False) |
| | | except Exception: |
| | | logging.warning( |
| | | "summary_writer.add_graph() is failed for the model. " |
| | | "The graph can't be added for tensorboard." |
| | | ) |
| | | del args |
| | | |
| | | with autocast(scaler is not None): |
| | | with reporter.measure_time("forward_time"): |
| | | retval = model(**batch) |