[FSDPv1] Only perform cat() during last microbatch backward() within FlattenParamsWrapper #1180
base: ngoyal_changes_for_pp_fp8
Conversation
…t for last microbatch
… flatten_parameter.unsharded_main_grad in last microbatch backward()
This approach makes sense to me!
If True, only let backward pass propagate to self.params, which will
invoke the _post_backward_hook() and concat() op, when self._require_backward_grad_sync
is True (e.g. last microbatch)
NOTE: this likely will incur more GPU memory usage
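A rough sketch of the behavior this docstring describes, for illustration only; the real FlattenParamsWrapper internals differ, and the attribute name unsharded_main_grad is taken from the commit message above:

```python
import torch

# Simplified illustration (not the PR's actual code): the expensive cat() over
# per-parameter gradients is skipped unless gradient sync is required, i.e. on
# the last microbatch; earlier microbatches only accumulate per-parameter grads.
def _maybe_build_flat_grad(self):
    if not self._require_backward_grad_sync:
        return  # intermediate microbatch: keep accumulating, no concat
    # Last microbatch: concatenate the accumulated fp32 grads into the flat grad.
    self.flat_param.unsharded_main_grad = torch.cat(
        [g.flatten() for g in self.fp32_grads]
    )
```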
Could you explain why there will be more GPU memory usage?
hi @awgu, current test results show the GPU memory overhead can be non-trivial (roughly 20% of 80 GB); we will follow up on reducing the memory usage.
if self.fp32_grads[param_index] is None:
    self.fp32_grads[param_index] = grad.to(torch.float32)
else:
    self.fp32_grads[param_index].add_(grad.data)
nit: I think grad.data can just be grad (saves one aten.detach call)
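A minimal sketch of the quoted snippet with that suggestion applied; the surrounding method name is hypothetical, while fp32_grads and param_index come from the excerpt above:

```python
import torch

def _accumulate_fp32_grad(self, param_index: int, grad: torch.Tensor) -> None:
    # Accumulate each parameter's gradient into a persistent fp32 buffer.
    if self.fp32_grads[param_index] is None:
        self.fp32_grads[param_index] = grad.to(torch.float32)
    else:
        # Passing `grad` directly (instead of `grad.data`) avoids an extra
        # aten.detach call, as suggested in the review comment.
        self.fp32_grads[param_index].add_(grad)
```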
* Changed to only run reshard hook if all gradients are computed
* Fix decreasing it/s with multi-grad hook
Co-authored-by: Jie Wang <[email protected]>
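The multi-grad-hook change itself is not shown in this excerpt; as a rough illustration of the idea (not the PR's code), a reshard callback can be deferred until every gradient of a wrapped module has been produced, e.g. with torch.autograd.graph.register_multi_grad_hook. The function and callback names below are hypothetical:

```python
import torch
from torch.autograd.graph import register_multi_grad_hook

def install_reshard_hook(module: torch.nn.Module, reshard_fn) -> None:
    """Call `reshard_fn` only after all of `module`'s grads are computed in a backward pass."""
    params = [p for p in module.parameters() if p.requires_grad]

    def _on_all_grads(grads):
        # With the default mode="all", this fires once per backward() after
        # every gradient in `params` is available, so resharding runs once.
        reshard_fn()

    register_multi_grad_hook(params, _on_all_grads)
```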
If optimize_backward_concat is set to True, the backward() pass is only allowed to propagate to FSDP.flat_params, which in turn invokes FSDP._post_backward_hook() and the concat() op, when FSDP._require_backward_grad_sync
is True (e.g. on the last microbatch).
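A minimal usage sketch, assuming fairscale's FullyShardedDataParallel accepts the new optimize_backward_concat flag and that _require_backward_grad_sync is toggled through the existing no_sync() context manager; the model and data helpers are placeholders:

```python
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = FSDP(build_model(), optimize_backward_concat=True)  # build_model() is a placeholder

microbatches = list(get_microbatches())  # placeholder data pipeline
for i, batch in enumerate(microbatches):
    is_last = i == len(microbatches) - 1
    if not is_last:
        # no_sync() keeps _require_backward_grad_sync False, so backward()
        # only accumulates per-parameter grads and skips the cat().
        with model.no_sync():
            model(batch).sum().backward()
    else:
        # Last microbatch: the flat gradient is built (cat()) and reduced.
        model(batch).sum().backward()
```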
Trace comparison
trace before change (SplitWithSizesBackward triggered every microbatch per FSDP module):
https://fburl.com/perfdoctor/qdt32ibh
trace with applied change (SplitWithSizesBackward triggered only in last microbatch per FSDP module):
https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.229652302632210.json.gz&bucket=acadia
Numerics verification
local run with deterministic mode
TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size 2, DP = 2, fp8 (no 1F1B) (loss bitwise on par)
baseline
https://www.internalfb.com/intern/paste/P1363180533/
test
https://www.internalfb.com/intern/paste/P1363177870/
TP=2, GPU=8, DP = 4, BF16, non-PP microbatching (loss bitwise on par)
baseline:
https://www.internalfb.com/intern/paste/P1322976356/
test:
https://www.internalfb.com/intern/paste/P1322871976/
TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size 2, DP = 2, BF16 (no 1F1B) (loss bitwise on par)
baseline
https://www.internalfb.com/intern/paste/P1358660231/
test
https://www.internalfb.com/intern/paste/P1358659328/
TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size 4, DP = 2, BF16 (1F1B) (loss bitwise on par)
baseline
https://www.internalfb.com/intern/paste/P1358780690
test
https://www.internalfb.com/intern/paste/P1358786994/
E2E MAST tests:
model = small, TP = 2, PP = 2, DP = 2 (loss on par)
baseline:
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-tl66r0qd
test:
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-km46966
Perf evaluation
model= llama3_kv8_balance2_ffn12, n_layers = 1, non-PP microbatching, bs = 128, fp8, TP 4, CP = 8
baseline:
e2e TFLOPS/s: 339.53
comp TFLOPS/s: 625.64
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-f7cdn9q
trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.27299292624533.json.gz&bucket=acadia
test:
e2e TFLOPS/s: 387.98 (~15% over baseline)
comp TFLOPS/s: 817.5 (~30% over baseline)
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-t56xpf
trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.71951644521316.json.gz&bucket=acadia