zhifu gao
2024-06-06 32e783664534bbb8d3b8ba64c2c2ecb42398eb00
funasr/train_utils/trainer_ds.py
@@ -621,7 +621,6 @@
            self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
    def forward_step(self, model, batch, loss_dict={}):
        dtype = torch.bfloat16
        with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
            retval = model(**batch)