From 4c3e502cb8fbbd16c7cf37e0c2564050cc55fd16 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 25 四月 2023 01:23:14 +0800
Subject: [PATCH] update
---
funasr/build_utils/build_trainer.py | 86 +++++++++++++++++++++---------------------
 1 file changed, 43 insertions(+), 43 deletions(-)
diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
index 71fc6df..aff99b5 100644
--- a/funasr/build_utils/build_trainer.py
+++ b/funasr/build_utils/build_trainer.py
@@ -183,14 +183,14 @@
raise RuntimeError(
"Require torch>=1.6.0 for Automatic Mixed Precision"
)
- if trainer_options.sharded_ddp:
- if fairscale is None:
- raise RuntimeError(
- "Requiring fairscale. Do 'pip install fairscale'"
- )
- scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
- else:
- scaler = GradScaler()
+ # if trainer_options.sharded_ddp:
+ # if fairscale is None:
+ # raise RuntimeError(
+ # "Requiring fairscale. Do 'pip install fairscale'"
+ # )
+ # scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
+ # else:
+ scaler = GradScaler()
else:
scaler = None
@@ -295,10 +295,10 @@
)
elif isinstance(scheduler, AbsEpochStepScheduler):
scheduler.step()
- if trainer_options.sharded_ddp:
- for optimizer in optimizers:
- if isinstance(optimizer, fairscale.optim.oss.OSS):
- optimizer.consolidate_state_dict()
+ # if trainer_options.sharded_ddp:
+ # for optimizer in optimizers:
+ # if isinstance(optimizer, fairscale.optim.oss.OSS):
+ # optimizer.consolidate_state_dict()
if not distributed_option.distributed or distributed_option.dist_rank == 0:
# 3. Report the results
@@ -306,8 +306,8 @@
if train_summary_writer is not None:
reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
- if trainer_options.use_wandb:
- reporter.wandb_log()
+ # if trainer_options.use_wandb:
+ # reporter.wandb_log()
# save tensorboard on oss
if trainer_options.use_pai and train_summary_writer is not None:
@@ -412,25 +412,25 @@
"The best model has been updated: " + ", ".join(_improved)
)
- log_model = (
- trainer_options.wandb_model_log_interval > 0
- and iepoch % trainer_options.wandb_model_log_interval == 0
- )
- if log_model and trainer_options.use_wandb:
- import wandb
-
- logging.info("Logging Model on this epoch :::::")
- artifact = wandb.Artifact(
- name=f"model_{wandb.run.id}",
- type="model",
- metadata={"improved": _improved},
- )
- artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
- aliases = [
- f"epoch-{iepoch}",
- "best" if best_epoch == iepoch else "",
- ]
- wandb.log_artifact(artifact, aliases=aliases)
+ # log_model = (
+ # trainer_options.wandb_model_log_interval > 0
+ # and iepoch % trainer_options.wandb_model_log_interval == 0
+ # )
+ # if log_model and trainer_options.use_wandb:
+ # import wandb
+ #
+ # logging.info("Logging Model on this epoch :::::")
+ # artifact = wandb.Artifact(
+ # name=f"model_{wandb.run.id}",
+ # type="model",
+ # metadata={"improved": _improved},
+ # )
+ # artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
+ # aliases = [
+ # f"epoch-{iepoch}",
+ # "best" if best_epoch == iepoch else "",
+ # ]
+ # wandb.log_artifact(artifact, aliases=aliases)
# 6. Remove the model files excluding n-best epoch and latest epoch
_removed = []
@@ -529,9 +529,9 @@
grad_clip = options.grad_clip
grad_clip_type = options.grad_clip_type
log_interval = options.log_interval
- no_forward_run = options.no_forward_run
+ # no_forward_run = options.no_forward_run
ngpu = options.ngpu
- use_wandb = options.use_wandb
+ # use_wandb = options.use_wandb
distributed = distributed_option.distributed
if log_interval is None:
@@ -559,9 +559,9 @@
break
batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
- if no_forward_run:
- all_steps_are_invalid = False
- continue
+ # if no_forward_run:
+ # all_steps_are_invalid = False
+ # continue
with autocast(scaler is not None):
with reporter.measure_time("forward_time"):
@@ -737,8 +737,8 @@
logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
if summary_writer is not None:
reporter.tensorboard_add_scalar(summary_writer, -log_interval)
- if use_wandb:
- reporter.wandb_log()
+ # if use_wandb:
+ # reporter.wandb_log()
if max_update_stop:
break
@@ -760,7 +760,7 @@
) -> None:
assert check_argument_types()
ngpu = options.ngpu
- no_forward_run = options.no_forward_run
+ # no_forward_run = options.no_forward_run
distributed = distributed_option.distributed
model.eval()
@@ -776,8 +776,8 @@
break
batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
- if no_forward_run:
- continue
+ # if no_forward_run:
+ # continue
retval = model(**batch)
if isinstance(retval, dict):
--
Gitblit v1.9.1