Skip to content

Commit

Permalink
file paths + fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-cgorrie committed Dec 12, 2023
1 parent bc3e4b3 commit 9fc4731
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 55 deletions.
3 changes: 2 additions & 1 deletion src/snowcli/cli/appify/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from snowcli.cli.appify.metadata import MetadataDumper
from snowcli.cli.appify.generate import (
load_catalog,
modifications,
generate_setup_statements,
rewrite_stage_imports,
Expand Down Expand Up @@ -51,7 +52,7 @@ def appify(
dumper = MetadataDumper(db, project.path)
dumper.execute()

catalog = {} # load_catalog(dumper.catalog_path)
catalog = load_catalog(dumper.catalog_path)
rewrite_stage_imports(catalog, dumper.referenced_stage_ids, dumper.metadata_path)

# generate the setup script
Expand Down
85 changes: 63 additions & 22 deletions src/snowcli/cli/appify/generate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Generator, List, Tuple

import re
import json
from textwrap import dedent
from contextlib import contextmanager
from pathlib import Path
from strictyaml import YAML, load
from click import ClickException

from snowcli.cli.appify.util import split_fqn_id

Expand All @@ -20,6 +23,16 @@
"streamlit": "usage on streamlit",
}

# FIXME: current streamlit get_ddl misses the final single quote
# Line-oriented patterns that extract pieces of a streamlit object's DDL text
# (as produced by get_ddl): the object name, the stage root location, and the
# main file. All are MULTILINE so ^/$ match individual lines of the DDL.
STREAMLIT_NAME = re.compile(r"^\s*create or replace streamlit (.+)$", re.MULTILINE)
# NOTE: the closing single quote is deliberately absent from this pattern
# because of the get_ddl bug described in the FIXME above.
STREAMLIT_ROOT_LOCATION = re.compile(r"^\s*root_location='(.+)$", re.MULTILINE)
STREAMLIT_MAIN_FILE = re.compile(r"^\s*main_file='(.+)'$", re.MULTILINE)


class MalformedStreamlitError(ClickException):
    """Raised when a dumped streamlit DDL file is missing an expected property."""

    def __init__(self, property: str, path: Path):
        message = f"Streamlit DDL is non-conforming for {property} at {path}"
        super().__init__(message)


@contextmanager
def modifications(path: Path) -> Generator[YAML, None, None]:
Expand All @@ -44,13 +57,6 @@ def get_ordering(catalog: dict) -> List[Tuple[str, str]]:
return []


def get_kind(catalog: dict, schema: str, object_name: str) -> str:
"""
Determine the kind of an object based on the metadata catalog.
"""
pass


def load_catalog(catalog_json: Path) -> dict:
"""
Returns the metadata catalog for the database, containing reference
def rewrite_stage_imports(
    catalog: dict, stage_ids: List[str], metadata_path: Path
) -> None:
    """
    Rewrite the "imports" part of callable / streamlit DDL statements as they now need to
    reference paths inside our application stage. We re-write the streamlit DDL fully, as
    there are missing features in NA (e.g. query_warehouse) and bugs in its get_ddl impl.

    Args:
        catalog: metadata catalog mapping fully-qualified object ids to dicts
            that carry at least a "kind" key.
        stage_ids: fully-qualified ids of stages whose contents were copied
            under /stages/<db>/<schema>/<name>/ in the application stage.
        metadata_path: root of the dumped metadata tree; each object's DDL
            lives at <metadata_path>/<schema>/<object_name>.sql and is
            rewritten in place.

    Raises:
        MalformedStreamlitError: if a streamlit's DDL lacks its name,
            main_file, or root_location.
    """

    def _rewrite_imports(s: str) -> str:
        """Replace @db.schema.stage/ references with in-app /stages/... paths."""
        # FIXME: likely quoting is wrong here.
        for stage_id in stage_ids:
            (stage_db, stage_schema, stage_name) = split_fqn_id(stage_id)
            needle = f"@{stage_id}/"
            replacement = f"/stages/{stage_db}/{stage_schema}/{stage_name}/"
            s = s.replace(needle, replacement)
        return s

    for object_id, object_info in catalog.items():
        if object_info["kind"] in CALLABLE_KINDS:
            (_db, schema, object_name) = split_fqn_id(object_id)
            sql_path = metadata_path / schema / f"{object_name}.sql"
            # Rewrite stage imports inside the callable's DDL in place.
            sql_path.write_text(_rewrite_imports(sql_path.read_text()))
        elif object_info["kind"] == "streamlit":
            (_db, schema, object_name) = split_fqn_id(object_id)
            sql_path = metadata_path / schema / f"{object_name}.sql"
            ddl_statement = sql_path.read_text()

            # BUG FIX: these MULTILINE patterns target lines anywhere in the
            # DDL, but Pattern.match() only anchors at the very start of the
            # string, so main_file / root_location could never match. Use
            # search() to scan the whole statement.
            if match := STREAMLIT_NAME.search(ddl_statement):
                name = match.group(1)
            else:
                raise MalformedStreamlitError("name", sql_path)

            if match := STREAMLIT_MAIN_FILE.search(ddl_statement):
                main_file = match.group(1)
            else:
                raise MalformedStreamlitError("main_file", sql_path)

            if match := STREAMLIT_ROOT_LOCATION.search(ddl_statement):
                root_location = match.group(1)
            else:
                raise MalformedStreamlitError("root_location", sql_path)

            from_clause = _rewrite_imports(root_location)
            # BUG FIX: MAIN_FILE's value was missing its closing single
            # quote, which produced invalid SQL.
            sql_path.write_text(
                dedent(
                    f"""
                    create or replace streamlit {name}
                    FROM '{from_clause}'
                    MAIN_FILE='{main_file}';
                    """
                )
            )


def generate_setup_statements(
Expand All @@ -100,11 +138,14 @@ def generate_setup_statements(
yield f"create or alter versioned schema {to_identifier(schema)};"
yield f"grant usage on schema {to_identifier(schema)} to application role {APP_PUBLIC};"

for schema, object_name in get_ordering(catalog):
kind = get_kind(catalog, schema, object_name)
for fqn in get_ordering(catalog):
(_db, schema, object_name) = split_fqn_id(fqn)
kind = catalog[fqn]["kind"]
yield f"use schema {to_identifier(schema)};"
# XXX: is this correct quoting?
yield f"execute immediate from './metadata/{schema}/{object_name}.sql';"
if kind in GRANT_BY_KIND:
# FIXME: need to refactor to split name + arguments so we can quote only the name
yield f"grant {GRANT_BY_KIND[kind]} {to_identifier(schema)}.{object_name};"
yield f"""
grant {GRANT_BY_KIND[kind]} {to_identifier(schema)}.{object_name} to application role {APP_PUBLIC};
""".strip()
36 changes: 17 additions & 19 deletions src/snowcli/cli/appify/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@

log = logging.getLogger(__name__)

REFERENCES_BY_NAME_JSON = "references_by_name.json"
REFERENCES_DOMAINS = ["function", "table", "view"]
REFERENCES_FILE_NAME = "references.json"
CATALOG_FILE_NAME = "catalog.json"

DOMAIN_TO_SHOW_COMMAND_NOUN = {
"function": "user functions",
Expand Down Expand Up @@ -103,6 +101,10 @@ def metadata_path(self) -> Path:
def stages_path(self) -> Path:
return self.project_path / "stages"

@cached_property
def catalog_path(self) -> Path:
    """Path of the metadata catalog JSON (metadata_path / catalog.json)."""
    return self.metadata_path / CATALOG_FILE_NAME

def get_stage_path(self, stage_id: str) -> Path:
    """Return the local dump directory for a fully-qualified stage id."""
    database, schema_name, stage = split_fqn_id(stage_id)
    return Path(self.stages_path, database, schema_name, stage)
Expand Down Expand Up @@ -158,8 +160,8 @@ def execute(self) -> None:
# functions, procedures, and streamlits appropriately.
for stage_id in self.referenced_stage_ids:
self.dump_stage(stage_id)
self.dump_references(self.metadata_path)

self.dump_references()
ordered_objects = self.get_ordering()

def process_schema(self, schema: str) -> None:
Expand Down Expand Up @@ -204,9 +206,9 @@ def process_schema(self, schema: str) -> None:
f.write(ddl)
self.update_references(schema_path, schema, object_name, domain)

def dump_references(self) -> None:
    """Serialize the collected reference catalog to the catalog JSON file.

    Writes self.references (populated by update_references) to
    self.catalog_path so appify's load_catalog can read it back.
    """
    # NOTE: the scraped span interleaved an obsolete overload taking a
    # ``path`` argument; the caller invokes this with no arguments, so the
    # parameterless form writing to self.catalog_path is the live one.
    with open(self.catalog_path, "w") as ref_file:
        json.dump(self.references, ref_file)

def dump_stage(self, stage_id: str) -> None:
Expand All @@ -220,26 +222,28 @@ def dump_stage(self, stage_id: str) -> None:
def update_references(
self, schema_path: str, schema: str, object_name: str, domain: str
) -> None:
log.info(f"grabbing references for object {schema}.{object_name} with domain {domain}")
log.info(
f"grabbing references for object {schema}.{object_name} with domain {domain}"
)
literal = self._object_literal(schema, object_name)
references_cursor = self._execute_query(
f"select system$GET_REFERENCES_BY_NAME_AS_OF_TIME({literal}, '{domain}')"
)
references_list = json.loads(references_cursor.fetchone()[0])
cleaned_up_ref_list = []
clean_ref_names = []
for reference in references_list:
for reference in references_list:
name = reference[0]
domain = reference[1]
if domain.upper() in ["FUNCTION"]:
cleaned_up_name = re.sub(r'^(.*\))(.*)$', r'\1', name) + "\""
if domain.upper() in ["FUNCTION"]:
cleaned_up_name = re.sub(r"^(.*\))(.*)$", r"\1", name) + '"'
cleaned_up_ref_list.append([cleaned_up_name, domain])
clean_ref_names.append(cleaned_up_name)
else:
cleaned_up_ref_list.append(reference)
clean_ref_names.append(name)
fqn = self.get_object_fully_qualified_name(schema, object_name)
self.references[self.get_object_fully_qualified_name(schema, object_name)] = {
self.references[self.get_object_fully_qualified_name(schema, object_name)] = {
"references": cleaned_up_ref_list,
"kind": domain,
"object_name": object_name,
Expand All @@ -251,11 +255,5 @@ def update_references(
def get_ordering(self) -> List[str]:
    """Return object ids in dependency order.

    Topologically sorts self.ordering_graph so that every object appears
    after the objects it references.

    Raises:
        graphlib.CycleError: if the reference graph contains a cycle.
    """
    log.info(f"ordering graph {self.ordering_graph}")
    ts = graphlib.TopologicalSorter(self.ordering_graph)
    # BUG FIX: the line below appeared twice; static_order() may only be
    # consumed once per sorter (a second call raises ValueError).
    return list(ts.static_order())






20 changes: 7 additions & 13 deletions src/snowcli/cli/appify/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from click import ClickException
from typing import Callable, Optional, List, Tuple
from snowflake.connector.cursor import DictCursor
from snowcli.cli.project.util import DB_SCHEMA_AND_NAME, SCHEMA_AND_NAME
from snowcli.cli.project.util import DB_SCHEMA_AND_NAME, IDENTIFIER

DB_SCHEMA_NAME_ARGS = f"{DB_SCHEMA_AND_NAME}([(].+[)])?"
STAGE_IMPORT_REGEX = f"@({DB_SCHEMA_AND_NAME})/"


Expand Down Expand Up @@ -34,17 +35,10 @@ def split_fqn_id(id: str) -> Tuple[str, str, str]:
"""
Splits a fully-qualified identifier into its consituent parts.
Returns (database, schema, name); quoting carries over from the input.
Name can have arguments in it, e.g. for callable objects.
"""
if match := re.fullmatch(DB_SCHEMA_AND_NAME, id):
return (match.group(1), match.group(2), match.group(3))
raise NotAQualifiedNameError(id)


def split_schema_and_object_id(id: str) -> Tuple[str, str, str]:
"""
Splits a partially-qualified identifier into its consituent parts.
Returns (schema, name); quoting carries over from the input.
"""
if match := re.fullmatch(DB_SCHEMA_AND_NAME, id):
return (match.group(1), match.group(2))
if match := re.fullmatch(DB_SCHEMA_NAME_ARGS, id):
args = match.group(4)
name = match.group(3) if args is None else f"{match.group(3)}{args}"
return (match.group(1), match.group(2), name)
raise NotAQualifiedNameError(id)

0 comments on commit 9fc4731

Please sign in to comment.