Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Comparing two semantic models results #185

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion admin_apps/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def verify_environment_setup() -> None:


if __name__ == "__main__":
from admin_apps.journeys import builder, iteration, partner
from admin_apps.journeys import builder, comparator, iteration, partner

def onboarding_dialog() -> None:
"""
Expand Down Expand Up @@ -114,6 +114,14 @@ def onboarding_dialog() -> None:
action="start",
)
partner.show()
st.markdown("")
if st.button(
"**📝 Compare two semantic models**",
use_container_width=True,
type="primary",
):
comparator.init_dialog()
st.markdown("")

verify_environment_setup()

Expand All @@ -130,5 +138,7 @@ def onboarding_dialog() -> None:
# The builder flow is simply an intermediate dialog before the iteration flow.
if st.session_state["page"] == GeneratorAppScreen.ITERATION:
iteration.show()
elif st.session_state["page"] == GeneratorAppScreen.COMPARATOR:
comparator.show()
else:
onboarding_dialog()
244 changes: 244 additions & 0 deletions admin_apps/journeys/comparator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
from typing import Any

import pandas as pd
import sqlglot
import streamlit as st
from loguru import logger
from snowflake.connector import SnowflakeConnection
from streamlit_monaco import st_monaco

from admin_apps.shared_utils import GeneratorAppScreen, return_home_button, send_message
from semantic_model_generator.data_processing.proto_utils import yaml_to_semantic_model
from semantic_model_generator.snowflake_utils.snowflake_connector import (
SnowflakeConnector,
)
from semantic_model_generator.validate_model import validate

MODEL1_PATH = "model1_path"
MODEL1_YAML = "model1_yaml"
MODEL2_PATH = "model2_path"
MODEL2_YAML = "model2_yaml"


def init_session_states() -> None:
st.session_state["page"] = GeneratorAppScreen.COMPARATOR


def comparator_app() -> None:
return_home_button()
st.write("## Compare two semantic models")
col1, col2 = st.columns(2)
with col1, st.container(border=True):
st.write(f"Model 1 from: `{st.session_state[MODEL1_PATH]}`")
content1 = st_monaco(
value=st.session_state[MODEL1_YAML],
height="400px",
language="yaml",
)

with col2, st.container(border=True):
st.write(f"Model 2 from: `{st.session_state[MODEL2_PATH]}`")
content2 = st_monaco(
value=st.session_state[MODEL2_YAML],
height="400px",
language="yaml",
)

if st.button("Validate models"):
with st.spinner(f"validating {st.session_state[MODEL1_PATH]}..."):
try:
validate(content1, st.session_state.account_name)
st.session_state["model1_valid"] = True
st.session_state[MODEL1_YAML] = content1
except Exception as e:
st.error(f"Validation failed on the first model with error: {e}")
st.session_state["model1_valid"] = False

with st.spinner(f"validating {st.session_state[MODEL2_PATH]}..."):
try:
validate(content2, st.session_state.account_name)
st.session_state["model2_valid"] = True
st.session_state[MODEL2_YAML] = content2
except Exception as e:
st.error(f"Validation failed on the second model with error: {e}")
st.session_state["model2_valid"] = False

if st.session_state.get("model1_valid", False) and st.session_state.get(
"model2_valid", False
):
st.success("Both models are correct.")
st.session_state["validated"] = True
else:
st.error("Please fix the models and try again.")
st.session_state["validated"] = False

if (
content1 != st.session_state[MODEL1_YAML]
or content2 != st.session_state[MODEL2_YAML]
):
st.info("Please validate the models again after making changes.")
st.session_state["validated"] = False

if not st.session_state.get("validated", False):
st.info("Please validate the models first.")
else:
prompt = st.text_input(
"What question would you like to ask the Cortex Analyst?"
)
if prompt:
st.write(f"Asking both models question: {prompt}")
user_message = [
{"role": "user", "content": [{"type": "text", "text": prompt}]}
]
connection = SnowflakeConnector(
account_name=st.session_state.account_name,
max_workers=1,
).open_connection(db_name="")

col1, col2 = st.columns(2)
ask_cortex_analyst(
user_message,
st.session_state[MODEL1_YAML],
connection,
col1,
"Model 1 is thinking...",
)
ask_cortex_analyst(
user_message,
st.session_state[MODEL2_YAML],
connection,
col2,
"Model 2 is thinking...",
)

# TODO:
# - Show the differences
# - Check if both models are pointing at the same table


def ask_cortex_analyst(
prompt: str,
semantic_model: str,
conn: SnowflakeConnection,
container: Any,
spinner_text: str,
) -> None:
"""Ask the Cortex Analyst a question and display the response.

