Update load test, use threads instead of asyncio (#3628)
BabyChouSr authored Nov 20, 2024
1 parent 691be9d commit 1cd4b74
Showing 1 changed file with 47 additions and 37 deletions.
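
The change swaps per-request asyncio tasks for a thread pool: each batch of blocking completion calls is submitted to a ThreadPoolExecutor, collected with future.result(), and paced with time.sleep(1). A minimal sketch of that pattern, with a stand-in fake_completion in place of the real litellm.completion call (names and batch sizes here are illustrative, not from the commit):

    import time
    from concurrent.futures import ThreadPoolExecutor


    def fake_completion(i):
        # Stand-in for a blocking litellm.completion() call.
        time.sleep(0.1)
        return f"response-{i}"


    def run_batched(n=10, batch_size=5):
        results = []
        with ThreadPoolExecutor(max_workers=batch_size) as executor:
            for i in range(0, n, batch_size):
                # Submit one batch of blocking calls to the worker threads.
                futures = [
                    executor.submit(fake_completion, j)
                    for j in range(i, min(i + batch_size, n))
                ]
                # Block until every future in the batch has resolved.
                results.extend(f.result() for f in futures)
                if i + batch_size < n:
                    time.sleep(1)  # wait 1 second before the next batch
        return results


    if __name__ == "__main__":
        print(run_batched())

Note that future.result() re-raises any exception from the worker; the script below sidesteps this because litellm_completion catches its own exceptions and returns the error string, which the results loop later filters out with isinstance(c, tuple).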
tests/load_test.py (47 additions, 37 deletions)
@@ -1,18 +1,15 @@
 import argparse
-import time, asyncio
-from openai import AsyncOpenAI, AsyncAzureOpenAI
+import time
+import threading
+from concurrent.futures import ThreadPoolExecutor
 import uuid
 import traceback
 import numpy as np
 from transformers import AutoTokenizer
+from litellm import completion
 
-# base_url - litellm proxy endpoint
-# api_key - litellm proxy api-key, is created proxy with auth
-litellm_client = None
 
-
-async def litellm_completion(args, tokenizer, image_url=None):
-    # Your existing code for litellm_completion goes here
+def litellm_completion(args, tokenizer, image_url=None):
     try:
         if image_url:
             messages = [
@@ -30,16 +27,24 @@ async def litellm_completion(args, tokenizer, image_url=None):
             ]
 
         start = time.time()
-        response = await litellm_client.chat.completions.create(
+
+        additional_api_kwargs = {}
+        if args.api_key:
+            additional_api_kwargs["api_key"] = args.api_key
+        if args.api_base:
+            additional_api_kwargs["api_base"] = args.api_base
+
+        response = completion(
             model=args.model,
             messages=messages,
             stream=True,
+            **additional_api_kwargs,
         )
         ttft = None
 
         itl_list = []
         content = ""
-        async for chunk in response:
+        for chunk in response:
             if chunk.choices[0].delta.content:
                 end_time = time.time()
                 if ttft is None:
@@ -52,43 +57,48 @@
         return content, ttft, itl_list
 
     except Exception as e:
-        # If there's an exception, log the error message
+        print(e)
         with open("error_log.txt", "a") as error_log:
             error_log.write(f"Error during completion: {str(e)}\n")
         return str(e)
 
 
-async def main(args):
+def main(args):
     n = args.num_total_responses
     batch_size = args.req_per_sec  # Requests per second
     start = time.time()
 
-    all_tasks = []
+    all_results = []
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    for i in range(0, n, batch_size):
-        batch = range(i, min(i + batch_size, n))
-        for _ in batch:
-            if args.include_image:
-                # Generate a random dimension for the image
-                if args.randomize_image_dimensions:
-                    y_dimension = np.random.randint(100, 1025)
+
+    with ThreadPoolExecutor(max_workers=batch_size) as executor:
+        for i in range(0, n, batch_size):
+            batch_futures = []
+            batch = range(i, min(i + batch_size, n))
+
+            for _ in batch:
+                if args.include_image:
+                    if args.randomize_image_dimensions:
+                        y_dimension = np.random.randint(100, 1025)
+                    else:
+                        y_dimension = 512
+                    image_url = f"https://placehold.co/1024x{y_dimension}/png"
+                    future = executor.submit(
+                        litellm_completion, args, tokenizer, image_url
+                    )
-                else:
-                    y_dimension = 512
-                image_url = f"https://placehold.co/1024x{y_dimension}/png"
-                task = asyncio.create_task(
-                    litellm_completion(args, tokenizer, image_url)
-                )
-            else:
-                task = asyncio.create_task(litellm_completion(args, tokenizer))
-            all_tasks.append(task)
-        if i + batch_size < n:
-            await asyncio.sleep(1)  # Wait 1 second before the next batch
-
-    all_completions = await asyncio.gather(*all_tasks)
+                else:
+                    future = executor.submit(litellm_completion, args, tokenizer)
+                batch_futures.append(future)
+
+            # Wait for batch to complete
+            for future in batch_futures:
+                all_results.append(future.result())
+
+            if i + batch_size < n:
+                time.sleep(1)  # Wait 1 second before next batch
 
     successful_completions = [
-        c for c in all_completions if isinstance(c, tuple) and len(c) == 3
+        c for c in all_results if isinstance(c, tuple) and len(c) == 3
     ]
     ttft_list = np.array([float(c[1]) for c in successful_completions])
     itl_list_flattened = np.array(
@@ -101,7 +111,7 @@ async def main(args):
 
     # Write errors to error_log.txt
    with open("load_test_errors.log", "a") as error_log:
-        for completion in all_completions:
+        for completion in all_results:
            if isinstance(completion, str):
                error_log.write(completion + "\n")
 
@@ ... @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="azure-gpt-3.5")
-    parser.add_argument("--server-address", type=str, default="http://0.0.0.0:9094")
+    parser.add_argument("--api-base", type=str, default=None)
+    parser.add_argument("--api-key", type=str, default=None)
     parser.add_argument("--num-total-responses", type=int, default=50)
     parser.add_argument("--req-per-sec", type=int, default=5)
     parser.add_argument("--include-image", action="store_true")
     parser.add_argument("--randomize-image-dimensions", action="store_true")
     args = parser.parse_args()
 
-    litellm_client = AsyncOpenAI(base_url=args.server_address, api_key="sk-1234")
     # Blank out contents of error_log.txt
     open("load_test_errors.log", "w").close()
 
-    asyncio.run(main(args))
+    main(args)
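
For reference, an invocation against a LiteLLM proxy might look like the following (the api-base URL and key are illustrative placeholders taken from the removed defaults, not values mandated by this commit):

    python tests/load_test.py --model azure-gpt-3.5 --api-base http://0.0.0.0:9094 --api-key sk-1234 --num-total-responses 50 --req-per-sec 5

Since --api-base and --api-key default to None and are only forwarded when set, omitting them falls back to whatever credentials litellm resolves on its own (for example, environment variables).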
