Skip to content

Commit

Permalink
added results table
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-tzayats committed Dec 5, 2024
1 parent b7bff16 commit a7ff990
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 20 deletions.
78 changes: 64 additions & 14 deletions app_utils/shared_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import json
import os
import time
import tempfile
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from io import StringIO
from typing import Any, Optional, List, Union
from typing import Any, Optional, List, Union, Dict, Tuple

import pandas as pd
import streamlit as st
Expand Down Expand Up @@ -201,7 +202,7 @@ def get_available_stages(schema: str) -> List[str]:
return fetch_stages_in_schema(get_snowflake_connection(), schema)

@st.cache_resource(show_spinner=False)
def validate_table_columns(table: str, columns_must_exist) -> bool:
def validate_table_columns(table: str, columns_must_exist: Tuple[str]) -> bool:
"""
Fetches the available stages from the Snowflake account.
Expand All @@ -214,9 +215,22 @@ def validate_table_columns(table: str, columns_must_exist) -> bool:
return False
return True

@st.cache_resource(show_spinner=False)
def validate_table_exist(schema: str, table_name) -> bool:
"""
Validate table exist in the Snowflake account.
Returns:
List[str]: A list of available stages.
"""
table_names = fetch_tables_views_in_schema(get_snowflake_connection(), schema)
table_names = [table.split(".")[2] for table in table_names]
if table_name.upper() in table_names:
return True
return False


def table_selector_container() -> Optional[str]:
def schema_selector_container(db_selector:Dict[str,str], schema_selector:Dict[str,str]) -> Optional[str]:
"""
Common component that encapsulates db/schema/table selection for the admin app.
When a db/schema/table is selected, it is saved to the session state for reading elsewhere.
Expand All @@ -227,10 +241,10 @@ def table_selector_container() -> Optional[str]:

# First, retrieve all databases that the user has access to.
eval_database = st.selectbox(
"Eval database",
db_selector["label"],
options=get_available_databases(),
index=None,
key="selected_eval_database",
key=db_selector["key"],
)
if eval_database:
# When a valid database is selected, fetch the available schemas in that database.
Expand All @@ -241,10 +255,51 @@ def table_selector_container() -> Optional[str]:
st.stop()

eval_schema = st.selectbox(
"Eval schema",
schema_selector["label"],
options=available_schemas,
index=None,
key="selected_eval_schema",
key=schema_selector["key"],
format_func=lambda x: format_snowflake_context(x, -1),
)
if eval_schema:
# When a valid schema is selected, fetch the available tables in that schema.
try:
available_tables = get_available_tables(eval_schema)
except (ValueError, ProgrammingError):
st.error("Insufficient permissions to read from the selected schema.")
st.stop()

return available_tables

def table_selector_container(db_selector:Dict[str,str], schema_selector:Dict[str,str],table_selector:Dict[str,str]) -> Optional[str]:
"""
Common component that encapsulates db/schema/table selection for the admin app.
When a db/schema/table is selected, it is saved to the session state for reading elsewhere.
Returns: None
"""
available_schemas = []
available_tables = []

# First, retrieve all databases that the user has access to.
eval_database = st.selectbox(
db_selector["label"],
options=get_available_databases(),
index=None,
key=db_selector["key"],
)
if eval_database:
# When a valid database is selected, fetch the available schemas in that database.
try:
available_schemas = get_available_schemas(eval_database)
except (ValueError, ProgrammingError):
st.error("Insufficient permissions to read from the selected database.")
st.stop()

eval_schema = st.selectbox(
schema_selector["label"],
options=available_schemas,
index=None,
key=schema_selector["key"],
format_func=lambda x: format_snowflake_context(x, -1),
)
if eval_schema:
Expand All @@ -256,10 +311,10 @@ def table_selector_container() -> Optional[str]:
st.stop()

tables = st.selectbox(
"Table name",
table_selector["label"],
options=available_tables,
index=None,
key="selected_eval_table",
key=table_selector["key"],
format_func=lambda x: format_snowflake_context(x, -1),
)

