diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu index a06dbb48fd332..25a494111c54b 100644 --- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu +++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu @@ -99,17 +99,9 @@ __global__ void apply_rotary_pos_half(T* mixed_query, rope_theta, \ max_out_tokens); -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) and ROCM_WAVEFRONT_SIZE == 64 #define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \ - if (threads_per_head == 4) { \ - LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \ - } else if (threads_per_head == 8) { \ - LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \ - } else if (threads_per_head == 16) { \ - LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \ - } else if (threads_per_head == 32) { \ - LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \ - } else if (threads_per_head == 64) { \ + if (threads_per_head == 64) { \ LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \ } else { \ assert(false); \