From 44ca8f9f9beec642d72a394c3a347da505e567d8 Mon Sep 17 00:00:00 2001 From: Yucheng Li Date: Tue, 3 Dec 2024 03:30:25 +0000 Subject: [PATCH] boundary check in indexing --- csrc/vertical_slash_index.cu | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/csrc/vertical_slash_index.cu b/csrc/vertical_slash_index.cu index 45af042..7839343 100644 --- a/csrc/vertical_slash_index.cu +++ b/csrc/vertical_slash_index.cu @@ -60,12 +60,20 @@ __global__ void convert_vertical_slash_indexes_kernel( int tmp_col_cnt = 0, tmp_blk_cnt = 0; int s = 0, v = 0; - int v_idx = vertical_indexes[v++]; - int s_idx = slash_indexes[s++]; - while (s_idx >= end_m) { + + // in case of vs are empty + int v_idx = (v < NNZ_V) ? vertical_indexes[v++] : (end_m + BLOCK_SIZE_M); + int s_idx = (s < NNZ_S) ? slash_indexes[s++] : -1; + + // make sure s_idx is valid + while (s_idx >= end_m && s < NNZ_S) { s_idx = slash_indexes[s++]; } - s_idx = max(end_m - s_idx, BLOCK_SIZE_M); + if (s_idx >= end_m) { + s_idx = end_m + BLOCK_SIZE_M; + } else { + s_idx = max(end_m - s_idx, BLOCK_SIZE_M); + } int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; while (1) { if (v_idx < range_end) {