From b5d3df75cf6462aa3bf42fd3c86fa2aa7f1c8a15 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 24 十一月 2023 00:54:44 +0800
Subject: [PATCH] setup jamo
---
funasr/build_utils/build_trainer.py | 108 +++++++++++++++++++++++++-----------------------------
1 files changed, 50 insertions(+), 58 deletions(-)
diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
index 060b57f..498d05d 100644
--- a/funasr/build_utils/build_trainer.py
+++ b/funasr/build_utils/build_trainer.py
@@ -25,7 +25,6 @@
import torch
import torch.nn
import torch.optim
-from typeguard import check_argument_types
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.main_funcs.average_nbest_models import average_nbest_models
@@ -75,14 +74,14 @@
grad_clip: float
grad_clip_type: float
log_interval: Optional[int]
- no_forward_run: bool
+ # no_forward_run: bool
use_tensorboard: bool
- use_wandb: bool
+ # use_wandb: bool
output_dir: Union[Path, str]
max_epoch: int
max_update: int
seed: int
- sharded_ddp: bool
+ # sharded_ddp: bool
patience: Optional[int]
keep_nbest_models: Union[int, List[int]]
nbest_averaging_interval: int
@@ -90,7 +89,7 @@
best_model_criterion: Sequence[Sequence[str]]
val_scheduler_criterion: Sequence[str]
unused_parameters: bool
- wandb_model_log_interval: int
+ # wandb_model_log_interval: int
use_pai: bool
oss_bucket: Union[oss2.Bucket, None]
@@ -118,7 +117,6 @@
def build_options(self, args: argparse.Namespace) -> TrainerOptions:
"""Build options consumed by train(), eval()"""
- assert check_argument_types()
return build_dataclass(TrainerOptions, args)
@classmethod
@@ -156,7 +154,6 @@
def run(self) -> None:
"""Perform training. This method performs the main process of training."""
- assert check_argument_types()
# NOTE(kamo): Don't check the type more strictly as far trainer_options
model = self.model
optimizers = self.optimizers
@@ -183,14 +180,14 @@
raise RuntimeError(
"Require torch>=1.6.0 for Automatic Mixed Precision"
)
- if trainer_options.sharded_ddp:
- if fairscale is None:
- raise RuntimeError(
- "Requiring fairscale. Do 'pip install fairscale'"
- )
- scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
- else:
- scaler = GradScaler()
+ # if trainer_options.sharded_ddp:
+ # if fairscale is None:
+ # raise RuntimeError(
+ # "Requiring fairscale. Do 'pip install fairscale'"
+ # )
+ # scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
+ # else:
+ scaler = GradScaler()
else:
scaler = None
@@ -249,14 +246,11 @@
for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
if iepoch != start_epoch:
logging.info(
- "{}/{}epoch started. Estimated time to finish: {}".format(
+ "{}/{}epoch started. Estimated time to finish: {} hours".format(
iepoch,
trainer_options.max_epoch,
- humanfriendly.format_timespan(
- (time.perf_counter() - start_time)
- / (iepoch - start_epoch)
- * (trainer_options.max_epoch - iepoch + 1)
- ),
+ (time.perf_counter() - start_time) / 3600.0 / (iepoch - start_epoch) * (
+ trainer_options.max_epoch - iepoch + 1),
)
)
else:
@@ -295,10 +289,10 @@
)
elif isinstance(scheduler, AbsEpochStepScheduler):
scheduler.step()
- if trainer_options.sharded_ddp:
- for optimizer in optimizers:
- if isinstance(optimizer, fairscale.optim.oss.OSS):
- optimizer.consolidate_state_dict()
+ # if trainer_options.sharded_ddp:
+ # for optimizer in optimizers:
+ # if isinstance(optimizer, fairscale.optim.oss.OSS):
+ # optimizer.consolidate_state_dict()
if not distributed_option.distributed or distributed_option.dist_rank == 0:
# 3. Report the results
@@ -306,8 +300,8 @@
if train_summary_writer is not None:
reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
- if trainer_options.use_wandb:
- reporter.wandb_log()
+ # if trainer_options.use_wandb:
+ # reporter.wandb_log()
# save tensorboard on oss
if trainer_options.use_pai and train_summary_writer is not None:
@@ -412,25 +406,25 @@
"The best model has been updated: " + ", ".join(_improved)
)
- log_model = (
- trainer_options.wandb_model_log_interval > 0
- and iepoch % trainer_options.wandb_model_log_interval == 0
- )
- if log_model and trainer_options.use_wandb:
- import wandb
-
- logging.info("Logging Model on this epoch :::::")
- artifact = wandb.Artifact(
- name=f"model_{wandb.run.id}",
- type="model",
- metadata={"improved": _improved},
- )
- artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
- aliases = [
- f"epoch-{iepoch}",
- "best" if best_epoch == iepoch else "",
- ]
- wandb.log_artifact(artifact, aliases=aliases)
+ # log_model = (
+ # trainer_options.wandb_model_log_interval > 0
+ # and iepoch % trainer_options.wandb_model_log_interval == 0
+ # )
+ # if log_model and trainer_options.use_wandb:
+ # import wandb
+ #
+ # logging.info("Logging Model on this epoch :::::")
+ # artifact = wandb.Artifact(
+ # name=f"model_{wandb.run.id}",
+ # type="model",
+ # metadata={"improved": _improved},
+ # )
+ # artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
+ # aliases = [
+ # f"epoch-{iepoch}",
+ # "best" if best_epoch == iepoch else "",
+ # ]
+ # wandb.log_artifact(artifact, aliases=aliases)
# 6. Remove the model files excluding n-best epoch and latest epoch
_removed = []
@@ -522,16 +516,15 @@
options: TrainerOptions,
distributed_option: DistributedOption,
) -> Tuple[bool, bool]:
- assert check_argument_types()
grad_noise = options.grad_noise
accum_grad = options.accum_grad
grad_clip = options.grad_clip
grad_clip_type = options.grad_clip_type
log_interval = options.log_interval
- no_forward_run = options.no_forward_run
+ # no_forward_run = options.no_forward_run
ngpu = options.ngpu
- use_wandb = options.use_wandb
+ # use_wandb = options.use_wandb
distributed = distributed_option.distributed
if log_interval is None:
@@ -559,9 +552,9 @@
break
batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
- if no_forward_run:
- all_steps_are_invalid = False
- continue
+ # if no_forward_run:
+ # all_steps_are_invalid = False
+ # continue
with autocast(scaler is not None):
with reporter.measure_time("forward_time"):
@@ -737,8 +730,8 @@
logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
if summary_writer is not None:
reporter.tensorboard_add_scalar(summary_writer, -log_interval)
- if use_wandb:
- reporter.wandb_log()
+ # if use_wandb:
+ # reporter.wandb_log()
if max_update_stop:
break
@@ -758,9 +751,8 @@
options: TrainerOptions,
distributed_option: DistributedOption,
) -> None:
- assert check_argument_types()
ngpu = options.ngpu
- no_forward_run = options.no_forward_run
+ # no_forward_run = options.no_forward_run
distributed = distributed_option.distributed
model.eval()
@@ -776,8 +768,8 @@
break
batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
- if no_forward_run:
- continue
+ # if no_forward_run:
+ # continue
retval = model(**batch)
if isinstance(retval, dict):
--
Gitblit v1.9.1