Skip to content

Commit

Permalink
added logging session states
Browse files Browse the repository at this point in the history
  • Loading branch information
beingkk committed Jan 2, 2024
1 parent 2fd9037 commit bcb774b
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 400 deletions.
80 changes: 78 additions & 2 deletions signals_app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import json
import os
import uuid

from datetime import datetime

import openai
import s3fs
import streamlit as st

from dotenv import load_dotenv
Expand Down Expand Up @@ -40,6 +44,10 @@
# Prompt: Following up on user's question
path_prompt_following_up = PROMPT_PATH + "04_follow_up.jsonl"

aws_key = os.environ["AWS_ACCESS_KEY_ID"]
aws_secret = os.environ["AWS_SECRET_ACCESS_KEY"]
s3_path = os.environ["S3_BUCKET"]


def auth_openai() -> None:
"""Authenticate with OpenAI."""
Expand Down Expand Up @@ -212,7 +220,7 @@ def predict_top_three_signals(user_message: str, allowed_signals: list) -> list:
return top_signals["prediction"]


def signals_bot(sidebar: bool = True) -> None:
def signals_bot() -> None:
"""Explain me a concept like I'm 3."""

# Define custom CSS
Expand Down Expand Up @@ -270,6 +278,17 @@ def signals_bot(sidebar: bool = True) -> None:
for m in intro_messages:
st.session_state.messages.append(m)
st.session_state["memory"].add_message(m)
# Keep count of the number of unique sessions
timestamp = current_time()
session_log = f"{timestamp}-{str(uuid.uuid4())}"
write_to_s3(
key=aws_key,
secret=aws_secret,
s3_path=f"{s3_path}/session-logs-signals",
filename="session_counter",
data={"session": session_log, "time": timestamp},
how="a",
)

# Display chat messages on app rerun
for message in st.session_state.messages:
Expand All @@ -292,10 +311,14 @@ def signals_bot(sidebar: bool = True) -> None:
intent = predict_intent(user_message, active_signal=st.session_state.active_signal)

if intent == "new_signal":
# Filter out signals that have already been covered
allowed_signals = [s for s in signals if s not in st.session_state.signals]
# Determine the most relevant signal to explain
signal_to_explain = predict_top_signal(user_message, allowed_signals)
# Keep track of already discussed signals
st.session_state.signals.append(signal_to_explain)
st.session_state.active_signal = signal_to_explain
# Generate a message about the signal
instruction = MessageTemplate.load(path_prompt_impact)
message_history = st.session_state["memory"].get_messages(max_tokens=3000) + [instruction]
with st.chat_message("assistant"):
Expand All @@ -310,11 +333,23 @@ def signals_bot(sidebar: bool = True) -> None:
)
st.session_state.messages.append({"role": "assistant", "content": full_response})
st.session_state["memory"].add_message({"role": "assistant", "content": full_response})
# Keep count of the number of signals
write_to_s3(
key=aws_key,
secret=aws_secret,
s3_path=f"{s3_path}/session-logs-signals",
filename="signal_counter",
data={"signal": signal_to_explain, "time": current_time()},
how="a",
)

elif intent == "more_signals":
# Filter out signals that have already been covered
allowed_signals = [s for s in signals if s not in st.session_state.signals]
# Determine the top three signals to explain
top_signals = predict_top_three_signals(st.session_state.user_info, allowed_signals)
top_signals_text = generate_signals_texts(signals_data, top_signals)
# Generate a message about the three signals
instruction = MessageTemplate.load(path_prompt_choice)
message_history = st.session_state["memory"].get_messages(max_tokens=3000) + [instruction]
with st.chat_message("assistant"):
Expand All @@ -328,6 +363,7 @@ def signals_bot(sidebar: bool = True) -> None:
st.session_state["memory"].add_message({"role": "assistant", "content": full_response})

elif intent == "following_up":
# Generate follow up message
instruction = MessageTemplate.load(path_prompt_following_up)
message_history = st.session_state["memory"].get_messages(max_tokens=3000) + [instruction]
with st.chat_message("assistant"):
Expand Down Expand Up @@ -363,11 +399,51 @@ def llm_call(selected_model: str, temperature: float, messages: MessageTemplate,
return full_response


def write_to_s3(key: str, secret: str, s3_path: str, filename: str, data: dict, how: str = "a") -> None:
"""Write data to a jsonl file in S3.
Parameters
----------
key
AWS access key ID.
secret
AWS secret access key.
s3_path
S3 bucket path.
filename
Name of the file to write to.
data
Data to write to the file.
how
How to write to the file. Default is "a" for append. Use "w" to overwrite.
"""
fs = s3fs.S3FileSystem(key=key, secret=secret)
with fs.open(f"{s3_path}/{filename}.jsonl", how) as f:
f.write(f"{json.dumps(data)}\n")


def current_time() -> str:
"""Return the current time as a string. Used as part of the session UUID."""
# Get current date and time
current_datetime = datetime.now()

# Convert to a long number format
datetime_string = current_datetime.strftime("%Y%m%d%H%M%S")

return datetime_string


def main() -> None:
"""Run the app."""
auth_openai()

signals_bot(sidebar=False)
signals_bot()


main()
Loading

0 comments on commit bcb774b

Please sign in to comment.