diff --git a/app_utils/shared_utils.py b/app_utils/shared_utils.py
index 5cf19ede..c3515877 100644
--- a/app_utils/shared_utils.py
+++ b/app_utils/shared_utils.py
@@ -201,8 +201,9 @@ def get_available_stages(schema: str) -> List[str]:
     """
     return fetch_stages_in_schema(get_snowflake_connection(), schema)
 
+
 @st.cache_resource(show_spinner=False)
-def validate_table_columns(table: str, columns_must_exist: Tuple[str]) -> bool:
+def validate_table_columns(table: str, columns_must_exist: Tuple[str, ...]) -> bool:
     """
-    Fetches the available stages from the Snowflake account.
+    Validates that the specified columns exist in the given table.
 
@@ -215,8 +216,9 @@ def validate_table_columns(table: str, columns_must_exist: Tuple[str]) -> bool:
         return False
     return True
 
+
 @st.cache_resource(show_spinner=False)
-def validate_table_exist(schema: str, table_name) -> bool:
+def validate_table_exist(schema: str, table_name: str) -> bool:
     """
-    Validate table exist in the Snowflake account.
+    Validates that the table exists in the given schema.
 
@@ -228,9 +230,11 @@ def validate_table_exist(schema: str, table_name) -> bool:
     if table_name.upper() in table_names:
         return True
     return False
-
-def schema_selector_container(db_selector:Dict[str,str], schema_selector:Dict[str,str]) -> Optional[str]:
+
+def schema_selector_container(
+    db_selector: Dict[str, str], schema_selector: Dict[str, str]
+) -> List[str]:
     """
     Common component that encapsulates db/schema/table selection for the admin app.
     When a db/schema/table is selected, it is saved to the session state for reading elsewhere.
 
