Inference Checkpoints in V2 (microsoft#4664)
Add capability to snapshot an engine and resume from it, reducing load
times for large models. Includes new unit tests to validate this
pipeline on a small scale.
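
As a rough illustration of the new flow (not part of the commit itself), here is a minimal sketch of snapshotting an engine and resuming from it; the model name, save directory, and default engine config below are illustrative assumptions:

```python
import logging
import os

from deepspeed.inference.v2 import (RaggedInferenceEngineConfig, build_engine_from_ds_checkpoint,
                                    build_hf_engine)

model_name = "mistralai/Mistral-7B-v0.1"      # any supported HF model; illustrative
save_path = "/tmp/ds_inference_checkpoint"    # illustrative snapshot directory
os.makedirs(save_path, exist_ok=True)

engine_config = RaggedInferenceEngineConfig()  # assuming the defaults are sufficient

# First run: build from HuggingFace weights, then snapshot the initialized engine.
engine = build_hf_engine(model_name, engine_config, debug_level=logging.INFO)
engine.serialize(save_path)

# Subsequent runs: resume directly from the snapshot, skipping the HF load path.
engine = build_engine_from_ds_checkpoint(save_path, engine_config, debug_level=logging.INFO)
```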

---------

Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Michael Wyatt <[email protected]>
Co-authored-by: Ammar Ahmad Awan <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
8 people authored Nov 14, 2023
1 parent c1ba6a1 commit 5411030
Showing 98 changed files with 928 additions and 4,601 deletions.
1 change: 1 addition & 0 deletions .github/workflows/nv-accelerate-v100.yml
@@ -6,6 +6,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
1 change: 1 addition & 0 deletions .github/workflows/nv-inference.yml
@@ -6,6 +6,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
1 change: 1 addition & 0 deletions .github/workflows/nv-lightning-v100.yml
@@ -6,6 +6,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
1 change: 1 addition & 0 deletions .github/workflows/nv-megatron.yml
@@ -6,6 +6,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
1 change: 1 addition & 0 deletions .github/workflows/nv-pre-compile-ops.yml
@@ -8,6 +8,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
1 change: 1 addition & 0 deletions .github/workflows/nv-torch-latest-cpu.yml
@@ -6,6 +6,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
1 change: 1 addition & 0 deletions .github/workflows/nv-torch-latest-v100.yml
@@ -6,6 +6,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
1 change: 1 addition & 0 deletions .github/workflows/nv-transformers-v100.yml
@@ -6,6 +6,7 @@ on:
- 'docs/**'
- 'blogs/**'
- 'deepspeed/inference/v2/**'
- "tests/unit/inference/v2/**"
merge_group:
branches: [ master ]
schedule:
2 changes: 1 addition & 1 deletion deepspeed/inference/__init__.py
@@ -4,4 +4,4 @@
# DeepSpeed Team
from .v2 import RaggedInferenceEngineConfig, DeepSpeedTPConfig
from .v2.engine_v2 import InferenceEngineV2
from .v2 import build_hf_engine
from .v2 import build_hf_engine, build_engine_from_ds_checkpoint
2 changes: 1 addition & 1 deletion deepspeed/inference/v2/__init__.py
@@ -4,4 +4,4 @@
# DeepSpeed Team
from .config_v2 import RaggedInferenceEngineConfig, DeepSpeedTPConfig
from .engine_v2 import InferenceEngineV2
from .engine_factory import build_hf_engine
from .engine_factory import build_hf_engine, build_engine_from_ds_checkpoint
2 changes: 1 addition & 1 deletion deepspeed/inference/v2/allocator.py
@@ -26,7 +26,7 @@ def on_device(method) -> torch.Tensor:
def wrapped(self, *args, **kwargs):
tensor = method(self, *args, **kwargs)
if isinstance(tensor, torch.Tensor):
return tensor.to(get_accelerator().current_device()).contiguous()
return tensor.to(get_accelerator().current_device())
return tensor

return wrapped
2 changes: 2 additions & 0 deletions deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -90,6 +90,8 @@ def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
param = checkpoint_sd[param_name]
yield param_name, param

del checkpoint_sd


if __name__ == "__main__":
# To test, add your auth_token here and run `python huggingface_engine.py`
115 changes: 85 additions & 30 deletions deepspeed/inference/v2/engine_factory.py
@@ -3,49 +3,104 @@

# DeepSpeed Team

import json
import logging
from typing import Any
import os
import pickle
from packaging import version

from .engine_v2 import InferenceEngineV2
from .config_v2 import RaggedInferenceEngineConfig
from .checkpoint import HuggingFaceCheckpointEngine
from .logging import inference_logger
from .model_implementations import (
OPTPolicy,
Llama2Policy,
MistralPolicy,
)
from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata


def build_engine_from_ds_checkpoint(path: str,
engine_config: RaggedInferenceEngineConfig,
debug_level: int = logging.INFO) -> InferenceEngineV2:
"""
Creates an engine from a checkpoint saved by ``InferenceEngineV2``.
Arguments:
path: Path to the checkpoint. This does not need to point to any files in particular,
just the directory containing the checkpoint.
engine_config: Engine configuration. See ``RaggedInferenceEngineConfig`` for details.
debug_level: Logging level to use. Unless you are actively seeing issues, the recommended
value is ``logging.INFO``.
Returns:
Fully initialized inference engine ready to serve queries.
"""

inference_logger(level=debug_level)
# Load metadata; to grab the policy name, all ranks just read rank 0's metadata file.
metadata_filename = make_metadata_filename(path, 0, engine_config.tensor_parallel.tp_size)
metadata = json.load(open(metadata_filename, "r"))
metadata = ModelMetadata.parse_raw(metadata)

# Get the policy
try:
policy_cls: InferenceV2Policy = POLICIES[metadata.policy]
except KeyError:
raise ValueError(f"Unknown policy {metadata.policy} for model {path}")

# Load the model config
model_config = pickle.load(open(os.path.join(path, "ds_model_config.pkl"), "rb"))
policy = policy_cls(model_config, inf_checkpoint_path=path)

return InferenceEngineV2(policy, engine_config)


def build_hf_engine(path: str,
engine_config: RaggedInferenceEngineConfig,
debug_level: int = logging.INFO,
random_weights_config: Any = None,
fill_random: bool = False) -> InferenceEngineV2:
debug_level: int = logging.INFO) -> InferenceEngineV2:
"""
Build an InferenceV2 engine for HuggingFace models.
Build an InferenceV2 engine for HuggingFace models. This accepts either a HuggingFace
model name or a path to an Inference-V2 checkpoint.
Arguments:
path: Path to the checkpoint. This does not need to point to any files in particular,
just the directory containing the checkpoint.
engine_config: Engine configuration. See ``RaggedInferenceEngineConfig`` for details.
debug_level: Logging level to use. Unless you are actively seeing issues, the recommended
value is ``logging.INFO``.
Returns:
Fully initialized inference engine ready to serve queries.
"""
# Set up logging
inference_logger(level=debug_level)

# get HF checkpoint engine
checkpoint_engine = HuggingFaceCheckpointEngine(path)

# get model config from HF AutoConfig
model_config = checkpoint_engine.model_config

# get the policy
# TODO: generalize this to other models
if model_config.model_type == "opt":
from .model_implementations.opt.policy import OPTPolicy
policy = OPTPolicy(checkpoint_engine, model_config)
elif model_config.model_type == "llama":
from .model_implementations.llama_v2.llama_v2_policy import Llama2Policy
policy = Llama2Policy(checkpoint_engine, model_config)
elif model_config.model_type == "mistral":
from .model_implementations.mistral.policy import MistralPolicy
# Ensure we're using the correct version of transformers for mistral
import transformers
assert version.parse(transformers.__version__) >= version.parse("4.34.0"), \
f"Mistral requires transformers >= 4.34.0, you have version {transformers.__version__}"
policy = MistralPolicy(checkpoint_engine, model_config)
if os.path.exists(os.path.join(path, "ds_model_config.pkl")):
return build_engine_from_ds_checkpoint(path, engine_config, debug_level=debug_level)
else:
raise ValueError(f"Unsupported model type {model_config.model_type}")
# Set up logging
inference_logger(level=debug_level)
# get HF checkpoint engine
checkpoint_engine = HuggingFaceCheckpointEngine(path)

return InferenceEngineV2(policy, engine_config)
# get model config from HF AutoConfig
model_config = checkpoint_engine.model_config

# get the policy
# TODO: generalize this to other models
if model_config.model_type == "opt":
policy = OPTPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "llama":
policy = Llama2Policy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "mistral":
# Ensure we're using the correct version of transformers for mistral
import transformers
assert version.parse(transformers.__version__) >= version.parse("4.34.0"), \
f"Mistral requires transformers >= 4.34.0, you have version {transformers.__version__}"
policy = MistralPolicy(model_config, checkpoint_engine=checkpoint_engine)
else:
raise ValueError(f"Unsupported model type {model_config.model_type}")

return InferenceEngineV2(policy, engine_config)
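
Worth noting from the new branch at the top of `build_hf_engine`: if the given path already contains `ds_model_config.pkl`, the call transparently dispatches to `build_engine_from_ds_checkpoint`. A minimal sketch, with an assumed snapshot directory:

```python
from deepspeed.inference.v2 import RaggedInferenceEngineConfig, build_hf_engine

save_path = "/tmp/ds_inference_checkpoint"  # assumed to have been written by engine.serialize()

# Because save_path contains ds_model_config.pkl, this resolves to
# build_engine_from_ds_checkpoint rather than the HuggingFace load path.
engine = build_hf_engine(save_path, RaggedInferenceEngineConfig())
```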
32 changes: 31 additions & 1 deletion deepspeed/inference/v2/engine_v2.py
@@ -4,6 +4,8 @@
# DeepSpeed Team

import os
import json
import pickle
from typing import Iterable, Tuple

import torch
@@ -17,6 +19,8 @@
from .logging import inference_logger
from .ragged import DSStateManager, RaggedBatchWrapper, PlaceholderSequenceDescriptor
from .scheduling_utils import SchedulingError, SchedulingResult
from .model_implementations.flat_model_helpers import make_param_filename, make_metadata_filename
from .model_implementations.inference_model_base import DSInferenceModelBase

from .config_v2 import RaggedInferenceEngineConfig

@@ -30,7 +34,7 @@ class InferenceEngineV2:
Configuration of the inference engine.
"""

#_model: DSInferenceModelBase
_model: DSInferenceModelBase
"""
Inference model supporting ragged inference.
"""
@@ -47,6 +51,13 @@ def free_blocks(self) -> int:
"""
return self._state_manager.free_blocks

@property
def model(self) -> DSInferenceModelBase:
"""
The model implementation.
"""
return self._model

def __init__(self, policy: InferenceV2Policy, engine_config: RaggedInferenceEngineConfig) -> None:
"""
Create the Inference V2 engine.
@@ -215,3 +226,22 @@ def flush(self, uid: int) -> None:
uid (int): The UID of the sequence to flush.
"""
self._state_manager.flush_sequence(uid)

def serialize(self, save_path: str) -> None:
"""
Serialize the model to a checkpoint directory.
Arguments:
save_path (str): Directory to write the serialized checkpoint files to.
"""
param_file_name = make_param_filename(save_path, self._model.tp_rank, self._model.tp_size)
metadata_file_name = make_metadata_filename(save_path, self._model.tp_rank, self._model.tp_size)

# Save the flattened parameters

torch.save(self._model.flattened_params, param_file_name)

json.dump(self._model.flattened_param_metadata.json(), open(metadata_file_name, "w"))

if self._model.tp_rank == 0:
pickle.dump(self._model._config, open(os.path.join(save_path, "ds_model_config.pkl"), "wb"))
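
For reference, a sketch of the files `serialize` produces, derived with the same filename helpers the engine uses; the directory and tensor-parallel layout below are assumptions for illustration:

```python
import os

from deepspeed.inference.v2.model_implementations.flat_model_helpers import (
    make_metadata_filename, make_param_filename)

save_path = "/tmp/ds_inference_checkpoint"  # illustrative
tp_rank, tp_size = 0, 1                     # illustrative single-rank layout

# Written by every rank:
print(make_param_filename(save_path, tp_rank, tp_size))     # flattened parameters (torch.save)
print(make_metadata_filename(save_path, tp_rank, tp_size))  # ModelMetadata as JSON
# Written by rank 0 only:
print(os.path.join(save_path, "ds_model_config.pkl"))       # pickled model config
```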
89 changes: 89 additions & 0 deletions deepspeed/inference/v2/inference_parameter.py
@@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Dict

import torch

CORE_PARAM = "_ds_core_param_key"

STR_TO_DTYPE = {
"torch.float32": torch.float32,
"torch.float64": torch.float64,
"torch.float16": torch.float16,
"torch.bfloat16": torch.bfloat16,
"torch.int64": torch.int64,
"torch.int32": torch.int32,
"torch.int16": torch.int16,
"torch.int8": torch.int8,
"torch.uint8": torch.uint8,
"torch.bool": torch.bool,
}


class InferenceParameter(torch.Tensor):
"""
An extension of the torch.Tensor class to support our inference focused features. One important
thing to note here is that an InferenceParam can be used as a torch.Tensor, but outputs of
torch.Tensor operations will not be InferenceParams.
"""

@staticmethod
def __new__(cls, tensor, *args, **kwargs):
new_tensor = super().__new__(cls, tensor, *args, **kwargs)
if hasattr(tensor, "_aux_attrs"):
setattr(new_tensor, "_aux_attrs", tensor.aux_attrs)
return new_tensor

def to(self, *args, **kwargs):
new_tensor = super().to(*args, **kwargs)
if hasattr(self, "_aux_attrs"):
setattr(new_tensor, "_aux_attrs", self.aux_attrs)
try:
_ = torch.device(args[0])
for name, attr in new_tensor.aux_attrs.items():
new_attr = attr.to(*args, **kwargs)
setattr(new_tensor, name, new_attr)
new_tensor.aux_attrs[name] = new_attr
except:
pass

return new_tensor

@classmethod
def initialize(cls, core_param: torch.Tensor, **kwargs) -> 'InferenceParameter':
"""
Create the inference parameter.
"""
param = InferenceParameter(core_param)
setattr(param, "_aux_attrs", kwargs)

for attr_name, attr in kwargs.items():
if hasattr(param, attr_name):
raise ValueError(f"Attribute {attr_name} already exists on param.")

if not isinstance(attr, torch.Tensor):
raise ValueError(f"Attribute {attr_name} must be a tensor.")

setattr(param, attr_name, attr)

return param

@classmethod
def initialize_raw(self, **kwargs) -> 'InferenceParameter':
"""
All kwargs must be torch.Tensors and must include the core parameter.
"""
if CORE_PARAM not in kwargs:
raise ValueError(f"Must provide core parameter, with key {CORE_PARAM}.")

return InferenceParameter.initialize(kwargs[CORE_PARAM], **kwargs)

@property
def aux_attrs(self) -> Dict[str, torch.Tensor]:
"""
Dictionary of auxiliary attributes.
"""
return self._aux_attrs
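
A short sketch of how the new `InferenceParameter` behaves; the `scales` auxiliary tensor here is an illustrative name, not something defined by the commit:

```python
import torch

from deepspeed.inference.v2.inference_parameter import InferenceParameter

weight = torch.randn(256, 256, dtype=torch.float16)
scales = torch.randn(256, dtype=torch.float32)  # illustrative auxiliary tensor

param = InferenceParameter.initialize(weight, scales=scales)

# The core data behaves like a regular torch.Tensor...
assert param.shape == (256, 256)
# ...while named auxiliary tensors ride along with the parameter.
assert torch.equal(param.scales, scales)
assert "scales" in param.aux_attrs

# param.to("cuda") would also attempt to move the auxiliary tensors (see `to` above);
# skipped here so the sketch runs without a GPU.
```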
5 changes: 5 additions & 0 deletions deepspeed/inference/v2/model_implementations/__init__.py
@@ -7,3 +7,8 @@
from .inference_transformer_base import DSTransformerModelBase, DSMoETransformerModelBase
from .inference_policy_base import InferenceV2Policy, ContainerMap
from .sharding import *

# Model Implementations
from .llama_v2 import *
from .opt import *
from .mistral import *
File renamed without changes.
