Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kschmaus port eval logic #217

Merged
merged 13 commits into from
Dec 10, 2024
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pyvenv

# VSCode
.vscode/settings.json
.vscode/launch.json
.vscode/.ropeproject
.vscode/*.log

Expand Down
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ run_mypy: ## Run mypy
mypy --config-file=mypy.ini .

run_flake8: ## Run flake8
flake8 --ignore=E203,E501,W503 --exclude=venv,pyvenv,tmp,*_pb2.py,*_pb2.pyi,images/*/src .
flake8 --ignore=E203,E501,W503 --exclude=venv,.venv,pyvenv,tmp,*_pb2.py,*_pb2.pyi,images/*/src .

check_black: ## Check to see if files would be updated with black.
# Exclude pyvenv and all generated protobuf code.
Expand All @@ -49,10 +49,10 @@ run_black: ## Run black to format files.
black --exclude="venv|pyvenv|tmp|.*_pb2.py|.*_pb2.pyi" .

check_isort: ## Check if files would be updated with isort.
isort --profile black --check --skip=venv --skip=pyvenv --skip-glob='*_pb2.py*' .
isort --profile black --check --skip=venv --skip=pyvenv --skip=.venv --skip-glob='*_pb2.py*' .

run_isort: ## Run isort to update imports.
isort --profile black --skip=pyvenv --skip=venv --skip=tmp --skip-glob='*_pb2.py*' .
isort --profile black --skip=pyvenv --skip=venv --skip=tmp --skip=.venv --skip-glob='*_pb2.py*' .


fmt_lint: shell ## lint/fmt in current python environment
Expand Down
8 changes: 4 additions & 4 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from app_utils.shared_utils import ( # noqa: E402
GeneratorAppScreen,
get_snowflake_connection,
set_sit_query_tag,
set_account_name,
set_host_name,
set_user_name,
set_streamlit_location,
set_sit_query_tag,
set_snowpark_session,
set_streamlit_location,
set_user_name,
)
from semantic_model_generator.snowflake_utils.env_vars import ( # noqa: E402
SNOWFLAKE_ACCOUNT_LOCATOR,
Expand All @@ -28,7 +28,7 @@ def failed_connection_popup() -> None:
Renders a dialog box detailing that the credentials provided could not be used to connect to Snowflake.
"""
st.markdown(
f"""It looks like the credentials provided could not be used to connect to the account."""
"""It looks like the credentials provided could not be used to connect to the account."""
)
st.stop()

Expand Down
4 changes: 2 additions & 2 deletions app_utils/chat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import re
from typing import Dict, Any
from typing import Any, Dict

import requests
import streamlit as st
Expand Down Expand Up @@ -32,7 +32,7 @@ def send_message(

resp = _snowflake.send_snow_api_request( # type: ignore
"POST",
f"/api/v2/cortex/analyst/message",
"/api/v2/cortex/analyst/message",
{},
{},
request_body,
Expand Down
56 changes: 18 additions & 38 deletions app_utils/shared_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@

import json
import os
import time
import tempfile
import time
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from io import StringIO
from typing import Any, Optional, List, Union, Dict, Tuple
from typing import Any, Dict, List, Optional, Union

import pandas as pd
import streamlit as st
from snowflake.snowpark import Session
from PIL import Image
from snowflake.connector import ProgrammingError
from snowflake.connector.connection import SnowflakeConnection
from snowflake.snowpark import Session

from semantic_model_generator.data_processing.proto_utils import (
proto_to_yaml,
Expand All @@ -27,22 +26,18 @@
)
from semantic_model_generator.protos import semantic_model_pb2
from semantic_model_generator.protos.semantic_model_pb2 import Dimension, Table
from semantic_model_generator.snowflake_utils.env_vars import ( # noqa: E402
assert_required_env_vars,
)
from semantic_model_generator.snowflake_utils.snowflake_connector import (
SnowflakeConnector,
fetch_databases,
fetch_schemas_in_database,
fetch_stages_in_schema,
fetch_table_schema,
fetch_tables_views_in_schema,
fetch_warehouses,
fetch_stages_in_schema,
fetch_yaml_names_in_stage,
fetch_columns_names_in_table,
)

from semantic_model_generator.snowflake_utils.env_vars import ( # noqa: E402
SNOWFLAKE_ACCOUNT_LOCATOR,
SNOWFLAKE_HOST,
SNOWFLAKE_USER,
assert_required_env_vars,
)

SNOWFLAKE_ACCOUNT = os.environ.get("SNOWFLAKE_ACCOUNT_LOCATOR", "")
Expand Down Expand Up @@ -105,6 +100,7 @@ def get_snowflake_connection() -> SnowflakeConnection:
if st.session_state["sis"]:
# Import SiS-required modules
import sys

from snowflake.snowpark.context import get_active_session

# Non-Anaconda supported packages must be added to path to import from stage
Expand Down Expand Up @@ -203,22 +199,20 @@ def get_available_stages(schema: str) -> List[str]:


@st.cache_resource(show_spinner=False)
def validate_table_columns(table: str, columns_must_exist: Tuple[str,...]) -> bool:
"""
Fetches the available stages from the Snowflake account.

Returns:
List[str]: A list of available stages.
"""
columns_names = fetch_columns_names_in_table(get_snowflake_connection(), table)
for col in columns_must_exist:
if col not in columns_names:
def validate_table_schema(table: str, schema: Dict[str, str]) -> bool:
table_schema = fetch_table_schema(get_snowflake_connection(), table)
# validate columns names
if set(schema.keys()) != set(table_schema.keys()):
return False
# validate column types
for col_name, col_type in table_schema.items():
if not (schema[col_name] in col_type):
return False
return True


@st.cache_resource(show_spinner=False)
def validate_table_exist(schema: str, table_name:str) -> bool:
def validate_table_exist(schema: str, table_name: str) -> bool:
"""
Validate table exist in the Snowflake account.

Expand Down Expand Up @@ -1414,17 +1408,3 @@ def to_dict(self) -> dict[str, str]:
"Schema": self.stage_schema,
"Stage": self.stage_name,
}


@dataclass
class SnowflakeTable:
table_database: str
table_schema: str
table_name: str

def to_dict(self) -> dict[str, str]:
return {
"Database": self.table_database,
"Schema": self.table_schema,
"Table": self.table_name,
}
Loading
Loading