diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py index 9a368b7a0a25..32228d69faa4 100644 --- a/deepspeed/checkpoint/deepspeed_checkpoint.py +++ b/deepspeed/checkpoint/deepspeed_checkpoint.py @@ -120,11 +120,12 @@ def _build_global_state(self): self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0) self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None) - def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index) -> dict: + def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index, strip_tensor_paddings: bool = True): return self.zero_checkpoint.get_state_for_rank(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index, - keys_to_ignore=[PARAM_SHAPES]) + keys_to_ignore=[PARAM_SHAPES], + strip_tensor_paddings=strip_tensor_paddings) def get_zero_files(self, pp_index, tp_index, dp_index) -> list: return self.zero_checkpoint.get_files_for_rank(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index) diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py index f7b75eee66d0..2b5dc81b2f66 100755 --- a/deepspeed/checkpoint/ds_to_universal.py +++ b/deepspeed/checkpoint/ds_to_universal.py @@ -111,7 +111,10 @@ def _save_checkpoint(file_path, chkpt_sd): def extract_zero_shards(dir, ds_checkpoint, indices_3D): pp_index, tp_index, dp_index = indices_3D - sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index) + sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, + tp_index=tp_index, + dp_index=dp_index, + strip_tensor_paddings=False) # pprint(f"Processing {dp_index=} {pp_index=}, {tp_index=}")