From 21efa9b9b785b5aecd0ed7f2c428848c37fccd45 Mon Sep 17 00:00:00 2001 From: rraminen Date: Thu, 11 Apr 2024 21:47:53 +0000 Subject: [PATCH] Support on all AMD GPUs --- .../inference/csrc/apply_rotary_pos_emb.cu | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu index a06dbb48fd332..25a494111c54b 100644 --- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu +++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu @@ -99,17 +99,9 @@ __global__ void apply_rotary_pos_half(T* mixed_query, rope_theta, \ max_out_tokens); -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) and ROCM_WAVEFRONT_SIZE == 64 #define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \ - if (threads_per_head == 4) { \ - LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \ - } else if (threads_per_head == 8) { \ - LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \ - } else if (threads_per_head == 16) { \ - LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \ - } else if (threads_per_head == 32) { \ - LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \ - } else if (threads_per_head == 64) { \ + if (threads_per_head == 64) { \ LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \ } else { \ assert(false); \