-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlocustfile.py
92 lines (75 loc) · 2.31 KB
/
locustfile.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
GLiClass Locust load test script.
"""
import random
import time
from textwrap import dedent
from datasets import load_dataset
from locust import HttpUser, between, task
from loguru import logger
# Port the GLiClass service under test listens on.
PORT = 8080
# Base URL of the service; the "fastapi" hostname presumably resolves via
# container networking (e.g. a compose service name) — confirm for your setup.
HOST = "http://fastapi:" + str(PORT)
# Hugging Face dataset used as a source of request text.
# https://huggingface.co/datasets/gsm8k
DATASET_ID = "gsm8k"
def load_data(dataset_id=None, config="main") -> dict:
    """Load the Hugging Face dataset used to build request payloads.

    Args:
        dataset_id: Hugging Face Hub dataset identifier. ``None`` (the
            default) means the module-level ``DATASET_ID``.
        config: Dataset configuration name passed to ``load_dataset``;
            defaults to ``"main"``.

    Returns:
        Whatever ``datasets.load_dataset`` returns for the whole dataset —
        callers in this file index it by split name (e.g. ``"test"``).
    """
    # Resolve the default lazily so the module constant is read at call time.
    if dataset_id is None:
        dataset_id = DATASET_ID
    logger.info("Loading dataset...")
    dataset = load_dataset(dataset_id, config)
    logger.info("Dataset loaded.")
    return dataset
def get_random_sample(data=None, split="test") -> str:
    """Return a random question/answer pair from the dataset as one string.

    Args:
        data: Mapping from split name to an indexable, sized collection of
            records that have ``"question"`` and ``"answer"`` fields.
            Defaults to the module-level ``DATA``.
        split: Name of the split in ``data`` to sample from.

    Returns:
        The sampled record's question and answer joined by a newline.

    Raises:
        ValueError: If the selected split contains no records.
    """
    if data is None:
        data = DATA
    records = data[split]
    if not records:
        raise ValueError(f"Dataset split {split!r} is empty.")
    # randrange excludes its upper bound; randint(0, len) could pick len
    # itself and raise IndexError.
    sample = records[random.randrange(len(records))]
    # Plain join instead of dedent(f"""..."""): the interpolated answer spans
    # several unindented lines, which defeats dedent and leaks the template's
    # newlines/indentation into the request text.
    return f"{sample['question']}\n{sample['answer']}"
# Loaded once at import time; get_random_sample() reads from this
# module-level object when building each request payload.
DATA = load_data()
class GLiClassLoadTest(HttpUser):
    """Locust load test for GLiClass API.

    Each simulated user submits a single-label classification request
    (labels Positive/Negative/Neutral) to ``/predict``. If the server replies
    with a ``task_id``, the user polls ``/result/{task_id}`` until the task
    leaves the ``"Processing"`` state or a polling cap is reached.
    """

    host = HOST
    wait_time = between(1, 5)

    # Seconds to wait after submitting before the first result request.
    RESULT_INITIAL_DELAY = 2
    # Seconds to wait between follow-up result requests.
    RESULT_POLL_INTERVAL = 5
    # Maximum number of follow-up result requests; bounds polling so a task
    # stuck in "Processing" cannot block this simulated user forever.
    RESULT_MAX_POLLS = 60

    def get_classification(self) -> None:
        """Submit a classification request and poll until its result is ready.

        Logs an error and returns early if ``/predict`` does not answer with
        HTTP 200/202, and returns silently if the response carries no
        ``task_id``. Otherwise logs the final task status and the number of
        HTTP requests this call issued.
        """
        headers = {"Content-Type": "application/json"}
        payload = {
            "inputs": [get_random_sample()],
            "labels": [
                "Positive",
                "Negative",
                "Neutral",
            ],
            "classification_type": "single-label",
        }
        resp = self.client.post("/predict", headers=headers, json=payload)
        if resp.status_code not in (200, 202):
            logger.error(f"Request failed: {resp.text}")
            return
        request_count = 1

        task_id = resp.json().get("task_id", None)
        if not task_id:
            return

        time.sleep(self.RESULT_INITIAL_DELAY)
        status, poll_count = self._poll_result(task_id, headers)
        request_count += poll_count

        if status == "Success":
            logger.info("Task completed.")
        else:
            # Covers failures, unknown statuses and an exhausted polling cap.
            logger.warning(f"Task {task_id} stopped polling with status: {status}")
        logger.info(f"Request count: {request_count}")

    def _poll_result(self, task_id, headers) -> tuple:
        """Fetch ``/result/{task_id}`` until its status is not "Processing".

        Issues one request immediately, then up to ``RESULT_MAX_POLLS`` more,
        spaced ``RESULT_POLL_INTERVAL`` seconds apart.

        Returns:
            ``(status, request_count)``: the last ``"status"`` value seen
            (``None`` if absent or if a request failed) and the number of
            GET requests issued.
        """
        endpoint = f"/result/{task_id}"
        status = None
        sent = 0
        for attempt in range(self.RESULT_MAX_POLLS + 1):
            if attempt:
                time.sleep(self.RESULT_POLL_INTERVAL)
            # A fixed ``name`` groups every task's polls under one entry in
            # Locust's statistics instead of one row per task_id.
            resp = self.client.get(endpoint, headers=headers, name="/result/[task_id]")
            sent += 1
            if not resp.ok:
                logger.error(f"Result request failed: {resp.text}")
                return None, sent
            status = resp.json().get("status", None)
            if status != "Processing":
                break
        return status, sent

    @task
    def execute_task(self) -> None:
        """Execute tasks."""
        self.get_classification()