funasr/bin/train.py @@ -128,6 +128,7 @@ else: model = model.to(device=kwargs.get("device", "cuda")) if local_rank == 0: logging.info(f"{model}") kwargs["device"] = next(model.parameters()).device