diff --git a/csrc/includes/ds_kernel_utils.h b/csrc/includes/ds_kernel_utils.h index 8e4888109fcd..f8b16ee6a315 100644 --- a/csrc/includes/ds_kernel_utils.h +++ b/csrc/includes/ds_kernel_utils.h @@ -23,7 +23,7 @@ used throughout the codebase. #ifdef __HIP_PLATFORM_AMD__ // constexpr variant of warpSize for templating -constexpr int hw_warp_size = 64; +constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE; #define HALF_PRECISION_AVAILABLE = 1 #include #include diff --git a/csrc/random_ltd/token_sort.cu b/csrc/random_ltd/token_sort.cu index 3049471cfe34..3c1dff49429f 100644 --- a/csrc/random_ltd/token_sort.cu +++ b/csrc/random_ltd/token_sort.cu @@ -16,7 +16,7 @@ constexpr int mem_vals = granularity / sizeof(int32_t); constexpr int max_buffer_size = (threads + 1) * mem_vals; #ifdef __HIP_PLATFORM_AMD__ -constexpr int warp_size = 64; +constexpr int warp_size = ROCM_WAVEFRONT_SIZE; #else constexpr int warp_size = 32; #endif diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu index a06dbb48fd33..25a494111c54 100644 --- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu +++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu @@ -99,17 +99,9 @@ __global__ void apply_rotary_pos_half(T* mixed_query, rope_theta, \ max_out_tokens); -#ifdef __HIP_PLATFORM_AMD__ +#if defined(__HIP_PLATFORM_AMD__) and ROCM_WAVEFRONT_SIZE == 64 #define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \ - if (threads_per_head == 4) { \ - LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \ - } else if (threads_per_head == 8) { \ - LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \ - } else if (threads_per_head == 16) { \ - LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \ - } else if (threads_per_head == 32) { \ - LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \ - } else if (threads_per_head == 64) { \ + if (threads_per_head == 64) { \ LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \ } else { \ assert(false); \ diff --git a/deepspeed/inference/v2/kernels/includes/ds_kernel_utils.h b/deepspeed/inference/v2/kernels/includes/ds_kernel_utils.h index 8e4888109fcd..f8b16ee6a315 100644 --- a/deepspeed/inference/v2/kernels/includes/ds_kernel_utils.h +++ b/deepspeed/inference/v2/kernels/includes/ds_kernel_utils.h @@ -23,7 +23,7 @@ used throughout the codebase. #ifdef __HIP_PLATFORM_AMD__ // constexpr variant of warpSize for templating -constexpr int hw_warp_size = 64; +constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE; #define HALF_PRECISION_AVAILABLE = 1 #include #include diff --git a/op_builder/builder.py b/op_builder/builder.py index 8dc825c7926d..18c130221b0e 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -107,6 +107,8 @@ def assert_no_cuda_mismatch(name=""): class OpBuilder(ABC): _rocm_version = None + _rocm_gpu_arch = None + _rocm_wavefront_size = None _is_rocm_pytorch = None _is_sycl_enabled = None _loaded_ops = {} @@ -229,6 +231,32 @@ def installed_rocm_version(): OpBuilder._rocm_version = (int(ROCM_MAJOR), int(ROCM_MINOR)) return OpBuilder._rocm_version + @staticmethod + def get_rocm_gpu_arch(): + if OpBuilder._rocm_gpu_arch: + return OpBuilder._rocm_gpu_arch + rocm_gpu_arch_cmd = "/opt/rocm/bin/rocminfo | grep -o -m 1 'gfx.*'" + try: + result = subprocess.check_output(rocm_gpu_arch_cmd, shell=True) + rocm_gpu_arch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + rocm_gpu_arch = "" + OpBuilder._rocm_gpu_arch = rocm_gpu_arch + return OpBuilder._rocm_gpu_arch + + @staticmethod + def get_rocm_wavefront_size(): + if OpBuilder._rocm_wavefront_size: + return OpBuilder._rocm_wavefront_size + rocm_wavefront_size_cmd = "/opt/rocm/bin/rocminfo | grep -Eo -m1 'Wavefront Size:[[:space:]]+[0-9]+' | grep -Eo '[0-9]+'" + try: + result = subprocess.check_output(rocm_wavefront_size_cmd, shell=True) + rocm_wavefront_size = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + rocm_wavefront_size = "32" + OpBuilder._rocm_wavefront_size = rocm_wavefront_size + return OpBuilder._rocm_wavefront_size + def include_paths(self): ''' Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed) @@ -520,6 +548,8 @@ def jit_load(self, verbose=True): if self.is_rocm_pytorch(): cxx_args.append("-D__HIP_PLATFORM_AMD__=1") + os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch() + cxx_args.append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) op_module = load(name=self.name, sources=self.strip_empty_entries(sources), @@ -650,6 +680,12 @@ def builder(self): if self.is_rocm_pytorch(): compile_args['cxx'].append("-D__HIP_PLATFORM_AMD__=1") + #cxx compiler args are required to compile cpp files + compile_args['cxx'].append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) + #nvcc compiler args are required to compile hip files + compile_args['nvcc'].append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) + if self.get_rocm_gpu_arch(): + os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch() cuda_ext = ExtensionBuilder(name=self.absolute_name(), sources=self.strip_empty_entries(self.sources()),