Skip to content

Commit

Permalink
Update add script to add a range and prioritise questions asked from …
Browse files Browse the repository at this point in the history
…9am-12pm and 8pm-10pm as well as questions asked during the weekend
  • Loading branch information
lickem22 committed Aug 30, 2024
1 parent 5c6ad01 commit 2ac97e3
Showing 1 changed file with 74 additions and 24 deletions.
98 changes: 74 additions & 24 deletions core_backend/add_new_data_to_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
QueryResponseContentDB,
ResponseFeedbackDB,
)
from app.urgency_detection.models import UrgencyQueryDB
from app.urgency_detection.models import UrgencyQueryDB, UrgencyResponseDB
from app.users.models import UserDB
from app.utils import get_key_hash
from litellm import completion
Expand All @@ -50,6 +50,7 @@
(ContentFeedbackDB, "feedback_datetime_utc"),
(QueryResponseContentDB, "created_datetime_utc"),
(UrgencyQueryDB, "message_datetime_utc"),
(UrgencyResponseDB, "response_datetime_utc"),
]

parser = argparse.ArgumentParser(
Expand All @@ -64,6 +65,7 @@
--api-key <API_KEY> \
--nb-workers 8 \
--start-date 01-08-23
--end-date 04-09-24
""",
)
Expand All @@ -82,6 +84,16 @@
help="Start date for the records in the format dd-mm-yy",
required=False,
)
parser.add_argument(
"--end-date",
help="End date for the records in the format dd-mm-yy",
required=False,
)
parser.add_argument(
"--subset",
help="Subset of the data to use for testing",
required=False,
)
args = parser.parse_args()


Expand Down Expand Up @@ -281,24 +293,60 @@ def process_urgency_detection(_id: int, text: str) -> tuple | None:
return None


def create_random_datetime_from_string(start_date: datetime) -> datetime:
def create_random_datetime(start_date: datetime, end_date: datetime) -> datetime:
"""
Create a random datetime from a date in the format "%d-%m-%y
to today
Create a random datetime from a date within a range
"""

time_difference = datetime.now() - start_date
time_difference = end_date - start_date
random_number_of_days = random.randint(0, time_difference.days)

random_number_of_seconds = random.randint(0, 86399) # Number of seconds in one day

random_number_of_seconds = random.randint(0, 86399)
random_datetime = start_date + timedelta(
days=random_number_of_days, seconds=random_number_of_seconds
)
return random_datetime


def update_date_of_records(models: list, random_dates: list, api_key: str) -> None:
def is_within_time_range(date: datetime) -> bool:
"""
Helper function to check if the date is within desired time range.
Prioritizing 9am-12pm and 8pm-10pm
"""
if 9 <= date.hour < 12 or 20 <= date.hour < 22:
return True
return False


def generate_distributed_dates(n: int, start: datetime, end: datetime) -> list:
"""
Generate dates with a specific distribution for the records
"""
dates: list[datetime] = []
while len(dates) < n:
date = create_random_datetime(start, end)

# More dates on weekends
if date.weekday() >= 5:

if (
is_within_time_range(date) or random.random() < 0.4
): # Within time range or 30% chance
dates.append(date)
else:
if random.random() < 0.6:
if is_within_time_range(date) or random.random() < 0.55:
dates.append(date)

return dates


def update_date_of_records(
models: list,
api_key: str,
start_date: datetime,
end_date: datetime,
) -> None:
"""
Update the date of the records in the database
"""
Expand All @@ -308,11 +356,7 @@ def update_date_of_records(models: list, random_dates: list, api_key: str) -> No
select(UserDB).where(UserDB.hashed_api_key == hashed_token)
).scalar_one()
queries = [c for c in session.query(QueryDB).all() if c.user_id == user.user_id]
if len(queries) > len(random_dates):
random_dates = random_dates + [
create_random_datetime_from_string(start_date)
for _ in range(len(queries) - len(random_dates))
]
random_dates = generate_distributed_dates(len(queries), start_date, end_date)
# Create a dictionary to map the query_id to the random date
date_map_dic = {queries[i].query_id: random_dates[i] for i in range(len(queries))}
for model in models:
Expand All @@ -323,8 +367,8 @@ def update_date_of_records(models: list, random_dates: list, api_key: str) -> No

for i, row in enumerate(rows):
# Set the date attribute to the random date
if hasattr(row, "query_id"):
date = date_map_dic[row.query_id]
if hasattr(row, "query_id") and model[0] != UrgencyQueryDB:
date = date_map_dic.get(row.query_id, None)
else:
date = random_dates[i]

Expand All @@ -351,17 +395,26 @@ def update_date_of_contents(date: datetime) -> None:
NB_WORKERS = int(args.nb_workers) if args.nb_workers else 8
API_KEY = args.api_key if args.api_key else ADMIN_API_KEY

date_string = args.start_date if args.start_date else "01-08-23"
start_date_string = args.start_date if args.start_date else "01-08-23"
end_date_string = args.end_date if args.end_date else None
date_format = "%d-%m-%y"
start_date = datetime.strptime(date_string, date_format)
start_date = datetime.strptime(start_date_string, date_format)
end_date = (
datetime.strptime(end_date_string, date_format)
if end_date_string
else datetime.now()
)
assert end_date, "Invalid end date. Please provide a valid date. Format is dd-mm-yy"
assert (
start_date and start_date < datetime.now()
), "Invalid start date. Please provide a valid start date."
start_date and start_date < end_date
), "Invalid start date. Please provide a valid start date. Format is dd-mm-yy"

subset = int(args.subset) if args.subset else None
path = args.csv
df = pd.read_csv(path)
df = pd.read_csv(path, nrows=subset)
saved_queries = defaultdict(list)
print("Processing search queries...")

# Using multithreading to speed up the process
with ThreadPoolExecutor(max_workers=NB_WORKERS) as executor:
future_to_text = {
Expand Down Expand Up @@ -444,11 +497,8 @@ def update_date_of_contents(date: datetime) -> None:
result = future.result()
print("Urgency Detection successfully processed")

random_dates = [
create_random_datetime_from_string(start_date) for _ in range(len(df))
]
print("Updating the date of the records...")
update_date_of_records(MODELS, random_dates, API_KEY)
update_date_of_records(MODELS, API_KEY, start_date, end_date)

print("Updating the date of the content records...")
update_date_of_contents(start_date)
Expand Down

0 comments on commit 2ac97e3

Please sign in to comment.