Skip to content

Commit

Permalink
Merge pull request karpathy#523 from ngc92/testability
Browse files Browse the repository at this point in the history
Testability
  • Loading branch information
karpathy authored Jun 9, 2024
2 parents d2482ff + 02e30fa commit 615ec0b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
9 changes: 8 additions & 1 deletion test_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ int main(int argc, char *argv[]) {
size_t L = model.config.num_layers;
size_t C = model.config.channels;

for (int i = 1; i < argc; i+=2) {
if (i + 1 >= argc) { exit(EXIT_FAILURE); } // must have arg after flag
if (argv[i][0] != '-') { exit(EXIT_FAILURE); } // must start with dash
if (argv[i][1] == 'w') { model.use_master_weights = atoi(argv[i+1]); }
else if (argv[i][1] == 'r') { model.recompute = atoi(argv[i+1]); }
}

// load additional information that we will use for debugging and error checking
FILE *state_file = fopenCheck("gpt2_124M_debug_state.bin", "rb");
int state_header[256];
Expand Down Expand Up @@ -320,5 +327,5 @@ int main(int argc, char *argv[]) {
free(expected_grads_memory);
free(grads_memory_cpu);
free(grads_memory_cpu_float);
return 0;
return allok ? EXIT_SUCCESS : EXIT_FAILURE;
}
2 changes: 1 addition & 1 deletion train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B,
encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C, main_stream); // encoding goes into residual[0]

// first layernorm isn't fused
layernorm_forward(acts.ln1, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream);
layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream);

for (int l = 0; l < L; l++) {
NvtxRange layer_range("Layer", l);
Expand Down

0 comments on commit 615ec0b

Please sign in to comment.