Skip to content

Commit

Permalink
Remove project section check outside with_project_definition decorator (
Browse files Browse the repository at this point in the history
#1276)

* Remove project section check outside with_project_definition decorator

* Add util function for checking project type
  • Loading branch information
sfc-gh-melnacouzi authored Jul 11, 2024
1 parent a3aad2b commit 8938645
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 73 deletions.
6 changes: 2 additions & 4 deletions src/snowflake/cli/api/commands/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ def global_options_with_connection(func: Callable):
)


def with_project_definition(
project_name: Optional[str] = None, is_optional: bool = False
):
def with_project_definition(is_optional: bool = False):
def _decorator(func: Callable):

return _options_decorator_factory(
Expand All @@ -86,7 +84,7 @@ def _decorator(func: Callable):
"project_definition",
inspect.Parameter.KEYWORD_ONLY,
annotation=Optional[str],
default=project_definition_option(project_name, is_optional),
default=project_definition_option(is_optional),
),
inspect.Parameter(
"env_overrides",
Expand Down
31 changes: 5 additions & 26 deletions src/snowflake/cli/api/commands/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from snowflake.cli.api.cli_global_context import cli_context_manager
from snowflake.cli.api.commands.typer_pre_execute import register_pre_execute_command
from snowflake.cli.api.console import cli_console
from snowflake.cli.api.exceptions import MissingConfiguration, NoProjectDefinitionError
from snowflake.cli.api.exceptions import MissingConfiguration
from snowflake.cli.api.output.formats import OutputFormat
from snowflake.cli.api.project.definition_manager import DefinitionManager
from snowflake.cli.api.rendering.jinja import CONTEXT_KEY
Expand Down Expand Up @@ -516,7 +516,7 @@ def execution_identifier_argument(sf_object: str, example: str) -> typer.Argumen
)


def register_project_definition(project_name: Optional[str], is_optional: bool) -> None:
def register_project_definition(is_optional: bool) -> None:
project_path = cli_context_manager.project_path_arg
env_overrides_args = cli_context_manager.project_env_overrides_args

Expand All @@ -530,42 +530,21 @@ def register_project_definition(project_name: Optional[str], is_optional: bool)
"Cannot find project definition (snowflake.yml). Please provide a path to the project or run this command in a valid project directory."
)

if project_name is not None and not getattr(project_definition, project_name, None):
raise NoProjectDefinitionError(
project_type=project_name, project_file=project_path
)

cli_context_manager.set_project_definition(project_definition)
cli_context_manager.set_project_root(project_root)
cli_context_manager.set_template_context(template_context)


def _get_project_long_name(project_short_name: Optional[str]) -> str:
if project_short_name is None:
return "Snowflake"

if project_short_name == "native_app":
project_long_name = "Snowflake Native App"
elif project_short_name == "streamlit":
project_long_name = "Streamlit app"
else:
project_long_name = project_short_name.replace("_", " ").capitalize()

return f"the {project_long_name}"


def project_definition_option(project_name: Optional[str], is_optional: bool):
def project_definition_option(is_optional: bool):
def project_definition_callback(project_path: str) -> None:
cli_context_manager.set_project_path_arg(project_path)
register_pre_execute_command(
lambda: register_project_definition(project_name, is_optional)
)
register_pre_execute_command(lambda: register_project_definition(is_optional))

return typer.Option(
None,
"-p",
"--project",
help=f"Path where {_get_project_long_name(project_name)} project resides. "
help=f"Path where Snowflake project resides. "
f"Defaults to current working directory.",
callback=_callback(lambda: project_definition_callback),
show_default=False,
Expand Down
23 changes: 23 additions & 0 deletions src/snowflake/cli/api/project/project_verification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) 2024 Snowflake Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from snowflake.cli.api.cli_global_context import cli_context
from snowflake.cli.api.exceptions import NoProjectDefinitionError


