Question about Ulysses and loss aggregation #6841

Open
pavelgein opened this issue Dec 9, 2024 · 7 comments

Comments

@pavelgein

Hi,
I am using Ulysses attention and the DeepSpeed ZeRO-3 optimizer for DPO training.
My question: what is the right way to aggregate the loss?
When one trains a model with CrossEntropyLoss, each rank yields the loss for its own subset of the sequence.
But when one trains a model with the DPO loss, we only have a loss for the whole example.
What is the right way to deal with this?

@samadejacobs
Contributor

Please take a look at sequence-parallel-aware cross entropy here.

@pavelgein
Author

Yes, I have seen this before. As far as I understand, in this approach each rank stores all the logits, and on the backward pass only the required parts of the gradients are taken into account.

I was trying an approach in which some reduction is done before communication inside the sequence parallel group (to reduce the communication load). I am going to try to build something along those lines.
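Because the standard DPO loss depends on each sequence only through its summed token log-probabilities, one way to reduce before communicating is for each rank to collapse its shard to a few scalars. Below is a minimal pure-Python sketch, assuming contiguous sequence sharding and a plain (non-length-normalized) DPO loss; the "all-reduce" is simulated by summing over ranks in one process, and the helper names (`shard`, `rank_partials`, `sp_dpo_loss`) are hypothetical. In real training that sum would be a collective over the sequence parallel group, and gradients only flow through it if the collective is autograd-aware (PyTorch ships some under `torch.distributed.nn`).

```python
import math
from typing import Dict, Sequence, Tuple


def logsigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x))
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def shard(values: Sequence, sp_size: int, rank: int) -> Sequence:
    # Contiguous split along the sequence dimension (assumed layout)
    chunk = (len(values) + sp_size - 1) // sp_size
    return values[rank * chunk:(rank + 1) * chunk]


def masked_sum(values: Sequence[float], mask: Sequence[int]) -> float:
    # Sum of log-probs over tokens that count toward the loss (mask == 1)
    return sum(v for v, m in zip(values, mask) if m)


def dpo_loss(pol_chosen: float, pol_rejected: float,
             ref_chosen: float, ref_rejected: float, beta: float = 0.1) -> float:
    # Standard DPO loss on sequence-level log-probs
    margin = (pol_chosen - ref_chosen) - (pol_rejected - ref_rejected)
    return -logsigmoid(beta * margin)


def rank_partials(chosen: Dict[str, list], rejected: Dict[str, list],
                  sp_size: int, rank: int) -> Tuple[float, float, float, float]:
    # Local reduction: this rank's shard collapses to 4 scalars before any communication
    def part(seq: Dict[str, list], key: str) -> float:
        return masked_sum(shard(seq[key], sp_size, rank), shard(seq["mask"], sp_size, rank))
    return (part(chosen, "policy"), part(rejected, "policy"),
            part(chosen, "ref"), part(rejected, "ref"))


def sp_dpo_loss(chosen: Dict[str, list], rejected: Dict[str, list],
                sp_size: int, beta: float = 0.1) -> float:
    partials = [rank_partials(chosen, rejected, sp_size, r) for r in range(sp_size)]
    # Simulated SUM all-reduce over the sequence parallel group
    totals = [sum(p[i] for p in partials) for i in range(4)]
    # Every rank now holds the same totals and computes the same DPO loss
    return dpo_loss(*totals, beta=beta)


# Per-token log-probs for one preference pair; mask 0 marks a prompt token
chosen = {"policy": [-0.5, -1.2, -0.3, -0.9], "ref": [-0.6, -1.0, -0.4, -1.1], "mask": [0, 1, 1, 1]}
rejected = {"policy": [-0.5, -2.0, -1.5, -0.7], "ref": [-0.6, -1.8, -1.2, -0.8], "mask": [0, 1, 1, 1]}
print(sp_dpo_loss(chosen, rejected, sp_size=2))
```

With this layout each rank sends four scalars per preference pair instead of logits; the cost is that the backward pass has to route d(loss)/d(total) back to every rank, which is what an autograd-aware collective would provide.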

@pavelgein
Author

I see a gradient difference.
I have two ranks in the sequence parallel group (SPG). I compute the gradients of the loss and slice them according to each rank's position in the SPG, so I have two gradients, g_1 and g_2, one on each rank.

When I run the same setup without sequence parallelism, I get a gradient g, and for all layers g ≈ g_1 + g_2.
So my question is whether the DeepSpeed ZeRO-3 optimizer handles this case correctly or just takes the average across all ranks.

I also see a difference in the gradient norm: around 3.24 on each rank with sequence parallelism and around 9.11 without sequence parallelism.
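The observation g ≈ g_1 + g_2 is what one would expect when every rank sees the same sequence-level loss but backpropagates only through its own tokens. A toy scalar example makes the arithmetic explicit; the model (score s = w · Σ t_i, loss (s − y)²) and the two-tokens-per-rank split are hypothetical:

```python
from typing import List, Tuple


def toy_loss_grad(w: float, tokens: List[float], y: float,
                  shard_ranges: List[Tuple[int, int]]) -> Tuple[float, List[float]]:
    # Toy model: sequence score s = sum(w * t); loss = (s - y)^2
    s = sum(w * t for t in tokens)
    upstream = 2.0 * (s - y)  # dL/ds, identical on every rank
    full_grad = upstream * sum(tokens)  # dL/dw without sequence parallelism
    # Each SP rank backpropagates only through its own slice of tokens
    rank_grads = [upstream * sum(tokens[a:b]) for a, b in shard_ranges]
    return full_grad, rank_grads


full, per_rank = toy_loss_grad(0.5, [1.0, 2.0, 3.0, 4.0], 1.0, [(0, 2), (2, 4)])
print(full, per_rank)  # 80.0 [24.0, 56.0]
```

Summing the per-rank gradients recovers the unsharded gradient (24 + 56 = 80), while averaging them would give 40, half of it.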

@ronald-d-rogers

@pavelgein, it does neither. As currently implemented, it just returns the loss without any reduction; you are meant to call loss.mean() or loss.sum() yourself.

I am going to start testing the loss outputs soon and was wondering whether this fixes your issue.
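When the per-token losses come back unreduced, loss.sum() on each rank composes directly (the per-rank sums add up to the full-sequence sum), but loss.mean() does not when ranks hold different numbers of valid tokens. A minimal sketch, assuming a 0/1 mask for padding or prompt tokens and simulating the all-reduce in one process; the helper names are hypothetical:

```python
from typing import List, Sequence, Tuple


def local_stats(losses: Sequence[float], mask: Sequence[int]) -> Tuple[float, int]:
    # Reduce this rank's unreduced per-token losses to (sum, count of valid tokens)
    total = sum(l for l, m in zip(losses, mask) if m)
    count = sum(1 for m in mask if m)
    return total, count


def global_mean(rank_stats: List[Tuple[float, int]]) -> float:
    # Simulated SUM all-reduce of both sums and counts, then one division
    total = sum(s for s, _ in rank_stats)
    count = sum(c for _, c in rank_stats)
    return total / count


def mean_of_rank_means(rank_stats: List[Tuple[float, int]]) -> float:
    # What averaging each rank's loss.mean() would give (biased if counts differ)
    return sum(s / c for s, c in rank_stats) / len(rank_stats)


# Rank 0 holds two valid tokens; rank 1 holds one valid token and one pad token
stats = [local_stats([2.0, 4.0], [1, 1]), local_stats([6.0, 0.0], [1, 0])]
print(global_mean(stats), mean_of_rank_means(stats))  # 4.0 4.5
```

The unsharded mean over the three valid tokens is (2 + 4 + 6) / 3 = 4.0, which only the sum-and-count version reproduces.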

@pavelgein
Author

@ronald-d-rogers I think I didn't use the right words in my question.

When we do DDP, we split the dataset across workers, compute gradients on every worker from its own inputs, and then average the gradients across all workers. Since the output on each worker depends only on its own input, this gives us a gradient estimate.

With sequence parallelism, the output of one worker is no longer independent of the inputs of the other workers.
If we treat the sequence parallel group as a single worker, then we should take the sum of the gradients inside this group, not the average.

@pavelgein
Author

Here the ZeRO-3 optimizer reduces the number of workers by a factor of the sequence parallel group size:

buffer_to_reduce.div_(world_sz / float(self.sequence_parallel_size))
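The divisor in the quoted div_ line can be checked against the "sum inside the SP group, average across data-parallel replicas" rule with plain arithmetic. This sketch assumes the buffer holds the SUM of gradients over all world_sz ranks when that division is applied (not verified here against the DeepSpeed source; dividing before or after the sum gives the same result arithmetically). The 2 × 2 layout and scalar gradients are hypothetical:

```python
from typing import List


def desired_grad(grads: List[List[float]]) -> float:
    # grads[d][s]: gradient on data-parallel replica d, SP rank s (scalars for brevity)
    # Sum within each SP group (one sample's partial grads), then average over DP replicas
    return sum(sum(group) for group in grads) / len(grads)


def flat_sum_then_div(grads: List[List[float]], sp_size: int) -> float:
    world_sz = sum(len(group) for group in grads)
    summed = sum(sum(group) for group in grads)  # SUM over all world ranks
    return summed / (world_sz / float(sp_size))  # divisor from the quoted div_ line


def naive_average(grads: List[List[float]]) -> float:
    world_sz = sum(len(group) for group in grads)
    return sum(sum(group) for group in grads) / world_sz  # plain DDP-style mean


g = [[1.0, 3.0], [2.0, 6.0]]  # dp_size = 2, sp_size = 2
print(desired_grad(g), flat_sum_then_div(g, 2), naive_average(g))  # 6.0 6.0 3.0
```

Under that assumption, dividing by world_sz / sp_size matches the sum-within-group rule, whereas a plain average over all ranks would be smaller by a factor of sp_size.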

@ronald-d-rogers

ronald-d-rogers commented Jan 2, 2025

Ah, I think I understand now. I actually tried the same thing you did (doing the reduction in the sequence parallel group) but gave up. I tried modifying the method he provided (VocabSequenceParallelCrossEntropy) to accept reduction as an argument, either passing it to nll_loss or returning loss_all.sum()/loss_all.mean() at the end. I think it is possible, but I ran into issues.


3 participants