| | |
| | | time4 = time.perf_counter() |
| | | loss_dict["speed_stats"]["backward_time"] = f"{time4 - time3:0.3f}" |
| | | |
| | | self.update_step(model, optim, scheduler, scaler, loss_dict) |
| | | self.update_step(model, optim, scheduler, scaler, loss_dict=loss_dict) |
| | | total_time = f"{(time.perf_counter() - time5) / accum_grad:0.3f}" |
| | | time5 = time.perf_counter() |
| | | |
| | |
| | | model=model, |
| | | dataloader_val=dataloader_val, |
| | | epoch=epoch, |
| | | writer=writer, |
| | | writer=self.writer, |
| | | step=batch_idx + 1, |
| | | step_in_epoch=self.step_in_epoch, |
| | | ) |
| | |
| | | else: |
| | | loss.backward() |
| | | |
| | | def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=None): |
| | | |
| | | def update_step(self, model, optim, scheduler, scaler, loss_dict=None): |
| | | batch_idx = loss_dict["batch_idx"] |
| | | if self.use_deepspeed: |
| | | model.step() |
| | | else: |
| | |
| | | from funasr.schedulers import scheduler_classes |
| | | from omegaconf import OmegaConf, DictConfig |
| | | import json |
| | | import deepspeed |
| | | |
| | | # optim |
| | | logging.info("Build optim") |
| | |
| | | scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf")) |
| | | |
| | | if self.use_deepspeed: |
| | | import deepspeed |
| | | |
| | | args = OmegaConf.create({"deepspeed_config": self.deepspeed_config}) |
| | | with open(self.deepspeed_config, "r") as fin: |