游雁
2024-06-24 1596f6f414f6f41da66506debb1dff19fffeb3ec
funasr/models/sense_voice/whisper_lib/triton_ops.py
@@ -11,9 +11,7 @@
@triton.jit
def dtw_kernel(
    cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
):
def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < M
@@ -43,9 +41,7 @@
@lru_cache(maxsize=None)
def median_kernel(filter_width: int):
    @triton.jit
    def kernel(
        y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
    ):  # x.shape[-1] == filter_width
    def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr):  # x.shape[-1] == filter_width
        row_idx = tl.program_id(0)
        offsets = tl.arange(0, BLOCK_SIZE)
        mask = offsets < y_stride
@@ -63,10 +59,7 @@
    kernel.src = kernel.src.replace(
        "    LOAD_ALL_ROWS_HERE",
        "\n".join(
            [
                f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
                for i in range(filter_width)
            ]
            [f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" for i in range(filter_width)]
        ),
    )
    kernel.src = kernel.src.replace(