From a93b025cba72401b5be5f4a8bca98c1d83ac5227 Mon Sep 17 00:00:00 2001
From: Junteng Jia
Date: Mon, 6 Nov 2023 23:18:09 -0800
Subject: [PATCH 1/3] Extend CheckpointFunction to track all tensor input/output

---
 .../nn/checkpoint/checkpoint_activations.py | 127 +++++++++++++-----
 1 file changed, 90 insertions(+), 37 deletions(-)

diff --git a/fairscale/nn/checkpoint/checkpoint_activations.py b/fairscale/nn/checkpoint/checkpoint_activations.py
index 520747f36..923832e1e 100644
--- a/fairscale/nn/checkpoint/checkpoint_activations.py
+++ b/fairscale/nn/checkpoint/checkpoint_activations.py
@@ -161,6 +161,62 @@ def checkpoint_wrapper(
     return module
 
 
+def dfs_simplified(entity):
+    if isinstance(entity, tuple):
+        return tuple(dfs_simplified(value) for value in entity)
+    elif isinstance(entity, list):
+        return [dfs_simplified(value) for value in entity]
+    elif isinstance(entity, dict):
+        return {key: dfs_simplified(value) for key, value in entity.items()}
+    elif isinstance(entity, torch.Tensor):
+        return entity.shape
+    else:
+        return entity
+
+
+SimpleEntity = collections.namedtuple("SimpleEntity", ["is_tensor", "value"])
+
+
+def serialize_tensors(inputs: Any) -> Tuple[Tuple[torch.Tensor], Any]:
+    tensors = []
+
+    def dfs(entity):
+        if isinstance(entity, tuple):
+            return tuple(dfs(value) for value in entity)
+        elif isinstance(entity, list):
+            return [dfs(value) for value in entity]
+        elif isinstance(entity, dict):
+            return {key: dfs(value) for key, value in entity.items()}
+        elif isinstance(entity, torch.Tensor):
+            tensors.append(entity)
+            return SimpleEntity(True, len(tensors) - 1)
+        else:
+            return SimpleEntity(False, entity)
+
+    non_tensors = dfs(inputs)
+
+    return tuple(tensors), non_tensors
+
+
+def deserialize_tensors(tensors: Tuple[torch.Tensor], non_tensors: Any) -> Any:
+    def dfs(entity):
+        if isinstance(entity, SimpleEntity):
+            if entity.is_tensor:
+                return tensors[entity.value]
+            else:
+                return entity.value
+        elif isinstance(entity, tuple):
+            return tuple(dfs(value) for value in entity)
+        elif isinstance(entity, list):
+            return [dfs(value) for value in entity]
+        elif isinstance(entity, dict):
+            return {key: dfs(value) for key, value in entity.items()}
+        else:
+            raise RuntimeError(f"Unexpected type {type(entity)}")
+
+    return dfs(non_tensors)
+
+
 def _checkpointed_forward(
     original_forward: Any, weak_self: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any
 ) -> Any:
@@ -173,8 +229,8 @@ def _checkpointed_forward(
     # Autograd Functions in PyTorch work best with positional args, since
     # the backward must return gradients (or None) for every input argument.
     # We can flatten keyword arguments to make this easier.
-    args = (module,) + args
-    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
+    tensor_inputs, non_tensor_inputs = serialize_tensors((module, args, kwargs))
+
     parent_ctx_dict: Dict[str, Any] = {
         "offload": offload_to_cpu,
     }
@@ -189,7 +245,7 @@ def _checkpointed_forward(
     # We get around this by saving the desired requires_grad value in output and
     # detaching the output if needed.
     output = CheckpointFunction.apply(
-        torch.tensor([], requires_grad=True), original_forward, parent_ctx_dict, kwarg_keys, *flat_args
+        torch.tensor([], requires_grad=True), original_forward, parent_ctx_dict, non_tensor_inputs, *tensor_inputs
     )
     output_requires_grad = parent_ctx_dict["output_requires_grad"]
     if not isinstance(output, torch.Tensor):
@@ -198,10 +254,9 @@ def _checkpointed_forward(
         # requires_grad
         output = [x.detach() if not output_requires_grad else x for x in output]
 
-        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
-        if packed_non_tensor_outputs:
-            output = unpack_non_tensors(output, packed_non_tensor_outputs)
-
+        non_tensor_outputs = parent_ctx_dict["non_tensor_outputs"]
+        if non_tensor_outputs:
+            output = deserialize_tensors(output, non_tensor_outputs)
     else:
         # If output should not require grad, then detach it, since otherwise it will
         # always have requires_grad = True due to our dummy tensor input above that
@@ -256,32 +311,28 @@ def forward(  # type: ignore
         dummy_tensor_requires_grad: torch.Tensor,
         run_function: Any,
         parent_ctx_dict: Dict[str, Any],
-        kwarg_keys: Tuple[str, ...],
-        *args: Any,
-        **kwargs: Any
+        non_tensor_inputs: Tuple[Any],
+        *tensor_inputs: torch.Tensor,
     ) -> Any:
-        torch_checkpoint.check_backward_validity(args)
+        torch_checkpoint.check_backward_validity(tensor_inputs)
 
         ctx.run_function = run_function
-        ctx.kwarg_keys = kwarg_keys
+        ctx.non_tensor_inputs = non_tensor_inputs
         ctx.fwd_rng_state = get_rng_state()
         ctx.had_autocast_in_fwd = is_autocast_enabled()
 
-        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
         if parent_ctx_dict["offload"]:
             ctx.fwd_device = tuple(x.device for x in tensor_inputs)
             ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
-            tensor_inputs = tuple(x.to("cpu", non_blocking=True) for x in tensor_inputs)
+            ctx.save_for_backward(*(x.to("cpu", non_blocking=True) for x in tensor_inputs))
         else:
-            ctx.fwd_device, ctx.grad_requirements = None, None
-
-        ctx.save_for_backward(*tensor_inputs)
-        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs
+            ctx.fwd_device = None
+            ctx.grad_requirements = None
+            ctx.save_for_backward(*tensor_inputs)
 
         with torch.no_grad(), enable_checkpointing():
-            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
-            outputs = run_function(*unpacked_args, **unpacked_kwargs)
-            the_module = unpacked_args[0]
+            the_module, args, kwargs = deserialize_tensors(tensor_inputs, non_tensor_inputs)
+            outputs = run_function(the_module, *args, **kwargs)
 
         # Because we run with torch.no_grad(), we can't actually access
         # outputs.requires_grad. Instead, we manually compute it by
@@ -303,13 +354,14 @@ def forward(  # type: ignore
             # Autograd Functions don't like non-Tensor outputs. We can split the
             # non-Tensor and Tensor outputs, returning the former by reference
             # through *parent_ctx_dict* and returning the latter directly.
-            outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
-            parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
-
-        return outputs
+            tensor_outputs, non_tensor_outputs = serialize_tensors(outputs)
+            parent_ctx_dict["non_tensor_outputs"] = non_tensor_outputs
+            return tensor_outputs
+        else:
+            return outputs
 
     @staticmethod
-    def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
+    def backward(ctx: Any, *grad_outputs: Any) -> Tuple[Optional[Tensor], ...]:
         if not torch.autograd._is_checkpoint_valid():
             raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
 
@@ -319,7 +371,7 @@ def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
             tensor_inputs = tuple(t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs))
             for i, need_grad in enumerate(ctx.grad_requirements):
                 tensor_inputs[i].requires_grad = need_grad
-        inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)
+        non_tensor_inputs = ctx.non_tensor_inputs
 
         # Store the current states.
         bwd_rng_state = get_rng_state()
@@ -328,26 +380,27 @@ def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
         set_rng_state(ctx.fwd_rng_state)
 
         with torch.enable_grad(), enable_recomputing(), autocast(ctx.had_autocast_in_fwd):
-            unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
-            outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
-            tensor_outputs, _ = split_non_tensors(outputs)
+            the_module, args, kwargs = deserialize_tensors(tensor_inputs, non_tensor_inputs)
+            outputs = ctx.run_function(the_module, *args, **kwargs)
+            tensor_outputs, _ = serialize_tensors(outputs)
 
         # Set the states back to what it was at the start of this function.
         set_rng_state(bwd_rng_state)
 
         # Run backward() with only Tensors that require grad
-        outputs_with_grad = []
-        args_with_grad = []
+        assert len(tensor_outputs) == len(grad_outputs)
+        tensor_outputs_with_grad = []
+        grad_outputs_with_grad = []
         for i in range(len(tensor_outputs)):
             if tensor_outputs[i].requires_grad:
-                outputs_with_grad.append(tensor_outputs[i])
-                args_with_grad.append(args[i])
+                tensor_outputs_with_grad.append(tensor_outputs[i])
+                grad_outputs_with_grad.append(grad_outputs[i])
 
-        if len(outputs_with_grad) == 0:
+        if len(tensor_outputs_with_grad) == 0:
             raise RuntimeError("None of the outputs have requires_grad=True, " "this checkpoint() is not necessary")
 
-        torch.autograd.backward(outputs_with_grad, args_with_grad)
+        torch.autograd.backward(tensor_outputs_with_grad, grad_outputs_with_grad)
 
-        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
+        grads = tuple(inp.grad for inp in tensor_inputs)
         return (None, None, None, None) + grads

From 55738b5c09463429a44b1453275b53971f790888 Mon Sep 17 00:00:00 2001
From: Junteng Jia
Date: Tue, 7 Nov 2023 10:08:03 -0800
Subject: [PATCH 2/3] Extend CheckpointFunction to track all tensor input/output (add comments)

---
 fairscale/nn/checkpoint/checkpoint_activations.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/fairscale/nn/checkpoint/checkpoint_activations.py b/fairscale/nn/checkpoint/checkpoint_activations.py
index 923832e1e..8ee9a0660 100644
--- a/fairscale/nn/checkpoint/checkpoint_activations.py
+++ b/fairscale/nn/checkpoint/checkpoint_activations.py
@@ -162,6 +162,10 @@ def checkpoint_wrapper(
 
 
 def dfs_simplified(entity):
+    """
+    a helper function that takes a python container (tuple, list, dict) and replaces
+    any tensor with its shape; the main purpose is printing and debugging
+    """
     if isinstance(entity, tuple):
         return tuple(dfs_simplified(value) for value in entity)
     elif isinstance(entity, list):
@@ -178,6 +182,11 @@ def dfs_simplified(entity):
 
 
 def serialize_tensors(inputs: Any) -> Tuple[Tuple[torch.Tensor], Any]:
+    """
+    given a python container inputs (tuple, list, dict) which may contain tensors,
+    this function extracts every tensor into a flat tuple and returns another
+    container in which each tensor is replaced by its index into that tuple
+    """
     tensors = []
 
     def dfs(entity):
@@ -199,7 +208,11 @@ def dfs(entity):
 
 
 def deserialize_tensors(tensors: Tuple[torch.Tensor], non_tensors: Any) -> Any:
+    """
+    the reverse function of serialize_tensors
+    """
     def dfs(entity):
+        # check SimpleEntity first, since it is a subclass of tuple
         if isinstance(entity, SimpleEntity):
             if entity.is_tensor:
                 return tensors[entity.value]

From 8710f0349d148d2664afc7bd9824478df98fb563 Mon Sep 17 00:00:00 2001
From: Junteng Jia
Date: Tue, 7 Nov 2023 10:17:52 -0800
Subject: [PATCH 3/3] Extend CheckpointFunction to track all tensor input/output (add comments)

---
 fairscale/nn/checkpoint/checkpoint_activations.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/fairscale/nn/checkpoint/checkpoint_activations.py b/fairscale/nn/checkpoint/checkpoint_activations.py
index 8ee9a0660..0188082b5 100644
--- a/fairscale/nn/checkpoint/checkpoint_activations.py
+++ b/fairscale/nn/checkpoint/checkpoint_activations.py
@@ -209,7 +209,9 @@ def dfs(entity):
 
 def deserialize_tensors(tensors: Tuple[torch.Tensor], non_tensors: Any) -> Any:
     """
-    the reverse function of serialize_tensors
+    the reverse function of serialize_tensors: given a tuple of tensors and a
+    container holding tensor indices, it returns the container with each index
+    replaced by the corresponding tensor
     """
     def dfs(entity):
         # check SimpleEntity first, since it is a subclass of tuple
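
Reviewer note (illustration only, not part of the patch series): a minimal sketch of how the helpers introduced above are meant to behave. It assumes a checkout with these patches applied, so that dfs_simplified, serialize_tensors and deserialize_tensors are importable from fairscale.nn.checkpoint.checkpoint_activations; they are not part of any released fairscale API.

    import torch

    from fairscale.nn.checkpoint.checkpoint_activations import (  # assumes a patched checkout
        deserialize_tensors,
        dfs_simplified,
        serialize_tensors,
    )

    a, b = torch.randn(2, 3), torch.randn(4)
    inputs = (a, {"bias": b, "p": 0.1}, ["tag", 7])

    # dfs_simplified only swaps tensors for their shapes, for compact debug printing.
    print(dfs_simplified(inputs))
    # -> (torch.Size([2, 3]), {'bias': torch.Size([4]), 'p': 0.1}, ['tag', 7])

    # serialize_tensors pulls every tensor into a flat tuple; the returned skeleton keeps the
    # container structure, with each tensor replaced by SimpleEntity(is_tensor=True, value=index)
    # and every other leaf wrapped as SimpleEntity(is_tensor=False, value=leaf).
    tensors, skeleton = serialize_tensors(inputs)
    assert tensors[0] is a and tensors[1] is b

    # deserialize_tensors is the inverse: indices in the skeleton are swapped back for tensors.
    rebuilt = deserialize_tensors(tensors, skeleton)
    assert rebuilt[0] is a and rebuilt[1]["bias"] is b and rebuilt[2] == ["tag", 7]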
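
A second illustrative sketch, unrelated to fairscale code, of the autograd contract that motivates the flattening: backward() must return one value per forward() argument, so every tensor has to reach apply() as its own positional argument to receive a gradient slot. This is exactly why the patch routes (module, args, kwargs) through serialize_tensors before CheckpointFunction.apply and why backward returns (None, None, None, None) + grads for the four non-tensor leading arguments.

    import torch

    class Scale(torch.autograd.Function):
        # Toy Function: one return slot in backward() per forward() argument.
        @staticmethod
        def forward(ctx, factor, x):  # factor is a non-tensor argument
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_out):
            # None for the non-tensor argument, a real gradient for the tensor one.
            return None, grad_out * ctx.factor

    x = torch.randn(3, requires_grad=True)
    Scale.apply(2.0, x).sum().backward()
    assert torch.allclose(x.grad, torch.full((3,), 2.0))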