diff --git a/fastchat/serve/monitor/clean_chat_data.py b/fastchat/serve/monitor/clean_chat_data.py index 2bda0e2c3..ec6da4a65 100644 --- a/fastchat/serve/monitor/clean_chat_data.py +++ b/fastchat/serve/monitor/clean_chat_data.py @@ -5,13 +5,16 @@ python3 clean_chat_data.py """ import argparse -import datetime import json import os +import hashlib from pytz import timezone -import time - +from functools import partial +from math import ceil +from datetime import datetime, timedelta from tqdm import tqdm +import time +import multiprocessing as mp from fastchat.serve.monitor.basic_stats import NUM_SERVERS from fastchat.serve.monitor.clean_battle_data import ( @@ -26,12 +29,20 @@ ) -def get_log_files(max_num_files=None): - dates = [] - for month in range(4, 12): - for day in range(1, 33): - dates.append(f"2023-{month:02d}-{day:02d}") +def date_range(start="2023-04-01"): + start_date = datetime.strptime(start, "%Y-%m-%d").date() + end_date = datetime.now().date() + delta = end_date - start_date + dates = [ + (start_date + timedelta(days=d)).strftime("%Y-%m-%d") + for d in range(delta.days + 2) + ] + return dates + + +def get_log_files(max_num_files=None): + dates = date_range() filenames = [] for d in dates: for i in range(NUM_SERVERS): @@ -44,90 +55,141 @@ def get_log_files(max_num_files=None): return filenames -def clean_chat_data(log_files, action_type): +def get_action_type_data(filename, action_type): + for _ in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + rows = [] + for l in lines: + row = json.loads(l) + if row["type"] == action_type: + rows.append(row) + return rows + + +def process_data(row, action_type): + try: + if action_type in ["chat", "upvote", "downvote"]: + state = row["state"] + model = row["model"] + elif action_type == "leftvote": + state = row["states"][0] + model = row["states"][0]["model_name"] + elif action_type == "rightvote": + state = row["states"][1] + model = row["states"][1]["model_name"] + conversation_id = state["conv_id"] + except KeyError: + return { + "ct_invalid_conv_id": 1, + } + + if conversation_id is None: + return { + "ct_invalid_conv_id": 1, + } + + conversation = to_openai_format(state["messages"][state["offset"] :]) + if not isinstance(model, str): + return { + "ct_invalid": 1, + } + model = replace_model_name(model, row["tstamp"]) + + try: + lang_code = detect_language(state["messages"][state["offset"]][1]) + except IndexError: + return { + "ct_invalid": 1, + } + + if not all(isinstance(x["content"], str) for x in conversation): + return { + "ct_invalid": 1, + } + + messages = "".join([x["content"] for x in conversation]).lower() + if NETWORK_ERROR_MSG in messages: + return { + "ct_network_error": 1, + } + user_id = hashlib.md5(row["ip"].encode()).hexdigest() + + # Prepare the result data + result = dict( + conversation_id=conversation_id, + model=model, + conversation=conversation, + turn=len(conversation) // 2, + language=lang_code, + user_id=user_id, + tstamp=row["tstamp"], + ) + + return { + "result": result, + "model": model, + } + + +def clean_chat_data(log_files, action_type, num_parallel): + with mp.Pool(num_parallel) as pool: + # Use partial to pass action_type to get_action_type_data + func = partial(get_action_type_data, action_type=action_type) + file_data = list( + tqdm( + pool.imap( + func, log_files, chunksize=ceil(len(log_files) / len(pool._pool)) + ), + total=len(log_files), + desc="Processing Log Files", + ) + ) + # filter out Nones as some files may not contain any data belong to action_type raw_data = [] - for filename in tqdm(log_files, desc="read files"): - for retry in range(5): - try: - lines = open(filename).readlines() - break - except FileNotFoundError: - time.sleep(2) - - for l in lines: - row = json.loads(l) - if row["type"] == action_type: - raw_data.append(row) + for data in file_data: + raw_data.extend(data) + raw_data = [r for r in raw_data if not (r is None)] + + # Use the multiprocessing Pool + with mp.Pool(num_parallel) as pool: + func = partial(process_data, action_type=action_type) + results = list( + tqdm( + pool.imap( + func, raw_data, chunksize=ceil(len(raw_data) / len(pool._pool)) + ), + total=len(raw_data), + desc="Processing Raw Data", + ) + ) - all_models = set() - all_ips = dict() - chats = [] + # Aggregate results from child processes ct_invalid_conv_id = 0 ct_invalid = 0 ct_network_error = 0 - for row in raw_data: - try: - if action_type in ["chat", "upvote", "downvote"]: - state = row["state"] - model = row["model"] - elif action_type == "leftvote": - state = row["states"][0] - model = row["states"][0]["model_name"] - elif action_type == "rightvote": - state = row["states"][1] - model = row["states"][1]["model_name"] - conversation_id = state["conv_id"] - except KeyError: - ct_invalid_conv_id += 1 - continue - - if conversation_id is None: - ct_invalid_conv_id += 1 - continue - - conversation = to_openai_format(state["messages"][state["offset"] :]) - if not isinstance(model, str): - ct_invalid += 1 - continue - model = replace_model_name(model, row["tstamp"]) - - try: - lang_code = detect_language(state["messages"][state["offset"]][1]) - except IndexError: - ct_invalid += 1 + all_models = set() + chats = [] + for data in tqdm(results): + if "ct_invalid_conv_id" in data: + ct_invalid_conv_id += data["ct_invalid_conv_id"] continue - - if not all(isinstance(x["content"], str) for x in conversation): - ct_invalid += 1 + if "ct_invalid" in data: + ct_invalid += data["ct_invalid"] continue - - messages = "".join([x["content"] for x in conversation]).lower() - if NETWORK_ERROR_MSG in messages: - ct_network_error += 1 + if "ct_network_error" in data: + ct_network_error += data["ct_network_error"] continue - - ip = row["ip"] - if ip not in all_ips: - all_ips[ip] = len(all_ips) - user_id = all_ips[ip] - - chats.append( - dict( - conversation_id=conversation_id, - model=model, - conversation=conversation, - turn=len(conversation) // 2, - language=lang_code, - user_id=user_id, - tstamp=row["tstamp"], - ) - ) - - all_models.update([model]) + all_models.update([data["model"]]) + chats.append(data["result"]) chats.sort(key=lambda x: x["tstamp"]) last_updated_tstamp = chats[-1]["tstamp"] - last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_datetime = datetime.fromtimestamp( last_updated_tstamp, tz=timezone("US/Pacific") ).strftime("%Y-%m-%d %H:%M:%S %Z") @@ -156,12 +218,13 @@ def clean_chat_data(log_files, action_type): parser = argparse.ArgumentParser() parser.add_argument("--action-type", type=str, default="chat") parser.add_argument("--max-num-files", type=int) + parser.add_argument("--num-parallel", type=int, default=16) args = parser.parse_args() log_files = get_log_files(args.max_num_files) - chats = clean_chat_data(log_files, args.action_type) + chats = clean_chat_data(log_files, args.action_type, args.num_parallel) last_updated_tstamp = chats[-1]["tstamp"] - cutoff_date = datetime.datetime.fromtimestamp( + cutoff_date = datetime.fromtimestamp( last_updated_tstamp, tz=timezone("US/Pacific") ).strftime("%Y%m%d")