Skip to content

Commit

Permalink
fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-tzayats committed Dec 6, 2024
1 parent a7ff990 commit ed4d939
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 65 deletions.
32 changes: 20 additions & 12 deletions app_utils/shared_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,9 @@ def get_available_stages(schema: str) -> List[str]:
"""
return fetch_stages_in_schema(get_snowflake_connection(), schema)


@st.cache_resource(show_spinner=False)
def validate_table_columns(table: str, columns_must_exist: Tuple[str]) -> bool:
def validate_table_columns(table: str, columns_must_exist: Tuple[str,...]) -> bool:
"""
Fetches the available stages from the Snowflake account.
Expand All @@ -215,8 +216,9 @@ def validate_table_columns(table: str, columns_must_exist: Tuple[str]) -> bool:
return False
return True


@st.cache_resource(show_spinner=False)
def validate_table_exist(schema: str, table_name) -> bool:
def validate_table_exist(schema: str, table_name:str) -> bool:
"""
Validate table exist in the Snowflake account.
Expand All @@ -228,9 +230,11 @@ def validate_table_exist(schema: str, table_name) -> bool:
if table_name.upper() in table_names:
return True
return False


def schema_selector_container(db_selector:Dict[str,str], schema_selector:Dict[str,str]) -> Optional[str]:

def schema_selector_container(
db_selector: Dict[str, str], schema_selector: Dict[str, str]
) -> List[str]:
"""
Common component that encapsulates db/schema/table selection for the admin app.
When a db/schema/table is selected, it is saved to the session state for reading elsewhere.
Expand Down Expand Up @@ -271,7 +275,12 @@ def schema_selector_container(db_selector:Dict[str,str], schema_selector:Dict[st

return available_tables

def table_selector_container(db_selector:Dict[str,str], schema_selector:Dict[str,str],table_selector:Dict[str,str]) -> Optional[str]:

def table_selector_container(
db_selector: Dict[str, str],
schema_selector: Dict[str, str],
table_selector: Dict[str, str],
) -> Optional[str]:
"""
Common component that encapsulates db/schema/table selection for the admin app.
When a db/schema/table is selected, it is saved to the session state for reading elsewhere.
Expand Down Expand Up @@ -1112,7 +1121,7 @@ def upload_yaml(file_name: str) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
tmp_file_path = os.path.join(temp_dir, f"{file_name}.yaml")

with open(tmp_file_path, "w", encoding='utf-8') as temp_file:
with open(tmp_file_path, "w", encoding="utf-8") as temp_file:
temp_file.write(yaml)

st.session_state.session.file.put(
Expand Down Expand Up @@ -1168,12 +1177,10 @@ def download_yaml(file_name: str, stage_name: str) -> str:

with tempfile.TemporaryDirectory() as temp_dir:
# Downloads the YAML to {temp_dir}/{file_name}.
st.session_state.session.file.get(
f"@{stage_name}/{file_name}", temp_dir
)
st.session_state.session.file.get(f"@{stage_name}/{file_name}", temp_dir)

tmp_file_path = os.path.join(temp_dir, f"{file_name}")
with open(tmp_file_path, "r", encoding='utf-8') as temp_file:
with open(tmp_file_path, "r", encoding="utf-8") as temp_file:
# Read the raw contents from {temp_dir}/{file_name} and return it as a string.
yaml_str = temp_file.read()
return yaml_str
Expand Down Expand Up @@ -1379,7 +1386,7 @@ def model(self) -> Optional[str]:
return st.session_state.semantic_model.name # type: ignore
return None

def to_dict(self) -> dict[str, Union[str,None]]:
def to_dict(self) -> dict[str, Union[str, None]]:
return {
"User": self.user,
"Stage": self.stage,
Expand Down Expand Up @@ -1408,6 +1415,7 @@ def to_dict(self) -> dict[str, str]:
"Stage": self.stage_name,
}


@dataclass
class SnowflakeTable:
table_database: str
Expand All @@ -1419,4 +1427,4 @@ def to_dict(self) -> dict[str, str]:
"Database": self.table_database,
"Schema": self.table_schema,
"Table": self.table_name,
}
}
127 changes: 84 additions & 43 deletions journeys/iteration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from streamlit import config
from streamlit import config

# Set minCachedMessageSize to 500 MB to disable forward message cache:
# st.set_config would trigger an error, only the set_config from config module works
config.set_option("global.minCachedMessageSize", 500 * 1e6)
Expand Down Expand Up @@ -188,9 +189,9 @@ def edit_verified_query(
st.session_state["successful_sql"] = True

except Exception as e:
st.session_state[
"error_state"
] = f"Edited SQL not compatible with semantic model provided, please double check: {e}"
st.session_state["error_state"] = (
f"Edited SQL not compatible with semantic model provided, please double check: {e}"
)

if st.session_state["error_state"] is not None:
st.error(st.session_state["error_state"])
Expand Down Expand Up @@ -384,9 +385,10 @@ def evaluation_data_dialog() -> None:
evaluation_table_columns = ["ID", "QUERY", "GOLD_SQL"]
st.markdown("Please select evaluation table")
table_selector_container(
db_selector={"key": "selected_eval_database","label":"Eval database"},
schema_selector={"key": "selected_eval_schema","label":"Eval schema"},
table_selector={"key": "selected_eval_table","label":"Eval table"},)
db_selector={"key": "selected_eval_database", "label": "Eval database"},
schema_selector={"key": "selected_eval_schema", "label": "Eval schema"},
table_selector={"key": "selected_eval_table", "label": "Eval table"},
)
if st.button("Use Table"):
if (
not st.session_state["selected_eval_database"]
Expand All @@ -395,8 +397,10 @@ def evaluation_data_dialog() -> None:
):
st.error("Please fill in all fields.")
return

if not validate_table_columns(st.session_state["selected_eval_table"], tuple(evaluation_table_columns)):

if not validate_table_columns(
st.session_state["selected_eval_table"], tuple(evaluation_table_columns)
):
st.error("Table must have columns {evaluation_table_columns}.")
return

Expand All @@ -407,21 +411,34 @@ def evaluation_data_dialog() -> None:
)
st.rerun()


@st.experimental_dialog("Evaluation Data", width="large")
def evaluation_results_data_dialog() -> None:
results_table_columns = {"ID":"VARCHAR", "QUERY":"VARCHAR", "GOLD_SQL":"VARCHAR","PREDICTED_SQL":"VARCHAR"}
results_table_columns = {
"ID": "VARCHAR",
"QUERY": "VARCHAR",
"GOLD_SQL": "VARCHAR",
"PREDICTED_SQL": "VARCHAR",
}
st.markdown("Please select results table")
eval_results_existing_table = st.checkbox("Use existing table")

if not eval_results_existing_table:
schema_selector_container(
db_selector={"key": "selected_results_eval_database","label":"Results database"},
schema_selector={"key": "selected_results_eval_schema","label":"Results schema"},)

db_selector={
"key": "selected_results_eval_database",
"label": "Results database",
},
schema_selector={
"key": "selected_results_eval_schema",
"label": "Results schema",
},
)

new_table_name = st.text_input(
key="selected_eval_results_table_name",
label="Enter the table name to upload evaluation results",
)
key="selected_eval_results_table_name",
label="Enter the table name to upload evaluation results",
)
if st.button("Create Table"):
if (
not st.session_state["selected_results_eval_database"]
Expand All @@ -430,28 +447,38 @@ def evaluation_results_data_dialog() -> None:
):
st.error("Please fill in all fields.")
return

if (
st.session_state["selected_results_eval_database"]
and st.session_state["selected_results_eval_schema"]
and validate_table_exist(st.session_state["selected_results_eval_schema"],new_table_name)
and validate_table_exist(
st.session_state["selected_results_eval_schema"], new_table_name
)
):
st.error("Table already exists")
return


with st.spinner("Creating table..."):
success = create_table_in_schema(conn = get_snowflake_connection(),
schema_name=st.session_state["selected_results_eval_schema"],
table_name=new_table_name,
columns_schema={f"{k} {v}" for k,v in results_table_columns.items()})
success = create_table_in_schema(
conn=get_snowflake_connection(),
schema_name=st.session_state["selected_results_eval_schema"],
table_name=new_table_name,
columns_schema=[
f"{k} {v}" for k, v in results_table_columns.items()
],
)
if success:
st.success(f"Table {new_table_name} created successfully!")
else:
st.error(f"Failed to create table {new_table_name}")
return

fqn_table_name = ".".join([st.session_state["selected_results_eval_schema"],new_table_name.upper()])

fqn_table_name = ".".join(
[
st.session_state["selected_results_eval_schema"],
new_table_name.upper(),
]
)

st.session_state["eval_results_table"] = SnowflakeTable(
table_database=st.session_state["selected_results_eval_database"],
Expand All @@ -463,9 +490,19 @@ def evaluation_results_data_dialog() -> None:

else:
table_selector_container(
db_selector={"key": "selected_results_eval_database","label":"Results database"},
schema_selector={"key": "selected_results_eval_schema","label":"Results schema"},
table_selector={"key": "selected_results_eval_table","label":"Results table"},)
db_selector={
"key": "selected_results_eval_database",
"label": "Results database",
},
schema_selector={
"key": "selected_results_eval_schema",
"label": "Results schema",
},
table_selector={
"key": "selected_results_eval_table",
"label": "Results table",
},
)
if st.button("Use Table"):
if (
not st.session_state["selected_results_eval_database"]
Expand All @@ -474,9 +511,14 @@ def evaluation_results_data_dialog() -> None:
):
st.error("Please fill in all fields.")
return

if not validate_table_columns(st.session_state["selected_results_eval_table"], tuple(results_table_columns.keys())):
st.error(f"Table must have columns {list(results_table_columns.keys())}.")

if not validate_table_columns(
st.session_state["selected_results_eval_table"],
tuple(results_table_columns.keys()),
):
st.error(
f"Table must have columns {list(results_table_columns.keys())}."
)
return

st.session_state["eval_results_table"] = SnowflakeTable(
Expand All @@ -487,8 +529,6 @@ def evaluation_results_data_dialog() -> None:
st.rerun()




@st.experimental_dialog("Upload", width="small")
def upload_dialog(content: str) -> None:
def upload_handler(file_name: str) -> None:
Expand Down Expand Up @@ -612,7 +652,6 @@ def yaml_editor(yaml_str: str) -> None:
"Evaluation Mode",
)


# Style text_area to mirror st.code
with stylable_container(key="customized_text_area", css_styles=css_yaml_editor):
content = st.text_area(
Expand Down Expand Up @@ -799,35 +838,37 @@ def chat_settings_dialog() -> None:
Note that the Cortex Analyst semantic model must be validated before integrating partner semantics."""



def evaluation_mode_show() -> None:
header_row = row([0.7, 0.3,0.3], vertical_align="center")
header_row = row([0.7, 0.3, 0.3], vertical_align="center")
header_row.markdown("**Evaluation**")
if header_row.button("Select Eval Table"):
evaluation_data_dialog()
if header_row.button("Select Result Table"):
evaluation_results_data_dialog()

if "validated" in st.session_state and not st.session_state["validated"]:
st.error("Please validate your semantic model before evaluating.")
return

if "eval_table" not in st.session_state:
st.error("Please select evaluation tables.")
return

if "eval_results_table" not in st.session_state:
st.error("Please select evaluation results tables.")
return

# TODO Replace with actual evaluation code probably from seperate file
if "eval_table" in st.session_state:
st.write(f'Using this table as eval table {st.session_state["eval_table"].to_dict()}')
st.write(
f'Using this table as eval table {st.session_state["eval_table"].to_dict()}'
)
if "eval_results_table" in st.session_state:
st.write(f'Using this table as eval results table {st.session_state["eval_results_table"].to_dict()}')
st.write(
f'Using this table as eval results table {st.session_state["eval_results_table"].to_dict()}'
)
if st.session_state.validated:
st.write("Model validated")



def show() -> None:
Expand Down Expand Up @@ -871,7 +912,7 @@ def show() -> None:
st.session_state.working_yml, language="yaml", line_numbers=True
)
elif st.session_state.eval_mode:

evaluation_mode_show()
else:
header_row = row([0.85, 0.15], vertical_align="center")
Expand Down
7 changes: 4 additions & 3 deletions partner/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ def upload_dbt_semantic() -> None:
stage_files = st.multiselect("Staged files", options=available_files)
if stage_files:
for staged_file in stage_files:
file_content = download_yaml(staged_file,
st.session_state["selected_iteration_stage"])
file_content = download_yaml(
staged_file, st.session_state["selected_iteration_stage"]
)
uploaded_files.append(file_content)
else:
uploaded_files = st.file_uploader( # type: ignore
Expand All @@ -80,7 +81,7 @@ def upload_dbt_semantic() -> None:
key="dbt_files",
)
if uploaded_files:
partner_semantic: list[Union[None,DBTSemanticModel]] = []
partner_semantic: list[Union[None, DBTSemanticModel]] = []
for file in uploaded_files:
partner_semantic.extend(read_dbt_yaml(file)) # type: ignore

Expand Down
2 changes: 1 addition & 1 deletion partner/looker.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def render_looker_explore_as_table(
target_lag: Optional[int] = 20,
target_lag_unit: Optional[str] = "minutes",
warehouse: Optional[str] = None,
) -> Union[None,dict[str, dict[str, str]]]:
) -> Union[None, dict[str, dict[str, str]]]:
"""
Creates materialized table corresponding to Looker Explore.
Args:
Expand Down
4 changes: 1 addition & 3 deletions partner/partner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,7 @@ def integrate_partner_semantics() -> None:
index=0,
help=COMPARE_SEMANTICS_HELP,
)
orphan_label, orphan_col1, orphan_col2 = st.columns(
3, gap="small"
)
orphan_label, orphan_col1, orphan_col2 = st.columns(3, gap="small")
with orphan_label:
st.write("Retain unmatched fields:")
with orphan_col1:
Expand Down
Loading

0 comments on commit ed4d939

Please sign in to comment.