From ccd4c4d240af6414c86af606e6ad9a01ac52e991 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期日, 23 四月 2023 17:47:12 +0800
Subject: [PATCH] update

---
 funasr/build_utils/build_trainer.py |   42 ++++++++++++++++++++++++++++++++----------
 1 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
index 55bc89c..dd592e4 100644
--- a/funasr/build_utils/build_trainer.py
+++ b/funasr/build_utils/build_trainer.py
@@ -128,16 +128,15 @@
         """Reserved for future development of another Trainer"""
         pass
 
-    @staticmethod
-    def resume(
-            checkpoint: Union[str, Path],
-            model: torch.nn.Module,
-            reporter: Reporter,
-            optimizers: Sequence[torch.optim.Optimizer],
-            schedulers: Sequence[Optional[AbsScheduler]],
-            scaler: Optional[GradScaler],
-            ngpu: int = 0,
-    ):
+    def resume(self,
+               checkpoint: Union[str, Path],
+               model: torch.nn.Module,
+               reporter: Reporter,
+               optimizers: Sequence[torch.optim.Optimizer],
+               schedulers: Sequence[Optional[AbsScheduler]],
+               scaler: Optional[GradScaler],
+               ngpu: int = 0,
+               ):
         states = torch.load(
             checkpoint,
             map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
@@ -800,3 +799,26 @@
             if distributed:
                 iterator_stop.fill_(1)
                 torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+
+
+def build_trainer(
+        args,
+        model: FunASRModel,
+        optimizers: Sequence[torch.optim.Optimizer],
+        schedulers: Sequence[Optional[AbsScheduler]],
+        train_dataloader: AbsIterFactory,
+        valid_dataloader: AbsIterFactory,
+        trainer_options,
+        distributed_option: DistributedOption
+):
+    trainer = Trainer(
+        args=args,
+        model=model,
+        optimizers=optimizers,
+        schedulers=schedulers,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
+        trainer_options=trainer_options,
+        distributed_option=distributed_option
+    )
+    return trainer

--
Gitblit v1.9.1