Args:
prompt (str): The question to ask the Cortex Analyst.
semantic_model (str): The semantic model to use for the question.
conn (SnowflakeConnection): The Snowflake connection to use for the question.
container (st.DeltaGenerator): The streamlit container to display the response (e.g. st.columns()).
spinner_text (str): The text to display in the waiting spinner

Returns:
None

"""
with container, st.container(border=True), st.spinner(spinner_text):
json_resp = send_message(conn, prompt, yaml_to_semantic_model(semantic_model))
display_content(conn, json_resp["message"]["content"])
st.json(json_resp, expanded=False)


@st.cache_data(show_spinner=False)
def prettify_sql(sql: str) -> str:
"""
Prettify SQL using SQLGlot with an option to use the Snowflake dialect for syntax checks.

Args:
sql (str): SQL query string to be formatted.

Returns:
str: Formatted SQL string or input SQL if sqlglot failed to parse.
"""
try:
# Parse the SQL using SQLGlot
expression = sqlglot.parse_one(sql, dialect="snowflake")

# Generate formatted SQL, specifying the dialect if necessary for specific syntax transformations
formatted_sql: str = expression.sql(dialect="snowflake", pretty=True)
return formatted_sql
except Exception as e:
logger.debug(f"Failed to prettify SQL: {e}")
return sql


def display_content(
conn: SnowflakeConnection,
content: list[dict[str, Any]],
) -> None:
"""Displays a content item for a message from the Cortex Analyst."""
for item in content:
if item["type"] == "text":
st.markdown(item["text"])
elif item["type"] == "suggestions":
with st.expander("Suggestions", expanded=True):
for suggestion in item["suggestions"]:
st.markdown(f"- {suggestion}")
elif item["type"] == "sql":
with st.container(height=500, border=False):
sql = item["statement"]
sql = prettify_sql(sql)
with st.container(height=250, border=False):
st.code(item["statement"], language="sql")
try:
df = pd.read_sql(sql, conn)
st.dataframe(df, hide_index=True)
except Exception as e:
st.error(f"Failed to execute SQL: {e}")
else:
logger.warning(f"Unknown content type: {item['type']}")
st.write(item)


def is_session_state_initialized() -> bool:
return all(
[
MODEL1_YAML in st.session_state,
MODEL2_YAML in st.session_state,
MODEL1_PATH in st.session_state,
MODEL2_PATH in st.session_state,
]
)


@st.dialog("Welcome to the Cortex Analyst Annotation Workspace! 📝", width="large")
def init_dialog() -> None:
init_session_states()

st.write(
"Please choose the two semantic model files that you would like to compare."
)

model_1_file = st.file_uploader(
"Choose first semantic model file",
type=["yaml"],
help="Choose a local YAML file that contains semantic model.",
)
model_2_file = st.file_uploader(
"Choose second semantic model file",
type=["yaml"],
help="Choose a local YAML file that contains semantic model.",
)

if st.button("Compare"):
if model_1_file is None or model_2_file is None:
st.error("Please upload the both models first.")
else:
st.session_state[MODEL1_PATH] = model_1_file.name
st.session_state[MODEL1_YAML] = model_1_file.getvalue().decode("utf-8")
st.session_state[MODEL2_PATH] = model_2_file.name
st.session_state[MODEL2_YAML] = model_2_file.getvalue().decode("utf-8")
st.rerun()

return_home_button()


def show() -> None:
init_session_states()
if is_session_state_initialized():
comparator_app()
else:
init_dialog()
4 changes: 1 addition & 3 deletions admin_apps/journeys/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from admin_apps.journeys.joins import joins_dialog
from admin_apps.shared_utils import (
API_ENDPOINT,
GeneratorAppScreen,
SnowflakeStage,
changed_from_last_validated_model,
Expand Down Expand Up @@ -74,9 +75,6 @@ def pretty_print_sql(sql: str) -> str:
return formatted_sql


API_ENDPOINT = "https://{HOST}/api/v2/cortex/analyst/message"


@st.cache_data(ttl=60, show_spinner=False)
def send_message(
_conn: SnowflakeConnection, messages: list[dict[str, str]]
Expand Down
37 changes: 37 additions & 0 deletions admin_apps/shared_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Any, Optional

import pandas as pd
import requests
import streamlit as st
from PIL import Image
from snowflake.connector import SnowflakeConnection
Expand Down Expand Up @@ -43,6 +44,8 @@
"https://logos-world.net/wp-content/uploads/2022/11/Snowflake-Symbol.png"
)

API_ENDPOINT = "https://{HOST}/api/v2/cortex/analyst/message"


@st.cache_resource
def get_connector() -> SnowflakeConnector:
Expand Down Expand Up @@ -120,6 +123,7 @@ class GeneratorAppScreen(str, Enum):

ONBOARDING = "onboarding"
ITERATION = "iteration"
COMPARATOR = "comparator"


def return_home_button() -> None:
Expand Down Expand Up @@ -889,6 +893,39 @@ def download_yaml(file_name: str, conn: SnowflakeConnection) -> str:
return yaml_str


def send_message(
_conn: SnowflakeConnection,
messages: list[dict[str, str]],
semantic_model: semantic_model_pb2.SemanticModel,
) -> dict[str, Any]:
"""
Calls the REST API with a list of messages and returns the response.
Args:
_conn: SnowflakeConnection, used to grab the token for auth.
messages: list of chat messages to pass to the Analyst API.

