universal-ckp: support megatron-deepspeed llama model (microsoft#4666)
Megatron-DeepSpeed's llama implementation of swiglu allocates a single
ColumnParallelLinear layer L, but this parameter is effectively a
container for two Linear layers L1 and L2 used as silu(L1(x)) * L2(x).
This requires special handling in ds_to_universal: to build the universal
representation of L, the TP slices of each sub parameter are concatenated
first, and L is then created by concatenating the full L1 and L2.

Signed-off-by: Moshe Island <[email protected]>
Co-authored-by: Moshe Island <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
4 people authored Nov 15, 2023
1 parent 00e7dc5 commit ce5e56a
Showing 3 changed files with 32 additions and 5 deletions.
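For context, here is a minimal PyTorch sketch of the fused SwiGLU layout described in the commit message above: two logical projections stored in one contiguous weight and split at forward time. All names and sizes are illustrative, not taken from Megatron-DeepSpeed.

```python
import torch
import torch.nn.functional as F

hidden, ffn = 8, 16
# One contiguous parameter L that actually holds the two logical weights L1 and L2.
fused_w = torch.randn(2 * ffn, hidden)

def swiglu(x, fused_w):
    # Split the container parameter into its two sub parameters along dim 0.
    w1, w2 = torch.chunk(fused_w, 2, dim=0)
    return F.silu(x @ w1.t()) * (x @ w2.t())

x = torch.randn(4, hidden)
print(swiglu(x, fused_w).shape)  # torch.Size([4, 16])
```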
9 changes: 9 additions & 0 deletions deepspeed/checkpoint/constants.py
@@ -64,10 +64,19 @@
# Parameter splitting/merging
PARAM_SLICE_MAPPINGS = 'param_slice_mappings'
CAT_DIM = "cat_dim"
# Following is a special case where a parameter effectively contains sub parameters.
# As an example, consider Megatron-DeepSpeed GPT SWIGLU implementation (mlp.h_to_4h).
# In this case, a single parameter is allocated contiguously, but used as separate parameters.
# When using universal checkpoint, we have to normalize the representation of the full parameter.
# We normalize it by concatenating all slices of the sub params and then concatenating the sub params.
# All concat operations are done on CAT_DIM (currently, using different concat dims for the sub params and for TP slicing is not supported).
# Similarly, load_hp_checkpoint_state has to take the needed actions when loading from universal.
PARAM_N_SUB_PARAMS = "param_n_sub_params"

# Regex list of parameters that require special handling
VOCABULARY_PARAMETER_PATTERNS = 'vocabulary_parameter_patterns'
PIPELINE_REPLICATED_PARAMETER_PATTERNS = 'pipeline_replicated_parameter_patterns'
PARAMETER_TO_AVERAGE_PATTERNS = 'parameter_to_average_patterns'
PARAMETER_WITH_ROW_PARALLELISM_PATTERNS = 'parameter_with_row_parallelism_patterns'
TP_REPLICATED_PARAMETER_PATTERNS = 'tp_replicated_parameter_patterns'
PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0 = 'parameter_with_2_sub_params_cat_dim_0'
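As a hedged illustration of how this new key might be consumed, merge_tp_slices (next file) looks the pattern list up in the model's universal checkpoint info; an entry could look like the sketch below. The regex patterns are placeholders, not the exact names Megatron-DeepSpeed registers.

```python
# Illustrative only: the regexes are placeholders, not Megatron-DeepSpeed's actual names.
from deepspeed.checkpoint.constants import (
    TP_REPLICATED_PARAMETER_PATTERNS,
    PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0,
)

universal_checkpoint_info = {
    # Parameters replicated across TP ranks (placeholder pattern).
    TP_REPLICATED_PARAMETER_PATTERNS: [r".*layernorm\..*"],
    # Fused SwiGLU weight: one parameter containing two sub params, concatenated on dim 0.
    PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0: [r".*mlp\.dense_h_to_4h\.weight"],
}
```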
14 changes: 12 additions & 2 deletions deepspeed/checkpoint/ds_to_universal.py
@@ -26,13 +26,15 @@
PARAM_SHAPES,
PARAM,
CAT_DIM,
PARAM_N_SUB_PARAMS,
VOCAB_TENSOR,
UNIVERSAL_CHECKPOINT_INFO,
VOCABULARY_PARAMETER_PATTERNS,
PIPELINE_REPLICATED_PARAMETER_PATTERNS,
TP_REPLICATED_PARAMETER_PATTERNS,
PARAMETER_TO_AVERAGE_PATTERNS,
PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0,
)


@@ -148,7 +150,6 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
shards = [torch.load(p) for p in paths]
slice = torch.cat(shards, dim=0).reshape(slice_shape)
slices.append(slice)

return slices


@@ -163,8 +164,9 @@ def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
parameters_to_average = universal_checkpoint_info.get(PARAMETER_TO_AVERAGE_PATTERNS, [])
parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, [])
vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, [])
parameters_with_2_sub_params_cat_dim_0 = universal_checkpoint_info.get(PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0, [])
unmatched_patterns = set(replicated_parameters + parameters_to_average + parameters_with_row_parallelism +
vocabulary_parameters)
vocabulary_parameters + parameters_with_2_sub_params_cat_dim_0)

def get_matched_pattern(patterns_, name_):
matched_ = [pattern_ for pattern_ in patterns_ if re.match(pattern_, name_)]
@@ -190,6 +192,14 @@ def get_matched_pattern(patterns_, name_):
elif get_matched_pattern(parameters_to_average, name):
param = sum(slices) / len(slices)
# print(f'merge {name} using average')
elif get_matched_pattern(parameters_with_2_sub_params_cat_dim_0, name):
cat_dim = 0
chunked_slices = [torch.chunk(s, 2, dim=cat_dim) for s in slices]
merged_chunks_0 = torch.cat([s[0] for s in chunked_slices], dim=cat_dim)
merged_chunks_1 = torch.cat([s[1] for s in chunked_slices], dim=cat_dim)
param = torch.cat([merged_chunks_0, merged_chunks_1], dim=cat_dim)
ckpt_dict[CAT_DIM] = cat_dim
ckpt_dict[PARAM_N_SUB_PARAMS] = 2
else:
cat_dim = 1 if get_matched_pattern(parameters_with_row_parallelism, name) else 0
# print(f"merge {name} with CAT DIM: {cat_dim}")
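To make the new parameters_with_2_sub_params_cat_dim_0 branch concrete, here is a toy PyTorch sketch of the same chunk-and-concat sequence on made-up tensor sizes; it checks that the merged result is the full L1 followed by the full L2, matching the normalization described in constants.py.

```python
import torch

# Made-up sizes; the real code operates on slices merged from zero shards.
tp_degree, ffn, hidden = 2, 4, 3
L1 = torch.randn(tp_degree * ffn, hidden)  # full first sub param
L2 = torch.randn(tp_degree * ffn, hidden)  # full second sub param

# Each TP rank holds its L1 slice stacked on top of its L2 slice (dim 0),
# which is how the fused ColumnParallelLinear weight is partitioned.
slices = [
    torch.cat([L1.chunk(tp_degree, dim=0)[r], L2.chunk(tp_degree, dim=0)[r]], dim=0)
    for r in range(tp_degree)
]

# Same steps as the new branch in merge_tp_slices:
cat_dim = 0
chunked_slices = [torch.chunk(s, 2, dim=cat_dim) for s in slices]
merged_chunks_0 = torch.cat([s[0] for s in chunked_slices], dim=cat_dim)  # full L1
merged_chunks_1 = torch.cat([s[1] for s in chunked_slices], dim=cat_dim)  # full L2
param = torch.cat([merged_chunks_0, merged_chunks_1], dim=cat_dim)

assert torch.equal(param, torch.cat([L1, L2], dim=0))
```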
14 changes: 11 additions & 3 deletions deepspeed/checkpoint/universal_checkpoint.py
@@ -6,7 +6,7 @@
import os
import torch
import types
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM)
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS)


def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
@@ -68,10 +68,18 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
# print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")

# since when we do many-to-1 on tp we cat sometimes on dim=0 and other times on dim=1, we have to do exactly the same in reverse
# special case is when a single parameter is effectively a container for multiple sub parameters
# (more details at PARAM_N_SUB_PARAMS definition)
chunk_dim = ckpt_dict.get(CAT_DIM, 0)
n_sub_params = ckpt_dict.get(PARAM_N_SUB_PARAMS, 1)
if n_sub_params > 1:
sub_params = full_hp_param.chunk(n_sub_params, dim=chunk_dim)
sub_params_tp_slice = [p.chunk(tp_world_size, dim=chunk_dim)[tp_rank] for p in sub_params]
tp_hp_slice = torch.cat(sub_params_tp_slice, dim=chunk_dim)
else:
# this performs the opposite of cat when merging TP slices
tp_hp_slice = full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank]

tp_hp_slice = tp_hp_slice.flatten()

lp_frag_address = hp_mapping.lp_fragment_address
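For the load side, a toy sketch (made-up sizes again) of the inverse that load_hp_checkpoint_state performs when PARAM_N_SUB_PARAMS is 2: chunk the normalized parameter into its sub params, take this rank's TP slice of each, and re-concatenate; the result is exactly the fused slice the rank originally held.

```python
import torch

tp_world_size, tp_rank, ffn, hidden = 2, 1, 4, 3
chunk_dim, n_sub_params = 0, 2  # values read from the ckpt dict as CAT_DIM / PARAM_N_SUB_PARAMS

# Normalized universal representation: full L1 followed by full L2 on chunk_dim.
L1 = torch.randn(tp_world_size * ffn, hidden)
L2 = torch.randn(tp_world_size * ffn, hidden)
full_hp_param = torch.cat([L1, L2], dim=chunk_dim)

# Same steps as the n_sub_params > 1 path above:
sub_params = full_hp_param.chunk(n_sub_params, dim=chunk_dim)
sub_params_tp_slice = [p.chunk(tp_world_size, dim=chunk_dim)[tp_rank] for p in sub_params]
tp_hp_slice = torch.cat(sub_params_tp_slice, dim=chunk_dim)

# Matches what this rank held before ds_to_universal: [its L1 slice ; its L2 slice].
expected = torch.cat(
    [L1.chunk(tp_world_size, dim=chunk_dim)[tp_rank],
     L2.chunk(tp_world_size, dim=chunk_dim)[tp_rank]],
    dim=chunk_dim,
)
assert torch.equal(tp_hp_slice, expected)
```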
