From 580b11b57ac4b62f7e2acda73813a4e10e8e4cd3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 10 十月 2023 17:17:29 +0800
Subject: [PATCH] v0.8.0

---
 funasr/build_utils/build_trainer.py |   99 +++++++++++++++++++++++--------------------------
 1 files changed, 47 insertions(+), 52 deletions(-)

diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
index 060b57f..03aa780 100644
--- a/funasr/build_utils/build_trainer.py
+++ b/funasr/build_utils/build_trainer.py
@@ -25,7 +25,6 @@
 import torch
 import torch.nn
 import torch.optim
-from typeguard import check_argument_types
 
 from funasr.iterators.abs_iter_factory import AbsIterFactory
 from funasr.main_funcs.average_nbest_models import average_nbest_models
@@ -75,14 +74,14 @@
     grad_clip: float
     grad_clip_type: float
     log_interval: Optional[int]
-    no_forward_run: bool
+    # no_forward_run: bool
     use_tensorboard: bool
-    use_wandb: bool
+    # use_wandb: bool
     output_dir: Union[Path, str]
     max_epoch: int
     max_update: int
     seed: int
-    sharded_ddp: bool
+    # sharded_ddp: bool
     patience: Optional[int]
     keep_nbest_models: Union[int, List[int]]
     nbest_averaging_interval: int
@@ -90,7 +89,7 @@
     best_model_criterion: Sequence[Sequence[str]]
     val_scheduler_criterion: Sequence[str]
     unused_parameters: bool
-    wandb_model_log_interval: int
+    # wandb_model_log_interval: int
     use_pai: bool
     oss_bucket: Union[oss2.Bucket, None]
 
@@ -118,7 +117,6 @@
 
     def build_options(self, args: argparse.Namespace) -> TrainerOptions:
         """Build options consumed by train(), eval()"""
-        assert check_argument_types()
         return build_dataclass(TrainerOptions, args)
 
     @classmethod
@@ -156,7 +154,6 @@
 
     def run(self) -> None:
         """Perform training. This method performs the main process of training."""
-        assert check_argument_types()
         # NOTE(kamo): Don't check the type more strictly as far trainer_options
         model = self.model
         optimizers = self.optimizers
@@ -183,14 +180,14 @@
                 raise RuntimeError(
                     "Require torch>=1.6.0 for  Automatic Mixed Precision"
                 )
-            if trainer_options.sharded_ddp:
-                if fairscale is None:
-                    raise RuntimeError(
-                        "Requiring fairscale. Do 'pip install fairscale'"
-                    )
-                scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
-            else:
-                scaler = GradScaler()
+            # if trainer_options.sharded_ddp:
+            #     if fairscale is None:
+            #         raise RuntimeError(
+            #             "Requiring fairscale. Do 'pip install fairscale'"
+            #         )
+            #     scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
+            # else:
+            scaler = GradScaler()
         else:
             scaler = None
 
@@ -295,10 +292,10 @@
                     )
                 elif isinstance(scheduler, AbsEpochStepScheduler):
                     scheduler.step()
-            if trainer_options.sharded_ddp:
-                for optimizer in optimizers:
-                    if isinstance(optimizer, fairscale.optim.oss.OSS):
-                        optimizer.consolidate_state_dict()
+            # if trainer_options.sharded_ddp:
+            #     for optimizer in optimizers:
+            #         if isinstance(optimizer, fairscale.optim.oss.OSS):
+            #             optimizer.consolidate_state_dict()
 
             if not distributed_option.distributed or distributed_option.dist_rank == 0:
                 # 3. Report the results
@@ -306,8 +303,8 @@
                 if train_summary_writer is not None:
                     reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
                     reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
-                if trainer_options.use_wandb:
-                    reporter.wandb_log()
+                # if trainer_options.use_wandb:
+                #     reporter.wandb_log()
 
                 # save tensorboard on oss
                 if trainer_options.use_pai and train_summary_writer is not None:
@@ -412,25 +409,25 @@
                         "The best model has been updated: " + ", ".join(_improved)
                     )
 
-                log_model = (
-                        trainer_options.wandb_model_log_interval > 0
-                        and iepoch % trainer_options.wandb_model_log_interval == 0
-                )
-                if log_model and trainer_options.use_wandb:
-                    import wandb
-
-                    logging.info("Logging Model on this epoch :::::")
-                    artifact = wandb.Artifact(
-                        name=f"model_{wandb.run.id}",
-                        type="model",
-                        metadata={"improved": _improved},
-                    )
-                    artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
-                    aliases = [
-                        f"epoch-{iepoch}",
-                        "best" if best_epoch == iepoch else "",
-                    ]
-                    wandb.log_artifact(artifact, aliases=aliases)
+                # log_model = (
+                #         trainer_options.wandb_model_log_interval > 0
+                #         and iepoch % trainer_options.wandb_model_log_interval == 0
+                # )
+                # if log_model and trainer_options.use_wandb:
+                #     import wandb
+                #
+                #     logging.info("Logging Model on this epoch :::::")
+                #     artifact = wandb.Artifact(
+                #         name=f"model_{wandb.run.id}",
+                #         type="model",
+                #         metadata={"improved": _improved},
+                #     )
+                #     artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
+                #     aliases = [
+                #         f"epoch-{iepoch}",
+                #         "best" if best_epoch == iepoch else "",
+                #     ]
+                #     wandb.log_artifact(artifact, aliases=aliases)
 
                 # 6. Remove the model files excluding n-best epoch and latest epoch
                 _removed = []
@@ -522,16 +519,15 @@
             options: TrainerOptions,
             distributed_option: DistributedOption,
     ) -> Tuple[bool, bool]:
-        assert check_argument_types()
 
         grad_noise = options.grad_noise
         accum_grad = options.accum_grad
         grad_clip = options.grad_clip
         grad_clip_type = options.grad_clip_type
         log_interval = options.log_interval
-        no_forward_run = options.no_forward_run
+        # no_forward_run = options.no_forward_run
         ngpu = options.ngpu
-        use_wandb = options.use_wandb
+        # use_wandb = options.use_wandb
         distributed = distributed_option.distributed
 
         if log_interval is None:
@@ -559,9 +555,9 @@
                     break
 
             batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
-            if no_forward_run:
-                all_steps_are_invalid = False
-                continue
+            # if no_forward_run:
+            #     all_steps_are_invalid = False
+            #     continue
 
             with autocast(scaler is not None):
                 with reporter.measure_time("forward_time"):
@@ -737,8 +733,8 @@
                 logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
                 if summary_writer is not None:
                     reporter.tensorboard_add_scalar(summary_writer, -log_interval)
-                if use_wandb:
-                    reporter.wandb_log()
+                # if use_wandb:
+                #     reporter.wandb_log()
 
             if max_update_stop:
                 break
@@ -758,9 +754,8 @@
             options: TrainerOptions,
             distributed_option: DistributedOption,
     ) -> None:
-        assert check_argument_types()
         ngpu = options.ngpu
-        no_forward_run = options.no_forward_run
+        # no_forward_run = options.no_forward_run
         distributed = distributed_option.distributed
 
         model.eval()
@@ -776,8 +771,8 @@
                     break
 
             batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
-            if no_forward_run:
-                continue
+            # if no_forward_run:
+            #     continue
 
             retval = model(**batch)
             if isinstance(retval, dict):

--
Gitblit v1.9.1