
cuda.parallel: Support structured types as algorithm inputs #3218

Draft
Wants to merge 6 commits into base: main
Conversation

@shwina (Contributor) commented Dec 25, 2024

Description

Closes #3135.

This PR enables using structured data types with cuda.parallel algorithms.

The numba CUDA target doesn't directly support using structured data types as inputs to device functions. Thus, the implementation works by defining a custom numba data type corresponding to the structured type, and compiling the user-provided reduction function for that custom data type.

⚠️ This PR supports only "flat" (non-nested) struct dtypes. I think it should be relatively straightforward to add support for nested structs, but I would prefer to do that in a future PR.
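For concreteness, here is a CPU-side sketch (plain numpy, illustrative only; the dtype, field names, and data are made up, not taken from the PR) of the kind of flat structured input and reduction this PR enables on the GPU:

```python
import numpy as np
from functools import reduce

# Illustrative flat (non-nested) structured dtype; hypothetical example.
point = np.dtype([("f", np.int32), ("g", np.int32), ("h", np.int32)])

def max_g_value(x, y):
    # Keep whichever record has the larger 'g' field.
    return x if x["g"] > y["g"] else y

data = np.array([(0, 3, 0), (1, 7, 1), (2, 5, 2)], dtype=point)
best = reduce(max_g_value, data)
print(int(best["g"]))  # 7
```

With this PR, the same user-provided `max_g_value` can be compiled for the GPU and applied to a device array of this dtype.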

Additional Context: Numba support for struct types

This PR involves wrapping a struct type in a custom numba ("wrapper") type. Ostensibly, numba supports using struct types directly, but for the CUDA target we get cudaErrorIllegalAddress when numba kernels are invoked on inputs of struct type.

I don't fully understand the reasons for this; it may be more apparent to someone more familiar with reading PTX. I suspect there are alignment issues when using struct types, as they translate to pointer-to-struct arguments to the generated device function.

Consider the following device function:

def max_g_value(x, y):
    return x if x['g'] > y['g'] else y

Below is the PTX generated by numba after compiling the function for (1) inputs as raw struct dtypes, and (2) inputs as "wrapper" types. If I use the code generated for raw struct inputs, I get cupy_backends.cuda.api.runtime.CUDARuntimeError: cudaErrorIllegalAddress: an illegal memory access was encountered.
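As a side note, the field layout can be read off the PTX listings: the wrapper version returns a 12-byte, 4-byte-aligned value (three .b32 params per input argument), and the raw-struct version loads 'g' at byte offset +4. That is consistent with a three-field int32 dtype such as the following (hypothetical, reconstructed for illustration):

```python
import numpy as np

# Hypothetical dtype consistent with the PTX listings: three int32
# fields (12 bytes total, 4-byte aligned) with 'g' as the second field.
rec = np.dtype([("f", np.int32), ("g", np.int32), ("h", np.int32)])
print(rec.itemsize)        # 12 -> matches .param .align 4 .b8 func_retval0[12]
print(rec.fields["g"][1])  # 4  -> matches ld.u32 %r1, [%rd1+4]
```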

Raw struct input PTX
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-35059454
// Cuda compilation tools, release 12.6, V12.6.85
// Based on NVVM 7.0.1
//

.version 8.5
.target sm_50
.address_size 64

        // .globl       max_g_value
.common .global .align 8 .u64 _ZN08NumbaEnv8__main__11max_g_valueB2v1B92cw51cXTLSUwv1sDUaKthqqNgoKmjgOR3W3CwAkMXLaJtQYkOIgxJU0gCqOkEJoHkbttqdVhoqlspQGNFHSgJ5BnXagIAE6RecordILi598EE6RecordILi598EE;
.visible .func  (.param .b64 func_retval0) max_g_value(
        .param .b64 max_g_value_param_0,
        .param .b64 max_g_value_param_1
)
{
        .reg .pred      %p<2>;
        .reg .b32       %r<3>;
        .reg .b64       %rd<4>;


        ld.param.u64    %rd1, [max_g_value_param_0];
        ld.param.u64    %rd2, [max_g_value_param_1];
        ld.u32  %r1, [%rd1+4];
        ld.u32  %r2, [%rd2+4];
        setp.gt.s32     %p1, %r1, %r2;
        selp.b64        %rd3, %rd1, %rd2, %p1;
        st.param.b64    [func_retval0+0], %rd3;
        ret;

}
Wrapper type input PTX
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-35059454
// Cuda compilation tools, release 12.6, V12.6.85
// Based on NVVM 7.0.1
//

.version 8.5
.target sm_50
.address_size 64

        // .globl       max_g_value
.common .global .align 8 .u64 _ZN08NumbaEnv8__main__11max_g_valueB2v1B92cw51cXTLSUwv1sDUaKthqqNgoKmjgOR3W3CwAkMXLaJtQYkOIgxJU0gCqOkEJoHkbttqdVhoqlspQGNFHSgJ5BnXagIAE13StructWrapper13StructWrapper;

.visible .func  (.param .align 4 .b8 func_retval0[12]) max_g_value(
        .param .b32 max_g_value_param_0,
        .param .b32 max_g_value_param_1,
        .param .b32 max_g_value_param_2,
        .param .b32 max_g_value_param_3,
        .param .b32 max_g_value_param_4,
        .param .b32 max_g_value_param_5
)
{
        .reg .pred      %p<2>;
        .reg .b32       %r<10>;


        ld.param.u32    %r1, [max_g_value_param_0];
        ld.param.u32    %r2, [max_g_value_param_1];
        ld.param.u32    %r3, [max_g_value_param_2];
        ld.param.u32    %r4, [max_g_value_param_3];
        ld.param.u32    %r5, [max_g_value_param_4];
        ld.param.u32    %r6, [max_g_value_param_5];
        setp.gt.s32     %p1, %r2, %r5;
        selp.b32        %r7, %r3, %r6, %p1;
        max.s32         %r8, %r2, %r5;
        selp.b32        %r9, %r1, %r4, %p1;
        st.param.b32    [func_retval0+0], %r9;
        st.param.b32    [func_retval0+4], %r8;
        st.param.b32    [func_retval0+8], %r7;
        ret;

}

(For posterity: the generated PTX can be viewed by compiling the code with output type "ptx" in this function and running the unit test introduced in this PR.)

Checklist

  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.


copy-pr-bot bot commented Dec 25, 2024

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@shwina (Contributor, Author) commented Dec 25, 2024

/ok to test


self.ltoir, _ = cuda.compile(
    op, sig=value_type(value_type, value_type), output="ltoir"
)
# if h_init is a struct, wrap it in a Record type:
@shwina (Contributor, Author) commented:

Suggested change:
- # if h_init is a struct, wrap it in a Record type:
+ # if h_init is a struct, wrap it in a custom numba struct-like type:


def wrap_struct(dtype: np.dtype) -> numba.types.Type:
    """
    Wrap the given numpy structured dtype in a numba type.
    """
@shwina (Contributor, Author) commented:
TODO: explain this better
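Pending that explanation, here is a rough sketch (an assumption about the approach, not the PR's actual code) of the information wrap_struct needs from the dtype: an ordered list of field names, field dtypes, and byte offsets, from which a custom numba wrapper type can lower each field as a separate scalar argument, as seen in the wrapper-type PTX above:

```python
import numpy as np

def wrap_struct_fields(dtype: np.dtype):
    """Sketch (hypothetical helper, not from the PR): flatten a flat
    structured dtype into (name, field dtype, byte offset) triples, the
    per-field layout a numba wrapper type would need in order to pass
    each field as its own scalar parameter."""
    if dtype.names is None:
        raise TypeError("expected a structured dtype")
    return [(name, dtype.fields[name][0], dtype.fields[name][1])
            for name in dtype.names]

rec = np.dtype([("f", np.int32), ("g", np.int32)])
print(wrap_struct_fields(rec))  # [('f', dtype('int32'), 0), ('g', dtype('int32'), 4)]
```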


🟩 CI finished in 24m 05s: Pass: 100%/1 | Total: 24m 05s | Avg: 24m 05s | Max: 24m 05s
  • 🟩 python: Pass: 100%/1 | Total: 24m 05s | Avg: 24m 05s | Max: 24m 05s

    🟩 cpu
      🟩 amd64              Pass: 100%/1   | Total: 24m 05s | Avg: 24m 05s | Max: 24m 05s
    🟩 ctk
      🟩 12.6               Pass: 100%/1   | Total: 24m 05s | Avg: 24m 05s | Max: 24m 05s
    🟩 cudacxx
      🟩 nvcc12.6           Pass: 100%/1   | Total: 24m 05s | Avg: 24m 05s | Max: 24m 05s
    🟩 cudacxx_family
      🟩 nvcc               Pass: 100%/1   | Total: 24m 05s | Avg: 24m 05s | Max: 24m 05s
    🟩 cxx
      🟩 GCC13              Pass: 100%/1   | Total: 24m 05s | Avg: 24m 05s | Max: 24m 05s
    🟩 cxx_family
      🟩 GCC                Pass: 100%/1   | Total: 24m 05s | Avg: 24m 05s | Max: 24m 05s
    🟩 gpu
      🟩 v100               Pass: 100%/1   | Total: 24m 05s | Avg: 24m 05s | Max: 24m 05s
    🟩 jobs
      🟩 Test               Pass: 100%/1   | Total: 24m 05s | Avg: 24m 05s | Max: 24m 05s
    

👃 Inspect Changes

Modifications in project?

Project
CCCL Infrastructure
libcu++
CUB
Thrust
CUDA Experimental
+/- python
CCCL C Parallel Library
Catch2Helper

Modifications in project or dependencies?

Project
CCCL Infrastructure
libcu++
CUB
Thrust
CUDA Experimental
+/- python
CCCL C Parallel Library
Catch2Helper

🏃‍ Runner counts (total jobs: 1)

# Runner
1 linux-amd64-gpu-v100-latest-1

@shwina force-pushed the cuda-parallel-support-structured-dtypes branch from 7fab84d to 2bd35e3 on January 1, 2025 at 12:40
@shwina force-pushed the cuda-parallel-support-structured-dtypes branch from 2bd35e3 to ac9cc55 on January 4, 2025 at 15:40
@shwina force-pushed the cuda-parallel-support-structured-dtypes branch from ac9cc55 to a78a187 on January 6, 2025 at 11:57
Successfully merging this pull request may close these issues.

[FEA]: Enable using custom data types with cuda.parallel