From e7237d8cb4f2430190817c260ea747e594d6ac35 Mon Sep 17 00:00:00 2001
From: xmx0632 <xmx0632@foxmail.com>
Date: Mon, 14 Apr 2025 13:40:12 +0800
Subject: [PATCH] add mac m1 mps support (#2477)

---
 funasr/auto/auto_model.py    |    1 +
 funasr/frontends/fused.py    |    2 ++
 funasr/utils/export_utils.py |    6 +++---
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index d274fb9..10d2ef6 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -184,6 +184,7 @@
         device = kwargs.get("device", "cuda")
         if ((device =="cuda" and not torch.cuda.is_available())
             or (device == "xpu" and not torch.xpu.is_available())
+            or (device == "mps" and not torch.backends.mps.is_available())
             or kwargs.get("ngpu", 1) == 0):
             device = "cpu"
             kwargs["batch_size"] = 1
diff --git a/funasr/frontends/fused.py b/funasr/frontends/fused.py
index 0935910..1da1e9f 100644
--- a/funasr/frontends/fused.py
+++ b/funasr/frontends/fused.py
@@ -80,6 +80,8 @@
             dev = "cuda"
         elif torch.xpu.is_available():
             dev = "xpu"
+        elif torch.backends.mps.is_available():
+            dev = "mps"
         else:
             dev = "cpu"
         if self.align_method == "linear_projection":
diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
index c89dd77..b03b052 100644
--- a/funasr/utils/export_utils.py
+++ b/funasr/utils/export_utils.py
@@ -28,12 +28,12 @@
                 **kwargs,
             )
         elif type == "torchscript":
-            device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
+            device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
             print("Exporting torchscripts on device {}".format(device))
             _torchscripts(m, path=export_dir, device=device)
         elif type == "bladedisc":
             assert (
-                torch.cuda.is_available() or torch.xpu.is_available()
+                torch.cuda.is_available() or torch.xpu.is_available() or torch.backends.mps.is_available()
             ), "Currently bladedisc optimization for FunASR only supports GPU"
             # bladedisc only optimizes encoder/decoder modules
             if hasattr(m, "encoder") and hasattr(m, "decoder"):
@@ -44,7 +44,7 @@
 
         elif type == "onnx_fp16":
             assert (
-                torch.cuda.is_available() or torch.xpu.is_available()
+                torch.cuda.is_available() or torch.xpu.is_available() or torch.backends.mps.is_available()
             ), "Currently onnx_fp16 optimization for FunASR only supports GPU"
 
             if hasattr(m, "encoder") and hasattr(m, "decoder"):

--
Gitblit v1.9.1