def assert_project_type(project_type: str):
if not getattr(cli_context.project_definition, project_type, None):
raise NoProjectDefinitionError(
project_type=project_type, project_file=cli_context.project_root
)
30 changes: 24 additions & 6 deletions src/snowflake/cli/plugins/nativeapp/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
MessageResult,
ObjectResult,
)
from snowflake.cli.api.project.project_verification import assert_project_type
from snowflake.cli.api.secure_path import SecurePath
from snowflake.cli.plugins.nativeapp.common_flags import (
ForceOption,
Expand Down Expand Up @@ -145,13 +146,16 @@ def app_list_templates(**options) -> CommandResult:


@app.command("bundle")
@with_project_definition("native_app")
@with_project_definition()
def app_bundle(
**options,
) -> CommandResult:
"""
Prepares a local folder with configured app artifacts.
"""

assert_project_type("native_app")

manager = NativeAppManager(
project_definition=cli_context.project_definition.native_app,
project_root=cli_context.project_root,
Expand All @@ -161,7 +165,7 @@ def app_bundle(


@app.command("run", requires_connection=True)
@with_project_definition("native_app")
@with_project_definition()
def app_run(
version: Optional[str] = typer.Option(
None,
Expand Down Expand Up @@ -191,6 +195,8 @@ def app_run(
then creates or upgrades an application object from the application package.
"""

assert_project_type("native_app")

is_interactive = False
if force:
policy = AllowAlwaysPolicy()
Expand Down Expand Up @@ -221,14 +227,17 @@ def app_run(


@app.command("open", requires_connection=True)
@with_project_definition("native_app")
@with_project_definition()
def app_open(
**options,
) -> CommandResult:
"""
Opens the Snowflake Native App inside of your browser,
once it has been installed in your account.
"""

assert_project_type("native_app")

manager = NativeAppManager(
project_definition=cli_context.project_definition.native_app,
project_root=cli_context.project_root,
Expand All @@ -243,7 +252,7 @@ def app_open(


@app.command("teardown", requires_connection=True)
@with_project_definition("native_app")
@with_project_definition()
def app_teardown(
force: Optional[bool] = ForceOption,
cascade: Optional[bool] = typer.Option(
Expand All @@ -257,6 +266,9 @@ def app_teardown(
"""
Attempts to drop both the application object and application package as defined in the project definition file.
"""

assert_project_type("native_app")

processor = NativeAppTeardownProcessor(
project_definition=cli_context.project_definition.native_app,
project_root=cli_context.project_root,
Expand All @@ -266,7 +278,7 @@ def app_teardown(


@app.command("deploy", requires_connection=True)
@with_project_definition("native_app")
@with_project_definition()
def app_deploy(
prune: Optional[bool] = typer.Option(
default=None,
Expand Down Expand Up @@ -296,6 +308,9 @@ def app_deploy(
Creates an application package in your Snowflake account and syncs the local changes to the stage without creating or updating the application.
Running this command with no arguments at all, as in `snow app deploy`, is a shorthand for `snow app deploy --prune --recursive`.
"""

assert_project_type("native_app")

has_paths = paths is not None and len(paths) > 0
if prune is None and recursive is None and not has_paths:
prune = True
Expand Down Expand Up @@ -329,11 +344,14 @@ def app_deploy(


@app.command("validate", requires_connection=True)
@with_project_definition("native_app")
@with_project_definition()
def app_validate(**options):
"""
Validates a deployed Snowflake Native App's setup script.
"""

assert_project_type("native_app")

manager = NativeAppManager(
project_definition=cli_context.project_definition.native_app,
project_root=cli_context.project_root,
Expand Down
16 changes: 13 additions & 3 deletions src/snowflake/cli/plugins/nativeapp/version/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
from snowflake.cli.api.output.types import CommandResult, MessageResult, QueryResult
from snowflake.cli.api.project.project_verification import assert_project_type
from snowflake.cli.plugins.nativeapp.common_flags import ForceOption, InteractiveOption
from snowflake.cli.plugins.nativeapp.policy import (
AllowAlwaysPolicy,
Expand All @@ -46,7 +47,7 @@


@app.command(requires_connection=True)
@with_project_definition("native_app")
@with_project_definition()
def create(
version: Optional[str] = typer.Argument(
None,
Expand All @@ -71,6 +72,9 @@ def create(
"""
Adds a new patch to the provided version defined in your application package. If the version does not exist, creates a version with patch 0.
"""

assert_project_type("native_app")

if version is None and patch is not None:
raise MissingParameter("Cannot provide a patch without version!")

Expand Down Expand Up @@ -107,13 +111,16 @@ def create(


@app.command("list", requires_connection=True)
@with_project_definition("native_app")
@with_project_definition()
def version_list(
**options,
) -> CommandResult:
"""
Lists all versions defined in an application package.
"""

assert_project_type("native_app")

processor = NativeAppRunProcessor(
project_definition=cli_context.project_definition.native_app,
project_root=cli_context.project_root,
Expand All @@ -123,7 +130,7 @@ def version_list(


@app.command(requires_connection=True)
@with_project_definition("native_app")
@with_project_definition()
def drop(
version: Optional[str] = typer.Argument(
None,
Expand All @@ -137,6 +144,9 @@ def drop(
Drops a version defined in your application package. Versions can either be passed in as an argument to the command or read from the `manifest.yml` file.
Dropping patches is not allowed.
"""

assert_project_type("native_app")

is_interactive = False
if force:
policy = AllowAlwaysPolicy()
Expand Down
15 changes: 10 additions & 5 deletions src/snowflake/cli/plugins/snowpark/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,15 @@
DEPLOYMENT_STAGE,
ObjectType,
)
from snowflake.cli.api.exceptions import (
SecretsWithoutExternalAccessIntegrationError,
)
from snowflake.cli.api.exceptions import SecretsWithoutExternalAccessIntegrationError
from snowflake.cli.api.identifiers import FQN
from snowflake.cli.api.output.types import (
CollectionResult,
CommandResult,
MessageResult,
SingleQueryResult,
)
from snowflake.cli.api.project.project_verification import assert_project_type
from snowflake.cli.api.project.schemas.snowpark.callable import (
FunctionSchema,
ProcedureSchema,
Expand Down Expand Up @@ -116,7 +115,7 @@


@app.command("deploy", requires_connection=True)
@with_project_definition("snowpark")
@with_project_definition()
def deploy(
replace: bool = ReplaceOption(
help="Replaces procedure or function, even if no detected changes to metadata"
Expand All @@ -128,6 +127,9 @@ def deploy(
By default, if any of the objects exist already the commands will fail unless `--replace` flag is provided.
All deployed objects use the same artifact which is deployed only once.
"""

assert_project_type("snowpark")

snowpark = cli_context.project_definition.snowpark
paths = SnowparkPackagePaths.for_snowpark_project(
project_root=SecurePath(cli_context.project_root),
Expand Down Expand Up @@ -379,7 +381,7 @@ def _read_snowflake_requrements_file(file_path: SecurePath):


@app.command("build", requires_connection=True)
@with_project_definition("snowpark")
@with_project_definition()
def build(
ignore_anaconda: bool = IgnoreAnacondaOption,
allow_shared_libraries: bool = AllowSharedLibrariesOption,
Expand All @@ -396,6 +398,9 @@ def build(
Builds the Snowpark project as a `.zip` archive that can be used by `deploy` command.
The archive is built using only the `src` directory specified in the project file.
"""

assert_project_type("snowpark")

if not deprecated_check_anaconda_for_pypi_deps:
ignore_anaconda = True
snowpark_paths = SnowparkPackagePaths.for_snowpark_project(
Expand Down
6 changes: 5 additions & 1 deletion src/snowflake/cli/plugins/streamlit/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
MessageResult,
SingleQueryResult,
)
from snowflake.cli.api.project.project_verification import assert_project_type
from snowflake.cli.api.project.schemas.streamlit.streamlit import Streamlit
from snowflake.cli.plugins.object.command_aliases import (
add_object_command_aliases,
Expand Down Expand Up @@ -118,7 +119,7 @@ def _check_file_exists_if_not_default(ctx: click.Context, value):


@app.command("deploy", requires_connection=True)
@with_project_definition("streamlit")
@with_project_definition()
@with_experimental_behaviour()
def streamlit_deploy(
replace: bool = ReplaceOption(
Expand All @@ -132,6 +133,9 @@ def streamlit_deploy(
environment.yml and any other pages or folders, if present. If you don’t specify a stage name, the `streamlit`
stage is used. If the specified stage does not exist, the command creates it.
"""

assert_project_type("streamlit")

streamlit: Streamlit = cli_context.project_definition.streamlit
if not streamlit:
return MessageResult("No streamlit were specified in project definition.")
Expand Down
Loading

0 comments on commit 8938645

Please sign in to comment.