Skip to content

Commit

Permalink
bugfix for LN recomputation
Browse files Browse the repository at this point in the history
  • Loading branch information
ngc92 committed Jun 8, 2024
1 parent 8fbfba2 commit 02e30fa
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B,
encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C, main_stream); // encoding goes into residual[0]

// first layernorm isn't fused
layernorm_forward(acts.ln1, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream);
layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream);

for (int l = 0; l < L; l++) {
NvtxRange layer_range("Layer", l);
Expand Down

0 comments on commit 02e30fa

Please sign in to comment.