From d43d0853dcf3a1db04302c7b527e92ace3ccfb55 Mon Sep 17 00:00:00 2001
From: AldarisX <aldaris@axnet.icu>
Date: 星期一, 07 四月 2025 21:20:31 +0800
Subject: [PATCH] add intel xpu support (#2468)
---
funasr/frontends/fused.py | 16 ++++++----------
1 files changed, 6 insertions(+), 10 deletions(-)
diff --git a/funasr/frontends/fused.py b/funasr/frontends/fused.py
index 24f73f4..0935910 100644
--- a/funasr/frontends/fused.py
+++ b/funasr/frontends/fused.py
@@ -7,14 +7,10 @@
class FusedFrontends(nn.Module):
- def __init__(
- self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
- ):
+ def __init__(self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000):
super().__init__()
- self.align_method = (
- align_method # fusing method : linear_projection only for now
- )
+ self.align_method = align_method # fusing method : linear_projection only for now
self.proj_dim = proj_dim # dim of the projection done on each frontend
self.frontends = [] # list of the frontends to combine
@@ -82,6 +78,8 @@
self.factors = [frontend.hop_length // self.gcd for frontend in self.frontends]
if torch.cuda.is_available():
dev = "cuda"
+ elif torch.xpu.is_available():
+ dev = "xpu"
else:
dev = "cpu"
if self.align_method == "linear_projection":
@@ -109,9 +107,7 @@
input_feats, feats_lens = frontend.forward(input, input_lengths)
self.feats.append([input_feats, feats_lens])
- if (
- self.align_method == "linear_projection"
- ): # TODO(Dan): to add other align methods
+ if self.align_method == "linear_projection": # TODO(Dan): to add other align methods
# first step : projections
self.feats_proj = []
@@ -141,4 +137,4 @@
else:
raise NotImplementedError
- return input_feats, feats_lens
\ No newline at end of file
+ return input_feats, feats_lens
--
Gitblit v1.9.1