Commit 08fc45b
delete copy paste code for gpt2 model init
karpathy committed Jun 3, 2024
1 parent c827dd2 commit 08fc45b
Showing 1 changed file with 34 additions and 38 deletions.
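The commit replaces two duplicated blocks of field initialization in gpt2_build_from_checkpoint() and gpt2_build_from_random() with a single shared helper, gpt2_init_common(). Below is a minimal, self-contained sketch of that pattern in plain C — the struct, helper, and builder names here are hypothetical stand-ins, not the real code, which operates on the full GPT2 struct in train_gpt2.cu:

#include <stdio.h>
#include <stdlib.h>

typedef struct {
    float *acts_memory;      // lazily allocated in forward()
    float *grads_memory;     // lazily allocated in backward()
    int use_master_weights;
    int recompute;
} Model;

// shared init for everything that is not the model weights
static void model_init_common(Model *m) {
    m->acts_memory = NULL;
    m->grads_memory = NULL;
    m->use_master_weights = 1; // safe default: keep an fp32 master copy
    m->recompute = 1;          // default: recompute gelu in backward
}

static void model_build_from_checkpoint(Model *m, const char *path) {
    (void)path;           // ... load weights from disk here ...
    model_init_common(m); // one call instead of a copy-pasted block
}

static void model_build_from_random(Model *m, int depth) {
    (void)depth;          // ... randomly initialize weights here ...
    model_init_common(m);
}

int main(void) {
    Model m;
    model_build_from_random(&m, 12);
    printf("use_master_weights=%d recompute=%d\n", m.use_master_weights, m.recompute);
    return 0;
}

Centralizing the defaults this way means a new field only has to be initialized in one place, so the two construction paths cannot silently drift apart.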
train_gpt2.cu: 34 additions & 38 deletions
@@ -2151,13 +2151,43 @@ typedef struct {
     floatX* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost
     float* cpu_losses_fp32; // same but fp32
     unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc.
-    int use_master_weights;
-    int recompute;
+    int use_master_weights; // keep master weights copy in float for optim update? 0|1
+    int recompute; // recompute gelu | layernorm forward during model backward? 0|1|2
     // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch?
     int* workload_indices; // encoder_backward, B*T*num_c_groups (int)
     int4* bucket_info; // encoder_backward, B*T*num_c_groups (int4) - size for worst case
 } GPT2;
 
+void gpt2_init_common(GPT2 *model) {
+    // common inits outside of the model weights
+    // the weights are initialized either in:
+    // - gpt2_build_from_checkpoint() if loading from a checkpoint
+    // - gpt2_build_from_random() if starting from scratch
+    // memory lazily initialized in forward()
+    model->acts_memory = NULL;
+    model->inputs = NULL;
+    model->targets = NULL;
+    model->cpu_losses = NULL;
+    model->cpu_losses_fp32 = NULL;
+    // the B,T params are determined and set, fixed on first batch in forward()
+    model->batch_size = 0;
+    model->seq_len = 0;
+    model->mean_loss = -1.0f; // -1.0f designates no loss, set at end of forward()
+    // memory lazily initialized in backward()
+    model->grads_memory = NULL;
+    model->grads_acts_memory = NULL;
+    model->workload_indices = NULL; // on cpu, for encoder_backward
+    model->bucket_info = NULL; // on cpu, for encoder_backward
+    // memory lazily initialized in update()
+    model->m_memory = NULL;
+    model->v_memory = NULL;
+    model->master_weights = NULL;
+    // other default settings
+    model->rng_state = 13371337; // used in stochastic rounding
+    model->use_master_weights = 1; // safe default: do keep master weights in fp32
+    model->recompute = 1; // good default: recompute gelu but not layernorm
+}
+
 void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) {
     // write the model to a checkpoint file
     printf0("Writing model to %s\n", checkpoint_path);
@@ -2247,25 +2277,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
     free(params_memory_cpu);
     fcloseCheck(model_file);
 
-    // other inits
-    model->acts_memory = NULL;
-    model->grads_memory = NULL;
-    model->m_memory = NULL;
-    model->v_memory = NULL;
-    model->master_weights = NULL;
-    model->grads_acts_memory = NULL;
-    model->inputs = NULL;
-    model->targets = NULL;
-    model->cpu_losses = NULL;
-    model->cpu_losses_fp32 = NULL;
-    model->workload_indices = NULL;
-    model->bucket_info = NULL;
-    model->batch_size = 0;
-    model->seq_len = 0;
-    model->mean_loss = -1.0f; // -1.0f will designate no loss
-    model->rng_state = 13371337;
-    model->use_master_weights = 1; // keep master weights copy in float for optim update?
-    model->recompute = 1; // default to recompute gelu during backward
+    gpt2_init_common(model);
 }
 
 void gpt2_build_from_random(GPT2 *model, int depth) {
@@ -2354,23 +2366,7 @@ void gpt2_build_from_random(GPT2 *model, int depth) {
     cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice));
     free(params_memory_cpu);
 
-    // other inits and defaults
-    model->acts_memory = NULL;
-    model->grads_memory = NULL;
-    model->m_memory = NULL;
-    model->v_memory = NULL;
-    model->master_weights = NULL;
-    model->grads_acts_memory = NULL;
-    model->inputs = NULL;
-    model->targets = NULL;
-    model->cpu_losses = NULL;
-    model->cpu_losses_fp32 = NULL;
-    model->batch_size = 0;
-    model->seq_len = 0;
-    model->mean_loss = -1.0f; // -1.0f designates no loss
-    model->rng_state = 13371337;
-    model->use_master_weights = 1; // keep master weights copy in float for optim update?
-    model->recompute = 1; // default to recompute gelu during backward
+    gpt2_init_common(model);
 }
 
 void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T, int grad_accum_steps=1) {
