server.py
from fastapi import FastAPI
from typing import List, Dict, Any, Optional
from pydantic import BaseModel
import uuid
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from inference import generate
import uvicorn
import argparse

app = FastAPI()


class ChatInput(BaseModel):
    messages: List[Dict[str, Any]]
    functions: Optional[List[Dict[str, Any]]] = None  # function definitions are optional
    temperature: float = 0.7  # sampling temperature; set a default value


@app.post("/v1/chat/completions")
async def chat_endpoint(chat_input: ChatInput):
    # Run inference and wrap the result in an OpenAI-style chat completion response.
    generated_message = generate(
        model, tokenizer, chat_input.messages, chat_input.functions, chat_input.temperature
    )
    return {
        "id": str(uuid.uuid4()),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model_name,
        "choices": [
            {
                "message": generated_message,
                "finish_reason": "stop",
                "index": 0,
            }
        ],
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Functionary API Server")
    parser.add_argument(
        "--model",
        type=str,
        default="musabgultekin/functionary-7b-v1",
        help="The model name to be used.",
    )
    args = parser.parse_args()

    # Load the model in fp16 on the first GPU. `model_name`, `model`, and
    # `tokenizer` are module-level globals read by the endpoint above.
    model_name = args.model
    model = AutoModelForCausalLM.from_pretrained(
        model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16
    ).to("cuda:0")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    uvicorn.run(app, host="0.0.0.0", port=8000)
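
Once the server is running, the endpoint can be exercised with a plain HTTP POST. Below is a minimal client sketch using the `requests` library; it assumes the server is reachable at localhost:8000, and the `get_current_weather` function schema is a hypothetical example, not part of this repo.

# client_example.py -- minimal sketch; assumes server.py is running on localhost:8000.
import requests

payload = {
    "messages": [
        {"role": "user", "content": "What is the weather in Istanbul?"}
    ],
    # Hypothetical function definition for illustration only.
    "functions": [
        {
            "name": "get_current_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        }
    ],
    "temperature": 0.7,
}

response = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
# The response mirrors the OpenAI chat completion shape built in chat_endpoint.
print(response.json()["choices"][0]["message"])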