added option to do backward AG over smaller set of gpus instead of full DDP world #1125
base: ngoyal_bf16_changes
This makes sense to me!
    # Cases for when zero2 world size > 1 but less than zero3 size
    zero2_world_size = dist.get_world_size(self.zero2_process_group)
    zero2_rank = dist.get_rank(self.zero2_process_group)
    chunks = p._full_param_padded.chunk(zero2_world_size)
I just want to mention that there is a divisibility assumption here (ZeRO-2 world size divides the ZeRO-3 world size), which should always hold in practice.
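To make the divisibility assumption concrete, here is a minimal sketch in plain Python. It is not part of the PR: the function name `zero2_shard_bounds` and its parameters are hypothetical, and it assumes the flat parameter has already been padded to a multiple of the ZeRO-3 world size.

```python
def zero2_shard_bounds(padded_numel, zero3_world_size, zero2_world_size, zero2_rank):
    """Return the [start, end) element range that `zero2_rank` would keep.

    Hypothetical helper: it mirrors `chunk(zero2_world_size)` only for the
    evenly divisible case and rejects everything else.
    """
    # Assumption: the flat parameter was padded to a multiple of the ZeRO-3 world size.
    if padded_numel % zero3_world_size != 0:
        raise ValueError("padded_numel must be a multiple of zero3_world_size")
    # The assumption from the comment: ZeRO-2 world size divides ZeRO-3 world size.
    if zero3_world_size % zero2_world_size != 0:
        raise ValueError("zero2_world_size must divide zero3_world_size")
    if not 0 <= zero2_rank < zero2_world_size:
        raise ValueError("zero2_rank out of range")
    shard_numel = padded_numel // zero2_world_size
    start = zero2_rank * shard_numel
    return start, start + shard_numel


if __name__ == "__main__":
    # 16 padded elements, ZeRO-3 over 8 ranks, ZeRO-2 over 4 ranks.
    for rank in range(4):
        print(rank, zero2_shard_bounds(16, 8, 4, rank))
```

If the assumption were violated, for example 16 padded elements and a ZeRO-2 group of 3, `torch.chunk` would return chunks of 6, 6, and 4 elements, so the per-rank shards would no longer be the same size. An explicit check like the one above would surface that early.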
    if wait_for_all_gather:
        torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
    return output_tensors
It looks like the only difference compared to `_rebuild_full_params()` is no SSD offload, no CPU offload, and using `p._zero2_fp16_shard`, `self.zero2_process_group`, and `self._free_zero2_param_shard()`. This makes sense to me.
    # free it until the work in the current stream completes.
    p._zero2_fp16_shard.record_stream(current_stream)
    free_storage_(p._zero2_fp16_shard)
It looks like `_zero2_fp16_shard` is allocated in the default stream (since `_zero2_shard_to_smaller_group()` is called from `forward()` without an explicit stream context manager):
fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Lines 1428 to 1430 in 0b77de4
    if self.reshard_after_forward:
        if self.zero2_process_group is not None:
            self._zero2_shard_to_smaller_group()
`_zero2_fp16_shard` is consumed in the "all_gather" stream, and this `_free_zero2_param_shard()` is called from that "all_gather" stream as well:
fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Lines 2121 to 2136 in 0b77de4
    # Fill output_tensor with (p.data for each shard in self.world_size)
    if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
        # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
        dist._all_gather_base(output_tensor, p._zero2_fp16_shard , group=self.zero2_process_group)
    else:
        chunks = list(output_tensor.chunk(self.world_size))
        dist.all_gather(chunks, p._zero2_fp16_shard, group=self.zero2_process_group)
    # Set p.data = output_tensor (with padding trimmed)
    update_p_data(output_tensor)
    if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
        self._free_zero2_param_shard([p])
    if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype):
        self._free_zero2_param_shard([p])
In that case, I do agree this `p._zero2_fp16_shard.record_stream(current_stream)` call is necessary to notify the caching allocator of the usage in the "all_gather" stream. However, I think the comment can be changed to say that it was allocated in the default stream. Alternatively, you can do something like `_cast_fp32_param_shards_to_fp16()`, but I am not sure if there is any actual overlap opportunity given the data dependencies.
fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Lines 2381 to 2391 in 0b77de4
    with torch.cuda.stream(self._streams["fp32_to_fp16"]):
        for p in params:
            assert p._fp16_shard is not None
            alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
            p._fp16_shard.copy_(
                # If move_params_to_cpu is True, this will be non-blocking
                # because _fp32_shard is pinned, otherwise it's a no-op.
                p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
            )
            p.data = p._fp16_shard
    torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])
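The stream-ownership point above (a tensor allocated on the default stream but consumed on a side stream) can be illustrated with a small standalone sketch. This is not FairScale code: the function name `use_on_side_stream` and the sizes are made up for illustration, and it relies only on the public PyTorch calls `torch.cuda.Stream`, `torch.cuda.stream`, `Stream.wait_stream`, and `Tensor.record_stream`. The import guard is only there so the sketch can be loaded on a machine without PyTorch or a GPU.

```python
try:
    import torch
except ImportError:  # allow importing this sketch where PyTorch is absent
    torch = None


def use_on_side_stream(n=1024):
    """Allocate on the default stream, consume on a side stream, then free."""
    if torch is None or not torch.cuda.is_available():
        return None  # the caching-allocator concern only applies to CUDA tensors

    side = torch.cuda.Stream()  # stands in for the "all_gather" stream
    shard = torch.ones(n, device="cuda")  # allocated on the current (default) stream

    # The side stream must not read `shard` before the default stream has filled it.
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        out = shard * 2  # consumed on the side stream
        # Tell the caching allocator that `shard` is in use on `side`, so its
        # memory is not handed out again until this side-stream work finishes.
        shard.record_stream(side)
    del shard  # safe to drop the reference now

    # The default stream waits for the side stream before reading `out`.
    torch.cuda.current_stream().wait_stream(side)
    return out.sum().item()


if __name__ == "__main__":
    print(use_on_side_stream())
```

Without the `record_stream` call, the allocator would only track the stream on which the tensor was allocated and could reuse the block while the side stream is still reading it, which is the situation the code comment under discussion should describe.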
One more comment: I am not familiar with how model checkpointing works in Fairscale FSDP, but one concern might be what happens if a user tries to checkpoint the model after forward only (before backward has run). Will this work out of the box? Perhaps this is not a major concern for your use case, since probably no one will save a state dict after only forward.