From d43d0853dcf3a1db04302c7b527e92ace3ccfb55 Mon Sep 17 00:00:00 2001
From: AldarisX <aldaris@axnet.icu>
Date: 星期一, 07 四月 2025 21:20:31 +0800
Subject: [PATCH] add intel xpu support (#2468)
---
funasr/frontends/fused.py | 2 ++
funasr/utils/export_utils.py | 6 +++---
funasr/auto/auto_model.py | 4 +++-
3 files changed, 8 insertions(+), 4 deletions(-)
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 96c642e..d274fb9 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -182,7 +182,9 @@
set_all_random_seed(kwargs.get("seed", 0))
device = kwargs.get("device", "cuda")
- if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
+ if ((device =="cuda" and not torch.cuda.is_available())
+ or (device == "xpu" and not torch.xpu.is_available())
+ or kwargs.get("ngpu", 1) == 0):
device = "cpu"
kwargs["batch_size"] = 1
kwargs["device"] = device
diff --git a/funasr/frontends/fused.py b/funasr/frontends/fused.py
index 0fa7639..0935910 100644
--- a/funasr/frontends/fused.py
+++ b/funasr/frontends/fused.py
@@ -78,6 +78,8 @@
self.factors = [frontend.hop_length // self.gcd for frontend in self.frontends]
if torch.cuda.is_available():
dev = "cuda"
+ elif torch.xpu.is_available():
+ dev = "xpu"
else:
dev = "cpu"
if self.align_method == "linear_projection":
diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
index ca04d75..c89dd77 100644
--- a/funasr/utils/export_utils.py
+++ b/funasr/utils/export_utils.py
@@ -28,12 +28,12 @@
**kwargs,
)
elif type == "torchscript":
- device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
print("Exporting torchscripts on device {}".format(device))
_torchscripts(m, path=export_dir, device=device)
elif type == "bladedisc":
assert (
- torch.cuda.is_available()
+ torch.cuda.is_available() or torch.xpu.is_available()
), "Currently bladedisc optimization for FunASR only supports GPU"
# bladedisc only optimizes encoder/decoder modules
if hasattr(m, "encoder") and hasattr(m, "decoder"):
@@ -44,7 +44,7 @@
elif type == "onnx_fp16":
assert (
- torch.cuda.is_available()
+ torch.cuda.is_available() or torch.xpu.is_available()
), "Currently onnx_fp16 optimization for FunASR only supports GPU"
if hasattr(m, "encoder") and hasattr(m, "decoder"):
--
Gitblit v1.9.1