@@ -271,7 +275,12 @@ def schema_selector_container(db_selector:Dict[str,str], schema_selector:Dict[st
 
     return available_tables
 
-def table_selector_container(db_selector:Dict[str,str], schema_selector:Dict[str,str],table_selector:Dict[str,str]) -> Optional[str]:
+
+def table_selector_container(
+    db_selector: Dict[str, str],
+    schema_selector: Dict[str, str],
+    table_selector: Dict[str, str],
+) -> Optional[str]:
     """
     Common component that encapsulates db/schema/table selection for the admin app.
     When a db/schema/table is selected, it is saved to the session state for reading elsewhere.
@@ -1112,7 +1121,7 @@ def upload_yaml(file_name: str) -> None:
     with tempfile.TemporaryDirectory() as temp_dir:
         tmp_file_path = os.path.join(temp_dir, f"{file_name}.yaml")
 
-        with open(tmp_file_path, "w", encoding='utf-8') as temp_file:
+        with open(tmp_file_path, "w", encoding="utf-8") as temp_file:
             temp_file.write(yaml)
 
         st.session_state.session.file.put(
@@ -1168,12 +1177,10 @@ def download_yaml(file_name: str, stage_name: str) -> str:
     with tempfile.TemporaryDirectory() as temp_dir:
         # Downloads the YAML to {temp_dir}/{file_name}.
-        st.session_state.session.file.get(
-            f"@{stage_name}/{file_name}", temp_dir
-        )
+        st.session_state.session.file.get(f"@{stage_name}/{file_name}", temp_dir)
 
         tmp_file_path = os.path.join(temp_dir, f"{file_name}")
-        with open(tmp_file_path, "r", encoding='utf-8') as temp_file:
+        with open(tmp_file_path, "r", encoding="utf-8") as temp_file:
             # Read the raw contents from {temp_dir}/{file_name} and return it as a string.
             yaml_str = temp_file.read()
             return yaml_str
 
@@ -1379,7 +1386,7 @@ def model(self) -> Optional[str]:
             return st.session_state.semantic_model.name  # type: ignore
         return None
 
-    def to_dict(self) -> dict[str, Union[str,None]]:
+    def to_dict(self) -> dict[str, Union[str, None]]:
         return {
             "User": self.user,
             "Stage": self.stage,
@@ -1408,6 +1415,7 @@ def to_dict(self) -> dict[str, str]:
             "Stage": self.stage_name,
         }
 
+
 @dataclass
 class SnowflakeTable:
     table_database: str
@@ -1419,4 +1427,4 @@ def to_dict(self) -> dict[str, str]:
             "Database": self.table_database,
             "Schema": self.table_schema,
             "Table": self.table_name,
-        }
\ No newline at end of file
+        }
diff --git a/journeys/iteration.py b/journeys/iteration.py
index f0c304af..35c08af5 100644
--- a/journeys/iteration.py
+++ b/journeys/iteration.py
@@ -1,4 +1,5 @@
-from streamlit import config
+from streamlit import config
+
 # Set minCachedMessageSize to 500 MB to disable forward message cache:
 # st.set_config would trigger an error, only the set_config from config module works
 config.set_option("global.minCachedMessageSize", 500 * 1e6)
@@ -188,9 +189,9 @@ def edit_verified_query(
             st.session_state["successful_sql"] = True
 
         except Exception as e:
-            st.session_state[
-                "error_state"
-            ] = f"Edited SQL not compatible with semantic model provided, please double check: {e}"
+            st.session_state["error_state"] = (
+                f"Edited SQL not compatible with semantic model provided, please double check: {e}"
+            )
 
     if st.session_state["error_state"] is not None:
         st.error(st.session_state["error_state"])
@@ -384,9 +385,10 @@ def evaluation_data_dialog() -> None:
     evaluation_table_columns = ["ID", "QUERY", "GOLD_SQL"]
     st.markdown("Please select evaluation table")
     table_selector_container(
-        db_selector={"key": "selected_eval_database","label":"Eval database"},
-        schema_selector={"key": "selected_eval_schema","label":"Eval schema"},
-        table_selector={"key": "selected_eval_table","label":"Eval table"},)
+        db_selector={"key": "selected_eval_database", "label": "Eval database"},
+        schema_selector={"key": "selected_eval_schema", "label": "Eval schema"},
+        table_selector={"key": "selected_eval_table", "label": "Eval table"},
+    )
     if st.button("Use Table"):
         if (
             not st.session_state["selected_eval_database"]
@@ -395,8 +397,10 @@ def evaluation_data_dialog() -> None:
         ):
            st.error("Please fill in all fields.")
            return
-
-        if not validate_table_columns(st.session_state["selected_eval_table"], tuple(evaluation_table_columns)):
-            st.error("Table must have columns {evaluation_table_columns}.")
+
+        if not validate_table_columns(
+            st.session_state["selected_eval_table"], tuple(evaluation_table_columns)
+        ):
+            st.error(f"Table must have columns {evaluation_table_columns}.")
             return
 
@@ -407,21 +411,34 @@ def evaluation_data_dialog() -> None:
         )
         st.rerun()
 
+
 @st.experimental_dialog("Evaluation Data", width="large")
 def evaluation_results_data_dialog() -> None:
-    results_table_columns = {"ID":"VARCHAR", "QUERY":"VARCHAR", "GOLD_SQL":"VARCHAR","PREDICTED_SQL":"VARCHAR"}
+    results_table_columns = {
+        "ID": "VARCHAR",
+        "QUERY": "VARCHAR",
+        "GOLD_SQL": "VARCHAR",
+        "PREDICTED_SQL": "VARCHAR",
+    }
     st.markdown("Please select results table")
     eval_results_existing_table = st.checkbox("Use existing table")
     if not eval_results_existing_table:
         schema_selector_container(
-            db_selector={"key": "selected_results_eval_database","label":"Results database"},
-            schema_selector={"key": "selected_results_eval_schema","label":"Results schema"},)
-
+            db_selector={
+                "key": "selected_results_eval_database",
+                "label": "Results database",
+            },
+            schema_selector={
+                "key": "selected_results_eval_schema",
+                "label": "Results schema",
+            },
+        )
+
         new_table_name = st.text_input(
-             key="selected_eval_results_table_name",
-             label="Enter the table name to upload evaluation results",
-             )
+            key="selected_eval_results_table_name",
+            label="Enter the table name to upload evaluation results",
+        )
 
         if st.button("Create Table"):
             if (
                 not st.session_state["selected_results_eval_database"]
@@ -430,28 +447,38 @@ def evaluation_results_data_dialog() -> None:
             ):
                 st.error("Please fill in all fields.")
                 return
-
+
             if (
                 st.session_state["selected_results_eval_database"]
                 and st.session_state["selected_results_eval_schema"]
-                and validate_table_exist(st.session_state["selected_results_eval_schema"],new_table_name)
+                and validate_table_exist(
+                    st.session_state["selected_results_eval_schema"], new_table_name
+                )
             ):
                 st.error("Table already exists")
                 return
-
             with st.spinner("Creating table..."):
-                success = create_table_in_schema(conn = get_snowflake_connection(),
-                    schema_name=st.session_state["selected_results_eval_schema"],
-                    table_name=new_table_name,
-                    columns_schema={f"{k} {v}" for k,v in results_table_columns.items()})
+                success = create_table_in_schema(
+                    conn=get_snowflake_connection(),
+                    schema_name=st.session_state["selected_results_eval_schema"],
+                    table_name=new_table_name,
+                    columns_schema=[
+                        f"{k} {v}" for k, v in results_table_columns.items()
+                    ],
+                )
 
                 if success:
                     st.success(f"Table {new_table_name} created successfully!")
                 else:
                     st.error(f"Failed to create table {new_table_name}")
                     return
-
-            fqn_table_name = ".".join([st.session_state["selected_results_eval_schema"],new_table_name.upper()])
+
+            fqn_table_name = ".".join(
+                [
+                    st.session_state["selected_results_eval_schema"],
+                    new_table_name.upper(),
+                ]
+            )
 
             st.session_state["eval_results_table"] = SnowflakeTable(
                 table_database=st.session_state["selected_results_eval_database"],
@@ -463,9 +490,19 @@ def evaluation_results_data_dialog() -> None:
 
     else:
         table_selector_container(
-            db_selector={"key": "selected_results_eval_database","label":"Results database"},
-            schema_selector={"key": "selected_results_eval_schema","label":"Results schema"},
-            table_selector={"key": "selected_results_eval_table","label":"Results table"},)
+            db_selector={
+                "key": "selected_results_eval_database",
+                "label": "Results database",
+            },
+            schema_selector={
+                "key": "selected_results_eval_schema",
+                "label": "Results schema",
+            },
+            table_selector={
+                "key": "selected_results_eval_table",
+                "label": "Results table",
+            },
+        )
         if st.button("Use Table"):
             if (
                 not st.session_state["selected_results_eval_database"]
@@ -474,9 +511,14 @@ def evaluation_results_data_dialog() -> None:
             ):
                 st.error("Please fill in all fields.")
                 return
-
-            if not validate_table_columns(st.session_state["selected_results_eval_table"], tuple(results_table_columns.keys())):
-                st.error(f"Table must have columns {list(results_table_columns.keys())}.")
+
+            if not validate_table_columns(
+                st.session_state["selected_results_eval_table"],
+                tuple(results_table_columns.keys()),
+            ):
+                st.error(
+                    f"Table must have columns {list(results_table_columns.keys())}."
+                )
                 return
 
             st.session_state["eval_results_table"] = SnowflakeTable(
@@ -487,8 +529,6 @@ def evaluation_results_data_dialog() -> None:
     st.rerun()
 
 
-
-
 @st.experimental_dialog("Upload", width="small")
 def upload_dialog(content: str) -> None:
     def upload_handler(file_name: str) -> None:
@@ -612,7 +652,6 @@ def yaml_editor(yaml_str: str) -> None:
         "Evaluation Mode",
     )
 
-
     # Style text_area to mirror st.code
     with stylable_container(key="customized_text_area", css_styles=css_yaml_editor):
         content = st.text_area(
@@ -799,35 +838,37 @@ def chat_settings_dialog() -> None:
     Note that the Cortex Analyst semantic model must be validated before integrating partner semantics."""
 
 
-
 def evaluation_mode_show() -> None:
-    header_row = row([0.7, 0.3,0.3], vertical_align="center")
+    header_row = row([0.7, 0.3, 0.3], vertical_align="center")
     header_row.markdown("**Evaluation**")
     if header_row.button("Select Eval Table"):
         evaluation_data_dialog()
     if header_row.button("Select Result Table"):
         evaluation_results_data_dialog()
-
+
     if "validated" in st.session_state and not st.session_state["validated"]:
         st.error("Please validate your semantic model before evaluating.")
         return
-
+
     if "eval_table" not in st.session_state:
         st.error("Please select evaluation tables.")
         return
-
+
     if "eval_results_table" not in st.session_state:
         st.error("Please select evaluation results tables.")
         return
 
     # TODO Replace with actual evaluation code probably from seperate file
     if "eval_table" in st.session_state:
-        st.write(f'Using this table as eval table {st.session_state["eval_table"].to_dict()}')
+        st.write(
+            f'Using this table as eval table {st.session_state["eval_table"].to_dict()}'
+        )
     if "eval_results_table" in st.session_state:
-        st.write(f'Using this table as eval results table {st.session_state["eval_results_table"].to_dict()}')
+        st.write(
+            f'Using this table as eval results table {st.session_state["eval_results_table"].to_dict()}'
+        )
     if st.session_state.validated:
         st.write("Model validated")
 
 
-
 def show() -> None:
@@ -871,7 +912,7 @@ def show() -> None:
             st.session_state.working_yml, language="yaml", line_numbers=True
         )
     elif st.session_state.eval_mode:
-
+
         evaluation_mode_show()
     else:
         header_row = row([0.85, 0.15], vertical_align="center")
diff --git a/partner/dbt.py b/partner/dbt.py
index a7a9f76d..46ff5da4 100644
--- a/partner/dbt.py
+++ b/partner/dbt.py
@@ -69,8 +69,9 @@ def upload_dbt_semantic() -> None:
         stage_files = st.multiselect("Staged files", options=available_files)
         if stage_files:
             for staged_file in stage_files:
-                file_content = download_yaml(staged_file,
-                                             st.session_state["selected_iteration_stage"])
+                file_content = download_yaml(
+                    staged_file, st.session_state["selected_iteration_stage"]
+                )
                 uploaded_files.append(file_content)
     else:
         uploaded_files = st.file_uploader(  # type: ignore
             "Upload dbt yaml file(s)",
             type=["yaml", "yml"],
             accept_multiple_files=True,
@@ -80,7 +81,7 @@ def upload_dbt_semantic() -> None:
             key="dbt_files",
         )
     if uploaded_files:
-        partner_semantic: list[Union[None,DBTSemanticModel]] = []
+        partner_semantic: list[Union[None, DBTSemanticModel]] = []
         for file in uploaded_files:
             partner_semantic.extend(read_dbt_yaml(file))  # type: ignore
diff --git a/partner/looker.py b/partner/looker.py
index f728a0b7..9a9bacce 100644
--- a/partner/looker.py
+++ b/partner/looker.py
@@ -523,7 +523,7 @@ def render_looker_explore_as_table(
     target_lag: Optional[int] = 20,
     target_lag_unit: Optional[str] = "minutes",
     warehouse: Optional[str] = None,
-) -> Union[None,dict[str, dict[str, str]]]:
+) -> Union[None, dict[str, dict[str, str]]]:
     """
     Creates materialized table corresponding to Looker Explore.
 
     Args:
diff --git a/partner/partner_utils.py b/partner/partner_utils.py
index 24d7710d..576ad4e1 100644
--- a/partner/partner_utils.py
+++ b/partner/partner_utils.py
@@ -306,9 +306,7 @@ def integrate_partner_semantics() -> None:
                 index=0,
                 help=COMPARE_SEMANTICS_HELP,
             )
-            orphan_label, orphan_col1, orphan_col2 = st.columns(
-                3, gap="small"
-            )
+            orphan_label, orphan_col1, orphan_col2 = st.columns(3, gap="small")
             with orphan_label:
                 st.write("Retain unmatched fields:")
             with orphan_col1:
diff --git a/semantic_model_generator/snowflake_utils/snowflake_connector.py b/semantic_model_generator/snowflake_utils/snowflake_connector.py
index 761274eb..84eb7b06 100644
--- a/semantic_model_generator/snowflake_utils/snowflake_connector.py
+++ b/semantic_model_generator/snowflake_utils/snowflake_connector.py
@@ -345,7 +345,10 @@ def fetch_stages_in_schema(conn: SnowflakeConnection, schema_name: str) -> list[
 
     return [f"{result[2]}.{result[3]}.{result[1]}" for result in stages]
 
-def fetch_columns_names_in_table(conn: SnowflakeConnection, table_fqn: str) -> list[str]:
+
+def fetch_columns_names_in_table(
+    conn: SnowflakeConnection, table_fqn: str
+) -> list[str]:
     """
     Fetches all columns that the current user has access to in the current table
     Args:
@@ -360,6 +363,7 @@ def fetch_columns_names_in_table(conn: SnowflakeConnection, table_fqn: str) -> l
     columns = cursor.fetchall()
     return [result[0] for result in columns]
 
+
 def fetch_yaml_names_in_stage(
     conn: SnowflakeConnection, stage_name: str, include_yml: bool = False
 ) -> list[str]:
@@ -383,8 +387,12 @@ def fetch_yaml_names_in_stage(
     # The file name is prefixed with "@{stage_name}/", so we need to remove that prefix.
     return [result[0].split("/")[-1] for result in yaml_files]
 
+
 def create_table_in_schema(
-    conn: SnowflakeConnection, table_name: str, schema_name: str, columns_schema: List[str]
+    conn: SnowflakeConnection,
+    table_name: str,
+    schema_name: str,
+    columns_schema: List[str],
 ) -> bool:
     """
     Creates a table in the specified schema with the specified columns
@@ -411,7 +419,7 @@ def create_table_in_schema(
     except ProgrammingError as e:
         logger.error(f"Error creating table: {e}")
         return False
-
+
 
 def get_valid_schemas_tables_columns_df(
     conn: SnowflakeConnection,