feat: Caching chat completions result (#83)
KenyonY authored Oct 18, 2023
1 parent e7da8de commit 3402fe3
Showing 20 changed files with 926 additions and 229 deletions.
18 changes: 10 additions & 8 deletions .env
@@ -1,16 +1,19 @@
# See .env.example for examples and explanations

# `LOG_CHAT`: whether to log chat content
#LOG_CHAT=false
#LOG_CHAT=true

#BENCHMARK_MODE=true
CACHE_CHAT_COMPLETION=true

# `CACHE_BACKEND`: MEMORY, LMDB, ROCKSDB, LevelDB
CACHE_BACKEND=MEMORY

#PRINT_CHAT=true
#BENCHMARK_MODE=true

# `OPENAI_BASE_URL`: base URL of any OpenAI-style service to forward; multiple URLs may be given, separated by commas.
# If more than one is given, no OPENAI_ROUTE_PREFIX/EXTRA_ROUTE_PREFIX may be the root route /
OPENAI_BASE_URL=https://api.openai-forward.com
#OPENAI_BASE_URL=https://api.openai.com
#OPENAI_BASE_URL=https://api.openai-forward.com
OPENAI_BASE_URL=https://api.openai.com

# `OPENAI_ROUTE_PREFIX`: forwarding route prefix for all OpenAI-style (logged) services
OPENAI_ROUTE_PREFIX=
@@ -29,8 +32,7 @@ EXTRA_ROUTE_PREFIX=
# `REQ_RATE_LIMIT`: request rate limit for the specified routes, applied per user
# format: {route: ratelimit-string}
# ratelimit-string format [count] [per|/] [n (optional)] [second|minute|hour|day|month|year] :ref:`ratelimit-string`: https://limits.readthedocs.io/en/stable/quickstart.html#rate-limit-string-notation
#REQ_RATE_LIMIT={"/v1/chat/completions":"60/minute;600/hour", "/v1/completions":"60/minute;600/hour"}
REQ_RATE_LIMIT={"/benchmark/v1/chat/completions":"10/10second;100/2minutes"}
REQ_RATE_LIMIT={"/v1/chat/completions":"100/2minutes", "/v1/completions":"60/minute;600/hour"}

# Rate limit backend: [memory, redis, memcached, ...] :ref: https://limits.readthedocs.io/en/stable/storage.html#
#REQ_RATE_LIMIT_BACKEND=redis://localhost:6379
@@ -43,7 +45,7 @@ GLOBAL_RATE_LIMIT=100/minute
RATE_LIMIT_STRATEGY=moving-window

# Rate limit for returned tokens
TOKEN_RATE_LIMIT={"/v1/chat/completions":"50/second","/v1/completions":"60/second"}
TOKEN_RATE_LIMIT={"/v1/chat/completions":"60/second","/v1/completions":"60/second"}


# TCP connection timeout (seconds)
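
For context, a minimal client-side sketch of how the cache settings above come into play. The base URL, API key, and model name are placeholders, and `caching` is the extra request flag this commit introduces (it mirrors the updated Examples/chat_completion.py later in this diff):

    import openai

    openai.api_base = "https://api.openai-forward.com/v1"  # placeholder forward address
    openai.api_key = "sk-******"                           # placeholder key

    # With CACHE_CHAT_COMPLETION=true on the forwarder, passing caching=True
    # lets an identical request be answered from the configured CACHE_BACKEND.
    resp = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "ni hao"}],
        caching=True,
    )
    print(resp["choices"][0]["message"]["content"])
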
4 changes: 4 additions & 0 deletions .env.example
@@ -3,6 +3,10 @@ LOG_CHAT=true

PRINT_CHAT=true

CACHE_CHAT_COMPLETION=true
# `CACHE_BACKEND`: MEMORY, LMDB, ROCKSDB
CACHE_BACKEND=MEMORY

BENCHMARK_MODE=true

# OPENAI_BASE_URL: base URL of any OpenAI-style service to forward; multiple URLs may be given, separated by commas.
4 changes: 4 additions & 0 deletions .gitignore
@@ -15,6 +15,10 @@ chat_*.yaml
Log/
Log-caloi-top/
dist/
CACHE_DB/
CACHE_LMDB/
CACHE_ROCKSDB/
CACHE_LEVELDB/

.run/

2 changes: 1 addition & 1 deletion Examples/.gitignore
@@ -1 +1 @@
config.yaml
config.yaml
1 change: 1 addition & 0 deletions Examples/benchmark/.gitignore
@@ -0,0 +1 @@
db_performance_compare.py
12 changes: 11 additions & 1 deletion Examples/chat_completion.py
@@ -10,15 +10,22 @@
stream = True
# stream = False

n = 1

# debug = True
debug = False

# is_function_call = True
is_function_call = False
caching = True

max_tokens = None

user_content = """
Implement the fastest known square root algorithm in C
"""
# user_content = "ni shi shei"
user_content = "ni hao"

mt = MeasureTime().start()

@@ -61,7 +68,10 @@
{"role": "user", "content": user_content},
],
stream=stream,
n=n,
max_tokens=max_tokens,
request_timeout=30,
caching=caching,
)

if stream:
@@ -71,7 +81,7 @@
else:
chunk_message = next(resp)['choices'][0]['delta']
if is_function_call:
function_call = chunk_message.get("function_call", "")
function_call = chunk_message["function_call"]
name = function_call["name"]
print(f"{chunk_message['role']}: \n{name}: ")
else:
1 change: 0 additions & 1 deletion Examples/completion.py
@@ -16,7 +16,6 @@


user_content = "现在让我们使用泰勒展开推导出牛顿法迭代公式: \n"
from sparrow import MeasureTime

resp = openai.Completion.create(
model="gpt-3.5-turbo-instruct",
134 changes: 78 additions & 56 deletions README.md

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions deploy.md
@@ -116,9 +116,9 @@ Render is arguably the easiest of all these deployment options, and the domain it generates
Then just wait for the deployment to finish.
Render's free plan: 750 hours of free instance time per month (which means a single instance can run non-stop), 100 GB of bandwidth, and 500 minutes of build time.

Note: by default, Render puts the service to sleep after 15 minutes without requests (the upside is that, while asleep, it does not consume the 750 h of free instance time); after it sleeps, the next request is blocked for 5~10 s.
If you don't want the service to go to sleep after 15 minutes, you can use a scheduled script (e.g. every 14 minutes) that sends a request to keep it alive. See `scripts/keep_render_alive.py` for a keep-alive script.
For zero-downtime deploys, set `Health Check Path` to `/healthz` in the settings.
Note: by default, Render puts the service to sleep after 15 minutes without requests (the upside is that, while asleep, it does not consume the 750 h of free instance time); after it sleeps, the next request is blocked for ~15 s.
If you don't want the service to go to sleep after 15 minutes, you can use a scheduled script (e.g. every 14 minutes) that pings the Render service to keep it alive. See `scripts/keep_render_alive.py` for a keep-alive script (a minimal sketch also follows the links below).
For zero-downtime deploys, configure `Health Check Path` as `/healthz` in the Render settings.

> https://render.openai-forward.com
> https://openai-forward.onrender.com
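
A minimal keep-alive sketch under the assumptions above; the real script is `scripts/keep_render_alive.py`, the URL below is the demo deployment, and `/healthz` is the health-check route defined in openai_forward/app.py:

    import time
    import requests

    URL = "https://openai-forward.onrender.com/healthz"  # assumed health-check endpoint

    while True:
        try:
            # Ping every 14 minutes so Render never reaches its 15-minute idle sleep.
            requests.get(URL, timeout=10)
        except requests.RequestException:
            pass  # ignore transient failures; the next ping retries
        time.sleep(14 * 60)
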
2 changes: 1 addition & 1 deletion openai_forward/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.6.1"
__version__ = "0.6.2"

from dotenv import load_dotenv

3 changes: 3 additions & 0 deletions openai_forward/app.py
@@ -4,6 +4,7 @@
from slowapi.errors import RateLimitExceeded

from . import __version__, custom_slowapi
from .cache.database import db_dict
from .forward.extra import generic_objs
from .forward.openai import openai_objs
from .helper import normalize_route as normalize_route_path
@@ -65,6 +66,8 @@ def healthz(request: Request):

@app.on_event("shutdown")
async def shutdown():
if hasattr(db_dict, "close"):
db_dict.close()
for obj in openai_objs:
await obj.client.close()
for obj in generic_objs:
83 changes: 64 additions & 19 deletions openai_forward/cache/chat_completions.py
@@ -16,13 +16,15 @@
@attrs.define(slots=True)
class ChatMessage:
role: Literal["user", "assistant", "system"]
content: str
content: Optional[str] = None
function_call: Optional[dict] = None


@attrs.define(slots=True)
class DeltaMessage:
role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None
function_call: Optional[dict] = None


@attrs.define(slots=True)
@@ -39,7 +41,7 @@ class ChatCompletionRequest:
class ChatCompletionResponseChoice:
index: int
message: ChatMessage
finish_reason: Literal["stop", "length"]
finish_reason: str


@attrs.define(slots=True)
@@ -83,7 +85,9 @@ class ChatCompletionsResponse:


@async_token_rate_limit(token_interval_conf)
async def stream_generate(model: str, texts, request: Request):
async def stream_generate(
model: str, texts, function_call_name: str | None, request: Request
):
created = int(time.time())
id = f"chatcmpl-{get_unique_id()}"

@@ -100,11 +104,15 @@ async def stream_generate(model: str, texts, request: Request):
created=created,
)

def serialize_delta(role=None, content=None, finish_reason=None):
def serialize_delta(
role=None, content=None, function_call=None, finish_reason=None
):
if role:
delta.role = role
if content:
delta.content = content
if function_call:
delta.function_call = function_call

choice_data.finish_reason = finish_reason
choice_data.delta = delta
@@ -119,21 +127,40 @@ def serialize_delta(role=None, content=None, finish_reason=None):
+ b'\n\n'
)

yield serialize_delta(role="assistant", content="")
if function_call_name:
yield serialize_delta(
role="assistant",
function_call={"name": function_call_name, "arguments": ""},
)
else:
yield serialize_delta(role="assistant", content="")

delta = DeltaMessage()
for text in texts:
yield serialize_delta(content=text)
if function_call_name:
yield serialize_delta(function_call={"arguments": text})
else:
yield serialize_delta(content=text)

delta = DeltaMessage()
yield serialize_delta(finish_reason="stop")
yield serialize_delta(
finish_reason="function_call" if function_call_name else "stop"
)

yield b'data: [DONE]\n\n'


@async_token_rate_limit(token_interval_conf)
async def stream_generate_efficient(model: str, texts, request: Request):
"""More efficient version of stream_generate"""
async def stream_generate_efficient(
model: str, texts, function_call_name: str | None, request: Request
):
"""More efficient (use dict) version of stream_generate
Args:
model (str): The model to use.
texts (List[str]): content or function_call['arguments'].
function_call_name (str | None): function_call['name'].
request (Request): A FastAPI request object.
"""
created = int(time.time())
id = f"chatcmpl-{get_unique_id()}"

@@ -147,11 +174,16 @@ async def stream_generate_efficient(model: str, texts, request: Request):
"created": created,
}

def serialize_delta(role=None, content=None, finish_reason=None):
def serialize_delta(
role=None, content=None, function_call=None, finish_reason=None
):
if role:
delta['role'] = role
if content:
delta['content'] = content
if function_call:
delta['function_call'] = function_call

choice_data['finish_reason'] = finish_reason
choice_data['delta'] = delta
@@ -160,26 +192,39 @@ def serialize_delta(role=None, content=None, finish_reason=None):

return b'data: ' + orjson.dumps(chunk) + b'\n\n'

yield serialize_delta(role="assistant", content="")
if function_call_name:
yield serialize_delta(
role="assistant",
function_call={"name": function_call_name, "arguments": ""},
)
else:
yield serialize_delta(role="assistant", content="")

delta = {}
for text in texts:
yield serialize_delta(content=text)
if function_call_name:
yield serialize_delta(function_call={"arguments": text})
else:
yield serialize_delta(content=text)

delta = {}
yield serialize_delta(finish_reason="stop")
yield serialize_delta(
finish_reason="function_call" if function_call_name else "stop"
)

yield b'data: [DONE]\n\n'


def generate(model: str, sentence, usage):
def generate(model: str, sentence: str | None, function_call: dict | None, usage: dict):
created = int(time.time())
id = f"chatcmpl-{get_unique_id()}"

choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=sentence),
finish_reason="stop",
message=ChatMessage(
role="assistant", content=sentence, function_call=function_call
),
finish_reason="function_call" if function_call else "stop",
)

data = ChatCompletionsResponse(
@@ -192,7 +237,7 @@ def generate(model: str, sentence, usage):
)

return orjson.dumps(
attrs.asdict(data, filter=attrs.filters.exclude(type(None))),
attrs.asdict(data),
option=orjson.OPT_APPEND_NEWLINE, # not necessary
)

@@ -229,12 +274,12 @@ async def chat_completions_benchmark(request: Request):

if stream:
return StreamingResponse(
stream_generate_efficient(model, model_result.texts, request),
stream_generate_efficient(model, model_result.texts, None, request),
# stream_generate(model, texts, request),
media_type="text/event-stream",
)
else:
return Response(
content=generate(model, model_result.texts[0], model_result.usage),
content=generate(model, model_result.texts[0], None, model_result.usage),
media_type="application/json",
)
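
To make the streamed wire format above concrete, here is a hedged consumer-side sketch of what stream_generate_efficient emits for a cached, non-function-call reply: `data: {json}` chunks whose `choices[0].delta` carries the role and then the content pieces, terminated by `data: [DONE]`. The endpoint URL, API key, and requests-based transport are illustrative assumptions:

    import json
    import requests

    url = "https://api.openai-forward.com/v1/chat/completions"  # placeholder forward endpoint
    headers = {"Authorization": "Bearer sk-******"}             # placeholder key
    payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "ni hao"}],
        "stream": True,
        "caching": True,
    }

    pieces = []
    with requests.post(url, json=payload, headers=headers, stream=True, timeout=30) as r:
        for line in r.iter_lines():
            if not line or not line.startswith(b"data: "):
                continue
            data = line[len(b"data: "):]
            if data == b"[DONE]":  # terminator emitted by the stream_generate_* helpers
                break
            delta = json.loads(data)["choices"][0]["delta"]
            pieces.append(delta.get("content") or "")

    print("".join(pieces))
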