| | |
| | | |
| | | |
| | | @triton.jit |
| | | def dtw_kernel( |
| | | cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr |
| | | ): |
| | | def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr): |
| | | offsets = tl.arange(0, BLOCK_SIZE) |
| | | mask = offsets < M |
| | | |
| | |
| | | @lru_cache(maxsize=None) |
| | | def median_kernel(filter_width: int): |
| | | @triton.jit |
| | | def kernel( |
| | | y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr |
| | | ): # x.shape[-1] == filter_width |
| | | def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr): # x.shape[-1] == filter_width |
| | | row_idx = tl.program_id(0) |
| | | offsets = tl.arange(0, BLOCK_SIZE) |
| | | mask = offsets < y_stride |
| | |
| | | kernel.src = kernel.src.replace( |
| | | " LOAD_ALL_ROWS_HERE", |
| | | "\n".join( |
| | | [ |
| | | f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" |
| | | for i in range(filter_width) |
| | | ] |
| | | [f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" for i in range(filter_width)] |
| | | ), |
| | | ) |
| | | kernel.src = kernel.src.replace( |