From c652f6814ac62eebb5fd1a55a303ee9110c87b58 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: Sun, 23 Apr 2023 17:30:38 +0800
Subject: [PATCH] Support multiple optimizers and schedulers; add build_trainer.py
---
funasr/bin/train.py | 8
funasr/build_utils/build_scheduler.py | 25 +
funasr/build_utils/build_trainer.py | 843 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
funasr/build_utils/build_optimizer.py | 4
4 files changed, 869 insertions(+), 11 deletions(-)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index e861199..c32a362 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -420,16 +420,16 @@
prepare_data(args, distributed_option)
model = build_model(args)
- optimizer = build_optimizer(args, model=model)
- scheduler = build_scheduler(args, optimizer)
+ optimizers = build_optimizer(args, model=model)
+ schedulers = build_scheduler(args, optimizers)
logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
distributed_option.dist_rank,
distributed_option.local_rank))
logging.info(pytorch_cudnn_version())
logging.info(model_summary(model))
- logging.info("Optimizer: {}".format(optimizer))
- logging.info("Scheduler: {}".format(scheduler))
+ logging.info("Optimizer: {}".format(optimizers))
+ logging.info("Scheduler: {}".format(schedulers))
# dump args to config.yaml
if not distributed_option.distributed or distributed_option.dist_rank == 0:
diff --git a/funasr/build_utils/build_optimizer.py b/funasr/build_utils/build_optimizer.py
index 3b27994..bd0b73d 100644
--- a/funasr/build_utils/build_optimizer.py
+++ b/funasr/build_utils/build_optimizer.py
@@ -23,4 +23,6 @@
if optim_class is None:
raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
optimizer = optim_class(model.parameters(), **args.optim_conf)
- return optimizer
\ No newline at end of file
+
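+    # Wrap the optimizer in a list so callers (e.g. Trainer.run) can handle
+    # single- and multi-optimizer setups uniformly.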
+ optimizers = [optimizer]
+ return optimizers
\ No newline at end of file
diff --git a/funasr/build_utils/build_scheduler.py b/funasr/build_utils/build_scheduler.py
index f0e6d1f..4b9990e 100644
--- a/funasr/build_utils/build_scheduler.py
+++ b/funasr/build_utils/build_scheduler.py
@@ -8,7 +8,7 @@
from funasr.schedulers.warmup_lr import WarmupLR
-def build_scheduler(args, optimizer):
+def build_scheduler(args, optimizers):
scheduler_classes = dict(
ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
lambdalr=torch.optim.lr_scheduler.LambdaLR,
@@ -24,8 +24,21 @@
CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
)
- scheduler_class = scheduler_classes.get(args.scheduler)
- if scheduler_class is None:
- raise ValueError(f"must be one of {list(scheduler_classes)}: {args.scheduler}")
- scheduler = scheduler_class(optimizer, **args.scheduler_conf)
- return scheduler
\ No newline at end of file
+ schedulers = []
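+    # Build one scheduler per optimizer: the first reads args.scheduler /
+    # args.scheduler_conf, the i-th (i >= 2) reads args.scheduler{i} /
+    # args.scheduler{i}_conf; a missing scheduler name yields None.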
+ for i, optim in enumerate(optimizers, 1):
+ suf = "" if i == 1 else str(i)
+ name = getattr(args, f"scheduler{suf}")
+ conf = getattr(args, f"scheduler{suf}_conf")
+ if name is not None:
+ cls_ = scheduler_classes.get(name)
+ if cls_ is None:
+ raise ValueError(
+ f"must be one of {list(scheduler_classes)}: {name}"
+ )
+ scheduler = cls_(optim, **conf)
+ else:
+ scheduler = None
+
+ schedulers.append(scheduler)
+
+ return schedulers
\ No newline at end of file
diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
new file mode 100644
index 0000000..8e4ee46
--- /dev/null
+++ b/funasr/build_utils/build_trainer.py
@@ -0,0 +1,843 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Trainer module."""
+import argparse
+from contextlib import contextmanager
+import dataclasses
+from dataclasses import is_dataclass
+from distutils.version import LooseVersion
+import logging
+from pathlib import Path
+import time
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+import humanfriendly
+import oss2
+from io import BytesIO
+import os
+import numpy as np
+import torch
+import torch.nn
+import torch.optim
+from typeguard import check_argument_types
+
+from funasr.iterators.abs_iter_factory import AbsIterFactory
+from funasr.main_funcs.average_nbest_models import average_nbest_models
+from funasr.main_funcs.calculate_all_attentions import calculate_all_attentions
+from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
+from funasr.schedulers.abs_scheduler import AbsEpochStepScheduler
+from funasr.schedulers.abs_scheduler import AbsScheduler
+from funasr.schedulers.abs_scheduler import AbsValEpochStepScheduler
+from funasr.torch_utils.add_gradient_noise import add_gradient_noise
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.recursive_op import recursive_average
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.models.base_model import FunASRModel
+from funasr.train.distributed_utils import DistributedOption
+from funasr.train.reporter import Reporter
+from funasr.train.reporter import SubReporter
+from funasr.utils.build_dataclass import build_dataclass
+
+if torch.distributed.is_available():
+ from torch.distributed import ReduceOp
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+ from torch.cuda.amp import GradScaler
+else:
+    # torch<1.6.0: fall back to a no-op autocast and disable GradScaler
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
+
+ GradScaler = None
+
+try:
+ import fairscale
+except ImportError:
+ fairscale = None
+
+
+@dataclasses.dataclass
+class TrainerOptions:
+ ngpu: int
+ resume: bool
+ use_amp: bool
+ train_dtype: str
+ grad_noise: bool
+ accum_grad: int
+ grad_clip: float
+ grad_clip_type: float
+ log_interval: Optional[int]
+ no_forward_run: bool
+ use_tensorboard: bool
+ use_wandb: bool
+ output_dir: Union[Path, str]
+ max_epoch: int
+ max_update: int
+ seed: int
+ sharded_ddp: bool
+ patience: Optional[int]
+ keep_nbest_models: Union[int, List[int]]
+ nbest_averaging_interval: int
+ early_stopping_criterion: Sequence[str]
+ best_model_criterion: Sequence[Sequence[str]]
+ val_scheduler_criterion: Sequence[str]
+ unused_parameters: bool
+ wandb_model_log_interval: int
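+    # PAI/OSS options: when use_pai is True, checkpoints and tensorboard logs
+    # are uploaded to the given oss2 bucket instead of the local output_dir.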
+ use_pai: bool
+ oss_bucket: Union[oss2.Bucket, None]
+ batch_interval: int
+
+
+class Trainer:
+    """Trainer having an optimizer.
+
+ If you'd like to use multiple optimizers, then inherit this class
+ and override the methods if necessary - at least "train_one_epoch()"
+
+ >>> class TwoOptimizerTrainer(Trainer):
+ ... @classmethod
+ ... def add_arguments(cls, parser):
+ ... ...
+ ...
+ ... @classmethod
+ ... def train_one_epoch(cls, model, optimizers, ...):
+ ... loss1 = model.model1(...)
+ ... loss1.backward()
+ ... optimizers[0].step()
+ ...
+ ... loss2 = model.model2(...)
+ ... loss2.backward()
+ ... optimizers[1].step()
+
+ """
+
+ def __init__(self):
+ raise RuntimeError("This class can't be instantiated.")
+
+ @classmethod
+ def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
+ """Build options consumed by train(), eval()"""
+ assert check_argument_types()
+ return build_dataclass(TrainerOptions, args)
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ """Reserved for future development of another Trainer"""
+ pass
+
+ @staticmethod
+ def resume(
+ checkpoint: Union[str, Path],
+ model: torch.nn.Module,
+ reporter: Reporter,
+ optimizers: Sequence[torch.optim.Optimizer],
+ schedulers: Sequence[Optional[AbsScheduler]],
+ scaler: Optional[GradScaler],
+ ngpu: int = 0,
+ ):
+ states = torch.load(
+ checkpoint,
+ map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
+ )
+ model.load_state_dict(states["model"])
+ reporter.load_state_dict(states["reporter"])
+ for optimizer, state in zip(optimizers, states["optimizers"]):
+ optimizer.load_state_dict(state)
+ for scheduler, state in zip(schedulers, states["schedulers"]):
+ if scheduler is not None:
+ scheduler.load_state_dict(state)
+ if scaler is not None:
+ if states["scaler"] is None:
+ logging.warning("scaler state is not found")
+ else:
+ scaler.load_state_dict(states["scaler"])
+
+ logging.info(f"The training was resumed using {checkpoint}")
+
+ @classmethod
+ def run(
+ cls,
+ model: FunASRModel,
+ optimizers: Sequence[torch.optim.Optimizer],
+ schedulers: Sequence[Optional[AbsScheduler]],
+ train_iter_factory: AbsIterFactory,
+ valid_iter_factory: AbsIterFactory,
+ trainer_options,
+ distributed_option: DistributedOption,
+ ) -> None:
+ """Perform training. This method performs the main process of training."""
+ assert check_argument_types()
+        # NOTE(kamo): Don't check the type of trainer_options more strictly here
+ assert is_dataclass(trainer_options), type(trainer_options)
+ assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))
+
+ if isinstance(trainer_options.keep_nbest_models, int):
+ keep_nbest_models = [trainer_options.keep_nbest_models]
+ else:
+ if len(trainer_options.keep_nbest_models) == 0:
+                logging.warning("keep_nbest_models is empty; defaulting to [1]")
+ trainer_options.keep_nbest_models = [1]
+ keep_nbest_models = trainer_options.keep_nbest_models
+
+ # assert batch_interval is set and >0
+ assert trainer_options.batch_interval > 0
+
+ output_dir = Path(trainer_options.output_dir)
+ reporter = Reporter()
+ if trainer_options.use_amp:
+ if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
+ raise RuntimeError(
+                    "torch>=1.6.0 is required for Automatic Mixed Precision"
+ )
+ if trainer_options.sharded_ddp:
+ if fairscale is None:
+ raise RuntimeError(
+                        "fairscale is required; install it with 'pip install fairscale'"
+ )
+ scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
+ else:
+ scaler = GradScaler()
+ else:
+ scaler = None
+
+ if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
+ cls.resume(
+ checkpoint=output_dir / "checkpoint.pb",
+ model=model,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ reporter=reporter,
+ scaler=scaler,
+ ngpu=trainer_options.ngpu,
+ )
+
+ start_epoch = reporter.get_epoch() + 1
+ if start_epoch == trainer_options.max_epoch + 1:
+ logging.warning(
+                f"The training has already reached max_epoch: {start_epoch}"
+ )
+
+ if distributed_option.distributed:
+ if trainer_options.sharded_ddp:
+ dp_model = fairscale.nn.data_parallel.ShardedDataParallel(
+ module=model,
+ sharded_optimizer=optimizers,
+ )
+ else:
+ dp_model = torch.nn.parallel.DistributedDataParallel(
+ model, find_unused_parameters=trainer_options.unused_parameters)
+ elif distributed_option.ngpu > 1:
+ dp_model = torch.nn.parallel.DataParallel(
+ model,
+ device_ids=list(range(distributed_option.ngpu)),
+ )
+ else:
+            # NOTE(kamo): DataParallel should also work with ngpu=1,
+ # but for debuggability it's better to keep this block.
+ dp_model = model
+
+ if trainer_options.use_tensorboard and (
+ not distributed_option.distributed or distributed_option.dist_rank == 0
+ ):
+ from torch.utils.tensorboard import SummaryWriter
+ if trainer_options.use_pai:
+ train_summary_writer = SummaryWriter(
+ os.path.join(trainer_options.output_dir, "tensorboard/train")
+ )
+ valid_summary_writer = SummaryWriter(
+ os.path.join(trainer_options.output_dir, "tensorboard/valid")
+ )
+ else:
+ train_summary_writer = SummaryWriter(
+ str(output_dir / "tensorboard" / "train")
+ )
+ valid_summary_writer = SummaryWriter(
+ str(output_dir / "tensorboard" / "valid")
+ )
+ else:
+ train_summary_writer = None
+
+ start_time = time.perf_counter()
+ for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
+ if iepoch != start_epoch:
+ logging.info(
+ "{}/{}epoch started. Estimated time to finish: {}".format(
+ iepoch,
+ trainer_options.max_epoch,
+ humanfriendly.format_timespan(
+ (time.perf_counter() - start_time)
+ / (iepoch - start_epoch)
+ * (trainer_options.max_epoch - iepoch + 1)
+ ),
+ )
+ )
+ else:
+ logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
+ set_all_random_seed(trainer_options.seed + iepoch)
+
+ reporter.set_epoch(iepoch)
+ # 1. Train and validation for one-epoch
+ with reporter.observe("train") as sub_reporter:
+ all_steps_are_invalid, max_update_stop = cls.train_one_epoch(
+ model=dp_model,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ iterator=train_iter_factory.build_iter(iepoch),
+ reporter=sub_reporter,
+ scaler=scaler,
+ summary_writer=train_summary_writer,
+ options=trainer_options,
+ distributed_option=distributed_option,
+ )
+
+ with reporter.observe("valid") as sub_reporter:
+ cls.validate_one_epoch(
+ model=dp_model,
+ iterator=valid_iter_factory.build_iter(iepoch),
+ reporter=sub_reporter,
+ options=trainer_options,
+ distributed_option=distributed_option,
+ )
+
+ # 2. LR Scheduler step
+ for scheduler in schedulers:
+ if isinstance(scheduler, AbsValEpochStepScheduler):
+ scheduler.step(
+ reporter.get_value(*trainer_options.val_scheduler_criterion)
+ )
+ elif isinstance(scheduler, AbsEpochStepScheduler):
+ scheduler.step()
+ if trainer_options.sharded_ddp:
+ for optimizer in optimizers:
+ if isinstance(optimizer, fairscale.optim.oss.OSS):
+ optimizer.consolidate_state_dict()
+
+ if not distributed_option.distributed or distributed_option.dist_rank == 0:
+ # 3. Report the results
+ logging.info(reporter.log_message())
+ if train_summary_writer is not None:
+ reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
+ reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
+ if trainer_options.use_wandb:
+ reporter.wandb_log()
+
+ # save tensorboard on oss
+ if trainer_options.use_pai and train_summary_writer is not None:
+ def write_tensorboard_summary(summary_writer_path, oss_bucket):
+ file_list = []
+ for root, dirs, files in os.walk(summary_writer_path, topdown=False):
+ for name in files:
+ file_full_path = os.path.join(root, name)
+ file_list.append(file_full_path)
+
+ for file_full_path in file_list:
+ with open(file_full_path, "rb") as f:
+ oss_bucket.put_object(file_full_path, f)
+
+ write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/train"),
+ trainer_options.oss_bucket)
+ write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/valid"),
+ trainer_options.oss_bucket)
+
+ # 4. Save/Update the checkpoint
+ if trainer_options.use_pai:
+ buffer = BytesIO()
+ torch.save(
+ {
+ "model": model.state_dict(),
+ "reporter": reporter.state_dict(),
+ "optimizers": [o.state_dict() for o in optimizers],
+ "schedulers": [
+ s.state_dict() if s is not None else None
+ for s in schedulers
+ ],
+ "scaler": scaler.state_dict() if scaler is not None else None,
+                        "ema_model": model.encoder.ema.model.state_dict()
+                        if hasattr(model, "encoder") and getattr(model.encoder, "ema", None) is not None else None,
+ },
+ buffer,
+ )
+ trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pb"),
+ buffer.getvalue())
+ else:
+ torch.save(
+ {
+ "model": model.state_dict(),
+ "reporter": reporter.state_dict(),
+ "optimizers": [o.state_dict() for o in optimizers],
+ "schedulers": [
+ s.state_dict() if s is not None else None
+ for s in schedulers
+ ],
+ "scaler": scaler.state_dict() if scaler is not None else None,
+ },
+ output_dir / "checkpoint.pb",
+ )
+
+ # 5. Save and log the model and update the link to the best model
+ if trainer_options.use_pai:
+ buffer = BytesIO()
+ torch.save(model.state_dict(), buffer)
+ trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir,
+ f"{iepoch}epoch.pb"), buffer.getvalue())
+ else:
+ torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")
+
+                # Point latest.pb at {iepoch}epoch.pb (symlink locally; object copy on OSS)
+ if trainer_options.use_pai:
+ p = os.path.join(trainer_options.output_dir, "latest.pb")
+ if trainer_options.oss_bucket.object_exists(p):
+ trainer_options.oss_bucket.delete_object(p)
+ trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
+ os.path.join(trainer_options.output_dir,
+ f"{iepoch}epoch.pb"), p)
+ else:
+ p = output_dir / "latest.pb"
+ if p.is_symlink() or p.exists():
+ p.unlink()
+ p.symlink_to(f"{iepoch}epoch.pb")
+
+ _improved = []
+ for _phase, k, _mode in trainer_options.best_model_criterion:
+ # e.g. _phase, k, _mode = "train", "loss", "min"
+ if reporter.has(_phase, k):
+ best_epoch = reporter.get_best_epoch(_phase, k, _mode)
+ # Creates sym links if it's the best result
+ if best_epoch == iepoch:
+ if trainer_options.use_pai:
+ p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pb")
+ if trainer_options.oss_bucket.object_exists(p):
+ trainer_options.oss_bucket.delete_object(p)
+ trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
+ os.path.join(trainer_options.output_dir,
+ f"{iepoch}epoch.pb"), p)
+ else:
+ p = output_dir / f"{_phase}.{k}.best.pb"
+ if p.is_symlink() or p.exists():
+ p.unlink()
+ p.symlink_to(f"{iepoch}epoch.pb")
+ _improved.append(f"{_phase}.{k}")
+ if len(_improved) == 0:
+ logging.info("There are no improvements in this epoch")
+ else:
+ logging.info(
+ "The best model has been updated: " + ", ".join(_improved)
+ )
+
+ log_model = (
+ trainer_options.wandb_model_log_interval > 0
+ and iepoch % trainer_options.wandb_model_log_interval == 0
+ )
+ if log_model and trainer_options.use_wandb:
+ import wandb
+
+ logging.info("Logging Model on this epoch :::::")
+ artifact = wandb.Artifact(
+ name=f"model_{wandb.run.id}",
+ type="model",
+ metadata={"improved": _improved},
+ )
+ artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
+ aliases = [
+ f"epoch-{iepoch}",
+ "best" if best_epoch == iepoch else "",
+ ]
+ wandb.log_artifact(artifact, aliases=aliases)
+
+ # 6. Remove the model files excluding n-best epoch and latest epoch
+ _removed = []
+ # Get the union set of the n-best among multiple criterion
+ nbests = set().union(
+ *[
+ set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
+ for ph, k, m in trainer_options.best_model_criterion
+ if reporter.has(ph, k)
+ ]
+ )
+
+ # Generated n-best averaged model
+ if (
+ trainer_options.nbest_averaging_interval > 0
+ and iepoch % trainer_options.nbest_averaging_interval == 0
+ ):
+ average_nbest_models(
+ reporter=reporter,
+ output_dir=output_dir,
+ best_model_criterion=trainer_options.best_model_criterion,
+ nbest=keep_nbest_models,
+ suffix=f"till{iepoch}epoch",
+ oss_bucket=trainer_options.oss_bucket,
+ pai_output_dir=trainer_options.output_dir,
+ )
+
+ for e in range(1, iepoch):
+ if trainer_options.use_pai:
+ p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
+ if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
+ trainer_options.oss_bucket.delete_object(p)
+ _removed.append(str(p))
+ else:
+ p = output_dir / f"{e}epoch.pb"
+ if p.exists() and e not in nbests:
+ p.unlink()
+ _removed.append(str(p))
+ if len(_removed) != 0:
+ logging.info("The model files were removed: " + ", ".join(_removed))
+
+            # 7. If no update happened at all, stop the training
+ if all_steps_are_invalid:
+ logging.warning(
+ f"The gradients at all steps are invalid in this epoch. "
+ f"Something seems wrong. This training was stopped at {iepoch}epoch"
+ )
+ break
+
+ if max_update_stop:
+ logging.info(
+ f"Stopping training due to "
+ f"num_updates: {trainer_options.num_updates} >= max_update: {trainer_options.max_update}"
+ )
+ break
+
+ # 8. Check early stopping
+ if trainer_options.patience is not None:
+ if reporter.check_early_stopping(
+ trainer_options.patience, *trainer_options.early_stopping_criterion
+ ):
+ break
+
+ else:
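+            # for-else: reached only when the epoch loop ran to max_epoch without a break.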
+ logging.info(
+                f"The training finished after {trainer_options.max_epoch} epochs"
+ )
+
+ # Generated n-best averaged model
+ if not distributed_option.distributed or distributed_option.dist_rank == 0:
+ average_nbest_models(
+ reporter=reporter,
+ output_dir=output_dir,
+ best_model_criterion=trainer_options.best_model_criterion,
+ nbest=keep_nbest_models,
+ oss_bucket=trainer_options.oss_bucket,
+ pai_output_dir=trainer_options.output_dir,
+ )
+
+ @classmethod
+ def train_one_epoch(
+ cls,
+ model: torch.nn.Module,
+ iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
+ optimizers: Sequence[torch.optim.Optimizer],
+ schedulers: Sequence[Optional[AbsScheduler]],
+ scaler: Optional[GradScaler],
+ reporter: SubReporter,
+ summary_writer,
+ options: TrainerOptions,
+ distributed_option: DistributedOption,
+ ) -> Tuple[bool, bool]:
+ assert check_argument_types()
+
+ grad_noise = options.grad_noise
+ accum_grad = options.accum_grad
+ grad_clip = options.grad_clip
+ grad_clip_type = options.grad_clip_type
+ log_interval = options.log_interval
+ no_forward_run = options.no_forward_run
+ ngpu = options.ngpu
+ use_wandb = options.use_wandb
+ distributed = distributed_option.distributed
+
+ if log_interval is None:
+ try:
+ log_interval = max(len(iterator) // 20, 10)
+ except TypeError:
+ log_interval = 100
+
+ model.train()
+ all_steps_are_invalid = True
+ max_update_stop = False
+        # [For distributed] Because iteration counts are not always equal between
+        # processes, send a stop-flag to the other processes if this iterator is finished
+ iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
+
+ # get the rank
+ rank = distributed_option.dist_rank
+ # get the num batch updates
+ num_batch_updates = 0
+        # output dir
+ output_dir = Path(options.output_dir)
+ # batch interval
+ batch_interval = options.batch_interval
+ assert batch_interval > 0
+
+ start_time = time.perf_counter()
+ for iiter, (_, batch) in enumerate(
+ reporter.measure_iter_time(iterator, "iter_time"), 1
+ ):
+ assert isinstance(batch, dict), type(batch)
+
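+            # On rank 0, periodically snapshot the model to OSS every
+            # `batch_interval` parameter updates (PAI mode only).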
+ if rank == 0:
+ if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
+                    num_batch_updates = (
+                        model.get_num_updates()
+                        if hasattr(model, "num_updates")
+                        else model.module.get_num_updates()
+                    )
+ if (num_batch_updates % batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai:
+ buffer = BytesIO()
+ torch.save(model.state_dict(), buffer)
+ options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"),
+ buffer.getvalue())
+
+ if distributed:
+ torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+ if iterator_stop > 0:
+ break
+
+ batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
+ if no_forward_run:
+ all_steps_are_invalid = False
+ continue
+
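+            # Forward pass; autocast is active only when AMP (a GradScaler) is in use.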
+ with autocast(scaler is not None):
+ with reporter.measure_time("forward_time"):
+ retval = model(**batch)
+
+ # Note(kamo):
+ # Supporting two patterns for the returned value from the model
+ # a. dict type
+ if isinstance(retval, dict):
+ loss = retval["loss"]
+ stats = retval["stats"]
+ weight = retval["weight"]
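+                    # optim_idx (optional) selects which optimizer/scheduler pair
+                    # this loss should update; None means update all of them.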
+ optim_idx = retval.get("optim_idx")
+ if optim_idx is not None and not isinstance(optim_idx, int):
+ if not isinstance(optim_idx, torch.Tensor):
+ raise RuntimeError(
+ "optim_idx must be int or 1dim torch.Tensor, "
+ f"but got {type(optim_idx)}"
+ )
+ if optim_idx.dim() >= 2:
+ raise RuntimeError(
+ "optim_idx must be int or 1dim torch.Tensor, "
+ f"but got {optim_idx.dim()}dim tensor"
+ )
+ if optim_idx.dim() == 1:
+ for v in optim_idx:
+ if v != optim_idx[0]:
+ raise RuntimeError(
+ "optim_idx must be 1dim tensor "
+ "having same values for all entries"
+ )
+ optim_idx = optim_idx[0].item()
+ else:
+ optim_idx = optim_idx.item()
+
+ # b. tuple or list type
+ else:
+ loss, stats, weight = retval
+ optim_idx = None
+
+ stats = {k: v for k, v in stats.items() if v is not None}
+ if ngpu > 1 or distributed:
+ # Apply weighted averaging for loss and stats
+ loss = (loss * weight.type(loss.dtype)).sum()
+
+ # if distributed, this method can also apply all_reduce()
+ stats, weight = recursive_average(stats, weight, distributed)
+
+ # Now weight is summation over all workers
+ loss /= weight
+ if distributed:
+ # NOTE(kamo): Multiply world_size because DistributedDataParallel
+ # automatically normalizes the gradient by world_size.
+ loss *= torch.distributed.get_world_size()
+
+ loss /= accum_grad
+
+ reporter.register(stats, weight)
+
+ with reporter.measure_time("backward_time"):
+ if scaler is not None:
+ # Scales loss. Calls backward() on scaled loss
+ # to create scaled gradients.
+ # Backward passes under autocast are not recommended.
+ # Backward ops run in the same dtype autocast chose
+ # for corresponding forward ops.
+ scaler.scale(loss).backward()
+ else:
+ loss.backward()
+
+ if iiter % accum_grad == 0:
+ if scaler is not None:
+ # Unscales the gradients of optimizer's assigned params in-place
+ for iopt, optimizer in enumerate(optimizers):
+ if optim_idx is not None and iopt != optim_idx:
+ continue
+ scaler.unscale_(optimizer)
+
+ # gradient noise injection
+ if grad_noise:
+ add_gradient_noise(
+ model,
+ reporter.get_total_count(),
+ duration=100,
+ eta=1.0,
+ scale_factor=0.55,
+ )
+
+ # compute the gradient norm to check if it is normal or not
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ model.parameters(),
+ max_norm=grad_clip,
+ norm_type=grad_clip_type,
+ )
+                # In PyTorch<=1.4, clip_grad_norm_ returns a float value
+ if not isinstance(grad_norm, torch.Tensor):
+ grad_norm = torch.tensor(grad_norm)
+
+ if not torch.isfinite(grad_norm):
+ logging.warning(
+ f"The grad norm is {grad_norm}. Skipping updating the model."
+ )
+
+ # Must invoke scaler.update() if unscale_() is used in the iteration
+ # to avoid the following error:
+ # RuntimeError: unscale_() has already been called
+ # on this optimizer since the last update().
+ # Note that if the gradient has inf/nan values,
+ # scaler.step skips optimizer.step().
+ if scaler is not None:
+ for iopt, optimizer in enumerate(optimizers):
+ if optim_idx is not None and iopt != optim_idx:
+ continue
+ scaler.step(optimizer)
+ scaler.update()
+
+ else:
+ all_steps_are_invalid = False
+ with reporter.measure_time("optim_step_time"):
+ for iopt, (optimizer, scheduler) in enumerate(
+ zip(optimizers, schedulers)
+ ):
+ if optim_idx is not None and iopt != optim_idx:
+ continue
+ if scaler is not None:
+ # scaler.step() first unscales the gradients of
+ # the optimizer's assigned params.
+ scaler.step(optimizer)
+ # Updates the scale for next iteration.
+ scaler.update()
+ else:
+ optimizer.step()
+ if isinstance(scheduler, AbsBatchStepScheduler):
+ scheduler.step()
+ for iopt, optimizer in enumerate(optimizers):
+ if optim_idx is not None and iopt != optim_idx:
+ continue
+ optimizer.zero_grad()
+
+ # Register lr and train/load time[sec/step],
+ # where step refers to accum_grad * mini-batch
+ reporter.register(
+ dict(
+ {
+ f"optim{i}_lr{j}": pg["lr"]
+ for i, optimizer in enumerate(optimizers)
+ for j, pg in enumerate(optimizer.param_groups)
+ if "lr" in pg
+ },
+ train_time=time.perf_counter() - start_time,
+ ),
+ )
+ start_time = time.perf_counter()
+
+ # update num_updates
+ if distributed:
+ if hasattr(model.module, "num_updates"):
+ model.module.set_num_updates(model.module.get_num_updates() + 1)
+ options.num_updates = model.module.get_num_updates()
+ if model.module.get_num_updates() >= options.max_update:
+ max_update_stop = True
+ else:
+ if hasattr(model, "num_updates"):
+ model.set_num_updates(model.get_num_updates() + 1)
+ options.num_updates = model.get_num_updates()
+ if model.get_num_updates() >= options.max_update:
+ max_update_stop = True
+
+ # NOTE(kamo): Call log_message() after next()
+ reporter.next()
+ if iiter % log_interval == 0:
+ num_updates = options.num_updates if hasattr(options, "num_updates") else None
+ logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
+ if summary_writer is not None:
+ reporter.tensorboard_add_scalar(summary_writer, -log_interval)
+ if use_wandb:
+ reporter.wandb_log()
+
+ if max_update_stop:
+ break
+
+ else:
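+            # for-else: reached only when the local iterator was exhausted without a
+            # break; raise the stop-flag so the other ranks exit their loops too.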
+ if distributed:
+ iterator_stop.fill_(1)
+ torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+ return all_steps_are_invalid, max_update_stop
+
+ @classmethod
+ @torch.no_grad()
+ def validate_one_epoch(
+ cls,
+ model: torch.nn.Module,
+ iterator: Iterable[Dict[str, torch.Tensor]],
+ reporter: SubReporter,
+ options: TrainerOptions,
+ distributed_option: DistributedOption,
+ ) -> None:
+ assert check_argument_types()
+ ngpu = options.ngpu
+ no_forward_run = options.no_forward_run
+ distributed = distributed_option.distributed
+
+ model.eval()
+
+        # [For distributed] Because iteration counts are not always equal between
+        # processes, send a stop-flag to the other processes if this iterator is finished
+ iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
+ for (_, batch) in iterator:
+ assert isinstance(batch, dict), type(batch)
+ if distributed:
+ torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+ if iterator_stop > 0:
+ break
+
+ batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
+ if no_forward_run:
+ continue
+
+ retval = model(**batch)
+ if isinstance(retval, dict):
+ stats = retval["stats"]
+ weight = retval["weight"]
+ else:
+ _, stats, weight = retval
+ if ngpu > 1 or distributed:
+ # Apply weighted averaging for stats.
+ # if distributed, this method can also apply all_reduce()
+ stats, weight = recursive_average(stats, weight, distributed)
+
+ reporter.register(stats, weight)
+ reporter.next()
+
+ else:
+ if distributed:
+ iterator_stop.fill_(1)
+ torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
--
Gitblit v1.9.1