funasr/bin/train.py @@ -128,7 +128,8 @@ else: model = model.to(device=kwargs.get("device", "cuda")) logging.info(f"{model}") if local_rank == 0: logging.info(f"{model}") kwargs["device"] = next(model.parameters()).device # optim