funasr/train_utils/trainer_ds.py @@ -30,8 +30,9 @@ yield else: if dtype == torch.float16 or dtype == torch.bfloat16: with autocast(enabled=True, dtype=dtype): yield yield # with autocast(enabled=True, dtype=dtype): # yield else: yield