infV2 fix for OPT size variants (microsoft#4694)
Co-authored-by: Jeff Rasley <[email protected]>
mrwyattii and jeffra authored Nov 17, 2023
1 parent ce0ebda commit a3926bb
Showing 5 changed files with 53 additions and 9 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/nv-a6000.yml
@@ -36,7 +36,7 @@ jobs:
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Install transformers
run: |
-git clone https://github.com/huggingface/transformers
+git clone --depth=1 https://github.com/huggingface/transformers
cd transformers
git rev-parse --short HEAD
python -m pip install .
@@ -56,7 +56,7 @@ jobs:
python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2_ops' unit/ --torch_ver="2.0" --cuda_ver="12"
- name: MII unit tests
run: |
-git clone https://github.com/microsoft/DeepSpeed-MII.git
+git clone --depth=1 https://github.com/microsoft/DeepSpeed-MII.git
cd DeepSpeed-MII
pip install .[dev]
cd tests
4 changes: 4 additions & 0 deletions deepspeed/inference/v2/engine_factory.py
@@ -91,6 +91,10 @@ def build_hf_engine(path: str,
# get the policy
# TODO: generalize this to other models
if model_config.model_type == "opt":
+if not model_config.do_layer_norm_before:
+    raise ValueError(
+        "Detected OPT-350m model. This model is not currently supported. If this is not the 350m model, please open an issue: https://github.com/microsoft/DeepSpeed-MII/issues"
+    )
policy = OPTPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "llama":
policy = Llama2Policy(model_config, checkpoint_engine=checkpoint_engine)
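Note on this check: of the released Hugging Face OPT checkpoints, only OPT-350m ships with do_layer_norm_before=False (post-layernorm), so the config flag doubles as a size-variant detector. A minimal sketch of the same test against a Hugging Face config (model id shown for illustration only):

from transformers import AutoConfig

# OPT-350m is the only released OPT size with do_layer_norm_before=False,
# so this flag identifies the variant the engine cannot serve.
config = AutoConfig.from_pretrained("facebook/opt-350m")

if config.model_type == "opt" and not config.do_layer_norm_before:
    print("Post-layernorm OPT (350m) detected; unsupported by inference v2")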
@@ -3,6 +3,7 @@

# DeepSpeed Team

+import re
from typing import Type

import torch
@@ -277,6 +278,30 @@ def set_dependency(self, dep_name: str, dep_value: torch.Tensor) -> None:
dep_name (str): The name of the dependency to set.
dep_value (torch.Tensor): The value to set the dependency to.
"""

+def get_dep_name_target(dep_name: str) -> str:
+    """
+    Helper method for getting the target name for a dependency from the
+    mapping params. Tries to match exact string first, then looks for
+    wildcards and attempts regex matching. Will return empty string if
+    no match found.
+    """
+    if dep_name in self.mapping_params:
+        # If we have an exact match, it's a direct mapping and we can
+        # immediately set the value.
+        return self.mapping_params[dep_name]
+
+    matched_targets = []
+    for key, target in self.mapping_params.items():
+        regex_key = key.replace("*", ".*")
+        if re.match(regex_key, dep_name):
+            matched_targets.append(target)
+    if len(matched_targets) > 1:
+        raise ValueError(f"Multiple targets matched for dependency {dep_name}: {matched_targets}")
+    if matched_targets:
+        return matched_targets[0]
+    return ""

if dep_name in self.mapping_params:
# If we have an exact match, it's a direct mapping and we can immediately set
# the value.
@@ -309,6 +334,22 @@ def set_dependency(self, dep_name: str, dep_value: torch.Tensor) -> None:
target_dependency = getattr(target_param, target_dependency_name)
target_dependency[target_idx] = dep_value
return

+# TODO: Refactor this with the help of cmikeh2
+# We should be able to combine this with the wildcard matching above.
+target = get_dep_name_target(dep_name)
+if target:
+    # Convert single targets to a list for consistency
+    if isinstance(target, str):
+        target = [target]
+
+    for target_name in target:
+        # Double setting doesn't set the attribute correctly, so we do a getattr then setattr
+        target_param_name, target_dependency_name = target_name.split(".")
+        target_param = getattr(self, target_param_name)
+        setattr(target_param, target_dependency_name, dep_value)
+    return

raise ValueError(
"Could not find a mapping for dependency \"{}\". Check that it is included in the ``MAPPING_PARAMS``. See docstring for more on ``MAPPING_PARAMS``"
.format(dep_name))
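A standalone sketch of the wildcard matching this hunk introduces (the mapping entries and names are illustrative; the real method also handles sub-parameter targets):

import re

# Illustrative subset of a PARAM_MAPPING: "*" in a key acts as a wildcard,
# and a value may be a single target or a list of targets.
mapping_params = {
    "*decoder.embed_tokens.weight": ["word_emb.params", "word_unembed.params"],
    "*decoder.final_layer_norm.weight": "final_norm_w.params",
}

def get_dep_name_target(dep_name):
    # Exact matches win; otherwise "*" becomes the regex ".*", and re.match
    # anchors the pattern at the start of the dependency name.
    if dep_name in mapping_params:
        return mapping_params[dep_name]
    matched = [t for k, t in mapping_params.items()
               if re.match(k.replace("*", ".*"), dep_name)]
    if len(matched) > 1:
        raise ValueError(f"Multiple targets matched for dependency {dep_name}: {matched}")
    return matched[0] if matched else ""

# Both the prefixed and unprefixed checkpoint names resolve to the same targets:
assert get_dep_name_target("model.decoder.embed_tokens.weight") == ["word_emb.params", "word_unembed.params"]
assert get_dep_name_target("decoder.embed_tokens.weight") == ["word_emb.params", "word_unembed.params"]
assert get_dep_name_target("decoder.final_layer_norm.weight") == "final_norm_w.params"
assert get_dep_name_target("not.a.real.name") == ""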
9 changes: 4 additions & 5 deletions deepspeed/inference/v2/model_implementations/opt/container.py
@@ -87,9 +87,8 @@ class OPTNonTransformerContainer(LayerContainer):
final_norm_b: NormParameter

PARAM_MAPPING = {
"model.decoder.embed_tokens.weight": "word_emb.params",
"model.decoder.embed_positions.weight": "word_emb_pos.params",
"model.decoder.final_layer_norm.weight": "final_norm_w.params",
"model.decoder.final_layer_norm.bias": "final_norm_b.params",
"lm_head.weight": "word_unembed.params",
"*decoder.embed_tokens.weight": ["word_emb.params", "word_unembed.params"],
"*decoder.embed_positions.weight": "word_emb_pos.params",
"*decoder.final_layer_norm.weight": "final_norm_w.params",
"*decoder.final_layer_norm.bias": "final_norm_b.params",
}
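Two things change in this mapping: the leading "*" lets checkpoint names with and without the "model." prefix resolve to the same parameters, and the list-valued entry fans the one embed_tokens tensor out to both the embedding and the unembedding. This is presumably also why the policy below now lists lm_head.weight as unmapped: the unembedding is populated from the tied embed_tokens tensor instead. A minimal sketch of the fan-out, with stand-in container classes (names are illustrative):

import torch

# Stand-ins for a layer container and its parameter objects.
class _Param:
    def __init__(self):
        self.params = None

class _Container:
    def __init__(self):
        self.word_emb = _Param()
        self.word_unembed = _Param()

container = _Container()
targets = ["word_emb.params", "word_unembed.params"]  # list value from PARAM_MAPPING
dep_value = torch.zeros(8, 4)  # stand-in for the loaded embedding tensor

# One checkpoint tensor is assigned to every mapped target.
for target_name in targets:
    param_name, attr_name = target_name.split(".")
    setattr(getattr(container, param_name), attr_name, dep_value)

assert container.word_emb.params is container.word_unembed.params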
4 changes: 2 additions & 2 deletions deepspeed/inference/v2/model_implementations/opt/policy.py
@@ -21,10 +21,10 @@ def build_container_map(self) -> ContainerMap:

transformer_containers = [OPTTransformerContainer(self.model) for _ in range(self.model.num_layers)]

-map.set_transformer_params(['model.decoder.layers'], transformer_containers)
+map.set_transformer_params(['model.decoder.layers', 'decoder.layers'], transformer_containers)

map.set_non_transformer_params(OPTNonTransformerContainer(self.model))

-map.set_unmapped_params([])
+map.set_unmapped_params(['lm_head.weight'])

return map
