From 8bd6f27fa6e555da254c045a8c3666138dfd9dea Mon Sep 17 00:00:00 2001 From: "K.Y" Date: Fri, 12 Apr 2024 00:34:01 +0800 Subject: [PATCH] feat: Distinguish between streaming and non streaming internally (#121) * feat: Distinguish between streaming and non streaming internally * clean up --- openai_forward/forward/core.py | 66 +++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 24 deletions(-) diff --git a/openai_forward/forward/core.py b/openai_forward/forward/core.py index ebf20cb..a39b748 100644 --- a/openai_forward/forward/core.py +++ b/openai_forward/forward/core.py @@ -495,6 +495,7 @@ async def aiter_bytes( route_path: str, uid: str, cache_key: str | None = None, + stream: bool | None = None, ): """ Asynchronously iterates through the bytes in the given aiohttp.ClientResponse object @@ -506,38 +507,54 @@ async def aiter_bytes( route_path (str): The API route path. uid (str): Unique identifier for the request. cache_key (bytes): The cache key. + stream (bool): Whether the response is a stream. Returns: AsyncGenerator[bytes]: Each chunk of bytes from the server's response. """ - queue_is_complete = False - - queue = Queue() - # todo: - task = asyncio.create_task(self.read_chunks(r, queue)) + yield_completed = False chunk_list = [] - try: - while True: - chunk = await queue.get() - if not isinstance(chunk, bytes): - queue.task_done() - queue_is_complete = True - break - if CACHE_OPENAI: - chunk_list.append(chunk) + chunk = None + if stream: + queue = Queue() + # todo: + task = asyncio.create_task(self.read_chunks(r, queue)) + try: + while True: + chunk = await queue.get() + if not isinstance(chunk, bytes): + queue.task_done() + yield_completed = True + break + if CACHE_OPENAI: + chunk_list.append(chunk) + yield chunk + except Exception as e: + logger.warning( + f"aiter_bytes error:{e}\nhost:{request.client.host} method:{request.method}: " + f"{traceback.format_exc()}" + ) + finally: + if not task.done(): + task.cancel() + else: + try: + chunk = await r.read() yield chunk - except Exception: - logger.warning( - f"aiter_bytes error:\nhost:{request.client.host} method:{request.method}: {traceback.format_exc()}" - ) - finally: - if not task.done(): - task.cancel() - r.release() + chunk_list.append(chunk) + chunk = bytearray(chunk) + yield_completed = True + except Exception as e: + logger.warning( + f"aiter_bytes error:{e}\nhost:{request.client.host} method:{request.method}: " + f"{traceback.format_exc()}" + ) + + r.release() if uid: - if r.ok and queue_is_complete: + if r.ok and yield_completed: target_info = self._handle_result( chunk, uid, route_path, request.method ) @@ -595,6 +612,7 @@ async def reverse_proxy(self, request: Request): request, route_path, model_set ) uid = payload_info["uid"] + stream = payload_info.get('stream', None) cached_response, cache_key = get_cached_response( payload, @@ -610,7 +628,7 @@ async def reverse_proxy(self, request: Request): r = await self.send(client_config, data=payload) return StreamingResponse( - self.aiter_bytes(r, request, route_path, uid, cache_key), + self.aiter_bytes(r, request, route_path, uid, cache_key, stream), status_code=r.status, media_type=r.headers.get("content-type"), )