diff --git a/test_gpt2.cu b/test_gpt2.cu index 690bca9e5..d1d5c64c8 100644 --- a/test_gpt2.cu +++ b/test_gpt2.cu @@ -109,6 +109,13 @@ int main(int argc, char *argv[]) { size_t L = model.config.num_layers; size_t C = model.config.channels; + for (int i = 1; i < argc; i+=2) { + if (i + 1 >= argc) { exit(EXIT_FAILURE); } // must have arg after flag + if (argv[i][0] != '-') { exit(EXIT_FAILURE); } // must start with dash + if (argv[i][1] == 'w') { model.use_master_weights = atoi(argv[i+1]); } + else if (argv[i][1] == 'r') { model.recompute = atoi(argv[i+1]); } + } + // load additional information that we will use for debugging and error checking FILE *state_file = fopenCheck("gpt2_124M_debug_state.bin", "rb"); int state_header[256];