Skip to content

Commit

Permalink
improve
Browse files Browse the repository at this point in the history
  • Loading branch information
CodingWithTim committed Nov 10, 2024
1 parent 234e3b0 commit f21c56f
Showing 1 changed file with 65 additions and 51 deletions.
116 changes: 65 additions & 51 deletions fastchat/serve/monitor/clean_chat_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
import argparse
import json
import os
import hashlib
from pytz import timezone
from functools import partial
from datetime import datetime, timedelta
import time
import multiprocessing as mp

from tqdm import tqdm

from fastchat.serve.monitor.basic_stats import NUM_SERVERS
from fastchat.serve.monitor.clean_battle_data import (
to_openai_format,
Expand Down Expand Up @@ -64,18 +63,15 @@ def get_action_type_data(filename, action_type):
except FileNotFoundError:
time.sleep(2)

rows = []
for l in lines:
row = json.loads(l)
if row["type"] == action_type:
return row

rows.append(row)
return rows

def process_data(row, action_type, all_ips):
# Initialize local counters
ct_invalid_conv_id = 0
ct_invalid = 0
ct_network_error = 0

def process_data(row, action_type):
try:
if action_type in ["chat", "upvote", "downvote"]:
state = row["state"]
Expand All @@ -88,40 +84,64 @@ def process_data(row, action_type, all_ips):
model = row["states"][1]["model_name"]
conversation_id = state["conv_id"]
except KeyError:
ct_invalid_conv_id += 1
return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
return {
"result": None,
"ct_invalid_conv_id": 1,
"ct_invalid": 0,
"ct_network_error": 0,
"model": None,
}

if conversation_id is None:
ct_invalid_conv_id += 1
return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
return {
"result": None,
"ct_invalid_conv_id": 1,
"ct_invalid": 0,
"ct_network_error": 0,
"model": None,
}

conversation = to_openai_format(state["messages"][state["offset"] :])
if not isinstance(model, str):
ct_invalid += 1
return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
return {
"result": None,
"ct_invalid_conv_id": 0,
"ct_invalid": 1,
"ct_network_error": 0,
"model": None,
}
model = replace_model_name(model, row["tstamp"])

try:
lang_code = detect_language(state["messages"][state["offset"]][1])
except IndexError:
ct_invalid += 1
return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
return {
"result": None,
"ct_invalid_conv_id": 0,
"ct_invalid": 1,
"ct_network_error": 0,
"model": None,
}

if not all(isinstance(x["content"], str) for x in conversation):
ct_invalid += 1
return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None
return {
"result": None,
"ct_invalid_conv_id": 0,
"ct_invalid": 1,
"ct_network_error": 0,
"model": None,
}

messages = "".join([x["content"] for x in conversation]).lower()
if NETWORK_ERROR_MSG in messages:
ct_network_error += 1
return None, ct_invalid_conv_id, ct_invalid, ct_network_error, None

ip = row["ip"]
# Synchronize access to all_ips using the lock
with LOCK:
if ip not in all_ips:
all_ips[ip] = len(all_ips)
user_id = all_ips[ip]
return {
"result": None,
"ct_invalid_conv_id": 0,
"ct_invalid": 0,
"ct_network_error": 1,
"model": None,
}
user_id = hashlib.md5(row["ip"].encode()).hexdigest()

# Prepare the result data
result = dict(
Expand All @@ -134,43 +154,37 @@ def process_data(row, action_type, all_ips):
tstamp=row["tstamp"],
)

return result, ct_invalid_conv_id, ct_invalid, ct_network_error, model
return {
"result": result,
"ct_invalid_conv_id": 0,
"ct_invalid": 0,
"ct_network_error": 0,
"model": model,
}


def clean_chat_data(log_files, action_type):
with mp.Pool() as pool:
# Use partial to pass action_type to get_action_type_data
func = partial(get_action_type_data, action_type=action_type)
raw_data = pool.map(func, log_files, chunksize=1)

file_data = pool.map(func, log_files, chunksize=1)
# filter out Nones, as some files may not contain any data belonging to action_type
raw_data = []
for data in file_data:
raw_data.extend(data)
raw_data = [r for r in raw_data if r is not None]
all_ips = MANAGER.dict()

# Use the multiprocessing Pool
with mp.Pool() as pool:
func = partial(process_data, action_type=action_type, all_ips=all_ips)
func = partial(process_data, action_type=action_type)
results = pool.map(func, raw_data, chunksize=1)

# Initialize counters and collections in the parent process
ct_invalid_conv_id = 0
ct_invalid = 0
ct_network_error = 0
all_models = set()
chats = []

# Aggregate results from child processes
for res in results:
if res is None:
continue
data, inv_conv_id, inv, net_err, model = res
ct_invalid_conv_id += inv_conv_id
ct_invalid += inv
ct_network_error += net_err
if data:
chats.append(data)
if model:
all_models.add(model)
ct_invalid_conv_id = sum([data["ct_invalid_conv_id"] for data in results])
ct_invalid = sum([data["ct_invalid"] for data in results])
ct_network_error = sum([data["ct_network_error"] for data in results])
all_models = set([data["model"] for data in results if not (data["model"] is None)])
chats = [data["result"] for data in results if not (data["result"] is None)]

chats.sort(key=lambda x: x["tstamp"])
last_updated_tstamp = chats[-1]["tstamp"]
Expand Down

0 comments on commit f21c56f

Please sign in to comment.