From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 13 Sep 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/build_utils/build_trainer.py | 262 +++++++++++++++++++++++-----------------------------
1 file changed, 117 insertions(+), 145 deletions(-)
diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
index 8e4ee46..03aa780 100644
--- a/funasr/build_utils/build_trainer.py
+++ b/funasr/build_utils/build_trainer.py
@@ -3,13 +3,15 @@
"""Trainer module."""
import argparse
-from contextlib import contextmanager
import dataclasses
+import logging
+import os
+import time
+from contextlib import contextmanager
from dataclasses import is_dataclass
from distutils.version import LooseVersion
-import logging
+from io import BytesIO
from pathlib import Path
-import time
from typing import Dict
from typing import Iterable
from typing import List
@@ -20,17 +22,13 @@
import humanfriendly
import oss2
-from io import BytesIO
-import os
-import numpy as np
import torch
import torch.nn
import torch.optim
-from typeguard import check_argument_types
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.main_funcs.average_nbest_models import average_nbest_models
-from funasr.main_funcs.calculate_all_attentions import calculate_all_attentions
+from funasr.models.base_model import FunASRModel
from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
from funasr.schedulers.abs_scheduler import AbsEpochStepScheduler
from funasr.schedulers.abs_scheduler import AbsScheduler
@@ -39,7 +37,6 @@
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.recursive_op import recursive_average
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.models.base_model import FunASRModel
from funasr.train.distributed_utils import DistributedOption
from funasr.train.reporter import Reporter
from funasr.train.reporter import SubReporter
@@ -77,14 +74,14 @@
grad_clip: float
grad_clip_type: float
log_interval: Optional[int]
- no_forward_run: bool
+ # no_forward_run: bool
use_tensorboard: bool
- use_wandb: bool
+ # use_wandb: bool
output_dir: Union[Path, str]
max_epoch: int
max_update: int
seed: int
- sharded_ddp: bool
+ # sharded_ddp: bool
patience: Optional[int]
keep_nbest_models: Union[int, List[int]]
nbest_averaging_interval: int
@@ -92,42 +89,34 @@
best_model_criterion: Sequence[Sequence[str]]
val_scheduler_criterion: Sequence[str]
unused_parameters: bool
- wandb_model_log_interval: int
+ # wandb_model_log_interval: int
use_pai: bool
oss_bucket: Union[oss2.Bucket, None]
- batch_interval: int
class Trainer:
- """Trainer having a optimizer.
-
- If you'd like to use multiple optimizers, then inherit this class
- and override the methods if necessary - at least "train_one_epoch()"
-
- >>> class TwoOptimizerTrainer(Trainer):
- ... @classmethod
- ... def add_arguments(cls, parser):
- ... ...
- ...
- ... @classmethod
- ... def train_one_epoch(cls, model, optimizers, ...):
- ... loss1 = model.model1(...)
- ... loss1.backward()
- ... optimizers[0].step()
- ...
- ... loss2 = model.model2(...)
- ... loss2.backward()
- ... optimizers[1].step()
+ """Trainer
"""
- def __init__(self):
- raise RuntimeError("This class can't be instantiated.")
+ def __init__(self,
+ args,
+ model: FunASRModel,
+ optimizers: Sequence[torch.optim.Optimizer],
+ schedulers: Sequence[Optional[AbsScheduler]],
+ train_dataloader: AbsIterFactory,
+ valid_dataloader: AbsIterFactory,
+ distributed_option: DistributedOption):
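+ # Keep every training component on the instance; run() reads them from self instead of taking arguments.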
+ self.trainer_options = self.build_options(args)
+ self.model = model
+ self.optimizers = optimizers
+ self.schedulers = schedulers
+ self.train_dataloader = train_dataloader
+ self.valid_dataloader = valid_dataloader
+ self.distributed_option = distributed_option
- @classmethod
- def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
+ def build_options(self, args: argparse.Namespace) -> TrainerOptions:
"""Build options consumed by train(), eval()"""
- assert check_argument_types()
return build_dataclass(TrainerOptions, args)
@classmethod
@@ -135,16 +124,15 @@
"""Reserved for future development of another Trainer"""
pass
- @staticmethod
- def resume(
- checkpoint: Union[str, Path],
- model: torch.nn.Module,
- reporter: Reporter,
- optimizers: Sequence[torch.optim.Optimizer],
- schedulers: Sequence[Optional[AbsScheduler]],
- scaler: Optional[GradScaler],
- ngpu: int = 0,
- ):
+ def resume(self,
+ checkpoint: Union[str, Path],
+ model: torch.nn.Module,
+ reporter: Reporter,
+ optimizers: Sequence[torch.optim.Optimizer],
+ schedulers: Sequence[Optional[AbsScheduler]],
+ scaler: Optional[GradScaler],
+ ngpu: int = 0,
+ ):
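+ """Restore training state (model, reporter, optimizers, schedulers, scaler) from a checkpoint."""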
states = torch.load(
checkpoint,
map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
@@ -164,20 +152,16 @@
logging.info(f"The training was resumed using {checkpoint}")
- @classmethod
- def run(
- cls,
- model: FunASRModel,
- optimizers: Sequence[torch.optim.Optimizer],
- schedulers: Sequence[Optional[AbsScheduler]],
- train_iter_factory: AbsIterFactory,
- valid_iter_factory: AbsIterFactory,
- trainer_options,
- distributed_option: DistributedOption,
- ) -> None:
+ def run(self) -> None:
"""Perform training. This method performs the main process of training."""
- assert check_argument_types()
# NOTE(kamo): Don't check the type more strictly as far trainer_options
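+ # Pull the components stored by __init__ into locals for the epoch loop below.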
+ model = self.model
+ optimizers = self.optimizers
+ schedulers = self.schedulers
+ train_dataloader = self.train_dataloader
+ valid_dataloader = self.valid_dataloader
+ trainer_options = self.trainer_options
+ distributed_option = self.distributed_option
assert is_dataclass(trainer_options), type(trainer_options)
assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))
@@ -189,9 +173,6 @@
trainer_options.keep_nbest_models = [1]
keep_nbest_models = trainer_options.keep_nbest_models
- # assert batch_interval is set and >0
- assert trainer_options.batch_interval > 0
-
output_dir = Path(trainer_options.output_dir)
reporter = Reporter()
if trainer_options.use_amp:
@@ -199,19 +180,19 @@
raise RuntimeError(
"Require torch>=1.6.0 for Automatic Mixed Precision"
)
- if trainer_options.sharded_ddp:
- if fairscale is None:
- raise RuntimeError(
- "Requiring fairscale. Do 'pip install fairscale'"
- )
- scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
- else:
- scaler = GradScaler()
+ # if trainer_options.sharded_ddp:
+ # if fairscale is None:
+ # raise RuntimeError(
+ # "Requiring fairscale. Do 'pip install fairscale'"
+ # )
+ # scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
+ # else:
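+ # With the sharded_ddp branch commented out, the standard AMP GradScaler is always used.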
+ scaler = GradScaler()
else:
scaler = None
if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
- cls.resume(
+ self.resume(
checkpoint=output_dir / "checkpoint.pb",
model=model,
optimizers=optimizers,
@@ -228,14 +209,8 @@
)
if distributed_option.distributed:
- if trainer_options.sharded_ddp:
- dp_model = fairscale.nn.data_parallel.ShardedDataParallel(
- module=model,
- sharded_optimizer=optimizers,
- )
- else:
- dp_model = torch.nn.parallel.DistributedDataParallel(
- model, find_unused_parameters=trainer_options.unused_parameters)
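+ # fairscale ShardedDataParallel support removed; always wrap the model in standard DDP.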
+ dp_model = torch.nn.parallel.DistributedDataParallel(
+ model, find_unused_parameters=trainer_options.unused_parameters)
elif distributed_option.ngpu > 1:
dp_model = torch.nn.parallel.DataParallel(
model,
@@ -288,11 +263,11 @@
reporter.set_epoch(iepoch)
# 1. Train and validation for one-epoch
with reporter.observe("train") as sub_reporter:
- all_steps_are_invalid, max_update_stop = cls.train_one_epoch(
+ all_steps_are_invalid, max_update_stop = self.train_one_epoch(
model=dp_model,
optimizers=optimizers,
schedulers=schedulers,
- iterator=train_iter_factory.build_iter(iepoch),
+ iterator=train_dataloader.build_iter(iepoch),
reporter=sub_reporter,
scaler=scaler,
summary_writer=train_summary_writer,
@@ -301,9 +276,9 @@
)
with reporter.observe("valid") as sub_reporter:
- cls.validate_one_epoch(
+ self.validate_one_epoch(
model=dp_model,
- iterator=valid_iter_factory.build_iter(iepoch),
+ iterator=valid_dataloader.build_iter(iepoch),
reporter=sub_reporter,
options=trainer_options,
distributed_option=distributed_option,
@@ -317,10 +292,10 @@
)
elif isinstance(scheduler, AbsEpochStepScheduler):
scheduler.step()
- if trainer_options.sharded_ddp:
- for optimizer in optimizers:
- if isinstance(optimizer, fairscale.optim.oss.OSS):
- optimizer.consolidate_state_dict()
+ # if trainer_options.sharded_ddp:
+ # for optimizer in optimizers:
+ # if isinstance(optimizer, fairscale.optim.oss.OSS):
+ # optimizer.consolidate_state_dict()
if not distributed_option.distributed or distributed_option.dist_rank == 0:
# 3. Report the results
@@ -328,8 +303,8 @@
if train_summary_writer is not None:
reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
- if trainer_options.use_wandb:
- reporter.wandb_log()
+ # if trainer_options.use_wandb:
+ # reporter.wandb_log()
# save tensorboard on oss
if trainer_options.use_pai and train_summary_writer is not None:
@@ -434,25 +409,25 @@
"The best model has been updated: " + ", ".join(_improved)
)
- log_model = (
- trainer_options.wandb_model_log_interval > 0
- and iepoch % trainer_options.wandb_model_log_interval == 0
- )
- if log_model and trainer_options.use_wandb:
- import wandb
-
- logging.info("Logging Model on this epoch :::::")
- artifact = wandb.Artifact(
- name=f"model_{wandb.run.id}",
- type="model",
- metadata={"improved": _improved},
- )
- artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
- aliases = [
- f"epoch-{iepoch}",
- "best" if best_epoch == iepoch else "",
- ]
- wandb.log_artifact(artifact, aliases=aliases)
+ # log_model = (
+ # trainer_options.wandb_model_log_interval > 0
+ # and iepoch % trainer_options.wandb_model_log_interval == 0
+ # )
+ # if log_model and trainer_options.use_wandb:
+ # import wandb
+ #
+ # logging.info("Logging Model on this epoch :::::")
+ # artifact = wandb.Artifact(
+ # name=f"model_{wandb.run.id}",
+ # type="model",
+ # metadata={"improved": _improved},
+ # )
+ # artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
+ # aliases = [
+ # f"epoch-{iepoch}",
+ # "best" if best_epoch == iepoch else "",
+ # ]
+ # wandb.log_artifact(artifact, aliases=aliases)
# 6. Remove the model files excluding n-best epoch and latest epoch
_removed = []
@@ -532,9 +507,8 @@
pai_output_dir=trainer_options.output_dir,
)
- @classmethod
def train_one_epoch(
- cls,
+ self,
model: torch.nn.Module,
iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
optimizers: Sequence[torch.optim.Optimizer],
@@ -545,16 +519,15 @@
options: TrainerOptions,
distributed_option: DistributedOption,
) -> Tuple[bool, bool]:
- assert check_argument_types()
grad_noise = options.grad_noise
accum_grad = options.accum_grad
grad_clip = options.grad_clip
grad_clip_type = options.grad_clip_type
log_interval = options.log_interval
- no_forward_run = options.no_forward_run
+ # no_forward_run = options.no_forward_run
ngpu = options.ngpu
- use_wandb = options.use_wandb
+ # use_wandb = options.use_wandb
distributed = distributed_option.distributed
if log_interval is None:
@@ -570,31 +543,11 @@
# processes, send stop-flag to the other processes if iterator is finished
iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
- # get the rank
- rank = distributed_option.dist_rank
- # get the num batch updates
- num_batch_updates = 0
- # ouput dir
- output_dir = Path(options.output_dir)
- # batch interval
- batch_interval = options.batch_interval
- assert batch_interval > 0
-
start_time = time.perf_counter()
for iiter, (_, batch) in enumerate(
reporter.measure_iter_time(iterator, "iter_time"), 1
):
assert isinstance(batch, dict), type(batch)
-
- if rank == 0:
- if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
- num_batch_updates = model.get_num_updates() if hasattr(model,
- "num_updates") else model.module.get_num_updates()
- if (num_batch_updates % batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai:
- buffer = BytesIO()
- torch.save(model.state_dict(), buffer)
- options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"),
- buffer.getvalue())
if distributed:
torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
@@ -602,9 +555,9 @@
break
batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
- if no_forward_run:
- all_steps_are_invalid = False
- continue
+ # if no_forward_run:
+ # all_steps_are_invalid = False
+ # continue
with autocast(scaler is not None):
with reporter.measure_time("forward_time"):
@@ -780,8 +733,8 @@
logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
if summary_writer is not None:
reporter.tensorboard_add_scalar(summary_writer, -log_interval)
- if use_wandb:
- reporter.wandb_log()
+ # if use_wandb:
+ # reporter.wandb_log()
if max_update_stop:
break
@@ -792,19 +745,17 @@
torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
return all_steps_are_invalid, max_update_stop
- @classmethod
@torch.no_grad()
def validate_one_epoch(
- cls,
+ self,
model: torch.nn.Module,
iterator: Iterable[Dict[str, torch.Tensor]],
reporter: SubReporter,
options: TrainerOptions,
distributed_option: DistributedOption,
) -> None:
- assert check_argument_types()
ngpu = options.ngpu
- no_forward_run = options.no_forward_run
+ # no_forward_run = options.no_forward_run
distributed = distributed_option.distributed
model.eval()
@@ -820,8 +771,8 @@
break
batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
- if no_forward_run:
- continue
+ # if no_forward_run:
+ # continue
retval = model(**batch)
if isinstance(retval, dict):
@@ -841,3 +792,24 @@
if distributed:
iterator_stop.fill_(1)
torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+
+
+def build_trainer(
+ args,
+ model: FunASRModel,
+ optimizers: Sequence[torch.optim.Optimizer],
+ schedulers: Sequence[Optional[AbsScheduler]],
+ train_dataloader: AbsIterFactory,
+ valid_dataloader: AbsIterFactory,
+ distributed_option: DistributedOption
+):
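+ """Build and return a Trainer wired with the given model, optimizers, schedulers and dataloaders.
+
+ Typical usage (sketch): build_trainer(args, model, optimizers, schedulers,
+ train_dataloader, valid_dataloader, distributed_option).run()
+ """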
+ trainer = Trainer(
+ args=args,
+ model=model,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ train_dataloader=train_dataloader,
+ valid_dataloader=valid_dataloader,
+ distributed_option=distributed_option
+ )
+ return trainer
--
Gitblit v1.9.1