From a7ff990dc50f7484a5f926f91fdb9b02a802a094 Mon Sep 17 00:00:00 2001 From: Tom Zayats Date: Thu, 5 Dec 2024 15:01:51 -0800 Subject: [PATCH] added results table --- app_utils/shared_utils.py | 78 +++++++++-- journeys/iteration.py | 121 +++++++++++++++++- .../snowflake_utils/snowflake_connector.py | 29 +++++ 3 files changed, 208 insertions(+), 20 deletions(-) diff --git a/app_utils/shared_utils.py b/app_utils/shared_utils.py index a13793aa..5cf19ede 100644 --- a/app_utils/shared_utils.py +++ b/app_utils/shared_utils.py @@ -3,11 +3,12 @@ import json import os import time +import tempfile from dataclasses import dataclass from datetime import datetime from enum import Enum from io import StringIO -from typing import Any, Optional, List, Union +from typing import Any, Optional, List, Union, Dict, Tuple import pandas as pd import streamlit as st @@ -201,7 +202,7 @@ def get_available_stages(schema: str) -> List[str]: return fetch_stages_in_schema(get_snowflake_connection(), schema) @st.cache_resource(show_spinner=False) -def validate_table_columns(table: str, columns_must_exist) -> bool: +def validate_table_columns(table: str, columns_must_exist: Tuple[str]) -> bool: """ Fetches the available stages from the Snowflake account. @@ -214,9 +215,22 @@ def validate_table_columns(table: str, columns_must_exist) -> bool: return False return True +@st.cache_resource(show_spinner=False) +def validate_table_exist(schema: str, table_name) -> bool: + """ + Validate that a table exists in the given Snowflake schema. + Returns: + bool: True if the table exists in the schema, False otherwise. 
+ """ + table_names = fetch_tables_views_in_schema(get_snowflake_connection(), schema) + table_names = [table.split(".")[2] for table in table_names] + if table_name.upper() in table_names: + return True + return False + -def table_selector_container() -> Optional[str]: +def schema_selector_container(db_selector:Dict[str,str], schema_selector:Dict[str,str]) -> Optional[str]: """ Common component that encapsulates db/schema/table selection for the admin app. When a db/schema/table is selected, it is saved to the session state for reading elsewhere. @@ -227,10 +241,10 @@ def table_selector_container() -> Optional[str]: # First, retrieve all databases that the user has access to. eval_database = st.selectbox( - "Eval database", + db_selector["label"], options=get_available_databases(), index=None, - key="selected_eval_database", + key=db_selector["key"], ) if eval_database: # When a valid database is selected, fetch the available schemas in that database. @@ -241,10 +255,51 @@ def table_selector_container() -> Optional[str]: st.stop() eval_schema = st.selectbox( - "Eval schema", + schema_selector["label"], options=available_schemas, index=None, - key="selected_eval_schema", + key=schema_selector["key"], + format_func=lambda x: format_snowflake_context(x, -1), + ) + if eval_schema: + # When a valid schema is selected, fetch the available tables in that schema. + try: + available_tables = get_available_tables(eval_schema) + except (ValueError, ProgrammingError): + st.error("Insufficient permissions to read from the selected schema.") + st.stop() + + return available_tables + +def table_selector_container(db_selector:Dict[str,str], schema_selector:Dict[str,str],table_selector:Dict[str,str]) -> Optional[str]: + """ + Common component that encapsulates db/schema/table selection for the admin app. + When a db/schema/table is selected, it is saved to the session state for reading elsewhere. 
+ Returns: None + """ + available_schemas = [] + available_tables = [] + + # First, retrieve all databases that the user has access to. + eval_database = st.selectbox( + db_selector["label"], + options=get_available_databases(), + index=None, + key=db_selector["key"], + ) + if eval_database: + # When a valid database is selected, fetch the available schemas in that database. + try: + available_schemas = get_available_schemas(eval_database) + except (ValueError, ProgrammingError): + st.error("Insufficient permissions to read from the selected database.") + st.stop() + + eval_schema = st.selectbox( + schema_selector["label"], + options=available_schemas, + index=None, + key=schema_selector["key"], format_func=lambda x: format_snowflake_context(x, -1), ) if eval_schema: @@ -256,10 +311,10 @@ def table_selector_container() -> Optional[str]: st.stop() tables = st.selectbox( - "Table name", + table_selector["label"], options=available_tables, index=None, - key="selected_eval_table", + key=table_selector["key"], format_func=lambda x: format_snowflake_context(x, -1), ) @@ -1052,9 +1107,6 @@ def show_yaml_in_dialog() -> None: def upload_yaml(file_name: str) -> None: """util to upload the semantic model.""" - import os - import tempfile - yaml = proto_to_yaml(st.session_state.semantic_model) with tempfile.TemporaryDirectory() as temp_dir: @@ -1113,8 +1165,6 @@ def model_is_validated() -> bool: def download_yaml(file_name: str, stage_name: str) -> str: """util to download a semantic YAML from a stage.""" - import os - import tempfile with tempfile.TemporaryDirectory() as temp_dir: # Downloads the YAML to {temp_dir}/{file_name}. 
diff --git a/journeys/iteration.py b/journeys/iteration.py index 92ddff72..f0c304af 100644 --- a/journeys/iteration.py +++ b/journeys/iteration.py @@ -28,7 +28,9 @@ return_home_button, stage_selector_container, table_selector_container, + schema_selector_container, validate_table_columns, + validate_table_exist, upload_yaml, validate_and_upload_tmp_yaml, ) @@ -46,6 +48,10 @@ from semantic_model_generator.protos import semantic_model_pb2 from semantic_model_generator.validate_model import validate +from semantic_model_generator.snowflake_utils.snowflake_connector import ( + create_table_in_schema, +) + def get_file_name() -> str: return st.session_state.file_name # type: ignore @@ -376,8 +382,11 @@ def chat_and_edit_vqr(_conn: SnowflakeConnection) -> None: @st.experimental_dialog("Evaluation Data", width="large") def evaluation_data_dialog() -> None: evaluation_table_columns = ["ID", "QUERY", "GOLD_SQL"] - st.markdown("Please enter evaluation select table") - table_selector_container() + st.markdown("Please select evaluation table") + table_selector_container( + db_selector={"key": "selected_eval_database","label":"Eval database"}, + schema_selector={"key": "selected_eval_schema","label":"Eval schema"}, + table_selector={"key": "selected_eval_table","label":"Eval table"},) if st.button("Use Table"): if ( not st.session_state["selected_eval_database"] @@ -387,8 +396,8 @@ def evaluation_data_dialog() -> None: st.error("Please fill in all fields.") return - if not validate_table_columns(st.session_state["selected_eval_table"], evaluation_table_columns): - st.error("Table must have columns {evaluation_table_columns} to be used in Evaluation") + if not validate_table_columns(st.session_state["selected_eval_table"], tuple(evaluation_table_columns)): + st.error("Table must have columns {evaluation_table_columns}.") return st.session_state["eval_table"] = SnowflakeTable( @@ -397,7 +406,85 @@ def evaluation_data_dialog() -> None: 
table_name=st.session_state["selected_eval_table"], ) st.rerun() + +@st.experimental_dialog("Evaluation Data", width="large") +def evaluation_results_data_dialog() -> None: + results_table_columns = {"ID":"VARCHAR", "QUERY":"VARCHAR", "GOLD_SQL":"VARCHAR","PREDICTED_SQL":"VARCHAR"} + st.markdown("Please select results table") + eval_results_existing_table = st.checkbox("Use existing table") + + if not eval_results_existing_table: + schema_selector_container( + db_selector={"key": "selected_results_eval_database","label":"Results database"}, + schema_selector={"key": "selected_results_eval_schema","label":"Results schema"},) + new_table_name = st.text_input( + key="selected_eval_results_table_name", + label="Enter the table name to upload evaluation results", + ) + if st.button("Create Table"): + if ( + not st.session_state["selected_results_eval_database"] + or not st.session_state["selected_results_eval_schema"] + or not new_table_name + ): + st.error("Please fill in all fields.") + return + + if ( + st.session_state["selected_results_eval_database"] + and st.session_state["selected_results_eval_schema"] + and validate_table_exist(st.session_state["selected_results_eval_schema"],new_table_name) + ): + st.error("Table already exists") + return + + + with st.spinner("Creating table..."): + success = create_table_in_schema(conn = get_snowflake_connection(), + schema_name=st.session_state["selected_results_eval_schema"], + table_name=new_table_name, + columns_schema={f"{k} {v}" for k,v in results_table_columns.items()}) + if success: + st.success(f"Table {new_table_name} created successfully!") + else: + st.error(f"Failed to create table {new_table_name}") + return + + fqn_table_name = ".".join([st.session_state["selected_results_eval_schema"],new_table_name.upper()]) + + st.session_state["eval_results_table"] = SnowflakeTable( + table_database=st.session_state["selected_results_eval_database"], + table_schema=st.session_state["selected_results_eval_schema"], + 
table_name=fqn_table_name, + ) + + st.rerun() + + else: + table_selector_container( + db_selector={"key": "selected_results_eval_database","label":"Results database"}, + schema_selector={"key": "selected_results_eval_schema","label":"Results schema"}, + table_selector={"key": "selected_results_eval_table","label":"Results table"},) + if st.button("Use Table"): + if ( + not st.session_state["selected_results_eval_database"] + or not st.session_state["selected_results_eval_schema"] + or not st.session_state["selected_results_eval_table"] + ): + st.error("Please fill in all fields.") + return + + if not validate_table_columns(st.session_state["selected_results_eval_table"], tuple(results_table_columns.keys())): + st.error(f"Table must have columns {list(results_table_columns.keys())}.") + return + + st.session_state["eval_results_table"] = SnowflakeTable( + table_database=st.session_state["selected_results_eval_database"], + table_schema=st.session_state["selected_results_eval_schema"], + table_name=st.session_state["selected_results_eval_table"], + ) + st.rerun() @@ -712,13 +799,35 @@ def chat_settings_dialog() -> None: Note that the Cortex Analyst semantic model must be validated before integrating partner semantics.""" + def evaluation_mode_show() -> None: - header_row = row([0.65, 0.15], vertical_align="center") + header_row = row([0.7, 0.3,0.3], vertical_align="center") header_row.markdown("**Evaluation**") - if header_row.button("Select Eval Data"): + if header_row.button("Select Eval Table"): evaluation_data_dialog() + if header_row.button("Select Result Table"): + evaluation_results_data_dialog() + + if "validated" in st.session_state and not st.session_state["validated"]: + st.error("Please validate your semantic model before evaluating.") + return + + if "eval_table" not in st.session_state: + st.error("Please select evaluation tables.") + return + + if "eval_results_table" not in st.session_state: + st.error("Please select evaluation results tables.") + 
+ return + + # TODO Replace with actual evaluation code probably from separate file if "eval_table" in st.session_state: st.write(f'Using this table as eval table {st.session_state["eval_table"].to_dict()}') + if "eval_results_table" in st.session_state: + st.write(f'Using this table as eval results table {st.session_state["eval_results_table"].to_dict()}') + if st.session_state.validated: + st.write("Model validated") + def show() -> None: diff --git a/semantic_model_generator/snowflake_utils/snowflake_connector.py b/semantic_model_generator/snowflake_utils/snowflake_connector.py index 5f8c1951..761274eb 100644 --- a/semantic_model_generator/snowflake_utils/snowflake_connector.py +++ b/semantic_model_generator/snowflake_utils/snowflake_connector.py @@ -383,6 +383,35 @@ def fetch_yaml_names_in_stage( # The file name is prefixed with "@{stage_name}/", so we need to remove that prefix. return [result[0].split("/")[-1] for result in yaml_files] +def create_table_in_schema( + conn: SnowflakeConnection, table_name: str, schema_name: str, columns_schema: List[str] +) -> bool: + """ + Creates a table in the specified schema with the specified columns + Args: + conn: SnowflakeConnection to run the query + table_name: The name of the table to create + schema_name: The name of the schema to create the table in + columns_schema: A list of column definitions (name followed by type) for the table + + Returns: True if the table was created successfully, False otherwise + """ + # Construct the create table query + create_table_query = f""" + CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ( + {', '.join(columns_schema)} + ) + """ + + # Execute the query + cursor = conn.cursor() + try: + cursor.execute(create_table_query) + return True + except ProgrammingError as e: + logger.error(f"Error creating table: {e}") + return False + def get_valid_schemas_tables_columns_df( conn: SnowflakeConnection,