Skip to content

Commit

Permalink
Add DB/Schema/Table selector to builder flow (#120)
Browse files Browse the repository at this point in the history
When using the builder flow, it's a bit of a pain to type in the exact
db/schema/table identifiers every time. This PR adds dropdown menus that
allows the user to select the relevant fields instead of having to type.

It would also be good to do this for the iteration flow as well,
providing db/schema/stage selectors for the YAML file.

Partially addresses
#94.

## Testing
Run the builder flow and verify that the db/schema/table selector is
functional and replaces the text input UI.
  • Loading branch information
sfc-gh-cnivera authored Aug 1, 2024
1 parent 0e63d27 commit 888eef4
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 57 deletions.
201 changes: 149 additions & 52 deletions admin_apps/journeys/builder.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,175 @@
import streamlit as st
from loguru import logger
from snowflake.connector import ProgrammingError

from admin_apps.shared_utils import GeneratorAppScreen, get_snowflake_connection
from semantic_model_generator.generate_model import generate_model_str_from_snowflake
from semantic_model_generator.snowflake_utils.snowflake_connector import (
fetch_table_names,
fetch_databases,
fetch_schemas_in_database,
fetch_tables_views_in_schema,
)


@st.cache_resource(show_spinner=False)
def get_available_tables() -> list[str]:
def get_available_tables(schema: str) -> list[str]:
"""
Simple wrapper around fetch_table_names to cache the results.
Returns: list of fully qualified table names
"""

return fetch_table_names(get_snowflake_connection())
return fetch_tables_views_in_schema(get_snowflake_connection(), schema)


@st.experimental_dialog("Selecting your tables", width="large")
def table_selector_dialog() -> None:
@st.cache_resource(show_spinner=False)
def get_available_schemas(db: str) -> list[str]:
"""
Simple wrapper around fetch_schemas to cache the results.
Returns: list of schema names
"""

return fetch_schemas_in_database(get_snowflake_connection(), db)


@st.cache_resource(show_spinner=False)
def get_available_databases() -> list[str]:
"""
Simple wrapper around fetch_databases to cache the results.
Returns: list of database names
"""

return fetch_databases(get_snowflake_connection())


def update_schemas_and_tables() -> None:
"""
Renders a dialog box for the user to input the tables they want to use in their semantic model.
Callback to run when the selected databases change. Ensures that if a database is deselected, the corresponding
schemas and tables are also deselected.
Returns: None
"""
databases = st.session_state["selected_databases"]

# Fetch the available schemas for the selected databases
schemas = []
for db in databases:
try:
schemas.extend(get_available_schemas(db))
except ProgrammingError:
logger.info(
f"Insufficient permissions to read from database {db}, skipping"
)

st.session_state["available_schemas"] = schemas

# Enforce that the previously selected schemas are still valid
valid_selected_schemas = [
schema for schema in st.session_state["selected_schemas"] if schema in schemas
]
st.session_state["selected_schemas"] = valid_selected_schemas
update_tables()


def update_tables() -> None:
"""
Callback to run when the selected schemas change. Ensures that if a schema is deselected, the corresponding
tables are also deselected.
"""
schemas = st.session_state["selected_schemas"]

# Fetch the available tables for the selected schemas
tables = []
for schema in schemas:
try:
tables.extend(get_available_tables(schema))
except ProgrammingError:
logger.info(
f"Insufficient permissions to read from schema {schema}, skipping"
)
st.session_state["available_tables"] = tables

# Enforce that the previously selected tables are still valid
valid_selected_tables = [
table for table in st.session_state["selected_tables"] if table in tables
]
st.session_state["selected_tables"] = valid_selected_tables


@st.experimental_dialog("Selecting your tables", width="large")
def table_selector_dialog() -> None:
st.write(
"Please fill out the following fields to start building your semantic model."
)
with st.form("table_selector_form"):
model_name = st.text_input(
"Semantic Model Name (no .yaml suffix)",
help="The name of the semantic model you are creating. This is separate from the filename, which we will set later.",
)
sample_values = st.selectbox(
"Maximum number of sample values per column",
list(range(1, 40)),
index=0,
help="NOTE: For dimensions, time measures, and measures, we enforce a minimum of 25, 3, and 3 sample values respectively.",
)
st.markdown("")

if "available_tables" not in st.session_state:
with st.spinner("Loading table definitions..."):
st.session_state["available_tables"] = get_available_tables()

tables = st.multiselect(
label="Tables",
options=st.session_state["available_tables"],
placeholder="Select the tables you'd like to include in your semantic model.",
)
st.markdown("<div style='margin: 240px;'></div>", unsafe_allow_html=True)
submit = st.form_submit_button(
"Submit", use_container_width=True, type="primary"
)
if submit:
if not model_name:
st.error("Please provide a name for your semantic model.")
elif not tables:
st.error("Please select at least one table to proceed.")
else:
with st.spinner("Generating model..."):
yaml_str = generate_model_str_from_snowflake(
base_tables=tables,
snowflake_account=st.session_state["account_name"],
semantic_model_name=model_name,
n_sample_values=sample_values, # type: ignore
conn=get_snowflake_connection(),
)

# Set the YAML session state so that the iteration app has access to the generated contents,
# then proceed to the iteration screen.
st.session_state["yaml"] = yaml_str
st.session_state["page"] = GeneratorAppScreen.ITERATION
st.rerun()
model_name = st.text_input(
"Semantic Model Name (no .yaml suffix)",
help="The name of the semantic model you are creating. This is separate from the filename, which we will set later.",
)
sample_values = st.selectbox(
"Maximum number of sample values per column",
list(range(1, 40)),
index=0,
help="NOTE: For dimensions, time measures, and measures, we enforce a minimum of 25, 3, and 3 sample values respectively.",
)
st.markdown("")

if "selected_databases" not in st.session_state:
st.session_state["selected_databases"] = []

if "selected_schemas" not in st.session_state:
st.session_state["selected_schemas"] = []

if "selected_tables" not in st.session_state:
st.session_state["selected_tables"] = []

with st.spinner("Loading databases..."):
available_databases = get_available_databases()

st.multiselect(
label="Databases",
options=available_databases,
placeholder="Select the databases that contain the tables you'd like to include in your semantic model.",
on_change=update_schemas_and_tables,
key="selected_databases",
)

st.multiselect(
label="Schemas",
options=st.session_state.get("available_schemas", []),
placeholder="Select the schemas that contain the tables you'd like to include in your semantic model.",
on_change=update_tables,
key="selected_schemas",
)

st.multiselect(
label="Tables",
options=st.session_state.get("available_tables", []),
placeholder="Select the tables you'd like to include in your semantic model.",
key="selected_tables",
)

st.markdown("<div style='margin: 240px;'></div>", unsafe_allow_html=True)
submit = st.button("Submit", use_container_width=True, type="primary")
if submit:
if not model_name:
st.error("Please provide a name for your semantic model.")
elif not st.session_state["selected_tables"]:
st.error("Please select at least one table to proceed.")
else:
with st.spinner("Generating model. This may take a minute or two..."):
yaml_str = generate_model_str_from_snowflake(
base_tables=st.session_state["selected_tables"],
snowflake_account=st.session_state["account_name"],
semantic_model_name=model_name,
n_sample_values=sample_values, # type: ignore
conn=get_snowflake_connection(),
)

st.session_state["yaml"] = yaml_str
st.session_state["page"] = GeneratorAppScreen.ITERATION
st.rerun()


def show() -> None:
Expand Down
54 changes: 49 additions & 5 deletions semantic_model_generator/snowflake_utils/snowflake_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,21 +244,65 @@ def _get_df(query: str) -> pd.DataFrame:
return pd.concat([tables, views], axis=0)


def fetch_table_names(conn: SnowflakeConnection) -> list[str]:
def fetch_databases(conn: SnowflakeConnection) -> List[str]:
"""
Fetches all tables that the current user has access to, throughout all db/schema.
Fetches all databases that the current user has access to
Args:
conn: SnowflakeConnection to run the query
Returns: a list of fully qualified table names.
Returns: a list of database names
"""
query = "show databases;"
cursor = conn.cursor()
cursor.execute(query)
results = cursor.fetchall()
return [result[1] for result in results]


query = "show tables;"
def fetch_schemas_in_database(conn: SnowflakeConnection, db_name: str) -> List[str]:
"""
Fetches all schemas that the current user has access to in the current database
Args:
conn: SnowflakeConnection to run the query
db_name: The name of the database to connect to.
Returns: a list of qualified schema names (db.schema)
"""
query = f"show schemas in database {db_name};"
cursor = conn.cursor()
cursor.execute(query)
results = cursor.fetchall()
return [f"{result[4]}.{result[1]}" for result in results]


def fetch_tables_views_in_schema(
conn: SnowflakeConnection, schema_name: str
) -> list[str]:
"""
Fetches all tables and views that the current user has access to in the current schema
Args:
conn: SnowflakeConnection to run the query
schema_name: The name of the schema to connect to.
Returns: a list of fully qualified table names.
"""
query = f"show tables in schema {schema_name};"
cursor = conn.cursor()
cursor.execute(query)
tables = cursor.fetchall()
# Each row in the result has columns (created_on, table_name, database_name, schema_name, ...)
return [f"{result[2]}.{result[3]}.{result[1]}" for result in results]
results = [f"{result[2]}.{result[3]}.{result[1]}" for result in tables]

query = f"show views in schema {schema_name};"
cursor = conn.cursor()
cursor.execute(query)
views = cursor.fetchall()
# Each row in the result has columns (created_on, view_name, reserved, database_name, schema_name, ...)
results += [f"{result[3]}.{result[4]}.{result[1]}" for result in views]

return results


def get_valid_schemas_tables_columns_df(
Expand Down

0 comments on commit 888eef4

Please sign in to comment.