From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 六月 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords

---
 funasr/models/sense_voice/whisper_lib/triton_ops.py |   13 +++----------
 1 files changed, 3 insertions(+), 10 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/triton_ops.py b/funasr/models/sense_voice/whisper_lib/triton_ops.py
index edd4564..9919595 100644
--- a/funasr/models/sense_voice/whisper_lib/triton_ops.py
+++ b/funasr/models/sense_voice/whisper_lib/triton_ops.py
@@ -11,9 +11,7 @@
 
 
 @triton.jit
-def dtw_kernel(
-    cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
-):
+def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr):
     offsets = tl.arange(0, BLOCK_SIZE)
     mask = offsets < M
 
@@ -43,9 +41,7 @@
 @lru_cache(maxsize=None)
 def median_kernel(filter_width: int):
     @triton.jit
-    def kernel(
-        y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
-    ):  # x.shape[-1] == filter_width
+    def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr):  # x.shape[-1] == filter_width
         row_idx = tl.program_id(0)
         offsets = tl.arange(0, BLOCK_SIZE)
         mask = offsets < y_stride
@@ -63,10 +59,7 @@
     kernel.src = kernel.src.replace(
         "    LOAD_ALL_ROWS_HERE",
         "\n".join(
-            [
-                f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
-                for i in range(filter_width)
-            ]
+            [f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" for i in range(filter_width)]
         ),
     )
     kernel.src = kernel.src.replace(

--
Gitblit v1.9.1