Skip to content

Commit

Permalink
SigLip Service
Browse files Browse the repository at this point in the history
  • Loading branch information
JoyboyBrian committed Dec 17, 2024
1 parent cbcab3a commit 2577104
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 15 deletions.
18 changes: 18 additions & 0 deletions nexa/cli/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,15 @@ def run_eval_tasks(args):
print("Please run: pip install 'nexaai[eval]'")
return

def run_siglip_server(args):
from nexa.siglip.nexa_siglip_server import run_nexa_ai_siglip_service
run_nexa_ai_siglip_service(
image_dir=args.image_dir,
host=args.host,
port=args.port,
reload=args.reload
)

def run_embedding_generation(args):
kwargs = {k: v for k, v in vars(args).items() if v is not None}
model_path = kwargs.pop("model_path")
Expand Down Expand Up @@ -599,6 +608,13 @@ def main():
perf_eval_group.add_argument("--device", type=str, help="Device to run performance evaluation on, choose from 'cpu', 'cuda', 'mps'", default="cpu")
perf_eval_group.add_argument("--new_tokens", type=int, help="Number of new tokens to evaluate", default=100)

# Siglip Server
siglip_parser = subparsers.add_parser("siglip", help="Run the Nexa AI SigLIP Service")
siglip_parser.add_argument("--image_dir", type=str, help="Directory of images to load")
siglip_parser.add_argument("--host", type=str, default="localhost", help="Host to bind the server to")
siglip_parser.add_argument("--port", type=int, default=8100, help="Port to bind the server to")
siglip_parser.add_argument("--reload", action="store_true", help="Enable automatic reloading on code changes")

args = parser.parse_args()

if args.command == "run":
Expand Down Expand Up @@ -627,6 +643,8 @@ def main():
run_onnx_inference(args)
elif args.command == "eval":
run_eval_tasks(args)
elif args.command == "siglip":
run_siglip_server(args)
elif args.command == "embed":
run_embedding_generation(args)
elif args.command == "pull":
Expand Down
Empty file added nexa/siglip/__init__.py
Empty file.
91 changes: 77 additions & 14 deletions nexa/siglip/nexa_siglip_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,24 @@
from pydantic import BaseModel
from fastapi import Request
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import os
import socket
import time
import argparse
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModel

app = FastAPI(title="Nexa AI SigLIP Image-Text Matching Service")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)

# Global variables
hostname = socket.gethostname()
Expand All @@ -23,6 +33,7 @@ class ImagePathRequest(BaseModel):
class SearchResponse(BaseModel):
image_path: str
similarity_score: float
latency: float

def init_model():
"""Initialize SigLIP model and processor"""
Expand Down Expand Up @@ -53,37 +64,65 @@ def load_images_from_directory(image_dir, valid_extensions=('.jpg', '.jpeg', '.p

@app.on_event("startup")
async def startup_event():
"""Initialize model when service starts"""
"""Initialize model and load images when service starts"""
init_model()
# Add image loading if image_dir is provided
if hasattr(app, "image_dir") and app.image_dir:
global images_dict
try:
images_dict = load_images_from_directory(app.image_dir)
print(f"Successfully loaded {len(images_dict)} images from {app.image_dir}")
except Exception as e:
print(f"Failed to load images: {str(e)}")

@app.get("/", response_class=HTMLResponse, tags=["Root"])
async def read_root(request: Request):
return HTMLResponse(
content=f"<h1>Welcome to Nexa AI SigLIP Image-Text Matching Service</h1><p>Hostname: {hostname}</p>"
)

@app.get("/v1/list_images")
async def list_images():
"""Return current image directory path and loaded images"""
current_dir = getattr(app, "image_dir", None)
return {
"image_dir": current_dir,
"images_count": len(images_dict),
"images": list(images_dict.keys()),
"status": "active" if current_dir and images_dict else "no_images_loaded"
}

@app.post("/v1/load_images")
async def load_images(request: ImagePathRequest):
"""Load images from specified directory"""
"""Load images from specified directory, replacing any previously loaded images"""
global images_dict
try:
images_dict = load_images_from_directory(request.image_dir)
temp_images = load_images_from_directory(request.image_dir)

if not temp_images:
raise ValueError("No valid images found in the specified directory")

images_dict.clear()
images_dict.update(temp_images)
app.image_dir = request.image_dir

return {
"message": f"Successfully loaded {len(images_dict)} images",
"message": f"Successfully loaded {len(images_dict)} images from {request.image_dir}",
"images": list(images_dict.keys())
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
current_count = len(images_dict)
error_message = f"Failed to load images: {str(e)}. Keeping existing {current_count} images."
raise HTTPException(status_code=400, detail=error_message)

@app.post("v1/find_similar", response_model=SearchResponse)
@app.post("/v1/find_similar", response_model=SearchResponse)
async def find_similar(text: str):
"""Find image most similar to input text"""
if not images_dict:
raise HTTPException(status_code=400, detail="No images available, please load images first")

try:

start_time = time.time()
image_paths = list(images_dict.keys())
images = list(images_dict.values())

Expand All @@ -99,17 +138,41 @@ async def find_similar(text: str):

return SearchResponse(
image_path=image_paths[max_prob_index],
similarity_score=max_prob
similarity_score=max_prob,
latency = round(time.time() - start_time, 3)
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")

@app.get("/v1/list_images")
async def list_images():
"""List all loaded images"""
return {"images": list(images_dict.keys())}


def run_nexa_ai_siglip_service(**kwargs):
host = kwargs.get("host", "localhost")
port = kwargs.get("port", 8100)
reload = kwargs.get("reload", False)
if kwargs.get("image_dir"):
app.image_dir = kwargs.get("image_dir")
uvicorn.run(app, host=host, port=port, reload=reload)

if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
parser = argparse.ArgumentParser(
description="Run the Nexa AI SigLIP Service"
)
parser.add_argument(
"--image_dir", type=str, help="Directory of images to load"
)
parser.add_argument(
"--host", type=str, default="localhost", help="Host to bind the server to"
)
parser.add_argument(
"--port", type=int, default=8100, help="Port to bind the server to"
)
parser.add_argument(
"--reload", type=bool, default=False, help="Reload the server on code changes"
)
args = parser.parse_args()
run_nexa_ai_siglip_service(
image_dir=args.image_dir,
host=args.host,
port=args.port,
reload=args.reload
)
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ convert = [
"nexa-gguf",
]

siglip = [
"torch",
"transformers",
"sentencepiece",
]

[project.urls]
Homepage = "https://github.com/NexaAI/nexa-sdk"
Issues = "https://github.com/NexaAI/nexa-sdk/issues"
Expand Down
7 changes: 6 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,9 @@ pytablewriter
sacrebleu
langdetect
rouge_score
immutabledict
immutabledict

# For SigLIP
torch
transformers
sentencepiece

0 comments on commit 2577104

Please sign in to comment.