Skip to content

Commit

Permalink
[MII] catch error wrt HF version and Mistral (microsoft#4634)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffra authored Nov 13, 2023
1 parent 0abf4df commit 0a6095f
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions deepspeed/inference/v2/engine_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import logging
from typing import Any
from packaging import version

from .engine_v2 import InferenceEngineV2
from .config_v2 import RaggedInferenceEngineConfig
Expand Down Expand Up @@ -39,6 +40,10 @@ def build_hf_engine(path: str,
policy = Llama2Policy(checkpoint_engine, model_config)
elif model_config.model_type == "mistral":
from .model_implementations.mistral.policy import MistralPolicy
# Ensure we're using the correct version of transformers for mistral
import transformers
assert version.parse(transformers.__version__) >= version.parse("4.34.0"), \
f"Mistral requires transformers >= 4.34.0, you have version {transformers.__version__}"
policy = MistralPolicy(checkpoint_engine, model_config)
else:
raise ValueError(f"Unsupported model type {model_config.model_type}")
Expand Down

0 comments on commit 0a6095f

Please sign in to comment.