From 3d9f094e9652d4b84894c6fd4eae39a4a753b0f0 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 16 五月 2023 23:48:00 +0800
Subject: [PATCH] train

---
 funasr/train/trainer.py |   40 ++++++++++++++++++++++++++++------------
 1 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index 2260f00..4052448 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -39,7 +39,7 @@
 from funasr.torch_utils.device_funcs import to_device
 from funasr.torch_utils.recursive_op import recursive_average
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 from funasr.train.distributed_utils import DistributedOption
 from funasr.train.reporter import Reporter
 from funasr.train.reporter import SubReporter
@@ -95,6 +95,7 @@
     use_pai: bool
     oss_bucket: Union[oss2.Bucket, None]
     batch_interval: int
+    bias_grad_times: float
 
 class Trainer:
     """Trainer having a optimizer.
@@ -165,7 +166,7 @@
     @classmethod
     def run(
         cls,
-        model: AbsESPnetModel,
+        model: FunASRModel,
         optimizers: Sequence[torch.optim.Optimizer],
         schedulers: Sequence[Optional[AbsScheduler]],
         train_iter_factory: AbsIterFactory,
@@ -186,9 +187,6 @@
                 logging.warning("No keep_nbest_models is given. Change to [1]")
                 trainer_options.keep_nbest_models = [1]
             keep_nbest_models = trainer_options.keep_nbest_models
-     
-        #assert batch_interval is set and >0
-        assert trainer_options.batch_interval > 0
  
         output_dir = Path(trainer_options.output_dir)
         reporter = Reporter()
@@ -549,8 +547,11 @@
         no_forward_run = options.no_forward_run
         ngpu = options.ngpu
         use_wandb = options.use_wandb
+        bias_grad_times = options.bias_grad_times
         distributed = distributed_option.distributed
 
+        if bias_grad_times != 1.0:
+            logging.warning("Using bias_grad_times: {} for gradient scaling".format(bias_grad_times))
         if log_interval is None:
             try:
                 log_interval = max(len(iterator) // 20, 10)
@@ -571,8 +572,7 @@
         #ouput dir
         output_dir = Path(options.output_dir)
         #batch interval
-        batch_interval = options.batch_interval       
-        assert batch_interval > 0
+        batch_interval = options.batch_interval
  
         start_time = time.perf_counter()
         for iiter, (_, batch) in enumerate(
@@ -580,16 +580,22 @@
         ):
             assert isinstance(batch, dict), type(batch)
 
-            if rank == 0:
+            if batch_interval > 0 and (not distributed_option.distributed or rank == 0):
                 if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
                     num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates()
-                if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None):
-                    if options.use_pai:
+                if num_batch_updates % batch_interval == 0:
+                    if options.use_pai and options.oss_bucket is not None:
                         buffer = BytesIO()
-                        torch.save(model.state_dict(), buffer)
+                        if hasattr(model, "module"):
+                            torch.save(model.module.state_dict(), buffer)
+                        else:
+                            torch.save(model.state_dict(), buffer)
                         options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())
                     else:
-                        torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
+                        if hasattr(model, "module"):
+                            torch.save(model.module.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
+                        else:
+                            torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
 
             if distributed:
                 torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
@@ -688,6 +694,16 @@
                         scale_factor=0.55,
                     )
 
+                # for contextual training
+                if bias_grad_times != 1.0:
+                    # contextual related parameter names
+                    cr_pnames = ["bias_encoder", "bias_embed", "decoder.bias_decoder", "decoder.bias_output"]
+                    for name, param in model.named_parameters():
+                        for cr_pname in cr_pnames:
+                            if cr_pname in name:
+                                param.grad *= bias_grad_times
+                                continue
+
                 # compute the gradient norm to check if it is normal or not
                 grad_norm = torch.nn.utils.clip_grad_norm_(
                     model.parameters(),

--
Gitblit v1.9.1