Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Not for merge] fp8allgather debug #1147

Open
wants to merge 1 commit into
base: ngoyal_changes_for_pp_fp8
Choose a base branch
from

Conversation

jiecaoyu
Copy link

@jiecaoyu jiecaoyu commented Oct 29, 2023

With @jspark1105 's commits enabling FP8 allgather, we can run test_te.py and also local training without PP.

However, if enabling PP, there are some issues with FP8 allgather that need to be fixed. This diff copies the changes from @jspark1105 's PR and includes the fixes we need in fairscale.

The fixes we need are as following (most fixes are naive and need better implementations):

  • When run model_chunk._rebuild_full_params_recursive() in xlformers/src/model_parallel_core/pipeline_parallel/fwd_bwd_schedules.py, we need to pass the FP8 training related settings into the context. All changes in xlformers are included in this commit.
    image

  • In TransformerEngine, we don't need to return the weight gradients for FP8 training since the gradients will be accumulated in .main_grad. All changes in TE are included in this commit.
    image

  • The FlattenParamsWrapper creates the view of the parameters every forward pass. It is unnecessary if we are not doing resharding after forward. Also, it creates a problem for FP8 allgather + PP because we create .main_grad in the beginning of the forward, and we can only access the last view of parameters. The earlier views of parameters are no longer accessable.
    image

  • We should not free the . _free_fp16_param_shard in the _post_backward_hook. The FP16 shard needs to be kept since each backward pass needs to use it.
    image

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 29, 2023
@jiecaoyu jiecaoyu changed the base branch from main to ngoyal_changes_for_pp_fp8 October 29, 2023 05:43
FP8GlobalStateManager.copy_amax_from_global_buffer(
m.fp8_meta, forward=True
)
# FIXME update_weight_scale_inv is only True for the first micro-batch

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed it's OK to update weight_scale_inv multiple times but it can be still annoying if we see numerical differences btw no-PP and PP (actually just micro-batching). Actually I wonder we can check is_first_microbatch in kwargs to skip this.

@jiecaoyu jiecaoyu force-pushed the jiecaoyu_fp8allgather_debug branch from fb0b563 to 8bebf15 Compare November 9, 2023 05:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants