🎨: improve cache module (#97)
* improve cache module

* update version
KenyonY authored Nov 21, 2023
1 parent a6f9a2c commit 3d760a0
Showing 8 changed files with 139 additions and 112 deletions.
4 changes: 2 additions & 2 deletions .env
@@ -3,10 +3,10 @@
# `LOG_CHAT`: Whether to log the chat
LOG_CHAT=true

CACHE_CHAT_COMPLETION=true

# `CACHE_BACKEND`: Options (MEMORY, LMDB, LevelDB)
CACHE_BACKEND=LMDB
CACHE_CHAT_COMPLETION=true
DEFAULT_REQUEST_CACHING_VALUE=false

#LOG_CACHE_DB_INFO=true

55 changes: 29 additions & 26 deletions Examples/chat_completion.py
@@ -17,8 +17,8 @@
# debug = True
debug = False

# is_function_call = True
is_function_call = False
is_tool_calls = True
# is_tool_call = False
caching = True

max_tokens = None
@@ -32,22 +32,24 @@

mt = MeasureTime().start()

# function_call
if is_function_call:
functions = [
if is_tool_calls:
tools = [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
"required": ["location"],
},
"required": ["location"],
},
}
]
@@ -56,10 +58,10 @@
messages=[
{"role": "user", "content": "What's the weather like in Boston today?"}
],
functions=functions,
function_call="auto", # auto is default, but we'll be explicit
tools=tools,
tool_choice="auto", # auto is default, but we'll be explicit
stream=stream,
timeout=30,
extra_body={"caching": caching},
)

else:
@@ -84,20 +86,21 @@
for idx, chunk in enumerate(resp):
chunk_message = chunk.choices[0].delta or ""
if idx == 0:
if is_function_call:
function_call = chunk_message.function_call or ""
name = function_call.name
if is_tool_calls:
function = chunk_message.tool_calls[0].function
name = function.name
print(f"{chunk_message.role}: \n{name}: ")
else:
print(f"{chunk_message.role}: ")
continue

content = ""
if is_function_call:
function_call = chunk_message.function_call or ""
if function_call:
content = function_call.arguments or ""

if is_tool_calls:
tool_calls = chunk_message.tool_calls
if tool_calls:
function = tool_calls[0].function
if function:
content = function.arguments or ""
else:
content = chunk_message.content or ""
print(content, end="")
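For reference, the streaming loop above only prints the tool-call fragments; a minimal sketch of assembling them into a complete call (reusing `resp` from the example, and assuming a single tool call per response):

```python
import json

# Accumulate the streamed tool-call fragments emitted chunk by chunk.
name, arguments = None, ""
for chunk in resp:
    delta = chunk.choices[0].delta
    if not delta or not delta.tool_calls:
        continue
    call = delta.tool_calls[0]
    if call.function.name:
        name = call.function.name
    arguments += call.function.arguments or ""

# e.g. get_current_weather {'location': 'Boston, MA'}
print(name, json.loads(arguments))
```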
8 changes: 4 additions & 4 deletions Examples/demo.py
@@ -11,8 +11,8 @@
base_url=config['api_base'],
)

caching = True
# caching = False
# extra_body={"caching": True}
extra_body = {}
stream = True

json_obj_case = True
@@ -30,7 +30,7 @@
{"role": "user", "content": "Who won the world series in 2020?"},
],
stream=stream,
extra_body={"caching": caching},
extra_body=extra_body,
)
if stream:
for chunk in response:
@@ -66,7 +66,7 @@
tools=tools,
tool_choice="auto",
stream=stream,
extra_body={"caching": caching},
extra_body=extra_body,
)

if stream:
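Aside from the variable rename, the per-request cache flag still travels through `extra_body`; a minimal sketch against a local openai-forward instance (base_url and api_key are placeholders):

```python
from openai import OpenAI

client = OpenAI(api_key="sk-...", base_url="http://localhost:8000/v1")  # placeholder values

# Explicitly opt in to caching for this request; omitting "caching" falls back
# to the server-side DEFAULT_REQUEST_CACHING_VALUE introduced in this commit.
resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Who won the world series in 2020?"}],
    extra_body={"caching": True},
)
print(resp.choices[0].message.content)
```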
2 changes: 1 addition & 1 deletion openai_forward/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.6.6"
__version__ = "0.6.7"

from dotenv import load_dotenv

171 changes: 94 additions & 77 deletions openai_forward/cache/chat_completions.py
@@ -116,71 +116,92 @@ class ChatCompletionsResponse:
sentences = cycle(corpus)


# todo: refactor this
# @async_token_rate_limit(token_interval_conf)
# async def stream_generate(
# model: str, texts, function_call_name: str | None, request: Request
# ):
# created = int(time.time())
# id = f"chatcmpl-{get_unique_id()}"
#
# delta = DeltaMessage()
#
# choice_data = ChatCompletionResponseStreamChoice(
# index=0, delta=delta, finish_reason=None
# )
# chunk = ChatCompletionsResponse(
# id=id,
# model=model,
# choices=[choice_data],
# object="chat.completion.chunk",
# created=created,
# )
#
# def serialize_delta(
# role=None, content=None, tool_calls=None, finish_reason=None
# ):
# if role:
# delta.role = role
# if content:
# delta.content = content
# if tool_calls:
# delta.tool_calls = tool_calls
#
# choice_data.finish_reason = finish_reason
# choice_data.delta = delta
#
# chunk.choices = [choice_data]
#
# return (
# b'data: '
# + orjson.dumps(
# attrs.asdict(chunk, filter=attrs.filters.exclude(type(None)))
# )
# + b'\n\n'
# )
#
# if function_call_name:
# yield serialize_delta(
# role="assistant",
# tool_calls={"name": function_call_name, "arguments": ""},
# )
# else:
# yield serialize_delta(role="assistant", content="")
#
# delta = DeltaMessage()
# for text in texts:
# if function_call_name:
# yield serialize_delta(tool_calls={"arguments": text})
# else:
# yield serialize_delta(content=text)
#
# delta = DeltaMessage()
# yield serialize_delta(
# finish_reason="function_call" if function_call_name else "stop"
# )
#
# yield b'data: [DONE]\n\n'
@async_token_rate_limit(token_interval_conf)
async def stream_generate(
model: str, content: str | None, tool_calls: list | None, request: Request
):
created = int(time.time())
id = f"chatcmpl-{get_unique_id()}"

if tool_calls:
function_name = tool_calls[0]['function']['name']
function_arguments = tool_calls[0]['function']['arguments']
texts = encode_as_pieces(function_arguments)
else:
texts = encode_as_pieces(content)

delta = DeltaMessage()

choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=delta, finish_reason=None
)
chunk = ChatCompletionsResponse(
id=id,
model=model,
choices=[choice_data],
object="chat.completion.chunk",
created=created,
system_fingerprint="fp_0123456789",
)

def serialize_delta(
role=None, content=None, delta_tool_calls=None, finish_reason=None
):
if role:
delta.role = role
if content:
delta.content = content
if tool_calls:
delta.tool_calls = delta_tool_calls

choice_data.finish_reason = finish_reason
choice_data.delta = delta

chunk.choices = [choice_data]

return (
b'data: '
+ orjson.dumps(
attrs.asdict(chunk, filter=attrs.filters.exclude(type(None)))
)
+ b'\n\n'
)

if tool_calls:
yield serialize_delta(
role="assistant",
delta_tool_calls=[
{
'index': 0,
'id': f"call_{get_unique_id()}",
'function': {"name": function_name, "arguments": ""},
'type': 'function',
}
],
)
else:
yield serialize_delta(role="assistant", content="")

delta = DeltaMessage()
for content in texts:
if tool_calls:
yield serialize_delta(
delta_tool_calls=[
{
'index': 0,
'id': None,
'function': {"name": None, "arguments": content},
'type': None,
}
],
)
else:
yield serialize_delta(content=content)

delta = DeltaMessage()
yield serialize_delta(finish_reason="tool_calls" if tool_calls else "stop")

yield b'data: [DONE]\n\n'


@async_token_rate_limit(token_interval_conf)
@@ -192,7 +213,7 @@ async def stream_generate_efficient(
model (str): The model to use.
content (str): content.
tool_calls (list | None): tool_calls list.
request (Request): A FastAPI request object.
request (Request): A FastAPI request object. For rate limit.
"""
created = int(time.time())
id = f"chatcmpl-{get_unique_id()}"
@@ -298,23 +319,19 @@ def generate(model: str, content: str | None, tool_calls: list | None, usage: di

@attrs.define(slots=True)
class ModelInferResult:
texts: List[str]
content: str
usage: dict


def model_inference(model: str, messages: List, stream: bool):
def model_inference(model: str, messages: List):
sentence = next(sentences)

if TIKTOKEN_VALID:
usage = count_tokens(messages, sentence, 'gpt-3.5-turbo')
else:
usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": -1}
usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}

if stream:
texts = encode_as_pieces(sentence)
else:
texts = [sentence]
return ModelInferResult(texts=texts, usage=usage)
return ModelInferResult(content=sentence, usage=usage)


@async_random_sleep(min_time=0, max_time=1)
@@ -324,15 +341,15 @@ async def chat_completions_benchmark(request: Request):
stream = payload.get("stream", False)
messages = payload.get("messages", [])

model_result = model_inference(model, messages, stream)
model_result = model_inference(model, messages)

if stream:
return StreamingResponse(
stream_generate_efficient(model, model_result.texts, None, request),
stream_generate_efficient(model, model_result.content, None, request),
media_type="text/event-stream",
)
else:
return Response(
content=generate(model, model_result.texts[0], None, model_result.usage),
content=generate(model, model_result.content, None, model_result.usage),
media_type="application/json",
)
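For context, `stream_generate` and `stream_generate_efficient` both emit standard SSE frames (`data: {json}\n\n`, terminated by `data: [DONE]`). A rough sketch of reading that stream with a raw HTTP client, using httpx for illustration (URL and payload are placeholders):

```python
import httpx
import orjson

with httpx.stream(
    "POST",
    "http://localhost:8000/v1/chat/completions",  # placeholder openai-forward URL
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": True,
    },
) as r:
    for line in r.iter_lines():
        if not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        delta = orjson.loads(data)["choices"][0]["delta"]
        print(delta.get("content") or "", end="")
```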
5 changes: 4 additions & 1 deletion openai_forward/content/openai.py
@@ -7,6 +7,7 @@
from orjson import JSONDecodeError

from ..helper import get_client_ip, get_unique_id, route_prefix_to_str
from ..settings import DEFAULT_REQUEST_CACHING_VALUE
from .helper import markdown_print, parse_sse_buffer, print


@@ -150,7 +151,9 @@ async def parse_payload(request: Request):
"tool_choice": payload.get("tool_choice", None),
"ip": get_client_ip(request) or "",
"uid": uid,
"caching": payload.pop("caching", False), # pop caching
"caching": payload.pop(
"caching", DEFAULT_REQUEST_CACHING_VALUE
), # pop caching
"datetime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
}
)
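The `pop` here is deliberate: the client-side `caching` flag arrives inside the JSON body (sent via `extra_body`), and removing it presumably keeps the non-standard field out of the payload that gets forwarded upstream. A minimal sketch of that idea (values are illustrative):

```python
DEFAULT_REQUEST_CACHING_VALUE = False  # stand-in for the value imported from settings

payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "hi"}],
    "caching": True,  # non-standard field injected by the client
}

# Absent flag -> env-driven default; present flag -> removed and honored.
caching = payload.pop("caching", DEFAULT_REQUEST_CACHING_VALUE)
assert "caching" not in payload
```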
1 change: 0 additions & 1 deletion openai_forward/forward/base.py
@@ -357,7 +357,6 @@ async def aiter_bytes(
target_info = self._handle_result(
chunk, uid, route_path, request.method
)
print(f"{target_info=}")
if target_info and CACHE_CHAT_COMPLETION and cache_key is not None:
cached_value = db_dict.get(cache_key, [])
cached_value.append(target_info["assistant"])
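The visible lines show the cache write path: once a streamed response has been fully parsed, the assistant message is appended to whatever is already stored under `cache_key`. A rough sketch of that read-modify-write pattern, with `db_dict` modeled as a plain dict (the write-back step is assumed, since it falls outside the shown hunk):

```python
db_dict: dict = {}  # stand-in for the MEMORY/LMDB/LevelDB-backed mapping

def cache_assistant_reply(cache_key, target_info, cache_enabled=True):
    # Mirrors the guarded write in aiter_bytes: only cache when enabled and keyed.
    if target_info and cache_enabled and cache_key is not None:
        cached_value = db_dict.get(cache_key, [])
        cached_value.append(target_info["assistant"])
        db_dict[cache_key] = cached_value  # write-back (assumed)
```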
5 changes: 5 additions & 0 deletions openai_forward/settings.py
@@ -56,8 +56,13 @@
os.environ.get("LOG_CACHE_DB_INFO", "false").strip().lower() == "true"
)
CACHE_BACKEND = os.environ.get("CACHE_BACKEND", "MEMORY").strip()
DEFAULT_REQUEST_CACHING_VALUE = False
if CACHE_CHAT_COMPLETION:
additional_start_info["cache_backend"] = CACHE_BACKEND
DEFAULT_REQUEST_CACHING_VALUE = (
os.environ.get("DEFAULT_REQUEST_CACHING_VALUE", "false").strip().lower()
== "true"
)

IP_WHITELIST = env2list("IP_WHITELIST", sep=ENV_VAR_SEP)
IP_BLACKLIST = env2list("IP_BLACKLIST", sep=ENV_VAR_SEP)
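Net effect: `DEFAULT_REQUEST_CACHING_VALUE` is only read from the environment when `CACHE_CHAT_COMPLETION` is enabled, and it merely supplies the default; an explicit per-request `caching` value still wins. A small sketch of how the two settings appear to combine (the helper below is illustrative, not part of the codebase):

```python
import os

CACHE_CHAT_COMPLETION = os.environ.get("CACHE_CHAT_COMPLETION", "false").strip().lower() == "true"

DEFAULT_REQUEST_CACHING_VALUE = False
if CACHE_CHAT_COMPLETION:
    DEFAULT_REQUEST_CACHING_VALUE = (
        os.environ.get("DEFAULT_REQUEST_CACHING_VALUE", "false").strip().lower() == "true"
    )

def effective_caching(payload: dict) -> bool:
    # Per-request flag (sent through extra_body) overrides the env default.
    return bool(payload.get("caching", DEFAULT_REQUEST_CACHING_VALUE))

print(effective_caching({"caching": False}))  # False even if the env default is true
print(effective_caching({}))                  # follows DEFAULT_REQUEST_CACHING_VALUE
```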
