[Not for merge] fp8allgather debug #1147
Open
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
With @jspark1105 's commits enabling FP8 allgather, we can run
test_te.py
and also local training without PP.However, if enabling PP, there are some issues with FP8 allgather that need to be fixed. This diff copies the changes from @jspark1105 's PR and includes the fixes we need in fairscale.
The fixes we need are as following (most fixes are naive and need better implementations):
When run
model_chunk._rebuild_full_params_recursive()
inxlformers/src/model_parallel_core/pipeline_parallel/fwd_bwd_schedules.py
, we need to pass the FP8 training related settings into the context. All changes in xlformers are included in this commit.In
TransformerEngine
, we don't need to return the weight gradients for FP8 training since the gradients will be accumulated in.main_grad
. All changes in TE are included in this commit.The
FlattenParamsWrapper
creates the view of the parameters every forward pass. It is unnecessary if we are not doing resharding after forward. Also, it creates a problem for FP8 allgather + PP because we create.main_grad
in the beginning of the forward, and we can only access the last view of parameters. The earlier views of parameters are no longer accessable.We should not free the
. _free_fp16_param_shard
in the_post_backward_hook
. The FP16 shard needs to be kept since each backward pass needs to use it.