From 4c3e502cb8fbbd16c7cf37e0c2564050cc55fd16 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 25 四月 2023 01:23:14 +0800
Subject: [PATCH] update
---
funasr/build_utils/build_trainer.py | 86 +++++++++++++++++++++---------------------
 1 file changed, 43 insertions(+), 43 deletions(-)
diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
index 71fc6df..aff99b5 100644
--- a/funasr/build_utils/build_trainer.py
+++ b/funasr/build_utils/build_trainer.py
@@ -183,14 +183,14 @@
raise RuntimeError(
"Require torch>=1.6.0 for Automatic Mixed Precision"
)
- if trainer_options.sharded_ddp:
- if fairscale is None:
- raise RuntimeError(
- "Requiring fairscale. Do 'pip install fairscale'"
- )
- scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
- else:
- scaler = GradScaler()
+ # if trainer_options.sharded_ddp:
+ # if fairscale is None:
+ # raise RuntimeError(
+ # "Requiring fairscale. Do 'pip install fairscale'"
+ # )
+ # scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
+ # else:
+ scaler = GradScaler()
else:
scaler = None
@@ -295,10 +295,10 @@
)
elif isinstance(scheduler, AbsEpochStepScheduler):
scheduler.step()
- if trainer_options.sharded_ddp:
- for optimizer in optimizers:
- if isinstance(optimizer, fairscale.optim.oss.OSS):
- optimizer.consolidate_state_dict()
+ # if trainer_options.sharded_ddp:
+ # for optimizer in optimizers:
+ # if isinstance(optimizer, fairscale.optim.oss.OSS):
+ # optimizer.consolidate_state_dict()
if not distributed_option.distributed or distributed_option.dist_rank == 0:
# 3. Report the results
@@ -306,8 +306,8 @@
if train_summary_writer is not None:
reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
- if trainer_options.use_wandb:
- reporter.wandb_log()
+ # if trainer_options.use_wandb:
+ # reporter.wandb_log()
# save tensorboard on oss
if trainer_options.use_pai and train_summary_writer is not None:
@@ -412,25 +412,25 @@
"The best model has been updated: " + ", ".join(_improved)
)
- log_model = (
- trainer_options.wandb_model_log_interval > 0
- and iepoch % trainer_options.wandb_model_log_interval == 0
- )
- if log_model and trainer_options.use_wandb:
- import wandb
-
- logging.info("Logging Model on this epoch :::::")
- artifact = wandb.Artifact(
- name=f"model_{wandb.run.id}",
- type="model",
- metadata={"improved": _improved},
- )
- artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
- aliases = [
- f"epoch-{iepoch}",
- "best" if best_epoch == iepoch else "",
- ]
- wandb.log_artifact(artifact, aliases=aliases)
+ # log_model = (
+ # trainer_options.wandb_model_log_interval > 0
+ # and iepoch % trainer_options.wandb_model_log_interval == 0
+ # )
+ # if log_model and trainer_options.use_wandb:
+ # import wandb
+ #
+ # logging.info("Logging Model on this epoch :::::")
+ # artifact = wandb.Artifact(
+ # name=f"model_{wandb.run.id}",
+ # type="model",
+ # metadata={"improved": _improved},
+ # )
+ # artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
+ # aliases = [
+ # f"epoch-{iepoch}",
+ # "best" if best_epoch == iepoch else "",
+ # ]
+ # wandb.log_artifact(artifact, aliases=aliases)
# 6. Remove the model files excluding n-best epoch and latest epoch
_removed = []
@@ -529,9 +529,9 @@
grad_clip = options.grad_clip
grad_clip_type = options.grad_clip_type
log_interval = options.log_interval
- no_forward_run = options.no_forward_run
+ # no_forward_run = options.no_forward_run
ngpu = options.ngpu
- use_wandb = options.use_wandb
+ # use_wandb = options.use_wandb
distributed = distributed_option.distributed
if log_interval is None:
@@ -559,9 +559,9 @@
break
batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
- if no_forward_run:
- all_steps_are_invalid = False
- continue
+ # if no_forward_run:
+ # all_steps_are_invalid = False
+ # continue
with autocast(scaler is not None):
with reporter.measure_time("forward_time"):
@@ -737,8 +737,8 @@
logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
if summary_writer is not None:
reporter.tensorboard_add_scalar(summary_writer, -log_interval)
- if use_wandb:
- reporter.wandb_log()
+ # if use_wandb:
+ # reporter.wandb_log()
if max_update_stop:
break
@@ -760,7 +760,7 @@
) -> None:
assert check_argument_types()
ngpu = options.ngpu
- no_forward_run = options.no_forward_run
+ # no_forward_run = options.no_forward_run
distributed = distributed_option.distributed
model.eval()
@@ -776,8 +776,8 @@
break
batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
- if no_forward_run:
- continue
+ # if no_forward_run:
+ # continue
retval = model(**batch)
if isinstance(retval, dict):
--
Gitblit v1.9.1