From 5a3e36ee378c955a6079f37dbcd376698c5942e9 Mon Sep 17 00:00:00 2001 From: Guy Bloom Date: Thu, 12 Dec 2024 10:51:52 -0500 Subject: [PATCH] POC: Add child entities to application package (#1856) * add child entities * children_artifacts_dir * unit tests * sanitize dir name docstring * error message on child directory collision --- .../cli/_plugins/nativeapp/commands.py | 5 +- .../nativeapp/entities/application_package.py | 181 ++++++++++++++++-- .../application_package_child_interface.py | 43 +++++ .../cli/_plugins/nativeapp/feature_flags.py | 1 + src/snowflake/cli/_plugins/nativeapp/utils.py | 11 ++ .../nativeapp/v2_conversions/compat.py | 6 +- .../_plugins/streamlit/streamlit_entity.py | 64 ++++++- .../cli/_plugins/workspace/manager.py | 12 +- src/snowflake/cli/api/entities/common.py | 4 + .../api/project/schemas/project_definition.py | 33 +++- .../test_application_package_entity.py | 4 +- tests/nativeapp/test_children.py | 152 +++++++++++++++ tests/nativeapp/test_manager.py | 7 +- tests/streamlit/test_streamlit_entity.py | 53 +++++ .../projects/napp_children/app/README.md | 1 + .../projects/napp_children/app/manifest.yml | 7 + .../napp_children/app/setup_script.sql | 3 + .../projects/napp_children/snowflake.yml | 21 ++ .../projects/napp_children/streamlit_app.py | 20 ++ 19 files changed, 598 insertions(+), 30 deletions(-) create mode 100644 src/snowflake/cli/_plugins/nativeapp/entities/application_package_child_interface.py create mode 100644 tests/nativeapp/test_children.py create mode 100644 tests/streamlit/test_streamlit_entity.py create mode 100644 tests/test_data/projects/napp_children/app/README.md create mode 100644 tests/test_data/projects/napp_children/app/manifest.yml create mode 100644 tests/test_data/projects/napp_children/app/setup_script.sql create mode 100644 tests/test_data/projects/napp_children/snowflake.yml create mode 100644 tests/test_data/projects/napp_children/streamlit_app.py diff --git a/src/snowflake/cli/_plugins/nativeapp/commands.py b/src/snowflake/cli/_plugins/nativeapp/commands.py index 2d5bcbf901..411067ac5e 100644 --- a/src/snowflake/cli/_plugins/nativeapp/commands.py +++ b/src/snowflake/cli/_plugins/nativeapp/commands.py @@ -362,7 +362,10 @@ def app_validate( if cli_context.output_format == OutputFormat.JSON: return ObjectResult( package.get_validation_result( - use_scratch_stage=True, interactive=False, force=True + action_ctx=ws.action_ctx, + use_scratch_stage=True, + interactive=False, + force=True, ) ) diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py index 54c643a628..8f76c7ac4c 100644 --- a/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py +++ b/src/snowflake/cli/_plugins/nativeapp/entities/application_package.py @@ -1,10 +1,11 @@ from __future__ import annotations import json +import os import re from pathlib import Path from textwrap import dedent -from typing import Any, List, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Set, Union import typer from click import BadOptionUsage, ClickException @@ -14,6 +15,7 @@ BundleMap, VersionInfo, build_bundle, + find_setup_script_file, find_version_info_in_manifest_file, ) from snowflake.cli._plugins.nativeapp.bundle_context import BundleContext @@ -30,6 +32,9 @@ PATCH_COL, VERSION_COL, ) +from snowflake.cli._plugins.nativeapp.entities.application_package_child_interface import ( + ApplicationPackageChildInterface, +) from snowflake.cli._plugins.nativeapp.exceptions import ( ApplicationPackageAlreadyExistsError, ApplicationPackageDoesNotExistError, @@ -48,9 +53,16 @@ from snowflake.cli._plugins.nativeapp.sf_facade_exceptions import ( InsufficientPrivilegesError, ) -from snowflake.cli._plugins.nativeapp.utils import needs_confirmation +from snowflake.cli._plugins.nativeapp.utils import needs_confirmation, sanitize_dir_name +from snowflake.cli._plugins.snowpark.snowpark_entity_model import ( + FunctionEntityModel, + ProcedureEntityModel, +) from snowflake.cli._plugins.stage.diff import DiffResult from snowflake.cli._plugins.stage.manager import StageManager +from snowflake.cli._plugins.streamlit.streamlit_entity_model import ( + StreamlitEntityModel, +) from snowflake.cli._plugins.workspace.context import ActionContext from snowflake.cli.api.cli_global_context import span from snowflake.cli.api.entities.common import ( @@ -75,6 +87,7 @@ from snowflake.cli.api.project.schemas.updatable_model import ( DiscriminatorField, IdentifierField, + UpdatableModel, ) from snowflake.cli.api.project.schemas.v1.native_app.package import DistributionOptions from snowflake.cli.api.project.schemas.v1.native_app.path_mapping import PathMapping @@ -94,6 +107,43 @@ from snowflake.connector import DictCursor, ProgrammingError from snowflake.connector.cursor import SnowflakeCursor +ApplicationPackageChildrenTypes = ( + StreamlitEntityModel | FunctionEntityModel | ProcedureEntityModel +) + + +class ApplicationPackageChildIdentifier(UpdatableModel): + schema_: Optional[str] = Field( + title="Child entity schema", alias="schema", default=None + ) + + +class EnsureUsableByField(UpdatableModel): + application_roles: Optional[Union[str, Set[str]]] = Field( + title="One or more application roles to be granted with the required privileges", + default=None, + ) + + @field_validator("application_roles") + @classmethod + def ensure_app_roles_is_a_set( + cls, application_roles: Optional[Union[str, Set[str]]] + ) -> Optional[Union[Set[str]]]: + if isinstance(application_roles, str): + return set([application_roles]) + return application_roles + + +class ApplicationPackageChildField(UpdatableModel): + target: str = Field(title="The key of the entity to include in this package") + ensure_usable_by: Optional[EnsureUsableByField] = Field( + title="Automatically grant the required privileges on the child object and its schema", + default=None, + ) + identifier: ApplicationPackageChildIdentifier = Field( + title="Entity identifier", default=None + ) + class ApplicationPackageEntityModel(EntityModelBase): type: Literal["application package"] = DiscriminatorField() # noqa: A003 @@ -101,23 +151,27 @@ class ApplicationPackageEntityModel(EntityModelBase): title="List of paths or file source/destination pairs to add to the deploy root", ) bundle_root: Optional[str] = Field( - title="Folder at the root of your project where artifacts necessary to perform the bundle step are stored.", + title="Folder at the root of your project where artifacts necessary to perform the bundle step are stored", default="output/bundle/", ) deploy_root: Optional[str] = Field( title="Folder at the root of your project where the build step copies the artifacts", default="output/deploy/", ) + children_artifacts_dir: Optional[str] = Field( + title="Folder under deploy_root where the child artifacts will be stored", + default="_children/", + ) generated_root: Optional[str] = Field( - title="Subdirectory of the deploy root where files generated by the Snowflake CLI will be written.", + title="Subdirectory of the deploy root where files generated by the Snowflake CLI will be written", default="__generated/", ) stage: Optional[str] = IdentifierField( - title="Identifier of the stage that stores the application artifacts.", + title="Identifier of the stage that stores the application artifacts", default="app_src.stage", ) scratch_stage: Optional[str] = IdentifierField( - title="Identifier of the stage that stores temporary scratch data used by the Snowflake CLI.", + title="Identifier of the stage that stores temporary scratch data used by the Snowflake CLI", default="app_src.stage_snowflake_cli_scratch", ) distribution: Optional[DistributionOptions] = Field( @@ -128,6 +182,19 @@ class ApplicationPackageEntityModel(EntityModelBase): title="Path to manifest.yml. Unused and deprecated starting with Snowflake CLI 3.2", default="", ) + children: Optional[List[ApplicationPackageChildField]] = Field( + title="Entities that will be bundled and deployed as part of this application package", + default=[], + ) + + @field_validator("children") + @classmethod + def verify_children_behind_flag( + cls, input_value: Optional[List[ApplicationPackageChildField]] + ) -> Optional[List[ApplicationPackageChildField]]: + if input_value and not FeatureFlag.ENABLE_NATIVE_APP_CHILDREN.is_enabled(): + raise AttributeError("Application package children are not supported yet") + return input_value @field_validator("identifier") @classmethod @@ -183,6 +250,10 @@ def project_root(self) -> Path: def deploy_root(self) -> Path: return self.project_root / self._entity_model.deploy_root + @property + def children_artifacts_deploy_root(self) -> Path: + return self.deploy_root / self._entity_model.children_artifacts_dir + @property def bundle_root(self) -> Path: return self.project_root / self._entity_model.bundle_root @@ -221,7 +292,7 @@ def post_deploy_hooks(self) -> list[PostDeployHook] | None: return model.meta and model.meta.post_deploy def action_bundle(self, action_ctx: ActionContext, *args, **kwargs): - return self._bundle() + return self._bundle(action_ctx) def action_deploy( self, @@ -237,6 +308,7 @@ def action_deploy( **kwargs, ): return self._deploy( + action_ctx=action_ctx, bundle_map=None, prune=prune, recursive=recursive, @@ -336,6 +408,7 @@ def action_validate( **kwargs, ): self.validate_setup_script( + action_ctx=action_ctx, use_scratch_stage=use_scratch_stage, interactive=interactive, force=force, @@ -390,7 +463,7 @@ def action_version_create( else: git_policy = AllowAlwaysPolicy() - bundle_map = self._bundle() + bundle_map = self._bundle(action_ctx) resolved_version, resolved_patch, resolved_label = self.resolve_version_info( version=version, patch=patch, @@ -404,6 +477,7 @@ def action_version_create( self.check_index_changes_in_git_repo(policy=policy, interactive=interactive) self._deploy( + action_ctx=action_ctx, bundle_map=bundle_map, prune=True, recursive=True, @@ -507,7 +581,7 @@ def action_version_drop( """ ) ) - self._bundle() + self._bundle(action_ctx) version_info = find_version_info_in_manifest_file(self.deploy_root) version = version_info.version_name if not version: @@ -692,7 +766,7 @@ def action_release_directive_unset( role=self.role, ) - def _bundle(self): + def _bundle(self, action_ctx: ActionContext = None): model = self._entity_model bundle_map = build_bundle(self.project_root, self.deploy_root, model.artifacts) bundle_context = BundleContext( @@ -705,10 +779,80 @@ def _bundle(self): ) compiler = NativeAppCompiler(bundle_context) compiler.compile_artifacts() + + if self._entity_model.children: + # Bundle children and append their SQL to setup script + # TODO Consider re-writing the logic below as a processor + children_sql = self._bundle_children(action_ctx=action_ctx) + setup_file_path = find_setup_script_file(deploy_root=self.deploy_root) + with open(setup_file_path, "r", encoding="utf-8") as file: + existing_setup_script = file.read() + if setup_file_path.is_symlink(): + setup_file_path.unlink() + with open(setup_file_path, "w", encoding="utf-8") as file: + file.write(existing_setup_script) + file.write("\n-- AUTO GENERATED CHILDREN SECTION\n") + file.write("\n".join(children_sql)) + file.write("\n") + return bundle_map + def _bundle_children(self, action_ctx: ActionContext) -> List[str]: + # Create _children directory + children_artifacts_dir = self.children_artifacts_deploy_root + os.makedirs(children_artifacts_dir) + children_sql = [] + for child in self._entity_model.children: + # Create child sub directory + child_artifacts_dir = children_artifacts_dir / sanitize_dir_name( + child.target + ) + try: + os.makedirs(child_artifacts_dir) + except FileExistsError: + raise ClickException( + f"Could not create sub-directory at {child_artifacts_dir}. Make sure child entity names do not collide with each other." + ) + child_entity: ApplicationPackageChildInterface = action_ctx.get_entity( + child.target + ) + child_entity.bundle(child_artifacts_dir) + app_role = ( + to_identifier( + child.ensure_usable_by.application_roles.pop() # TODO Support more than one application role + ) + if child.ensure_usable_by and child.ensure_usable_by.application_roles + else None + ) + child_schema = ( + to_identifier(child.identifier.schema_) + if child.identifier and child.identifier.schema_ + else None + ) + children_sql.append( + child_entity.get_deploy_sql( + artifacts_dir=child_artifacts_dir.relative_to(self.deploy_root), + schema=child_schema, + ) + ) + if app_role: + children_sql.append( + f"CREATE APPLICATION ROLE IF NOT EXISTS {app_role};" + ) + if child_schema: + children_sql.append( + f"GRANT USAGE ON SCHEMA {child_schema} TO APPLICATION ROLE {app_role};" + ) + children_sql.append( + child_entity.get_usage_grant_sql( + app_role=app_role, schema=child_schema + ) + ) + return children_sql + def _deploy( self, + action_ctx: ActionContext, bundle_map: BundleMap | None, prune: bool, recursive: bool, @@ -733,7 +877,7 @@ def _deploy( stage_fqn = stage_fqn or self.stage_fqn # 1. Create a bundle if one wasn't passed in - bundle_map = bundle_map or self._bundle() + bundle_map = bundle_map or self._bundle(action_ctx) # 2. Create an empty application package, if none exists try: @@ -765,6 +909,7 @@ def _deploy( if validate: self.validate_setup_script( + action_ctx=action_ctx, use_scratch_stage=False, interactive=interactive, force=force, @@ -1054,7 +1199,11 @@ def execute_post_deploy_hooks(self): ) def validate_setup_script( - self, use_scratch_stage: bool, interactive: bool, force: bool + self, + action_ctx: ActionContext, + use_scratch_stage: bool, + interactive: bool, + force: bool, ): workspace_ctx = self._workspace_ctx console = workspace_ctx.console @@ -1062,6 +1211,7 @@ def validate_setup_script( """Validates Native App setup script SQL.""" with console.phase(f"Validating Snowflake Native App setup script."): validation_result = self.get_validation_result( + action_ctx=action_ctx, use_scratch_stage=use_scratch_stage, force=force, interactive=interactive, @@ -1083,13 +1233,18 @@ def validate_setup_script( @span("validate_setup_script") def get_validation_result( - self, use_scratch_stage: bool, interactive: bool, force: bool + self, + action_ctx: ActionContext, + use_scratch_stage: bool, + interactive: bool, + force: bool, ): """Call system$validate_native_app_setup() to validate deployed Native App setup script.""" stage_fqn = self.stage_fqn if use_scratch_stage: stage_fqn = self.scratch_stage_fqn self._deploy( + action_ctx=action_ctx, bundle_map=None, prune=True, recursive=True, diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/application_package_child_interface.py b/src/snowflake/cli/_plugins/nativeapp/entities/application_package_child_interface.py new file mode 100644 index 0000000000..c4f13871e4 --- /dev/null +++ b/src/snowflake/cli/_plugins/nativeapp/entities/application_package_child_interface.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional + + +class ApplicationPackageChildInterface(ABC): + @abstractmethod + def bundle(self, bundle_root=Path, *args, **kwargs) -> None: + """ + Bundles the entity artifacts into the provided root directory. Must not have any side-effects, such as deploying the artifacts into a stage, etc. + @param bundle_root: The directory where the bundle contents should be put. + """ + pass + + @abstractmethod + def get_deploy_sql( + self, + artifacts_dir: Path, + schema: Optional[str], + *args, + **kwargs, + ) -> str: + """ + Returns the SQL that would create the entity object. Must not execute the SQL or have any other side-effects. + @param artifacts_dir: Path to the child entity artifacts directory relative to the deploy root. + @param [Optional] schema: Schema to use when creating the object. + """ + pass + + @abstractmethod + def get_usage_grant_sql( + self, + app_role: str, + schema: Optional[str], + *args, + **kwargs, + ) -> str: + """ + Returns the SQL that would grant the required USAGE privilege to the provided application role on the entity object. Must not execute the SQL or have any other side-effects. + @param app_role: The application role to grant the privileges to. + @param [Optional] schema: The schema where the object was created. + """ + pass diff --git a/src/snowflake/cli/_plugins/nativeapp/feature_flags.py b/src/snowflake/cli/_plugins/nativeapp/feature_flags.py index dbc47e7483..dc7e93bf51 100644 --- a/src/snowflake/cli/_plugins/nativeapp/feature_flags.py +++ b/src/snowflake/cli/_plugins/nativeapp/feature_flags.py @@ -22,4 +22,5 @@ class FeatureFlag(FeatureFlagMixin): ENABLE_NATIVE_APP_PYTHON_SETUP = BooleanFlag( "ENABLE_NATIVE_APP_PYTHON_SETUP", False ) + ENABLE_NATIVE_APP_CHILDREN = BooleanFlag("ENABLE_NATIVE_APP_CHILDREN", False) ENABLE_RELEASE_CHANNELS = BooleanFlag("ENABLE_RELEASE_CHANNELS", None) diff --git a/src/snowflake/cli/_plugins/nativeapp/utils.py b/src/snowflake/cli/_plugins/nativeapp/utils.py index 87fa989d2a..fa2a4cebd5 100644 --- a/src/snowflake/cli/_plugins/nativeapp/utils.py +++ b/src/snowflake/cli/_plugins/nativeapp/utils.py @@ -96,3 +96,14 @@ def verify_no_directories(paths_to_sync: Iterable[Path]): def verify_exists(path: Path): if not path.exists(): raise ClickException(f"The following path does not exist: {path}") + + +def sanitize_dir_name(dir_name: str) -> str: + """ + Returns a string that is safe to use as a directory name. + For simplicity, this function is over restricitive: it strips non alphanumeric characters, + unless listed in the allow list. Additional characters can be allowed in the future, but + we need to be careful to consider both Unix/Windows directory naming rules. + """ + allowed_chars = [" ", "_"] + return "".join(char for char in dir_name if char in allowed_chars or char.isalnum()) diff --git a/src/snowflake/cli/_plugins/nativeapp/v2_conversions/compat.py b/src/snowflake/cli/_plugins/nativeapp/v2_conversions/compat.py index 93d60c2e2b..a72a12f68d 100644 --- a/src/snowflake/cli/_plugins/nativeapp/v2_conversions/compat.py +++ b/src/snowflake/cli/_plugins/nativeapp/v2_conversions/compat.py @@ -217,7 +217,11 @@ def wrapper(*args, **kwargs): entities_to_keep.add(app_definition.entity_id) kwargs["app_entity_id"] = app_definition.entity_id for entity_id in list(original_pdf.entities): - if entity_id not in entities_to_keep: + entity_type = original_pdf.entities[entity_id].type.lower() + if ( + entity_type in ["application", "application package"] + and entity_id not in entities_to_keep + ): # This happens after templates are rendered, # so we can safely remove the entity del original_pdf.entities[entity_id] diff --git a/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py b/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py index 6def772525..6b187ba54b 100644 --- a/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py +++ b/src/snowflake/cli/_plugins/streamlit/streamlit_entity.py @@ -1,12 +1,72 @@ +from pathlib import Path +from typing import Optional + +from snowflake.cli._plugins.nativeapp.artifacts import build_bundle +from snowflake.cli._plugins.nativeapp.entities.application_package_child_interface import ( + ApplicationPackageChildInterface, +) +from snowflake.cli._plugins.nativeapp.feature_flags import FeatureFlag from snowflake.cli._plugins.streamlit.streamlit_entity_model import ( StreamlitEntityModel, ) from snowflake.cli.api.entities.common import EntityBase +from snowflake.cli.api.project.schemas.v1.native_app.path_mapping import PathMapping -class StreamlitEntity(EntityBase[StreamlitEntityModel]): +# WARNING: This entity is not implemented yet. The logic below is only for demonstrating the +# required interfaces for composability (used by ApplicationPackageEntity behind a feature flag). +class StreamlitEntity( + EntityBase[StreamlitEntityModel], ApplicationPackageChildInterface +): """ A Streamlit app. """ - pass + def __init__(self, *args, **kwargs): + if not FeatureFlag.ENABLE_NATIVE_APP_CHILDREN.is_enabled(): + raise NotImplementedError("Streamlit entity is not implemented yet") + super().__init__(*args, **kwargs) + + @property + def project_root(self) -> Path: + return self._workspace_ctx.project_root + + @property + def deploy_root(self) -> Path: + return self.project_root / "output" / "deploy" + + def action_bundle( + self, + *args, + **kwargs, + ): + return self.bundle() + + def bundle(self, bundle_root=None): + return build_bundle( + self.project_root, + bundle_root or self.deploy_root, + [ + PathMapping(src=str(artifact)) + for artifact in self._entity_model.artifacts + ], + ) + + def get_deploy_sql( + self, + artifacts_dir: Optional[Path] = None, + schema: Optional[str] = None, + ): + entity_id = self.entity_id + if artifacts_dir: + streamlit_name = f"{schema}.{entity_id}" if schema else entity_id + return f"CREATE OR REPLACE STREAMLIT {streamlit_name} FROM '{artifacts_dir}' MAIN_FILE='{self._entity_model.main_file}';" + else: + return f"CREATE OR REPLACE STREAMLIT {entity_id} MAIN_FILE='{self._entity_model.main_file}';" + + def get_usage_grant_sql(self, app_role: str, schema: Optional[str] = None): + entity_id = self.entity_id + streamlit_name = f"{schema}.{entity_id}" if schema else entity_id + return ( + f"GRANT USAGE ON STREAMLIT {streamlit_name} TO APPLICATION ROLE {app_role};" + ) diff --git a/src/snowflake/cli/_plugins/workspace/manager.py b/src/snowflake/cli/_plugins/workspace/manager.py index 25b56d542f..10d7fef9c7 100644 --- a/src/snowflake/cli/_plugins/workspace/manager.py +++ b/src/snowflake/cli/_plugins/workspace/manager.py @@ -1,3 +1,4 @@ +from functools import cached_property from pathlib import Path from typing import Dict @@ -58,10 +59,7 @@ def perform_action(self, entity_id: str, action: EntityActions, *args, **kwargs) """ entity = self.get_entity(entity_id) if entity.supports(action): - action_ctx = ActionContext( - get_entity=self.get_entity, - ) - return entity.perform(action, action_ctx, *args, **kwargs) + return entity.perform(action, self.action_ctx, *args, **kwargs) else: raise ValueError(f'This entity type does not support "{action.value}"') @@ -69,6 +67,12 @@ def perform_action(self, entity_id: str, action: EntityActions, *args, **kwargs) def project_root(self) -> Path: return self._project_root + @cached_property + def action_ctx(self) -> ActionContext: + return ActionContext( + get_entity=self.get_entity, + ) + def _get_default_role() -> str: role = default_role() diff --git a/src/snowflake/cli/api/entities/common.py b/src/snowflake/cli/api/entities/common.py index c7bd6bfb0f..c444dc0897 100644 --- a/src/snowflake/cli/api/entities/common.py +++ b/src/snowflake/cli/api/entities/common.py @@ -63,6 +63,10 @@ def __init__(self, entity_model: T, workspace_ctx: WorkspaceContext): self._entity_model = entity_model self._workspace_ctx = workspace_ctx + @property + def entity_id(self): + return self._entity_model.entity_id + @classmethod def get_entity_model_type(cls) -> Type[T]: """ diff --git a/src/snowflake/cli/api/project/schemas/project_definition.py b/src/snowflake/cli/api/project/schemas/project_definition.py index cda6ecd8eb..2b0f4f5cf0 100644 --- a/src/snowflake/cli/api/project/schemas/project_definition.py +++ b/src/snowflake/cli/api/project/schemas/project_definition.py @@ -15,12 +15,17 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from types import UnionType +from typing import Any, Dict, List, Optional, Union, get_args, get_origin from packaging.version import Version from pydantic import Field, ValidationError, field_validator, model_validator from pydantic_core.core_schema import ValidationInfo from snowflake.cli._plugins.nativeapp.entities.application import ApplicationEntityModel +from snowflake.cli._plugins.nativeapp.entities.application_package import ( + ApplicationPackageChildrenTypes, + ApplicationPackageEntityModel, +) from snowflake.cli.api.project.errors import SchemaValidationError from snowflake.cli.api.project.schemas.entities.common import ( TargetField, @@ -159,6 +164,12 @@ def _validate_single_entity( target_object = entity.from_ target_type = target_object.get_type() cls._validate_target_field(target_key, target_type, entities) + elif entity.type == ApplicationPackageEntityModel.get_type(): + for child_entity in entity.children: + target_key = child_entity.target + cls._validate_target_field( + target_key, ApplicationPackageChildrenTypes, entities + ) @classmethod def _validate_target_field( @@ -168,11 +179,20 @@ def _validate_target_field( raise ValueError(f"No such target: {target_key}") # Validate the target type - actual_target_type = entities[target_key].__class__ - if target_type and target_type is not actual_target_type: - raise ValueError( - f"Target type mismatch. Expected {target_type.__name__}, got {actual_target_type.__name__}" - ) + if target_type: + actual_target_type = entities[target_key].__class__ + if get_origin(target_type) in (Union, UnionType): + if actual_target_type not in get_args(target_type): + expected_types_str = ", ".join( + [t.__name__ for t in get_args(target_type)] + ) + raise ValueError( + f"Target type mismatch. Expected one of [{expected_types_str}], got {actual_target_type.__name__}" + ) + elif target_type is not actual_target_type: + raise ValueError( + f"Target type mismatch. Expected {target_type.__name__}, got {actual_target_type.__name__}" + ) @model_validator(mode="before") @classmethod @@ -200,6 +220,7 @@ def apply_mixins(cls, data: Dict, info: ValidationInfo) -> Dict: mixin_defs=data["mixins"], ) entities[entity_name] = merged_values + return data @classmethod diff --git a/tests/nativeapp/test_application_package_entity.py b/tests/nativeapp/test_application_package_entity.py index 2a0e632a6d..0772a5ada0 100644 --- a/tests/nativeapp/test_application_package_entity.py +++ b/tests/nativeapp/test_application_package_entity.py @@ -45,8 +45,8 @@ ) -def _get_app_pkg_entity(project_directory): - with project_directory("workspaces_simple") as project_root: +def _get_app_pkg_entity(project_directory, test_dir="workspaces_simple"): + with project_directory(test_dir) as project_root: with Path(project_root / "snowflake.yml").open() as definition_file_path: project_definition = yaml.safe_load(definition_file_path) model = ApplicationPackageEntityModel( diff --git a/tests/nativeapp/test_children.py b/tests/nativeapp/test_children.py new file mode 100644 index 0000000000..fca85666e3 --- /dev/null +++ b/tests/nativeapp/test_children.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +from pathlib import Path +from textwrap import dedent + +import pytest +import yaml +from snowflake.cli._plugins.nativeapp.entities.application_package import ( + ApplicationPackageEntityModel, +) +from snowflake.cli._plugins.nativeapp.feature_flags import FeatureFlag +from snowflake.cli._plugins.streamlit.streamlit_entity import StreamlitEntity +from snowflake.cli._plugins.workspace.context import ActionContext +from snowflake.cli._plugins.workspace.manager import WorkspaceManager +from snowflake.cli.api.project.errors import SchemaValidationError +from snowflake.cli.api.project.schemas.project_definition import ( + DefinitionV20, +) + +from tests.testing_utils.mock_config import mock_config_key + + +def _get_app_pkg_entity(project_directory): + with project_directory("napp_children") as project_root: + with Path(project_root / "snowflake.yml").open() as definition_file_path: + project_definition = DefinitionV20(**yaml.safe_load(definition_file_path)) + wm = WorkspaceManager( + project_definition=project_definition, + project_root=project_root, + ) + pkg_entity = wm.get_entity("pkg") + streamlit_entity = wm.get_entity("my_streamlit") + action_ctx = ActionContext( + get_entity=lambda entity_id: streamlit_entity, + ) + return ( + pkg_entity, + action_ctx, + ) + + +def test_children_feature_flag_is_disabled(): + assert FeatureFlag.ENABLE_NATIVE_APP_CHILDREN.is_enabled() == False + with pytest.raises(AttributeError) as err: + ApplicationPackageEntityModel( + **{"type": "application package", "children": [{"target": "some_child"}]} + ) + assert str(err.value) == "Application package children are not supported yet" + + +def test_invalid_children_type(): + with mock_config_key("enable_native_app_children", True): + definition_input = { + "definition_version": "2", + "entities": { + "pkg": { + "type": "application package", + "artifacts": [], + "children": [ + { + # packages cannot contain other packages as children + "target": "pkg2" + } + ], + }, + "pkg2": { + "type": "application package", + "artifacts": [], + }, + }, + } + with pytest.raises(SchemaValidationError) as err: + DefinitionV20(**definition_input) + assert "Target type mismatch" in str(err.value) + + +def test_invalid_children_target(): + with mock_config_key("enable_native_app_children", True): + definition_input = { + "definition_version": "2", + "entities": { + "pkg": { + "type": "application package", + "artifacts": [], + "children": [ + { + # no such entity + "target": "sl" + } + ], + }, + }, + } + with pytest.raises(SchemaValidationError) as err: + DefinitionV20(**definition_input) + assert "No such target: sl" in str(err.value) + + +def test_valid_children(): + with mock_config_key("enable_native_app_children", True): + definition_input = { + "definition_version": "2", + "entities": { + "pkg": { + "type": "application package", + "artifacts": [], + "children": [{"target": "sl"}], + }, + "sl": {"type": "streamlit", "identifier": "my_streamlit"}, + }, + } + project_definition = DefinitionV20(**definition_input) + wm = WorkspaceManager( + project_definition=project_definition, + project_root="", + ) + child_entity_id = project_definition.entities["pkg"].children[0] + child_entity = wm.get_entity(child_entity_id.target) + assert child_entity.__class__ == StreamlitEntity + + +def test_children_bundle_with_custom_dir(project_directory): + with mock_config_key("enable_native_app_children", True): + app_pkg, action_ctx = _get_app_pkg_entity(project_directory) + bundle_result = app_pkg.action_bundle(action_ctx) + deploy_root = bundle_result.deploy_root() + + # Application package artifacts + assert (deploy_root / "README.md").exists() + assert (deploy_root / "manifest.yml").exists() + assert (deploy_root / "setup_script.sql").exists() + + # Child artifacts + assert ( + deploy_root / "_entities" / "my_streamlit" / "streamlit_app.py" + ).exists() + + # Generated setup script section + with open(deploy_root / "setup_script.sql", "r") as f: + setup_script_content = f.read() + custom_dir_path = Path("_entities", "my_streamlit") + assert setup_script_content.endswith( + dedent( + f""" + -- AUTO GENERATED CHILDREN SECTION + CREATE OR REPLACE STREAMLIT v_schema.my_streamlit FROM '{custom_dir_path}' MAIN_FILE='streamlit_app.py'; + CREATE APPLICATION ROLE IF NOT EXISTS my_app_role; + GRANT USAGE ON SCHEMA v_schema TO APPLICATION ROLE my_app_role; + GRANT USAGE ON STREAMLIT v_schema.my_streamlit TO APPLICATION ROLE my_app_role; + """ + ) + ) diff --git a/tests/nativeapp/test_manager.py b/tests/nativeapp/test_manager.py index 57cafe7a07..c61467c044 100644 --- a/tests/nativeapp/test_manager.py +++ b/tests/nativeapp/test_manager.py @@ -1376,6 +1376,7 @@ def test_validate_use_scratch_stage(mock_execute, mock_deploy, temp_dir, mock_cu pd = wm._project_definition # noqa: SLF001 pkg_model: ApplicationPackageEntityModel = pd.entities["app_pkg"] mock_deploy.assert_called_with( + action_ctx=wm.action_ctx, bundle_map=None, prune=True, recursive=True, @@ -1452,6 +1453,7 @@ def test_validate_failing_drops_scratch_stage( pd = wm._project_definition # noqa: SLF001 pkg_model: ApplicationPackageEntityModel = pd.entities["app_pkg"] mock_deploy.assert_called_with( + action_ctx=wm.action_ctx, bundle_map=None, prune=True, recursive=True, @@ -1511,7 +1513,10 @@ def test_validate_raw_returns_data(mock_execute, temp_dir, mock_cursor): pkg = wm.get_entity("app_pkg") assert ( pkg.get_validation_result( - use_scratch_stage=False, interactive=False, force=True + action_ctx=wm.action_ctx, + use_scratch_stage=False, + interactive=False, + force=True, ) == failure_data ) diff --git a/tests/streamlit/test_streamlit_entity.py b/tests/streamlit/test_streamlit_entity.py new file mode 100644 index 0000000000..315e34b8e5 --- /dev/null +++ b/tests/streamlit/test_streamlit_entity.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest +from snowflake.cli._plugins.streamlit.streamlit_entity import ( + StreamlitEntity, +) +from snowflake.cli._plugins.streamlit.streamlit_entity_model import ( + StreamlitEntityModel, +) +from snowflake.cli._plugins.workspace.context import WorkspaceContext +from snowflake.cli.api.console import cli_console as cc +from snowflake.cli.api.project.definition_manager import DefinitionManager + +from tests.testing_utils.mock_config import mock_config_key + + +def test_cannot_instantiate_without_feature_flag(): + with pytest.raises(NotImplementedError) as err: + StreamlitEntity() + assert str(err.value) == "Streamlit entity is not implemented yet" + + +def test_nativeapp_children_interface(temp_dir): + with mock_config_key("enable_native_app_children", True): + dm = DefinitionManager() + ctx = WorkspaceContext( + console=cc, + project_root=dm.project_root, + get_default_role=lambda: "mock_role", + get_default_warehouse=lambda: "mock_warehouse", + ) + main_file = "main.py" + (Path(temp_dir) / main_file).touch() + model = StreamlitEntityModel( + type="streamlit", + main_file=main_file, + artifacts=[main_file], + ) + sl = StreamlitEntity(model, ctx) + + sl.bundle() + bundle_artifact = Path(temp_dir) / "output" / "deploy" / main_file + deploy_sql_str = sl.get_deploy_sql() + grant_sql_str = sl.get_usage_grant_sql(app_role="app_role") + + assert bundle_artifact.exists() + assert deploy_sql_str == "CREATE OR REPLACE STREAMLIT None MAIN_FILE='main.py';" + assert ( + grant_sql_str + == "GRANT USAGE ON STREAMLIT None TO APPLICATION ROLE app_role;" + ) diff --git a/tests/test_data/projects/napp_children/app/README.md b/tests/test_data/projects/napp_children/app/README.md new file mode 100644 index 0000000000..7e59600739 --- /dev/null +++ b/tests/test_data/projects/napp_children/app/README.md @@ -0,0 +1 @@ +# README diff --git a/tests/test_data/projects/napp_children/app/manifest.yml b/tests/test_data/projects/napp_children/app/manifest.yml new file mode 100644 index 0000000000..0b8b9b892c --- /dev/null +++ b/tests/test_data/projects/napp_children/app/manifest.yml @@ -0,0 +1,7 @@ +# This is the v2 version of the napp_init_v1 project + +manifest_version: 1 + +artifacts: + setup_script: setup_script.sql + readme: README.md diff --git a/tests/test_data/projects/napp_children/app/setup_script.sql b/tests/test_data/projects/napp_children/app/setup_script.sql new file mode 100644 index 0000000000..ade6eccbd6 --- /dev/null +++ b/tests/test_data/projects/napp_children/app/setup_script.sql @@ -0,0 +1,3 @@ +CREATE OR ALTER VERSIONED SCHEMA v_schema; +CREATE APPLICATION ROLE IF NOT EXISTS my_app_role; +GRANT USAGE ON SCHEMA v_schema TO APPLICATION ROLE my_app_role; diff --git a/tests/test_data/projects/napp_children/snowflake.yml b/tests/test_data/projects/napp_children/snowflake.yml new file mode 100644 index 0000000000..52667820df --- /dev/null +++ b/tests/test_data/projects/napp_children/snowflake.yml @@ -0,0 +1,21 @@ +definition_version: 2 +entities: + pkg: + type: application package + identifier: my_pkg + artifacts: + - src: app/* + dest: ./ + children_artifacts_dir: _entities + children: + - target: my_streamlit + identifier: + schema: v_schema + ensure_usable_by: + application_roles: ["my_app_role"] + + my_streamlit: + type: streamlit + main_file: streamlit_app.py + artifacts: + - streamlit_app.py diff --git a/tests/test_data/projects/napp_children/streamlit_app.py b/tests/test_data/projects/napp_children/streamlit_app.py new file mode 100644 index 0000000000..45c8ad3822 --- /dev/null +++ b/tests/test_data/projects/napp_children/streamlit_app.py @@ -0,0 +1,20 @@ +from http.client import HTTPSConnection + +import _snowflake +import streamlit as st + + +def get_secret_value(): + return _snowflake.get_generic_secret_string("generic_secret") + + +def send_request(): + host = "docs.snowflake.com" + conn = HTTPSConnection(host) + conn.request("GET", "/") + response = conn.getresponse() + st.success(f"Response status: {response.status}") + + +st.title(f"Example streamlit app.") +st.button("Send request", on_click=send_request)