update vlm test
xyyimian committed Aug 21, 2024
1 parent c9f3d4f commit 5bc20f8
Showing 3 changed files with 50 additions and 24 deletions.
20 changes: 10 additions & 10 deletions nexa/gguf/__init__.py
@@ -1,11 +1,11 @@
-# from .nexa_inference_image import NexaImageInference
-# from .nexa_inference_text import NexaTextInference
-# from .nexa_inference_vlm import NexaVLMInference
-# from .nexa_inference_voice import NexaVoiceInference
+from .nexa_inference_image import NexaImageInference
+from .nexa_inference_text import NexaTextInference
+from .nexa_inference_vlm import NexaVLMInference
+from .nexa_inference_voice import NexaVoiceInference
 
-# __all__ = [
-#     "NexaImageInference",
-#     "NexaTextInference",
-#     "NexaVLMInference",
-#     "NexaVoiceInference",
-# ]
+__all__ = [
+    "NexaImageInference",
+    "NexaTextInference",
+    "NexaVLMInference",
+    "NexaVoiceInference",
+]
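
With these exports uncommented, the four inference classes become importable directly from the package root. A minimal sketch of the new import path, assuming the "nanollava" model name used by the updated test below:

from nexa.gguf import NexaVLMInference

# "nanollava" mirrors the model identifier in the updated test;
# other model names or local paths are untested assumptions here.
vlm = NexaVLMInference(model_path="nanollava")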
37 changes: 36 additions & 1 deletion nexa/gguf/nexa_inference_vlm.py
@@ -219,7 +219,14 @@ def run(self):
             logging.error(f"Error during generation: {e}", exc_info=True)
             print("\n")
 
-    def create_chat_completion(self, messages, temperature, max_tokens, top_k, top_p, stream, stop):
+    def create_chat_completion(self,
+                               messages,
+                               max_tokens: int = 2048,
+                               temperature: float = 0.2,
+                               top_p: float = 0.95,
+                               top_k: int = 40,
+                               stream=False,
+                               stop=[]):
         """
         Generate text completion for a given chat prompt.
 
@@ -231,6 +238,34 @@ def create_chat_completion(self, messages, temperature, max_tokens, top_k, top_p
             top_p (float): Top-p sampling parameter.
             stream (bool): Stream the output.
             stop (list): List of stop words for early stopping.
+
+        Returns:
+            Iterator: An iterator over the generated text completion.
+
+        Return format:
+            {
+              "choices": [
+                {
+                  "finish_reason": "stop",
+                  "index": 0,
+                  "message": {
+                    "content": "The 2020 World Series was played in Texas at Globe Life Field in Arlington.",
+                    "role": "assistant"
+                  },
+                  "logprobs": null
+                }
+              ],
+              "created": 1677664795,
+              "id": "chatcmpl-7QyqpwdfhqwajicIEznoc6Q47XAyW",
+              "model": "gpt-4o-mini",
+              "object": "chat.completion",
+              "usage": {
+                "completion_tokens": 17,
+                "prompt_tokens": 57,
+                "total_tokens": 74
+              }
+            }
+
+        Usage:
+            message = completion.choices[0].message.content
         """
         return self.model.create_chat_completion(
             messages=messages,
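
With the new defaults, the reworked method can be driven entirely by keyword arguments. A hedged sketch of a call against the documented return shape; the dict-style access at the end is an assumption (the docstring's usage line suggests attribute access, while the JSON example suggests a plain dict):

from nexa.gguf import NexaVLMInference

model = NexaVLMInference(model_path="nanollava")  # model name assumed, per the updated test
completion = model.create_chat_completion(
    messages=[{"role": "user", "content": "Where was the 2020 World Series played?"}],
    max_tokens=256,   # the new default is 2048
    temperature=0.2,
    top_p=0.95,
    top_k=40,
    stream=False,
)
# Assumes the return value matches the documented JSON shape:
print(completion["choices"][0]["message"]["content"])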
17 changes: 4 additions & 13 deletions tests/test_vlm.py
@@ -1,8 +1,7 @@
 import base64
 import os
 
-from nexa.gguf.llama import llama
-from nexa.gguf.llama.llama_chat_format import NanoLlavaChatHandler
+from nexa.gguf import NexaVLMInference
 from tests.utils import download_model
 from nexa.gguf.lib_utils import is_gpu_available
 import tempfile
@@ -23,18 +22,10 @@ def test_image_generation():
     model_url = "https://nexa-model-hub-bucket.s3.us-west-1.amazonaws.com/public/nanoLLaVA/model-fp16.gguf"
     mmproj_url = "https://nexa-model-hub-bucket.s3.us-west-1.amazonaws.com/public/nanoLLaVA/projector-fp16.gguf"
 
-    model_path = download_model(model_url, temp_dir)
-    mmproj_path = download_model(mmproj_url, temp_dir)
-    chat_handler = NanoLlavaChatHandler(clip_model_path=mmproj_path)
-
-    llm = llama.Llama(
-        model_path=model_path,
-        chat_handler=chat_handler,
-        n_ctx=2048,  # n_ctx should be increased to accommodate the image embedding
-        n_gpu_layers=-1 if is_gpu_available() else 0,  # Uncomment to use GPU acceleration
-        verbose=False,
+    model = NexaVLMInference(
+        model_path="nanollava",
     )
-    output = llm.create_chat_completion(
+    output = model.create_chat_completion(
         messages=[
             {
                 "role": "system",
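
The updated test exercises only the non-streaming path. For completeness, a sketch of the streaming variant, reusing the model object from the test; the per-chunk "delta" structure is an assumption based on common chat-completion stream formats and is not confirmed by this commit:

for chunk in model.create_chat_completion(
    messages=[{"role": "user", "content": "Describe this image."}],
    stream=True,
):
    delta = chunk["choices"][0].get("delta", {})  # assumed chunk layout
    print(delta.get("content", ""), end="", flush=True)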
