From 580b11b57ac4b62f7e2acda73813a4e10e8e4cd3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 10 Oct 2023 17:17:29 +0800
Subject: [PATCH] v0.8.0
---
funasr/build_utils/build_trainer.py | 99 +++++++++++++++++++++++--------------------------
 1 file changed, 47 insertions(+), 52 deletions(-)
diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
index 060b57f..03aa780 100644
--- a/funasr/build_utils/build_trainer.py
+++ b/funasr/build_utils/build_trainer.py
@@ -25,7 +25,6 @@
import torch
import torch.nn
import torch.optim
-from typeguard import check_argument_types
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.main_funcs.average_nbest_models import average_nbest_models
@@ -75,14 +74,14 @@
grad_clip: float
grad_clip_type: float
log_interval: Optional[int]
- no_forward_run: bool
+ # no_forward_run: bool
use_tensorboard: bool
- use_wandb: bool
+ # use_wandb: bool
output_dir: Union[Path, str]
max_epoch: int
max_update: int
seed: int
- sharded_ddp: bool
+ # sharded_ddp: bool
patience: Optional[int]
keep_nbest_models: Union[int, List[int]]
nbest_averaging_interval: int
@@ -90,7 +89,7 @@
best_model_criterion: Sequence[Sequence[str]]
val_scheduler_criterion: Sequence[str]
unused_parameters: bool
- wandb_model_log_interval: int
+ # wandb_model_log_interval: int
use_pai: bool
oss_bucket: Union[oss2.Bucket, None]
@@ -118,7 +117,6 @@
def build_options(self, args: argparse.Namespace) -> TrainerOptions:
"""Build options consumed by train(), eval()"""
- assert check_argument_types()
return build_dataclass(TrainerOptions, args)
@classmethod
@@ -156,7 +154,6 @@
def run(self) -> None:
"""Perform training. This method performs the main process of training."""
- assert check_argument_types()
# NOTE(kamo): Don't check the type more strictly as far trainer_options
model = self.model
optimizers = self.optimizers
@@ -183,14 +180,14 @@
raise RuntimeError(
"Require torch>=1.6.0 for Automatic Mixed Precision"
)
- if trainer_options.sharded_ddp:
- if fairscale is None:
- raise RuntimeError(
- "Requiring fairscale. Do 'pip install fairscale'"
- )
- scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
- else:
- scaler = GradScaler()
+ # if trainer_options.sharded_ddp:
+ # if fairscale is None:
+ # raise RuntimeError(
+ # "Requiring fairscale. Do 'pip install fairscale'"
+ # )
+ # scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
+ # else:
+ scaler = GradScaler()
else:
scaler = None
@@ -295,10 +292,10 @@
)
elif isinstance(scheduler, AbsEpochStepScheduler):
scheduler.step()
- if trainer_options.sharded_ddp:
- for optimizer in optimizers:
- if isinstance(optimizer, fairscale.optim.oss.OSS):
- optimizer.consolidate_state_dict()
+ # if trainer_options.sharded_ddp:
+ # for optimizer in optimizers:
+ # if isinstance(optimizer, fairscale.optim.oss.OSS):
+ # optimizer.consolidate_state_dict()
if not distributed_option.distributed or distributed_option.dist_rank == 0:
# 3. Report the results
@@ -306,8 +303,8 @@
if train_summary_writer is not None:
reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
- if trainer_options.use_wandb:
- reporter.wandb_log()
+ # if trainer_options.use_wandb:
+ # reporter.wandb_log()
# save tensorboard on oss
if trainer_options.use_pai and train_summary_writer is not None:
@@ -412,25 +409,25 @@
"The best model has been updated: " + ", ".join(_improved)
)
- log_model = (
- trainer_options.wandb_model_log_interval > 0
- and iepoch % trainer_options.wandb_model_log_interval == 0
- )
- if log_model and trainer_options.use_wandb:
- import wandb
-
- logging.info("Logging Model on this epoch :::::")
- artifact = wandb.Artifact(
- name=f"model_{wandb.run.id}",
- type="model",
- metadata={"improved": _improved},
- )
- artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
- aliases = [
- f"epoch-{iepoch}",
- "best" if best_epoch == iepoch else "",
- ]
- wandb.log_artifact(artifact, aliases=aliases)
+ # log_model = (
+ # trainer_options.wandb_model_log_interval > 0
+ # and iepoch % trainer_options.wandb_model_log_interval == 0
+ # )
+ # if log_model and trainer_options.use_wandb:
+ # import wandb
+ #
+ # logging.info("Logging Model on this epoch :::::")
+ # artifact = wandb.Artifact(
+ # name=f"model_{wandb.run.id}",
+ # type="model",
+ # metadata={"improved": _improved},
+ # )
+ # artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
+ # aliases = [
+ # f"epoch-{iepoch}",
+ # "best" if best_epoch == iepoch else "",
+ # ]
+ # wandb.log_artifact(artifact, aliases=aliases)
# 6. Remove the model files excluding n-best epoch and latest epoch
_removed = []
@@ -522,16 +519,15 @@
options: TrainerOptions,
distributed_option: DistributedOption,
) -> Tuple[bool, bool]:
- assert check_argument_types()
grad_noise = options.grad_noise
accum_grad = options.accum_grad
grad_clip = options.grad_clip
grad_clip_type = options.grad_clip_type
log_interval = options.log_interval
- no_forward_run = options.no_forward_run
+ # no_forward_run = options.no_forward_run
ngpu = options.ngpu
- use_wandb = options.use_wandb
+ # use_wandb = options.use_wandb
distributed = distributed_option.distributed
if log_interval is None:
@@ -559,9 +555,9 @@
break
batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
- if no_forward_run:
- all_steps_are_invalid = False
- continue
+ # if no_forward_run:
+ # all_steps_are_invalid = False
+ # continue
with autocast(scaler is not None):
with reporter.measure_time("forward_time"):
@@ -737,8 +733,8 @@
logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
if summary_writer is not None:
reporter.tensorboard_add_scalar(summary_writer, -log_interval)
- if use_wandb:
- reporter.wandb_log()
+ # if use_wandb:
+ # reporter.wandb_log()
if max_update_stop:
break
@@ -758,9 +754,8 @@
options: TrainerOptions,
distributed_option: DistributedOption,
) -> None:
- assert check_argument_types()
ngpu = options.ngpu
- no_forward_run = options.no_forward_run
+ # no_forward_run = options.no_forward_run
distributed = distributed_option.distributed
model.eval()
@@ -776,8 +771,8 @@
break
batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
- if no_forward_run:
- continue
+ # if no_forward_run:
+ # continue
retval = model(**batch)
if isinstance(retval, dict):
--
Gitblit v1.9.1