游雁
2024-05-17 1ca314955fbe150db9a3f40193ca10736a9a4260
deepspeed
1 file changed
funasr/train_utils/trainer_ds.py 10 ●●●●
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -397,7 +397,7 @@
                 time4 = time.perf_counter()
                 loss_dict["speed_stats"]["backward_time"] = f"{time4 - time3:0.3f}"
-            self.update_step(model, optim, scheduler, scaler, loss_dict)
+            self.update_step(model, optim, scheduler, scaler, loss_dict=loss_dict)
             total_time = f"{(time.perf_counter() - time5) / accum_grad:0.3f}"
             time5 = time.perf_counter()
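Note on the hunk above: under the old update_step signature (removed in the @@ -469 hunk below), the fifth positional parameter was batch_idx, so the positional call silently bound loss_dict to batch_idx and left loss_dict as None. A minimal sketch of that pitfall, using hypothetical stand-in arguments rather than the real trainer objects:

    # Old parameter order: batch_idx sits before loss_dict.
    def update_step(model, optim, scheduler, scaler, batch_idx=0, loss_dict=None):
        return batch_idx, loss_dict

    stats = {"batch_idx": 3}
    print(update_step("m", "o", "s", "sc", stats))            # ({'batch_idx': 3}, None): dict captured by batch_idx
    print(update_step("m", "o", "s", "sc", loss_dict=stats))  # (0, {'batch_idx': 3}): intended binding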
@@ -415,7 +415,7 @@
                     model=model,
                     dataloader_val=dataloader_val,
                     epoch=epoch,
-                    writer=writer,
+                    writer=self.writer,
                     step=batch_idx + 1,
                     step_in_epoch=self.step_in_epoch,
                 )
@@ -469,8 +469,8 @@
             else:
                 loss.backward()

-    def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=None):
+    def update_step(self, model, optim, scheduler, scaler, loss_dict=None):
+        batch_idx = loss_dict["batch_idx"]
         if self.use_deepspeed:
             model.step()
         else:
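The two hunks above change the contract: batch_idx is no longer a parameter and is instead read out of loss_dict, so every caller must now place a "batch_idx" key in the dict (the keyword call fixed at @@ -397 presumably relies on the training loop storing it there). A standalone sketch of the new contract, not the FunASR class itself:

    # New shape: batch_idx travels inside loss_dict (hypothetical standalone copy).
    def update_step(model, optim, scheduler, scaler, loss_dict=None):
        batch_idx = loss_dict["batch_idx"]  # KeyError if a caller forgets the key
        return batch_idx

    print(update_step(None, None, None, None, loss_dict={"batch_idx": 7}))  # 7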
@@ -747,7 +747,6 @@
         from funasr.schedulers import scheduler_classes
         from omegaconf import OmegaConf, DictConfig
         import json
-        import deepspeed

         # optim
         logging.info("Build optim")
@@ -764,6 +763,7 @@
         scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))

         if self.use_deepspeed:
+            import deepspeed
             args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
             with open(self.deepspeed_config, "r") as fin:
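The last two hunks move import deepspeed out of the shared optimizer/scheduler setup and into the use_deepspeed branch, so DeepSpeed becomes a soft dependency: installations without it can still run the plain-PyTorch path. A minimal sketch of this deferred-import pattern, with a hypothetical build function:

    def build_optim(use_deepspeed: bool) -> str:
        if use_deepspeed:
            import deepspeed  # resolved only when DeepSpeed is actually requested
            return f"deepspeed {deepspeed.__version__}"
        return "plain torch path, deepspeed never imported"

    print(build_optim(False))  # succeeds even on machines without DeepSpeed installed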