From df5f263e5fe3d7961b1aeb3589012400a9905a8f Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 24 四月 2023 16:17:41 +0800
Subject: [PATCH] update

---
 funasr/build_utils/build_trainer.py |   42 ++++++++++++++++++++++++++++++------------
 1 files changed, 30 insertions(+), 12 deletions(-)

diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
index 55bc89c..060b57f 100644
--- a/funasr/build_utils/build_trainer.py
+++ b/funasr/build_utils/build_trainer.py
@@ -107,7 +107,6 @@
                  schedulers: Sequence[Optional[AbsScheduler]],
                  train_dataloader: AbsIterFactory,
                  valid_dataloader: AbsIterFactory,
-                 trainer_options,
                  distributed_option: DistributedOption):
         self.trainer_options = self.build_options(args)
         self.model = model
@@ -115,7 +114,6 @@
         self.schedulers = schedulers
         self.train_dataloader = train_dataloader
         self.valid_dataloader = valid_dataloader
-        self.trainer_options = trainer_options
         self.distributed_option = distributed_option
 
     def build_options(self, args: argparse.Namespace) -> TrainerOptions:
@@ -128,16 +126,15 @@
         """Reserved for future development of another Trainer"""
         pass
 
-    @staticmethod
-    def resume(
-            checkpoint: Union[str, Path],
-            model: torch.nn.Module,
-            reporter: Reporter,
-            optimizers: Sequence[torch.optim.Optimizer],
-            schedulers: Sequence[Optional[AbsScheduler]],
-            scaler: Optional[GradScaler],
-            ngpu: int = 0,
-    ):
+    def resume(self,
+               checkpoint: Union[str, Path],
+               model: torch.nn.Module,
+               reporter: Reporter,
+               optimizers: Sequence[torch.optim.Optimizer],
+               schedulers: Sequence[Optional[AbsScheduler]],
+               scaler: Optional[GradScaler],
+               ngpu: int = 0,
+               ):
         states = torch.load(
             checkpoint,
             map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
@@ -800,3 +797,24 @@
             if distributed:
                 iterator_stop.fill_(1)
                 torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+
+
+def build_trainer(
+        args,
+        model: FunASRModel,
+        optimizers: Sequence[torch.optim.Optimizer],
+        schedulers: Sequence[Optional[AbsScheduler]],
+        train_dataloader: AbsIterFactory,
+        valid_dataloader: AbsIterFactory,
+        distributed_option: DistributedOption
+):
+    trainer = Trainer(
+        args=args,
+        model=model,
+        optimizers=optimizers,
+        schedulers=schedulers,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
+        distributed_option=distributed_option
+    )
+    return trainer

--
Gitblit v1.9.1