diff --git a/test_gpt2.cu b/test_gpt2.cu
index 1fbae8197..d1d5c64c8 100644
--- a/test_gpt2.cu
+++ b/test_gpt2.cu
@@ -109,6 +109,13 @@ int main(int argc, char *argv[]) {
     size_t L = model.config.num_layers;
     size_t C = model.config.channels;
 
+    for (int i = 1; i < argc; i+=2) {
+        if (i + 1 >= argc) { exit(EXIT_FAILURE); } // must have arg after flag
+        if (argv[i][0] != '-') { exit(EXIT_FAILURE); } // must start with dash
+        if (argv[i][1] == 'w') { model.use_master_weights = atoi(argv[i+1]); }
+        else if (argv[i][1] == 'r') { model.recompute = atoi(argv[i+1]); }
+    }
+
     // load additional information that we will use for debugging and error checking
     FILE *state_file = fopenCheck("gpt2_124M_debug_state.bin", "rb");
     int state_header[256];
@@ -320,5 +327,5 @@ int main(int argc, char *argv[]) {
     free(expected_grads_memory);
     free(grads_memory_cpu);
     free(grads_memory_cpu_float);
-    return 0;
+    return allok ? EXIT_SUCCESS : EXIT_FAILURE;
 }
diff --git a/train_gpt2.cu b/train_gpt2.cu
index 1a1293864..286afbb6d 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -797,7 +797,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B,
     encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C, main_stream); // encoding goes into residual[0]
 
     // first layernorm isn't fused
-    layernorm_forward(acts.ln1, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream);
+    layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream);
 
     for (int l = 0; l < L; l++) {
         NvtxRange layer_range("Layer", l);
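
Usage sketch (not part of the patch): the new flags are parsed in pairs, so every flag must be followed by a value; -w N sets model.use_master_weights and -r N sets model.recompute, both via atoi. Example invocations, assuming the test binary is built as ./test_gpt2cu (the binary name is an assumption based on llm.c's usual Makefile targets):

    ./test_gpt2cu -w 1 -r 0    # enable master weights, disable recomputation
    ./test_gpt2cu -w 0 -r 2    # with recompute >= 2, gpt2_forward writes the first layernorm into the acts.lnf scratch buffer instead of storing acts.ln1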