Expand Down Expand Up @@ -1052,9 +1107,6 @@ def show_yaml_in_dialog() -> None:

def upload_yaml(file_name: str) -> None:
"""util to upload the semantic model."""
import os
import tempfile

yaml = proto_to_yaml(st.session_state.semantic_model)

with tempfile.TemporaryDirectory() as temp_dir:
Expand Down Expand Up @@ -1113,8 +1165,6 @@ def model_is_validated() -> bool:

def download_yaml(file_name: str, stage_name: str) -> str:
"""util to download a semantic YAML from a stage."""
import os
import tempfile

with tempfile.TemporaryDirectory() as temp_dir:
# Downloads the YAML to {temp_dir}/{file_name}.
Expand Down
121 changes: 115 additions & 6 deletions journeys/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
return_home_button,
stage_selector_container,
table_selector_container,
schema_selector_container,
validate_table_columns,
validate_table_exist,
upload_yaml,
validate_and_upload_tmp_yaml,
)
Expand All @@ -46,6 +48,10 @@
from semantic_model_generator.protos import semantic_model_pb2
from semantic_model_generator.validate_model import validate

from semantic_model_generator.snowflake_utils.snowflake_connector import (
create_table_in_schema,
)


def get_file_name() -> str:
return st.session_state.file_name # type: ignore
Expand Down Expand Up @@ -376,8 +382,11 @@ def chat_and_edit_vqr(_conn: SnowflakeConnection) -> None:
@st.experimental_dialog("Evaluation Data", width="large")
def evaluation_data_dialog() -> None:
evaluation_table_columns = ["ID", "QUERY", "GOLD_SQL"]
st.markdown("Please enter evaluation select table")
table_selector_container()
st.markdown("Please select evaluation table")
table_selector_container(
db_selector={"key": "selected_eval_database","label":"Eval database"},
schema_selector={"key": "selected_eval_schema","label":"Eval schema"},
table_selector={"key": "selected_eval_table","label":"Eval table"},)
if st.button("Use Table"):
if (
not st.session_state["selected_eval_database"]
Expand All @@ -387,8 +396,8 @@ def evaluation_data_dialog() -> None:
st.error("Please fill in all fields.")
return

if not validate_table_columns(st.session_state["selected_eval_table"], evaluation_table_columns):
st.error("Table must have columns {evaluation_table_columns} to be used in Evaluation")
if not validate_table_columns(st.session_state["selected_eval_table"], tuple(evaluation_table_columns)):
st.error("Table must have columns {evaluation_table_columns}.")
return

st.session_state["eval_table"] = SnowflakeTable(
Expand All @@ -397,7 +406,85 @@ def evaluation_data_dialog() -> None:
table_name=st.session_state["selected_eval_table"],
)
st.rerun()

@st.experimental_dialog("Evaluation Data", width="large")
def evaluation_results_data_dialog() -> None:
results_table_columns = {"ID":"VARCHAR", "QUERY":"VARCHAR", "GOLD_SQL":"VARCHAR","PREDICTED_SQL":"VARCHAR"}
st.markdown("Please select results table")
eval_results_existing_table = st.checkbox("Use existing table")

if not eval_results_existing_table:
schema_selector_container(
db_selector={"key": "selected_results_eval_database","label":"Results database"},
schema_selector={"key": "selected_results_eval_schema","label":"Results schema"},)

new_table_name = st.text_input(
key="selected_eval_results_table_name",
label="Enter the table name to upload evaluation results",
)
if st.button("Create Table"):
if (
not st.session_state["selected_results_eval_database"]
or not st.session_state["selected_results_eval_schema"]
or not new_table_name
):
st.error("Please fill in all fields.")
return

if (
st.session_state["selected_results_eval_database"]
and st.session_state["selected_results_eval_schema"]
and validate_table_exist(st.session_state["selected_results_eval_schema"],new_table_name)
):
st.error("Table already exists")
return


with st.spinner("Creating table..."):
success = create_table_in_schema(conn = get_snowflake_connection(),
schema_name=st.session_state["selected_results_eval_schema"],
table_name=new_table_name,
columns_schema={f"{k} {v}" for k,v in results_table_columns.items()})
if success:
st.success(f"Table {new_table_name} created successfully!")
else:
st.error(f"Failed to create table {new_table_name}")
return

fqn_table_name = ".".join([st.session_state["selected_results_eval_schema"],new_table_name.upper()])

st.session_state["eval_results_table"] = SnowflakeTable(
table_database=st.session_state["selected_results_eval_database"],
table_schema=st.session_state["selected_results_eval_schema"],
table_name=fqn_table_name,
)

st.rerun()

else:
table_selector_container(
db_selector={"key": "selected_results_eval_database","label":"Results database"},
schema_selector={"key": "selected_results_eval_schema","label":"Results schema"},
table_selector={"key": "selected_results_eval_table","label":"Results table"},)
if st.button("Use Table"):
if (
not st.session_state["selected_results_eval_database"]
or not st.session_state["selected_results_eval_schema"]
or not st.session_state["selected_results_eval_table"]
):
st.error("Please fill in all fields.")
return

if not validate_table_columns(st.session_state["selected_results_eval_table"], tuple(results_table_columns.keys())):
st.error(f"Table must have columns {list(results_table_columns.keys())}.")
return

st.session_state["eval_results_table"] = SnowflakeTable(
table_database=st.session_state["selected_results_eval_database"],
table_schema=st.session_state["selected_results_eval_schema"],
table_name=st.session_state["selected_results_eval_table"],
)
st.rerun()



Expand Down Expand Up @@ -712,13 +799,35 @@ def chat_settings_dialog() -> None:
Note that the Cortex Analyst semantic model must be validated before integrating partner semantics."""



def evaluation_mode_show() -> None:
header_row = row([0.65, 0.15], vertical_align="center")
header_row = row([0.7, 0.3,0.3], vertical_align="center")
header_row.markdown("**Evaluation**")
if header_row.button("Select Eval Data"):
if header_row.button("Select Eval Table"):
evaluation_data_dialog()
if header_row.button("Select Result Table"):
evaluation_results_data_dialog()

if "validated" in st.session_state and not st.session_state["validated"]:
st.error("Please validate your semantic model before evaluating.")
return

if "eval_table" not in st.session_state:
st.error("Please select evaluation tables.")
return

if "eval_results_table" not in st.session_state:
st.error("Please select evaluation results tables.")
return

# TODO Replace with actual evaluation code probably from seperate file
if "eval_table" in st.session_state:
st.write(f'Using this table as eval table {st.session_state["eval_table"].to_dict()}')
if "eval_results_table" in st.session_state:
st.write(f'Using this table as eval results table {st.session_state["eval_results_table"].to_dict()}')
if st.session_state.validated:
st.write("Model validated")



def show() -> None:
Expand Down
29 changes: 29 additions & 0 deletions semantic_model_generator/snowflake_utils/snowflake_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,35 @@ def fetch_yaml_names_in_stage(
# The file name is prefixed with "@{stage_name}/", so we need to remove that prefix.
return [result[0].split("/")[-1] for result in yaml_files]

def create_table_in_schema(
conn: SnowflakeConnection, table_name: str, schema_name: str, columns_schema: List[str]
) -> bool:
"""
Creates a table in the specified schema with the specified columns
Args:
conn: SnowflakeConnection to run the query
table_name: The name of the table to create
schema_name: The name of the schema to create the table in
columns: A list of Column objects representing the columns of the table
Returns: True if the table was created successfully, False otherwise
"""
# Construct the create table query
create_table_query = f"""
CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (
{', '.join(columns_schema)}
)
"""

# Execute the query
cursor = conn.cursor()
try:
cursor.execute(create_table_query)
return True
except ProgrammingError as e:
logger.error(f"Error creating table: {e}")
return False


def get_valid_schemas_tables_columns_df(
conn: SnowflakeConnection,
Expand Down

0 comments on commit a7ff990

Please sign in to comment.