Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add main grad before fwd pass #1142

Draft
wants to merge 5 commits into
base: ngoyal_changes_for_pp_fp8
Choose a base branch
from

Conversation

vedanuj
Copy link

@vedanuj vedanuj commented Oct 4, 2023

Adds main_grad before FWD pass to FlatParameter

to be used with https://github.com/fairinternal/xlformers/pull/1418

@vedanuj vedanuj changed the base branch from main to ngoyal_changes_for_pp_fp8 October 4, 2023 15:32
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 4, 2023
@vedanuj vedanuj requested review from jspark1105, tmarkstrum, ngoyal2707, awgu and jianyuh and removed request for tmarkstrum October 4, 2023 15:32
@vedanuj vedanuj mentioned this pull request Oct 4, 2023
10 tasks
assert param.grad is not None, param.shape
if param.grad.requires_grad:
raise RuntimeError("FSDP only works with gradients that don't require gradients")
# assert param.grad is not None, param.shape
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps some check is needed to make sure parameters are not shared (as would be the case with weights tying)?

param.grad = None
if param.main_grad is not None:
grad = param.main_grad
param.main_grad = None
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't .main_grad need to be restored somewhere before next forward?

@awgu
Copy link

awgu commented Oct 4, 2023

If we construct flat_param.main_grad before forward and set individual param.main_grad as view into flat_param.main_grad before forward, then we hold the flat_param.main_grad in memory for all FSDP instances going into backward, which may increase peak memory.

An alternative option is to construct flat_param.main_grad (as zeros or empty depending on if TE adds or copies to the memory) in the pre-backward hook and separately set param.main_grad as views into flat_param.main_grad only in the pre-backward hook.

def _pre_backward_hook(*unused: Any) -> None:

One option could be to add this logic to _prep_grads_for_backward():

param.grad.data = param.grad.data.float()
if param.grad is not None:
if param.main_grad is not None:
param.main_grad.copy_(param.grad.float())
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: torch can upcast and copy in one kernel:

Suggested change
param.main_grad.copy_(param.grad.float())
param.main_grad.copy_(param.grad)
Existing

Upcast kernel + copy kernel
Screenshot 2023-10-04 at 12 27 52 PM

New

Only upcast kernel

Screenshot 2023-10-04 at 12 29 32 PM

Correctness Example
>>> t_fp32 = torch.empty((4,))
>>> t_bf16 = torch.randn((4,), dtype=torch.bfloat16)
>>> t_fp32
tensor([-8.3762e-20,  3.0801e-41, -1.3043e-16,  3.0801e-41])
>>> t_bf16
tensor([-1.3516, -0.5156, -0.6055,  0.3535], dtype=torch.bfloat16)
>>> t_fp32.copy_(t_bf16)
tensor([-1.3516, -0.5156, -0.6055,  0.3535])
>>> t_fp32
tensor([-1.3516, -0.5156, -0.6055,  0.3535])

@vedanuj
Copy link
Author

vedanuj commented Oct 5, 2023

@awgu It seems from my testing that the changes are still necessary in FlattenParamsWrapper otherwise it complains that .main_grad is not there for parameter.

@jspark1105 I have borrowed some changes from your PR #1136 to update the view when reallocating the zero buffers for main_grad.

@@ -1721,35 +1722,48 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# reductions in post_backward stream.
self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._streams["post_backward"]):
orig_grad_data = param.grad.data
if param.main_grad is not None and not param.main_grad.eq(0).all():
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we concerned that this param.main_grad.eq(0).all() might be a CPU sync? Perhaps, it is not so much a concern if we already have CPU syncs for rate limiting FSDP.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there another way I can check if main_grad is non zero without doing a CPU sync?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are checking if this is all zeros to skip modules that didn't use main_grad?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes .. because all parameters have .main_grad, so not sure how to make sure we are not using the ones that do not have the grads stored in .main_grad

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants