ds
游雁
2024-05-20 b3b10158097b10aa26ee3469c5ba8fd20c745de3
ds
1个文件已修改
20 ■■■■ 已修改文件
funasr/train_utils/trainer_ds.py 20 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer_ds.py
@@ -574,12 +574,12 @@
            loss_dict["lr"] = scheduler.get_last_lr()[0]
            loss_dict["batch_num_epoch"] = len(dataloader_train)
            self.val_loss_avg = (
                self.val_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
            self.train_loss_avg = (
                self.train_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
            ) / (batch_idx + 1)
            if "acc" in loss_dict["stats"]:
                self.val_acc_avg = (
                    self.val_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
                self.train_acc_avg = (
                    self.train_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
                ) / (batch_idx + 1)
            self.log(loss_dict, tag="train")
@@ -612,12 +612,12 @@
            time_beg = time.perf_counter()
        if self.use_ddp or self.use_fsdp or self.use_deepspeed:
            val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
            val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
            dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
            dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
            self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
            self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
            train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
            train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
            dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
            dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
            self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
            self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
    def forward_step(self, model, batch, loss_dict={}):
        dtype = torch.bfloat16