游雁
2024-06-07 60b3c42d6d3d90b97918b10d506efd6c471e1ba8
funasr/train_utils/trainer_ds.py
@@ -621,7 +621,6 @@
            self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
    def forward_step(self, model, batch, loss_dict={}):
        dtype = torch.bfloat16
        with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
            retval = model(**batch)