From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 六月 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords
---
funasr/models/sense_voice/whisper_lib/triton_ops.py | 13 +++----------
1 files changed, 3 insertions(+), 10 deletions(-)
diff --git a/funasr/models/sense_voice/whisper_lib/triton_ops.py b/funasr/models/sense_voice/whisper_lib/triton_ops.py
index edd4564..9919595 100644
--- a/funasr/models/sense_voice/whisper_lib/triton_ops.py
+++ b/funasr/models/sense_voice/whisper_lib/triton_ops.py
@@ -11,9 +11,7 @@
@triton.jit
-def dtw_kernel(
- cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
-):
+def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr):
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < M
@@ -43,9 +41,7 @@
@lru_cache(maxsize=None)
def median_kernel(filter_width: int):
@triton.jit
- def kernel(
- y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
- ): # x.shape[-1] == filter_width
+ def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr): # x.shape[-1] == filter_width
row_idx = tl.program_id(0)
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < y_stride
@@ -63,10 +59,7 @@
kernel.src = kernel.src.replace(
" LOAD_ALL_ROWS_HERE",
"\n".join(
- [
- f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
- for i in range(filter_width)
- ]
+ [f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" for i in range(filter_width)]
),
)
kernel.src = kernel.src.replace(
--
Gitblit v1.9.1