Compared with the native Torch cross-entropy, the gradient differences of the classifier are very large. #14

Open
BangguWu opened this issue Dec 17, 2024 · 4 comments

BangguWu commented Dec 17, 2024

I have implemented a toy example:

import torch
import copy

from cut_cross_entropy import linear_cross_entropy
torch.manual_seed(42)

model = torch.nn.Linear(128, 128, dtype=torch.bfloat16).cuda()
classifier_weights = torch.nn.Linear(128, 32, dtype=torch.bfloat16).cuda()

model_c = copy.deepcopy(model)
classifier_weights_c = copy.deepcopy(classifier_weights)

input = torch.randn(8192, 128, dtype=torch.bfloat16).cuda()
input.requires_grad = True
input.retain_grad()

input_c = input.clone()

labels = torch.randint(0, 32, (8192,)).cuda()
labels_c = labels.clone()

embeddings = model(input)

shift_embeddings = embeddings[:-1, :]
shift_labels = labels[1:]

manual_shift_loss = linear_cross_entropy(shift_embeddings, classifier_weights.weight, shift_labels)
manual_shift_loss.backward()

embeddings_c = model_c(input_c)
logits = classifier_weights_c(embeddings_c)
shift_logits_c = logits[:-1, :]
shift_labels_c = labels_c[1:]
manual_shift_loss_c = torch.nn.functional.cross_entropy(shift_logits_c, shift_labels_c)
manual_shift_loss_c.backward()

print(f"model grad: {model.weight.grad.mean()}")
print(f"model_c grad: {model_c.weight.grad.mean()}")

print(f"classifier_weights grad: {classifier_weights.weight.grad.mean()}")
print(f"classifier_weights_c grad: {classifier_weights_c.weight.grad.mean()}")

and the output is:

model grad: -4.589557647705078e-06
model_c grad: -4.76837158203125e-06
classifier_weights grad: 2.2631138563156128e-07
classifier_weights_c grad: -4.307366907596588e-08

The gradient difference for the classifier looks very large.
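For a sharper check than the means (which can cancel across entries), the gradients can also be compared elementwise; a minimal sketch reusing the tensors from the script above:

g_cce = classifier_weights.weight.grad.float()
g_torch = classifier_weights_c.weight.grad.float()

# Compare the full gradient tensors rather than their means.
print(f"max abs diff: {(g_cce - g_torch).abs().max()}")
print(f"relative diff: {(g_cce - g_torch).norm() / g_torch.norm()}")
print(f"allclose: {torch.allclose(g_cce, g_torch, atol=1e-3, rtol=1e-2)}")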

I have also tried training an LLM with a GPT-2 architecture; the loss gap is about 0.06 after training on 100B tokens.

Is there anything wrong with my usage?

@zhixuan-lin

Probably related to this: I observe that the loss increases during long-context training when using linear_cross_entropy, while other non-fused implementations work fine. So there is very likely something wrong with the implementation of linear_cross_entropy (probably numerical precision).

@zhixuan-lin

My guess is that dc (here, the accumulator for the gradients of the linear layer weights) is kept in low precision, so when the sequence is long there can be large numerical errors. This is also mentioned in the flash-linear-attention repository (see this commit).
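A minimal sketch of the effect (not the actual kernel, just naive summation of many per-token contributions, comparing a bf16 accumulator against fp32):

import torch

torch.manual_seed(0)

# 8192 per-token contributions to one weight-gradient entry, as in the toy example above.
contribs = torch.randn(8192)

exact = contribs.double().sum()        # high-precision reference
fp32_sum = contribs.sum()              # fp32 accumulator

bf16_sum = torch.zeros((), dtype=torch.bfloat16)
for x in contribs.to(torch.bfloat16):  # sequential accumulation in bf16
    bf16_sum = bf16_sum + x

print(f"fp32 accumulator error: {abs(fp32_sum.double() - exact).item():.3e}")
print(f"bf16 accumulator error: {abs(bf16_sum.double() - exact).item():.3e}")

The bf16 accumulator error grows with the number of summed terms, which would be consistent with the problem getting worse at longer context lengths.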

@erikwijmans
Collaborator

This is something we've noticed too, specifically when training models from scratch.

@BangguWu is the loss gap on train or val? I have seldom seen the train loss differ (except in the case of Triton bugs that we haven't worked around yet), but I have seen the val loss differ when training from scratch if the validation set has tokens that aren't present in the train set.

@erikwijmans
Collaborator

We have been working on some updates on this branch. You can install it via pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@ewijmans/kahan-updates".

That adds two options to linear_cross_entropy and patch:

  1. An impl="cce_exact" option that should match PyTorch exactly. That implementation is slower as it disables gradient filtering, but in our testing it isn't significantly slower in full model training and can be faster due the ability to increase the batch size.
  2. A use_kahan flag that uses Kahan summation to deal with dc and de being kept in lower precision. This uses a bit more memory for temporary buffers. Setting impl="cce_exact" will automatically set use_kahan=True.

@BangguWu @zhixuan-lin If you are up for some beta testing, feel free to try these out and let me know how it goes.
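For reference, on the toy script from the first post this would look roughly as follows (treat it as a sketch; use_kahan only exists on that branch and the exact argument names may still change):

from cut_cross_entropy import linear_cross_entropy

# Exact mode: disables gradient filtering and turns on Kahan summation internally.
loss_exact = linear_cross_entropy(
    shift_embeddings, classifier_weights.weight, shift_labels, impl="cce_exact"
)

# Default impl, but with Kahan summation for the dc/de accumulators.
loss_kahan = linear_cross_entropy(
    shift_embeddings, classifier_weights.weight, shift_labels, use_kahan=True
)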
