Clean Upvote and Downvote data (#3611)
CodingWithTim authored Nov 13, 2024
1 parent 185e1a9 commit 5ac9372
Showing 1 changed file with 146 additions and 83 deletions.
229 changes: 146 additions & 83 deletions fastchat/serve/monitor/clean_chat_data.py
@@ -5,13 +5,16 @@
python3 clean_chat_data.py
"""
import argparse
import json
import os
import hashlib
from pytz import timezone

from functools import partial
from math import ceil
from datetime import datetime, timedelta
from tqdm import tqdm
import time
import multiprocessing as mp

from fastchat.serve.monitor.basic_stats import NUM_SERVERS
from fastchat.serve.monitor.clean_battle_data import (
@@ -26,12 +29,20 @@
)


def get_log_files(max_num_files=None):
dates = []
for month in range(4, 12):
for day in range(1, 33):
dates.append(f"2023-{month:02d}-{day:02d}")
def date_range(start="2023-04-01"):
start_date = datetime.strptime(start, "%Y-%m-%d").date()
end_date = datetime.now().date()
delta = end_date - start_date
dates = [
(start_date + timedelta(days=d)).strftime("%Y-%m-%d")
for d in range(delta.days + 2)
]

return dates
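# For example (hypothetical run date): date_range("2024-11-10") executed on
# 2024-11-12 yields ["2024-11-10", "2024-11-11", "2024-11-12", "2024-11-13"].
# The `delta.days + 2` deliberately runs one day past today, presumably so
# log files stamped with a server-local date ahead of this machine's clock
# are not missed.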


def get_log_files(max_num_files=None):
dates = date_range()
filenames = []
for d in dates:
for i in range(NUM_SERVERS):
@@ -44,90 +55,141 @@ def get_log_files(max_num_files=None):
return filenames


def get_action_type_data(filename, action_type):
    # Retry a few times: the log file may be briefly missing while it is rotated.
    lines = []
    for _ in range(5):
        try:
            with open(filename) as f:
                lines = f.readlines()
            break
        except FileNotFoundError:
            time.sleep(2)

    rows = []
    for l in lines:
        row = json.loads(l)
        if row["type"] == action_type:
            rows.append(row)
    return rows


def process_data(row, action_type):
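    # "chat", "upvote", and "downvote" rows carry a single conversation state
    # plus an explicit model name; battle votes ("leftvote"/"rightvote") carry
    # a pair of states, so the model name is read from the voted-for side.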
try:
if action_type in ["chat", "upvote", "downvote"]:
state = row["state"]
model = row["model"]
elif action_type == "leftvote":
state = row["states"][0]
model = row["states"][0]["model_name"]
elif action_type == "rightvote":
state = row["states"][1]
model = row["states"][1]["model_name"]
conversation_id = state["conv_id"]
except KeyError:
return {
"ct_invalid_conv_id": 1,
}

if conversation_id is None:
return {
"ct_invalid_conv_id": 1,
}

conversation = to_openai_format(state["messages"][state["offset"] :])
if not isinstance(model, str):
return {
"ct_invalid": 1,
}
model = replace_model_name(model, row["tstamp"])

try:
lang_code = detect_language(state["messages"][state["offset"]][1])
except IndexError:
return {
"ct_invalid": 1,
}

if not all(isinstance(x["content"], str) for x in conversation):
return {
"ct_invalid": 1,
}

messages = "".join([x["content"] for x in conversation]).lower()
if NETWORK_ERROR_MSG in messages:
return {
"ct_network_error": 1,
}
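    # Pseudonymize the requester: only an MD5 digest of the IP is kept as the
    # user_id, never the raw address.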
user_id = hashlib.md5(row["ip"].encode()).hexdigest()

# Prepare the result data
result = dict(
conversation_id=conversation_id,
model=model,
conversation=conversation,
turn=len(conversation) // 2,
language=lang_code,
user_id=user_id,
tstamp=row["tstamp"],
)

return {
"result": result,
"model": model,
}
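# A successful row comes back as {"result": <cleaned record>, "model": <name>};
# failures come back as a single-key counter dict ({"ct_invalid_conv_id": 1},
# {"ct_invalid": 1}, or {"ct_network_error": 1}) that the caller tallies.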


def clean_chat_data(log_files, action_type, num_parallel):
with mp.Pool(num_parallel) as pool:
# Use partial to pass action_type to get_action_type_data
func = partial(get_action_type_data, action_type=action_type)
file_data = list(
tqdm(
pool.imap(
func, log_files, chunksize=ceil(len(log_files) / len(pool._pool))
),
total=len(log_files),
desc="Processing Log Files",
)
)
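    # First stage done: each worker parsed whole log files, so file I/O and
    # JSON decoding ran in parallel across the pool.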
    # Filter out Nones, as some files may not contain any data belonging to the action_type.
raw_data = []
for data in file_data:
raw_data.extend(data)
    raw_data = [r for r in raw_data if r is not None]

# Use the multiprocessing Pool
with mp.Pool(num_parallel) as pool:
func = partial(process_data, action_type=action_type)
results = list(
tqdm(
pool.imap(
func, raw_data, chunksize=ceil(len(raw_data) / len(pool._pool))
),
total=len(raw_data),
desc="Processing Raw Data",
)
)
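    # Second stage: per-row cleaning is re-sharded across the same pool, so a
    # single huge log file cannot leave most workers idle.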

# Aggregate results from child processes
ct_invalid_conv_id = 0
ct_invalid = 0
ct_network_error = 0
all_models = set()
chats = []
for data in tqdm(results):
if "ct_invalid_conv_id" in data:
ct_invalid_conv_id += data["ct_invalid_conv_id"]
continue

if "ct_invalid" in data:
ct_invalid += data["ct_invalid"]
continue

messages = "".join([x["content"] for x in conversation]).lower()
if NETWORK_ERROR_MSG in messages:
ct_network_error += 1
if "ct_network_error" in data:
ct_network_error += data["ct_network_error"]
continue

all_models.update([data["model"]])
chats.append(data["result"])

chats.sort(key=lambda x: x["tstamp"])
last_updated_tstamp = chats[-1]["tstamp"]
    last_updated_datetime = datetime.fromtimestamp(
last_updated_tstamp, tz=timezone("US/Pacific")
).strftime("%Y-%m-%d %H:%M:%S %Z")

@@ -156,12 +218,13 @@ def clean_chat_data(log_files, action_type):
parser = argparse.ArgumentParser()
parser.add_argument("--action-type", type=str, default="chat")
parser.add_argument("--max-num-files", type=int)
parser.add_argument("--num-parallel", type=int, default=16)
args = parser.parse_args()

log_files = get_log_files(args.max_num_files)
    chats = clean_chat_data(log_files, args.action_type, args.num_parallel)
last_updated_tstamp = chats[-1]["tstamp"]
    cutoff_date = datetime.fromtimestamp(
last_updated_tstamp, tz=timezone("US/Pacific")
).strftime("%Y%m%d")

