Skip to content

Commit

Permalink
Merge pull request karpathy#578 from ngc92/matmul
Browse files Browse the repository at this point in the history
Remove cublaslt from fp32cu versions
  • Loading branch information
karpathy authored Jun 12, 2024
2 parents 42dff87 + d679364 commit 95cef79
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 91 deletions.
102 changes: 101 additions & 1 deletion dev/cuda/matmul_forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,88 @@ __global__ void add_bias(float* out, const float* bias, int B, int T, int OC) {
}
}

// kernel 4: semi-efficient handwritten kernel
// see trimat_forward.cu for some intermediate development steps
__device__ float4 ld_vec(const float* address) {
return *reinterpret_cast<const float4*>(address);
}

__device__ void st_vec(float* address, float4 val) {
*reinterpret_cast<float4*>(address) = val;
}

__global__ void __launch_bounds__(16*16) matmul_forward_kernel4(float* out,
const float* inp, const float* weight, const float* bias,
int C, int OC) {
// out is (B,T,OC). OC is short for "output channels", e.g. OC = 4 * C
// inp is (B,T,C), weight is (OC, C), bias is (OC)
// each thread handles 8x8 elements; each block 128 by 128 elements.
int oc = 8*(blockIdx.y * blockDim.y + threadIdx.y);

// buffers to cache chunks of the input matrices
__shared__ float lhs_s[128][32];
__shared__ float rhs_s[128][32];

// adjust our pointers for the current block
inp += 128 * blockIdx.x * C;
weight += 128 * blockIdx.y * C;
out += 128 * blockIdx.x * OC + 128 * blockIdx.y;

float vals[8][8] = {};
if(bias != NULL) {
for (int i = 0; i < 8; i++) {
for (int j = 0; j < 8; j += 4) {
float4 b = ld_vec(bias + oc + j);
vals[i][j+0] = b.x;
vals[i][j+1] = b.y;
vals[i][j+2] = b.z;
vals[i][j+3] = b.w;
}
}
}

int si_start = 4*(16 * threadIdx.y + threadIdx.x);
for (int so = 0; so < C; so += 32) {
__syncthreads();
int xmod8 = threadIdx.x % 8;
int xby8 = threadIdx.x / 8;
int xo = 4 * xmod8;
for(int y = 2 * threadIdx.y + xby8; y < 128; y += 32) {
st_vec(&lhs_s[y][xo], ld_vec(inp + y * C + so + xo));
st_vec(&rhs_s[y][xo], ld_vec(weight + y * C + so + xo));
}
__syncthreads();

for (int si = si_start; si < si_start + 32; si += 4) {
float4 rhs[8];
for (int u = 0; u < 8; ++u) {
rhs[u] = ld_vec(&rhs_s[u + 8 * threadIdx.y][si % 32]);
}

for (int ii = 0; ii < 8; ++ii) {
float4 lhs = ld_vec(&lhs_s[ii + 8 * threadIdx.x][si % 32]);
for (int ji = 0; ji < 8; ++ji) {
vals[ii][ji] += lhs.x * rhs[ji].x;
vals[ii][ji] += lhs.y * rhs[ji].y;
vals[ii][ji] += lhs.z * rhs[ji].z;
vals[ii][ji] += lhs.w * rhs[ji].w;
}
}
}
}

for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 8; j += 4) {
float4 result;
result.x = vals[i][j + 0];
result.y = vals[i][j + 1];
result.z = vals[i][j + 2];
result.w = vals[i][j + 3];
st_vec(out + (8*threadIdx.x+i) * OC + 8*threadIdx.y + j, result);
}
}
}

// ----------------------------------------------------------------------------
// kernel launcher

Expand Down Expand Up @@ -218,6 +300,21 @@ void matmul_forward3(float* out,
cublasCheck(cublasLtMatrixLayoutDestroy(biasLayout));
}

// handwritten, relatively efficient non-tensorcore matmul kernel
void matmul_forward4(float* out,
const float* inp, const float* weight, const float* bias,
int B, int T, int C, int OC,
int sqrt_block_size) {
// out is (B,T,OC). OC is short for "output channels", e.g. OC = 4 * C
// inp is (B,T,C), weight is (OC, C), bias is (OC)
sqrt_block_size = 16;

dim3 gridDim(ceil_div(B * T, 8*sqrt_block_size), ceil_div(OC, 8*sqrt_block_size));
dim3 blockDim(sqrt_block_size, sqrt_block_size);
matmul_forward_kernel4<<<gridDim, blockDim>>>(out, inp, weight, bias, C, OC);
cudaCheck(cudaGetLastError());
}

// kernel version dispatch
void matmul_forward(int kernel_num,
float* out,
Expand All @@ -234,6 +331,9 @@ void matmul_forward(int kernel_num,
case 3:
matmul_forward3(out, inp, weight, bias, B, T, C, OC);
break;
case 4:
matmul_forward4(out, inp, weight, bias, B, T, C, OC, sqrt_block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
Expand All @@ -245,7 +345,7 @@ void matmul_forward(int kernel_num,
int main(int argc, char **argv) {
srand(0);

int B = 8;
int B = 32;
int T = 1024;
int C = 768;
int OC = 768 * 4; // expansion of 4, e.g. in the MLP
Expand Down
2 changes: 1 addition & 1 deletion dev/cuda/trimat_forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ __device__ void matmul_tri3(float* p, int PS, const float* k, int KS, const floa
}

for (int i = 0; i < 8; ++i) {
// no need to keep lhs around for the i loop, its only reused in the j loop anyway.
// no need to keep lhs around for the i loop, it's only reused in the j loop anyway.
float4 lhs = ld_vec(q + i * QS + hs);
for (int j = 0; j < 8; ++j) {
vals[i][j] += lhs.x * rhs[j].x;
Expand Down
4 changes: 0 additions & 4 deletions test_gpt2_fp32.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@ int main(int argc, char *argv[]) {

// setup cuBLAS and cuBLASLt
cublasCheck(cublasCreate(&cublas_handle));
cublasCheck(cublasLtCreate(&cublaslt_handle));
// TF32 precision is equivalent to torch.set_float32_matmul_precision('high')
int enable_tf32 = deviceProp.major >= 8 ? 1 : 0;
enable_tf32 = 0; // NOTE: disable TF32 for testing!!!
printf("enable_tf32: %d\n", enable_tf32);
cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;
cublasMath_t cublas_math_mode = enable_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;
cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode));
cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size));

// build the GPT-2 model from a checkpoint
GPT2 model;
Expand Down Expand Up @@ -231,9 +229,7 @@ int main(int argc, char *argv[]) {
free(expected_grads_memory);
free(calculated_grads_memory);
gpt2_free(&model);
cudaCheck(cudaFree(cublaslt_workspace));
cublasCheck(cublasDestroy(cublas_handle));
cublasCheck(cublasLtDestroy(cublaslt_handle));

return 0;
}
Loading

0 comments on commit 95cef79

Please sign in to comment.