Returns: The raw ChatMessage response from Analyst.
"""
request_body = {
"messages": messages,
"semantic_model": proto_to_yaml(semantic_model),
}
api_endpoint = API_ENDPOINT.format(HOST=st.session_state.host_name)
resp = requests.post(
api_endpoint,
json=request_body,
headers={
"Authorization": f'Snowflake Token="{_conn.rest.token}"', # type: ignore[union-attr]
"Content-Type": "application/json",
},
)
if resp.status_code < 400:
json_resp: dict[str, Any] = resp.json()
return json_resp
else:
raise Exception(f"Failed request with status {resp.status_code}: {resp.text}")


def get_sit_query_tag(
vendor: Optional[str] = None, action: Optional[str] = None
) -> str:
Expand Down
10 changes: 8 additions & 2 deletions semantic_model_generator/data_processing/cte_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,10 @@ def generate_select(
non_agg_cte
+ f"SELECT * FROM {logical_table_name(table_in_column_format)} LIMIT {limit}"
)
sqls_to_return.append(_convert_to_snowflake_sql(non_agg_sql))
# sqls_to_return.append(_convert_to_snowflake_sql(non_agg_sql))
sqls_to_return.append(
non_agg_sql
) # do not convert to snowflake sql for now, as sqlglot make mistakes sometimes, e.g. with TO_DATE()

# Generate select query for columns with aggregation exprs.
agg_cols = [
Expand All @@ -280,7 +283,10 @@ def generate_select(
agg_cte
+ f"SELECT * FROM {logical_table_name(table_in_column_format)} LIMIT {limit}"
)
sqls_to_return.append(_convert_to_snowflake_sql(agg_sql))
# sqls_to_return.append(_convert_to_snowflake_sql(agg_sql))
sqls_to_return.append(
agg_sql
) # do not convert to snowflake sql for now, as sqlglot make mistakes sometimes, e.g. with TO_DATE()
return sqls_to_return


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
ConnectionType = TypeVar("ConnectionType")
# Append this to the end of the auto-generated comments to indicate that the comment was auto-generated.
AUTOGEN_TOKEN = "__"
_autogen_model = "llama3-8b"
_autogen_model = "llama3.1-70b"

# This is the raw column name from snowflake information schema or desc table
_COMMENT_COL = "COMMENT"
Expand Down