diff --git a/fastchat/serve/monitor/clean_chat_data.py b/fastchat/serve/monitor/clean_chat_data.py
index 26899d829..c8921ac0c 100644
--- a/fastchat/serve/monitor/clean_chat_data.py
+++ b/fastchat/serve/monitor/clean_chat_data.py
@@ -7,14 +7,13 @@
 import argparse
 import json
 import os
+import hashlib
 from pytz import timezone
 from functools import partial
 from datetime import datetime, timedelta
 import time
 import multiprocessing as mp
 
-from tqdm import tqdm
-
 from fastchat.serve.monitor.basic_stats import NUM_SERVERS
 from fastchat.serve.monitor.clean_battle_data import (
     to_openai_format,
@@ -64,18 +63,15 @@ def get_action_type_data(filename, action_type):
         except FileNotFoundError:
             time.sleep(2)
 
+    rows = []
     for l in lines:
         row = json.loads(l)
         if row["type"] == action_type:
-            return row
-
+            rows.append(row)
+    return rows
 
 
-def process_data(row, action_type, all_ips):
-    # Initialize local counters
-    ct_invalid_conv_id = 0
-    ct_invalid = 0
-    ct_network_error = 0
+def process_data(row, action_type):
     try:
         if action_type in ["chat", "upvote", "downvote"]:
             state = row["state"]
@@ -88,40 +84,64 @@ def process_data(row, action_type, all_ips):
             model = row["states"][1]["model_name"]
         conversation_id = state["conv_id"]
     except KeyError:
-        ct_invalid_conv_id += 1
-        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+        return {
+            "result": None,
+            "ct_invalid_conv_id": 1,
+            "ct_invalid": 0,
+            "ct_network_error": 0,
+            "model": None,
+        }
 
     if conversation_id is None:
-        ct_invalid_conv_id += 1
-        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+        return {
+            "result": None,
+            "ct_invalid_conv_id": 1,
+            "ct_invalid": 0,
+            "ct_network_error": 0,
+            "model": None,
+        }
 
     conversation = to_openai_format(state["messages"][state["offset"] :])
     if not isinstance(model, str):
-        ct_invalid += 1
-        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+        return {
+            "result": None,
+            "ct_invalid_conv_id": 0,
+            "ct_invalid": 1,
+            "ct_network_error": 0,
+            "model": None,
+        }
     model = replace_model_name(model, row["tstamp"])
 
     try:
         lang_code = detect_language(state["messages"][state["offset"]][1])
     except IndexError:
-        ct_invalid += 1
-        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+        return {
+            "result": None,
+            "ct_invalid_conv_id": 0,
+            "ct_invalid": 1,
+            "ct_network_error": 0,
+            "model": None,
+        }
 
     if not all(isinstance(x["content"], str) for x in conversation):
-        ct_invalid += 1
-        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
+        return {
+            "result": None,
+            "ct_invalid_conv_id": 0,
+            "ct_invalid": 1,
+            "ct_network_error": 0,
+            "model": None,
+        }
 
     messages = "".join([x["content"] for x in conversation]).lower()
     if NETWORK_ERROR_MSG in messages:
-        ct_network_error += 1
-        return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
-
-    ip = row["ip"]
-    # Synchronize access to all_ips using the lock
-    with LOCK:
-        if ip not in all_ips:
-            all_ips[ip] = len(all_ips)
-        user_id = all_ips[ip]
+        return {
+            "result": None,
+            "ct_invalid_conv_id": 0,
+            "ct_invalid": 0,
+            "ct_network_error": 1,
+            "model": None,
+        }
+    user_id = hashlib.md5(row["ip"].encode()).hexdigest()
 
     # Prepare the result data
     result = dict(
@@ -134,43 +154,37 @@ def process_data(row, action_type, all_ips):
         tstamp=row["tstamp"],
     )
 
-    return result, ct_invalid_conv_id, ct_invalid, ct_network_error, model
+    return {
+        "result": result,
+        "ct_invalid_conv_id": 0,
+        "ct_invalid": 0,
+        "ct_network_error": 0,
+        "model": model,
+    }
 
 
 def clean_chat_data(log_files, action_type):
     with mp.Pool() as pool:
         # Use partial to pass action_type to get_action_type_data
         func = partial(get_action_type_data, action_type=action_type)
-        raw_data = pool.map(func, log_files, chunksize=1)
-
+        file_data = pool.map(func, log_files, chunksize=1)
     # filter out Nones as some files may not contain any data belong to action_type
+    raw_data = []
+    for data in file_data:
+        raw_data.extend(data)
     raw_data = [r for r in raw_data if r is not None]
 
-    all_ips = MANAGER.dict()
 
     # Use the multiprocessing Pool
     with mp.Pool() as pool:
-        func = partial(process_data, action_type=action_type, all_ips=all_ips)
+        func = partial(process_data, action_type=action_type)
         results = pool.map(func, raw_data, chunksize=1)
 
-    # Initialize counters and collections in the parent process
-    ct_invalid_conv_id = 0
-    ct_invalid = 0
-    ct_network_error = 0
-    all_models = set()
-    chats = []
-    # Aggregate results from child processes
-    for res in results:
-        if res is None:
-            continue
-        data, inv_conv_id, inv, net_err, model = res
-        ct_invalid_conv_id += inv_conv_id
-        ct_invalid += inv
-        ct_network_error += net_err
-        if data:
-            chats.append(data)
-        if model:
-            all_models.add(model)
+    ct_invalid_conv_id = sum([data["ct_invalid_conv_id"] for data in results])
+    ct_invalid = sum([data["ct_invalid"] for data in results])
+    ct_network_error = sum([data["ct_network_error"] for data in results])
+    all_models = set([data["model"] for data in results if not (data["model"] is None)])
+    chats = [data["result"] for data in results if not (data["result"] is None)]
 
     chats.sort(key=lambda x: x["tstamp"])
     last_updated_tstamp = chats[-1]["tstamp"]
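Reviewer note, not part of the patch: the refactor works because `process_data` is now a pure function. Each worker derives a stable pseudonymous `user_id` with `hashlib.md5` instead of coordinating through the shared `Manager` dict and lock, and per-row counters come back as a dict that the parent simply sums. A minimal standalone sketch of that pattern follows; the `SAMPLE_ROWS` data and `process` helper are hypothetical illustrations, not code from the repo.

```python
import hashlib
from multiprocessing import Pool

# Hypothetical stand-ins for parsed log rows; not real arena data.
SAMPLE_ROWS = [
    {"ip": "203.0.113.7", "valid": True},
    {"ip": "203.0.113.7", "valid": False},
    {"ip": "198.51.100.2", "valid": True},
]


def process(row):
    # Pure per-row work: the md5 digest of the IP is deterministic, so the
    # same IP maps to the same user_id in every worker with no shared state.
    user_id = hashlib.md5(row["ip"].encode()).hexdigest()
    if not row["valid"]:
        return {"result": None, "ct_invalid": 1}
    return {"result": {"user_id": user_id}, "ct_invalid": 0}


if __name__ == "__main__":
    with Pool() as pool:
        results = pool.map(process, SAMPLE_ROWS, chunksize=1)

    # Aggregate in the parent, mirroring clean_chat_data above.
    ct_invalid = sum(r["ct_invalid"] for r in results)
    chats = [r["result"] for r in results if r["result"] is not None]
    assert ct_invalid == 1 and len(chats) == 2
```

One trade-off worth flagging: unlike the old `all_ips` counter, the digest is stable across runs, so user ids from separate cleaning jobs stay comparable; but since the IPv4 space is small enough to brute-force, md5 of an IP is pseudonymous rather than anonymous.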