Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FSDPv1] Only perform cat() during last microbatch backward() within FlattenParamsWrapper #1180

Draft
wants to merge 27 commits into
base: ngoyal_changes_for_pp_fp8
Choose a base branch
from

Conversation

chrisxcai
Copy link

@chrisxcai chrisxcai commented Apr 29, 2024

If optimize_backward_concat is set to be True, only let the backward() pass propagate to FSDP.flat_params, which will
invoke the FSDP. _post_backward_hook() and concat() op, when FSDP._require_backward_grad_sync
is True (e.g. last microbatch)

Trace comparison

trace before change (SplitWithSizesBackward triggered every microbatch per FSDP module):
https://fburl.com/perfdoctor/qdt32ibh

trace with applied change (SplitWithSizesBackward triggered only in last microbatch per FSDP module):
https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.229652302632210.json.gz&bucket=acadia

numerics verification

local run with deterministic mode

TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size 2, DP = 2, fp8 (no 1F1B) (loss bitwise on par)

baseline

NVTE_DISABLE_NVRTC=1 CUDA_LAUNCH_BLOCKING=1 PYTHONPATH=~/benchmark/fairscale_repos/fairscale/ CRYPTOGRAPHY_OPENSSL_NO_LEGACY=1 torchrun --master_port 1024 --nproc_per_node=8 train.py --dump_dir /tmp/chriscai/xldumps --model_parallel_size 2 --pipeline_parallel_size 2 --num_layers_per_virtual_pipeline_stage=4  --seq_len=1024 --gpu_check_level=-1 --steps=10 --log_all_steps=True --profile_freq=10 --dump_profile_traces=True --profile_with_stack=True --model.n_layers=8 --reshard_after_forward=False --batch_size=4 --model.efficient_attn=cutlass --model.attn_bias_type=causal --model.layer_ckpt=none --model=small --model.sequence_parallel=True --mem_snapshot_stop_step 5 --log_all_steps=True --enable_deterministic_training=True --log_freq=1 --model.use_te_layers=True --optim.use_fp32_copy_optim=True --model.benchmark_perf=False --model.use_fp8=True --model.fp8_wgrad=True --optimize_backward_concat=False

https://www.internalfb.com/intern/paste/P1363180533/

test

NVTE_DISABLE_NVRTC=1 CUDA_LAUNCH_BLOCKING=1 PYTHONPATH=~/benchmark/fairscale_repos/fairscale/ CRYPTOGRAPHY_OPENSSL_NO_LEGACY=1 torchrun --master_port 1024 --nproc_per_node=8 train.py --dump_dir /tmp/chriscai/xldumps --model_parallel_size 2 --pipeline_parallel_size 2 --num_layers_per_virtual_pipeline_stage=4  --seq_len=1024 --gpu_check_level=-1 --steps=10 --log_all_steps=True --profile_freq=10 --dump_profile_traces=True --profile_with_stack=True --model.n_layers=8 --reshard_after_forward=False --batch_size=4 --model.efficient_attn=cutlass --model.attn_bias_type=causal --model.layer_ckpt=none --model=small --model.sequence_parallel=True --mem_snapshot_stop_step 5 --log_all_steps=True --enable_deterministic_training=True --log_freq=1 --model.use_te_layers=True --optim.use_fp32_copy_optim=True --model.benchmark_perf=False --model.use_fp8=True --model.fp8_wgrad=True --optimize_backward_concat=True

https://www.internalfb.com/intern/paste/P1363177870/

TP=2, GPU=8, DP = 4, BF16, non-PP microbatching (loss bitwise on par)

baseline:
https://www.internalfb.com/intern/paste/P1322976356/
test :
https://www.internalfb.com/intern/paste/P1322871976/

TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size 2, DP = 2, BF16 (no 1F1B) (loss bitwise on par)

baseline
https://www.internalfb.com/intern/paste/P1358660231/

test
https://www.internalfb.com/intern/paste/P1358659328/

TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size 4, DP = 2, BF16 (1F1B) (loss bitwise on par)

baseline
https://www.internalfb.com/intern/paste/P1358780690

test
https://www.internalfb.com/intern/paste/P1358786994/

E2E MAST tests:

model = small, TP = 2, PP = 2, DP = 2 (loss on par)

baseline:
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-tl66r0qd

test:
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-km46966

loss

Perf evaluation

model= llama3_kv8_balance2_ffn12, n_layers = 1, non-PP microbatching, bs = 128, fp8, TP 4, CP = 8

baseline:
e2e TFLOPS/s: 339.53
comp TFLOPS/s: 625.64

https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-f7cdn9q
trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.27299292624533.json.gz&bucket=acadia

test:
e2e TFLOPS/s: 387.98 (~15%)
comp TFLOPS/s: 817.5 (~30%)

https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-t56xpf
trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.71951644521316.json.gz&bucket=acadia

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 29, 2024
@chrisxcai chrisxcai requested review from GD06 and yuchenhao April 29, 2024 22:36
@chrisxcai chrisxcai changed the title [WIP] Make FSDPv1 only perform cat() during last microbatch backward() within FlattenParamsWrapper [FSDPv1] Only perform cat() during last microbatch backward() within FlattenParamsWrapper May 15, 2024
@chrisxcai chrisxcai requested a review from awgu May 15, 2024 07:46
Copy link

@awgu awgu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach makes sense to me!

If True, only let backward pass propagate to self.params, which will
invoke the _post_backward_hook() and concat() op, when self._require_backward_grad_sync
is True (e.g. last microbatch)
NOTE: this likely will incur more GPU memory usage
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why there will be more GPU memory usage?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @awgu, currently by testing results it shows the GPU memory overhead could be non-trivial (20% of 80G), we will follow up on reducing the memory usage
Screenshot 2024-05-15 at 10 39 19 AM
Screenshot 2024-05-15 at 10 40 18 AM

if self.fp32_grads[param_index] is None:
self.fp32_grads[param_index] = grad.to(torch.float32)
else:
self.fp32_grads[param_index].add_(grad.data)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think grad.data can just be grad (save one aten.detach call)

@chrisxcai chrisxcai changed the base branch from ngoyal_changes_for_pp_fp8 to ngoyal_changes_for_pp_fp8_jiecaoyu_free_fp16_shard May 15, 2024 21:51
awgu and others added 4 commits May 15, 2024 15:17
* Changed to only run reshard hook if all gradients computed

* Fix decreasing it/s with multi-grad hook
@chrisxcai chrisxcai changed the base branch from ngoyal_changes_for_pp_fp8_jiecaoyu_free_fp16_shard to ngoyal_changes_for_pp_fp8 May 15, 2024 23:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants