Skip to content

Commit

Permalink
Merge pull request #93 from NexaAI/david/bugfix
Browse files Browse the repository at this point in the history
correctly apply chat_format
  • Loading branch information
zhiyuan8 authored Sep 16, 2024
2 parents 353087d + ef01f95 commit ac900e0
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
16 changes: 16 additions & 0 deletions nexa/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,28 @@ class ModelType(Enum):

# Maps a model name to the chat-format identifier to use when running it.
# The text-inference constructor looks a model up here by the part of the
# model path before ":" converted to lowercase, so every key must be
# lowercase. Both short aliases (e.g. "llama3") and full model names
# (e.g. "meta-llama-3-8b-instruct") are listed so either spelling resolves.
# Models missing from this map get no chat format (the lookup defaults to None).
# NOTE(review): values are presumably chat_format names understood by the
# inference backend -- confirm against the backend's supported formats.
NEXA_RUN_CHAT_TEMPLATE_MAP = {
"llama2": "llama-2",
"llama-2-7b-chat": "llama-2",
"llama3": "llama-3",
"meta-llama-3-8b-instruct": "llama-3",
"llama3.1": "llama-3",
"meta-llama-3.1-8b-instruct": "llama-3",
"gemma": "gemma",
"gemma-1.1-2b-instruct": "gemma",
"gemma-1.1-7b-instruct": "gemma",
"gemma-2b-instruct": "gemma",
"gemma-7b-instruct": "gemma",
"gemma-2-2b-instruct": "gemma",
"gemma-2-9b-instruct": "gemma",
"qwen1.5": "qwen",
"qwen1.5-7b-instruct": "qwen",
"codeqwen1.5-7b-instruct": "qwen",
"qwen2": "qwen",
"qwen2-0.5b-instruct": "qwen",
"qwen2-1.5b-instruct": "qwen",
"qwen2-7b-instruct": "qwen",
"mistral": "mistral-instruct",
"mistral-7b-instruct-v0.3": "mistral-instruct",
"mistral-7b-instruct-v0.2": "mistral-instruct",
}

NEXA_RUN_COMPLETION_TEMPLATE_MAP = {
Expand Down
2 changes: 1 addition & 1 deletion nexa/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def pull_model_from_official(model_path):


def get_run_type_from_model_path(model_path):
    """Return the run-type value for an official model path.

    Args:
        model_path: Model identifier, optionally followed by ``:<tag>``
            (for example ``"name:tag"`` or just ``"name"``).

    Returns:
        The ``.value`` of the model type registered for the model name in
        ``NEXA_OFFICIAL_MODELS_TYPE``; names that are not registered fall
        back to ``ModelType.NLP``.
    """
    # Keep only the text before the first ":"; indexing the split result
    # (rather than unpacking it into exactly two names) also accepts paths
    # with no tag or with more than one ":" instead of raising ValueError.
    model_name = model_path.split(":")[0]
    return NEXA_OFFICIAL_MODELS_TYPE.get(model_name, ModelType.NLP).value


Expand Down
12 changes: 4 additions & 8 deletions nexa/gguf/nexa_inference_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,12 @@ def __init__(self, model_path, local_path=None, stop_words=None, **kwargs):
exc_info=True,
)
exit(1)

self.stop_words = (
stop_words if stop_words else NEXA_STOP_WORDS_MAP.get(model_path, [])
)
self.profiling = kwargs.get("profiling", False)

self.chat_format = NEXA_RUN_CHAT_TEMPLATE_MAP.get(model_path, None)
self.completion_template = NEXA_RUN_COMPLETION_TEMPLATE_MAP.get(
model_path, None
)
model_name = model_path.split(":")[0].lower()
self.stop_words = (stop_words if stop_words else NEXA_STOP_WORDS_MAP.get(model_name, []))
self.chat_format = NEXA_RUN_CHAT_TEMPLATE_MAP.get(model_name, None)
self.completion_template = NEXA_RUN_COMPLETION_TEMPLATE_MAP.get(model_name, None)

if not kwargs.get("streamlit", False):
self._load_model()
Expand Down

0 comments on commit ac900e0

Please sign in to comment.