Commit

copy and cast kernel move to cuda_utils, good general util
karpathy committed Jun 3, 2024
1 parent d697eae commit 45b0d1b
Showing 2 changed files with 32 additions and 32 deletions.
31 changes: 31 additions & 0 deletions llmc/cuda_utils.cuh
@@ -79,6 +79,37 @@ __device__ void store128cg(ElementType* target, Packed128<ElementType> value) {
typedef Packed128<float> f128;
typedef Packed128<floatX> x128;

// ----------------------------------------------------------------------------
// Copy, cast functions

// device functions and the kernel to cast data between types
template<typename Td, typename Ts>
__device__ Td cast_value(Ts val);

template<>
__device__ float cast_value<float, float>(float val) {
    return val;
}

template<>
__device__ float cast_value<float, half>(half val) {
    return __half2float(val);
}

template<>
__device__ float cast_value<float, __nv_bfloat16>(__nv_bfloat16 val) {
    return __bfloat162float(val);
}

template<typename Td, typename Ts>
__global__ void copy_and_cast_kernel(Td* dst, const Ts* src, size_t n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // need to try grid stride looping for more perf later
    if (idx < n) {
        dst[idx] = cast_value<Td, Ts>(src[idx]);
    }
}

// ----------------------------------------------------------------------------
// Warp/Block communication primitives

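For context, a minimal sketch of how host code might launch the kernel added above to cast a floatX buffer into a float buffer. The helper name copy_and_cast_to_float, the 512-thread block size, and the bare cudaGetLastError check are illustrative assumptions, not code from this commit.

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
// assumes floatX, cast_value and copy_and_cast_kernel are visible, e.g. via the llmc headers

// hypothetical helper: copy-and-cast a floatX array into a float array on the given stream
void copy_and_cast_to_float(float* dst, const floatX* src, size_t n, cudaStream_t stream) {
    const int block_size = 512;                                      // assumed launch config
    const int grid_size = (int)((n + block_size - 1) / block_size);  // ceil(n / block_size)
    copy_and_cast_kernel<float, floatX><<<grid_size, block_size, 0, stream>>>(dst, src, n);
    cudaError_t err = cudaGetLastError();                            // catch launch failures
    if (err != cudaSuccess) {
        fprintf(stderr, "copy_and_cast_kernel launch failed: %s\n", cudaGetErrorString(err));
        exit(EXIT_FAILURE);
    }
}
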
33 changes: 1 addition & 32 deletions train_gpt2.cu
@@ -34,7 +34,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
#include "llmc/cuda_common.h"
// defines:
// Packed128, f128, x128
-// warpReduceSum, warpReduceMax, blockReduce
+// warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel
#include "llmc/cuda_utils.cuh"
// defines: CUBLAS_LOWP, cublasCheck, cublaslt_workspace_size, cublaslt_workspace
// defines: cublas_compute, cublaslt_handle, cublas_handle
@@ -250,37 +250,6 @@ void set_zero_configs(MultiGpuConfig* multi_gpu_config, int zero_stage, size_t t
}
}

// ----------------------------------------------------------------------------
// Kernels

// device functions and the kernel to cast data between types
template<typename Td, typename Ts>
__device__ Td cast_value(Ts val);

template<>
__device__ float cast_value<float, float>(float val) {
    return val;
}

template<>
__device__ float cast_value<float, half>(half val) {
    return __half2float(val);
}

template<>
__device__ float cast_value<float, __nv_bfloat16>(__nv_bfloat16 val) {
    return __bfloat162float(val);
}

template<typename Td, typename Ts>
__global__ void copy_and_cast_kernel(Td* dst, const Ts* src, size_t n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // need to try grid stride looping for more perf later
    if (idx < n) {
        dst[idx] = cast_value<Td, Ts>(src[idx]);
    }
}

// ----------------------------------------------------------------------------
// GPT-2 model definition

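The kernel's inline note about trying grid-stride looping refers to the standard pattern where a fixed-size grid strides over the whole array; a hedged sketch of that variant (not part of this commit) is below.

// hypothetical grid-stride variant of copy_and_cast_kernel (not in this commit):
// each thread handles elements idx, idx + stride, idx + 2*stride, ... so any grid size covers n
template<typename Td, typename Ts>
__global__ void copy_and_cast_kernel_gridstride(Td* dst, const Ts* src, size_t n) {
    size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    size_t stride = (size_t)gridDim.x * blockDim.x;
    for (size_t i = idx; i < n; i += stride) {
        dst[i] = cast_value<Td, Ts>(src[i]);
    }
}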
