Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DeepSpeedCheckpoint: support custom final ln idx #5506

Merged
merged 7 commits into from
May 29, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions deepspeed/checkpoint/deepspeed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# DeepSpeed Team

import os
import re
from typing import Dict
import torch

Expand All @@ -21,6 +22,7 @@
ARGS_KEY = 'args'
CHECKPOINT_INFO_KEY = 'checkpoint_info'
ITERATION_KEY = 'iteration'
LAYER_FILE_PREFIX_PATTERN = r'layer_(\d+)-model_.*'

SEQUENTIAL_LAYERS = [
'input_layernorm.weight', 'input_layernorm.bias', 'self_attention.dense.bias', 'post_attention_layernorm.weight',
Expand All @@ -32,7 +34,13 @@

class DeepSpeedCheckpoint(object):

def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
def __init__(self,
dir,
tp_degree=None,
pp_degree=None,
dp_degree=None,
final_layer_norm_idx=FINAL_LAYER_NORM_INDEX):
self.final_layer_norm_idx = final_layer_norm_idx
self.dir = dir

pipeline_parallel = len(get_files_with_prefix(get_files(dir), LAYER_FILE_PREFIX)) > 0
Expand Down Expand Up @@ -73,7 +81,7 @@ def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
self.pp_to_transformer_map = self._build_pp_transformer_map()
self.transformer_file_map = self._build_transformer_file_map()
self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX)
self.tp_to_final_norm_map = self._build_tp_other_layer_map(self.final_layer_norm_idx)
self._build_global_state()

def is_change_tp_degree(self):
Expand Down Expand Up @@ -125,7 +133,7 @@ def get_embedding_layer_id(self):
return self.layer_keys[EMBEDDING_LAYER_INDEX]

def get_final_norm_layer_id(self):
return self.layer_keys[FINAL_LAYER_NORM_INDEX]
return self.layer_keys[self.final_layer_norm_idx]

def get_iteration(self):
if not ITERATION_KEY in self.global_state:
Expand Down Expand Up @@ -214,7 +222,7 @@ def get_2d_parallel_files(self, tp_index: int, pp_index: int) -> list:
def _build_pp_transformer_map(self):
data_map = {}
if self.pp_degree > 0:
transformer_layers = self.layer_keys[1:-1]
transformer_layers = self.layer_keys[1:self.final_layer_norm_idx]
layers_per_pp = len(transformer_layers) // self.pp_degree
data_map = {
i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp]
Expand All @@ -229,7 +237,7 @@ def _dump_mapping(self, data_map, map_tag=None):
print(f'{k} = {v}')

def _build_transformer_file_map(self):
transformer_layer_keys = self.layer_keys[1:-1]
transformer_layer_keys = self.layer_keys[1:self.final_layer_norm_idx]
file_map = {}
# XXX: this is not guaranteed
layers_per_pp = 1
Expand All @@ -238,7 +246,7 @@ def _build_transformer_file_map(self):
#print(f"{transformer_layer_keys} {layers_per_pp}")
for key_index, layer_key in enumerate(transformer_layer_keys):
pp_index = key_index // layers_per_pp
layer_files = get_files_with_prefix(self.layer_files, layer_key)
layer_files = get_files_with_prefix(self.layer_files, layer_key + '-')
layer_file_partitions = partition_data(layer_files, self.tp_degree)
for tp_index in range(self.tp_degree):
map_key = (tp_index, pp_index)
Expand All @@ -263,11 +271,13 @@ def validate_files(self):

def _get_layer_keys(self):
    """Return the sorted list of layer keys found in the checkpoint dir.

    Extracts the numeric layer id from each layer file name using
    ``LAYER_FILE_PREFIX_PATTERN`` (e.g. ``layer_04-model_00-...`` -> ``04``)
    and rebuilds keys as ``LAYER_FILE_PREFIX + id``, sorted numerically so
    that e.g. layer 10 sorts after layer 2 (a plain string sort would not).
    """
    key_set = set()
    for file_path in self.layer_files:
        _, fname = os.path.split(file_path)
        # NOTE(review): assumes every layer file name matches the pattern;
        # a non-matching name would raise AttributeError on .group() —
        # confirm upstream filtering guarantees this.
        layer_id = re.search(LAYER_FILE_PREFIX_PATTERN, fname).group(1)
        key_set.add(layer_id)
    # Sort ids numerically, not lexicographically; ids are already strings,
    # so no str() conversion is needed when rebuilding the keys.
    sorted_ids = sorted(key_set, key=int)
    return [LAYER_FILE_PREFIX + layer_id for layer_id in sorted_ids]

def _merge_state_dicts(self, sd_list):
merged_sd = {}
Expand Down
Loading