feat: allow ChatLlamaCpp offline use
Benjoyo committed Apr 22, 2024
1 parent 66dffbe commit 9cc5f06
Showing 5 changed files with 43 additions and 15 deletions.
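The net effect of the commit is a new force_offline flag on ChatLlamaCpp: when set, the GGUF model file is resolved from the local Hugging Face cache instead of being fetched from the Hub. A minimal usage sketch, assuming ChatLlamaCpp is the class defined in the changed module and that the model was cached by a previous online run (all other constructor arguments keep the defaults shown in the diff):

from bpm_ai_inference.llm.llama_cpp.llama_chat import ChatLlamaCpp

# Default behaviour: Llama.from_pretrained downloads (or reuses) the model from the Hub.
llm = ChatLlamaCpp()

# Behaviour added by this commit: locate the cached GGUF file on disk and load it
# directly, so no network access is needed.
offline_llm = ChatLlamaCpp(force_offline=True)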
1 change: 0 additions & 1 deletion bpm_ai_inference/llm/llama_cpp/_constants.py
@@ -5,4 +5,3 @@
DEFAULT_QUANT_SMALL = "*Q2_K.gguf"
DEFAULT_TEMPERATURE = 0.0
DEFAULT_MAX_RETRIES = 8
a
29 changes: 20 additions & 9 deletions bpm_ai_inference/llm/llama_cpp/llama_chat.py
@@ -4,22 +4,23 @@
from typing import Dict, Any, Optional, List

from bpm_ai_core.llm.common.llm import LLM
from bpm_ai_core.llm.common.message import ChatMessage, AssistantMessage, SystemMessage, ToolCallMessage
from bpm_ai_core.llm.common.message import ChatMessage, AssistantMessage, ToolCallMessage
from bpm_ai_core.llm.common.tool import Tool
from bpm_ai_core.llm.openai_chat.util import messages_to_openai_dicts
from bpm_ai_core.prompt.prompt import Prompt
from bpm_ai_core.tracing.tracing import Tracing
from bpm_ai_core.util.json_schema import expand_simplified_json_schema
from llama_cpp.llama_grammar import json_schema_to_gbnf, LlamaGrammar

from bpm_ai_inference.llm.llama_cpp._constants import DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_MAX_RETRIES, \
    DEFAULT_QUANT_BALANCED
from bpm_ai_inference.llm.llama_cpp.util import messages_to_llama_dicts
from bpm_ai_inference.util.files import find_file
from bpm_ai_inference.util.hf import hf_home

logger = logging.getLogger(__name__)

try:
    from llama_cpp import Llama, CreateChatCompletionResponse, llama_grammar
    from llama_cpp.llama_grammar import json_schema_to_gbnf, LlamaGrammar

    has_llama_cpp_python = True
except ImportError:
@@ -39,6 +40,7 @@ def __init__(
        filename: str = DEFAULT_QUANT_BALANCED,
        temperature: float = DEFAULT_TEMPERATURE,
        max_retries: int = DEFAULT_MAX_RETRIES,
        force_offline: bool = False
    ):
        if not has_llama_cpp_python:
            raise ImportError('llama-cpp-python is not installed')
@@ -48,12 +50,21 @@ def __init__(
            max_retries=max_retries,
            retryable_exceptions=[]
        )
        self.llm = Llama.from_pretrained(
            repo_id=model,
            filename=filename,
            n_ctx=4096,
            verbose=False
        )
        n_ctx = 4096
        if force_offline:
            self.llm = Llama(
                model_path=find_file(hf_home() + "/hub/models--" + model.replace("/", "--"), filename),
                n_ctx=n_ctx,
                verbose=False
            )
        else:
            self.llm = Llama.from_pretrained(
                repo_id=model,
                filename=filename,
                n_ctx=n_ctx,
                verbose=False
            )


    async def _generate_message(
        self,
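The offline branch above reconstructs the Hugging Face hub cache directory by hand and then searches it for a matching GGUF file. A sketch of that resolution under the standard cache layout (the repo id and quant pattern below are hypothetical examples, not values from the diff):

import os

from bpm_ai_inference.util.files import find_file
from bpm_ai_inference.util.hf import hf_home

model = "TheOrg/SomeModel-GGUF"   # hypothetical repo id
filename = "*Q4_K_M.gguf"         # hypothetical quant pattern

# e.g. ~/.cache/huggingface/hub/models--TheOrg--SomeModel-GGUF
cache_dir = os.path.join(hf_home(), "hub", "models--" + model.replace("/", "--"))

# find_file descends into the snapshots/<revision>/ folders and returns the first
# file matching the pattern, or None if the model was never downloaded.
model_path = find_file(cache_dir, filename)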
15 changes: 15 additions & 0 deletions bpm_ai_inference/util/files.py
@@ -0,0 +1,15 @@
import glob
import os


def find_file(path: str, filename_pattern: str):
    """
    Find a file with the pattern below the given path.
    Return the first occurrence if there are multiple matches.

    Returns:
        str: The path to the first matching file, or None if no file is found.
    """
    search_pattern = os.path.join(path, "**", filename_pattern)
    matching_files = glob.glob(search_pattern, recursive=True)
    return matching_files[0] if matching_files else None
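Because the pattern is joined with "**" and glob.glob runs with recursive=True, the search descends into arbitrarily nested subdirectories, which is what lets it reach the snapshots/<revision>/ folders of the HF cache. A small usage sketch with hypothetical paths:

from bpm_ai_inference.util.files import find_file

# Returns something like ".../snapshots/abc123/somemodel.Q4_K_M.gguf",
# or None when nothing under the directory matches the pattern.
path = find_file("/root/.cache/huggingface/hub/models--TheOrg--SomeModel-GGUF", "*Q4_K_M.gguf")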
5 changes: 5 additions & 0 deletions bpm_ai_inference/util/hf.py
@@ -0,0 +1,5 @@
import os


def hf_home():
    return os.getenv('HF_HOME', os.path.join(os.path.expanduser("~"), ".cache", "huggingface"))
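This helper mirrors the Hugging Face default cache location, so setting HF_HOME redirects both the GGUF lookup in ChatLlamaCpp and the ONNX export directory in optimum.py. A minimal sketch (the directory is only an example):

import os
os.environ["HF_HOME"] = "/data/hf-cache"  # example location

from bpm_ai_inference.util.hf import hf_home
assert hf_home() == "/data/hf-cache"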
8 changes: 3 additions & 5 deletions bpm_ai_inference/util/optimum.py
@@ -12,6 +12,8 @@
from optimum.onnxruntime.configuration import OptimizationConfig, AutoQuantizationConfig, AutoOptimizationConfig
from transformers import AutoTokenizer

from bpm_ai_inference.util.hf import hf_home

logger = logging.getLogger(__name__)

FILENAME_ONNX = "model.onnx"
@@ -38,7 +40,7 @@ def _holisticon_onnx_repository_id(model_name: str) -> str:

def get_optimized_model(model: str, task: str, optimization_level: int = None, push_to_hub: bool = False):
    model_name = model
    model_dir = _hf_home() + "/onnx/" + model.replace("/", "--")
    model_dir = hf_home() + "/onnx/" + model.replace("/", "--")
    tokenizer = AutoTokenizer.from_pretrained(model)

    optimization_level = optimization_level or int(os.getenv("OPTIMIZATION_LEVEL", "2"))
@@ -62,10 +64,6 @@ def get_optimized_model(model: str, task: str, optimization_level: int = None, p
    return model, tokenizer


def _hf_home():
    return os.getenv('HF_HOME', os.path.join(os.path.expanduser("~"), ".cache", "huggingface"))


def _check_exists_on_hub(repository_id: str, filename: str) -> str | None:
    fs = HfFileSystem()
    if not fs.exists(repository_id):
