Clean Upvote and Downvote data #3611

Merged: 10 commits, merged Nov 13, 2024

Changes from 1 commit
199 changes: 123 additions & 76 deletions fastchat/serve/monitor/clean_chat_data.py
@@ -5,11 +5,13 @@
 python3 clean_chat_data.py
 """
 import argparse
-import datetime
 import json
 import os
 from pytz import timezone
+from functools import partial
+from datetime import datetime, timedelta
 import time
+import multiprocessing as mp
 
 from tqdm import tqdm
@@ -24,14 +26,24 @@
 NETWORK_ERROR_MSG = (
     "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.".lower()
 )
+MANAGER = mp.Manager()
+LOCK = MANAGER.Lock()
 
 
-def get_log_files(max_num_files=None):
-    dates = []
-    for month in range(4, 12):
-        for day in range(1, 33):
-            dates.append(f"2023-{month:02d}-{day:02d}")
+def date_range(start="2023-04-01"):
+    start_date = datetime.strptime(start, "%Y-%m-%d").date()
+    end_date = datetime.now().date()
+    delta = end_date - start_date
+    dates = [
+        (start_date + timedelta(days=d)).strftime("%Y-%m-%d")
+        for d in range(delta.days + 2)
+    ]
+
+    return dates
+
+
+def get_log_files(max_num_files=None):
+    dates = date_range()
     filenames = []
     for d in dates:
         for i in range(NUM_SERVERS):
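
For reference, a small sketch (hypothetical dates, not part of the PR) of what the new date_range() helper produces; note range(delta.days + 2), which runs one day past today so that log files dated tomorrow are still matched:

from datetime import date, timedelta

start_date = date(2023, 4, 1)
end_date = date(2023, 4, 3)          # pretend datetime.now().date() is this
delta = end_date - start_date        # 2 days
dates = [
    (start_date + timedelta(days=d)).strftime("%Y-%m-%d")
    for d in range(delta.days + 2)   # range(4): today and tomorrow included
]
print(dates)  # ['2023-04-01', '2023-04-02', '2023-04-03', '2023-04-04']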
@@ -44,90 +56,125 @@ def get_log_files(max_num_files=None):
     return filenames


-def clean_chat_data(log_files, action_type):
-    raw_data = []
-    for filename in tqdm(log_files, desc="read files"):
-        for retry in range(5):
-            try:
-                lines = open(filename).readlines()
-                break
-            except FileNotFoundError:
-                time.sleep(2)
-
-        for l in lines:
-            row = json.loads(l)
-            if row["type"] == action_type:
-                raw_data.append(row)
+def get_action_type_data(filename, action_type):
+    for _ in range(5):
+        try:
+            lines = open(filename).readlines()
+            break
+        except FileNotFoundError:
+            time.sleep(2)
 
-    all_models = set()
-    all_ips = dict()
-    chats = []
+    for l in lines:
+        row = json.loads(l)
+        if row["type"] == action_type:
+            return row


+def process_data(row, action_type, all_ips):
+    # Initialize local counters
     ct_invalid_conv_id = 0
     ct_invalid = 0
     ct_network_error = 0
-    for row in raw_data:
-        try:
-            if action_type in ["chat", "upvote", "downvote"]:
-                state = row["state"]
-                model = row["model"]
-            elif action_type == "leftvote":
-                state = row["states"][0]
-                model = row["states"][0]["model_name"]
-            elif action_type == "rightvote":
-                state = row["states"][1]
-                model = row["states"][1]["model_name"]
-            conversation_id = state["conv_id"]
-        except KeyError:
-            ct_invalid_conv_id += 1
-            continue
-
-        if conversation_id is None:
-            ct_invalid_conv_id += 1
-            continue
+    try:
+        if action_type in ["chat", "upvote", "downvote"]:
+            state = row["state"]
+            model = row["model"]
+        elif action_type == "leftvote":
+            state = row["states"][0]
+            model = row["states"][0]["model_name"]
+        elif action_type == "rightvote":
+            state = row["states"][1]
+            model = row["states"][1]["model_name"]
+        conversation_id = state["conv_id"]
+    except KeyError:
+        ct_invalid_conv_id += 1
+        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+
+    if conversation_id is None:
+        ct_invalid_conv_id += 1
+        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None

+    conversation = to_openai_format(state["messages"][state["offset"] :])
+    if not isinstance(model, str):
+        ct_invalid += 1
+        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+    model = replace_model_name(model, row["tstamp"])
+
+    try:
+        lang_code = detect_language(state["messages"][state["offset"]][1])
+    except IndexError:
+        ct_invalid += 1
+        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+
+    if not all(isinstance(x["content"], str) for x in conversation):
+        ct_invalid += 1
+        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+
+    messages = "".join([x["content"] for x in conversation]).lower()
+    if NETWORK_ERROR_MSG in messages:
+        ct_network_error += 1
+        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+
+    ip = row["ip"]
+    # Synchronize access to all_ips using the lock
+    with LOCK:
+        if ip not in all_ips:
+            all_ips[ip] = len(all_ips)
+        user_id = all_ips[ip]

-        conversation = to_openai_format(state["messages"][state["offset"] :])
-        if not isinstance(model, str):
-            ct_invalid += 1
-            continue
-        model = replace_model_name(model, row["tstamp"])
+    # Prepare the result data
+    result = dict(
+        conversation_id=conversation_id,
+        model=model,
+        conversation=conversation,
+        turn=len(conversation) // 2,
+        language=lang_code,
+        user_id=user_id,
+        tstamp=row["tstamp"],
+    )
 
-        try:
-            lang_code = detect_language(state["messages"][state["offset"]][1])
-        except IndexError:
-            ct_invalid += 1
-            continue
+    return result, ct_invalid_conv_id, ct_invalid, ct_network_error, model

-        if not all(isinstance(x["content"], str) for x in conversation):
-            ct_invalid += 1
-            continue
-
-        messages = "".join([x["content"] for x in conversation]).lower()
-        if NETWORK_ERROR_MSG in messages:
-            ct_network_error += 1
-            continue
+def clean_chat_data(log_files, action_type):
+    with mp.Pool() as pool:
+        # Use partial to pass action_type to get_action_type_data
+        func = partial(get_action_type_data, action_type=action_type)
+        raw_data = pool.map(func, log_files, chunksize=1)
 
-        ip = row["ip"]
-        if ip not in all_ips:
-            all_ips[ip] = len(all_ips)
-        user_id = all_ips[ip]
+    # Filter out Nones, since some files may not contain any data belonging to action_type
+    raw_data = [r for r in raw_data if r is not None]
+    all_ips = MANAGER.dict()
+
+    # Use the multiprocessing Pool
+    with mp.Pool() as pool:
+        func = partial(process_data, action_type=action_type, all_ips=all_ips)
+        results = pool.map(func, raw_data, chunksize=1)

-        chats.append(
-            dict(
-                conversation_id=conversation_id,
-                model=model,
-                conversation=conversation,
-                turn=len(conversation) // 2,
-                language=lang_code,
-                user_id=user_id,
-                tstamp=row["tstamp"],
-            )
-        )
+    # Initialize counters and collections in the parent process
+    ct_invalid_conv_id = 0
+    ct_invalid = 0
+    ct_network_error = 0
+    all_models = set()
+    chats = []
 
-        all_models.update([model])
+    # Aggregate results from child processes
+    for res in results:
+        if res is None:
+            continue
+        data, inv_conv_id, inv, net_err, model = res
+        ct_invalid_conv_id += inv_conv_id
+        ct_invalid += inv
+        ct_network_error += net_err
+        if data:
+            chats.append(data)
+        if model:
+            all_models.add(model)
 
     chats.sort(key=lambda x: x["tstamp"])
     last_updated_tstamp = chats[-1]["tstamp"]
-    last_updated_datetime = datetime.datetime.fromtimestamp(
+    last_updated_datetime = datetime.fromtimestamp(
         last_updated_tstamp, tz=timezone("US/Pacific")
     ).strftime("%Y-%m-%d %H:%M:%S %Z")
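
The refactor hinges on MANAGER.dict() and the Manager-backed LOCK being proxy objects that can be shared with Pool workers, so the check-then-insert on all_ips stays atomic across processes. A minimal, self-contained sketch of that pattern (simplified names, not the PR's exact code):

import multiprocessing as mp
from functools import partial


def assign_user_id(ip, all_ips, lock):
    # The check-then-insert must be atomic, or two workers could hand
    # the same IP two different user ids.
    with lock:
        if ip not in all_ips:
            all_ips[ip] = len(all_ips)
        return all_ips[ip]


if __name__ == "__main__":
    manager = mp.Manager()
    all_ips = manager.dict()   # proxy: safe to pass to Pool workers
    lock = manager.Lock()      # proxy: picklable, unlike a plain mp.Lock()
    ips = ["1.1.1.1", "2.2.2.2", "1.1.1.1"]
    with mp.Pool(2) as pool:
        func = partial(assign_user_id, all_ips=all_ips, lock=lock)
        ids = pool.map(func, ips, chunksize=1)
    print(ids)  # [0, 1, 0]: repeated IPs map to the same user id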

@@ -161,7 +208,7 @@ def clean_chat_data(log_files, action_type):
     log_files = get_log_files(args.max_num_files)
     chats = clean_chat_data(log_files, args.action_type)
     last_updated_tstamp = chats[-1]["tstamp"]
-    cutoff_date = datetime.datetime.fromtimestamp(
+    cutoff_date = datetime.fromtimestamp(
         last_updated_tstamp, tz=timezone("US/Pacific")
     ).strftime("%Y%m%d")
