Skip to content

Commit

Permalink
Generalized names for processor compiler and context
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-bdufour committed Dec 9, 2024
1 parent 9cb87b3 commit 5a47d22
Show file tree
Hide file tree
Showing 12 changed files with 112 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


@dataclass
class BundleContext:
class ArtifactProcessorContext:
package_name: str
artifacts: List[PathMapping]
project_root: Path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from typing import Optional

from click import ClickException
from snowflake.cli._plugins.nativeapp.bundle_context import BundleContext
from snowflake.cli._plugins.nativeapp.artifact_processor_context import (
ArtifactProcessorContext,
)
from snowflake.cli.api.project.schemas.entities.common import (
PathMapping,
ProcessorMapping,
Expand Down Expand Up @@ -74,9 +76,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
class ArtifactProcessor(ABC):
def __init__(
self,
bundle_ctx: BundleContext,
processor_ctx: ArtifactProcessorContext,
) -> None:
self._bundle_ctx = bundle_ctx
self._processor_ctx = processor_ctx

@abstractmethod
def process(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from typing import Dict, Optional

from click import ClickException
from snowflake.cli._plugins.nativeapp.bundle_context import BundleContext
from snowflake.cli._plugins.nativeapp.artifact_processor_context import (
ArtifactProcessorContext,
)
from snowflake.cli._plugins.nativeapp.codegen.artifact_processor import (
ArtifactProcessor,
UnsupportedArtifactProcessorError,
Expand All @@ -41,22 +43,20 @@
ProcessorClassType = type[ArtifactProcessor]


class NativeAppCompiler:
class ArtifactProcessorRegistrar:
"""
Compiler class to perform custom processing on all relevant Native Apps artifacts (specified in the project definition file)
before an application package can be created from those artifacts.
Keeps track of registered artifact processors, and invokes them on a set of artifacts.
An artifact can have more than one processor specified for it, and this class will execute those processors in the order in which they are listed for that artifact.
The class also maintains a dictionary of processors it creates in order to reuse them across artifacts, since processor initialization
is independent of the artifact to process.
"""

def __init__(
self,
bundle_ctx: BundleContext,
processor_ctx: ArtifactProcessorContext,
):
self._assert_absolute_paths(bundle_ctx)
self._assert_absolute_paths(processor_ctx)
self._processor_classes_by_name: Dict[str, ProcessorClassType] = {}
self._bundle_ctx = bundle_ctx
self._processor_ctx = processor_ctx
# dictionary of all processors created and shared between different artifact objects.
self.cached_processors: Dict[str, ArtifactProcessor] = {}

Expand All @@ -78,9 +78,9 @@ def register(self, processor_cls: ProcessorClassType):
self._processor_classes_by_name[str(name)] = processor_cls

@staticmethod
def _assert_absolute_paths(bundle_ctx: BundleContext):
def _assert_absolute_paths(processor_ctx: ArtifactProcessorContext):
for name in ["Project", "Deploy", "Bundle", "Generated"]:
path = getattr(bundle_ctx, f"{name.lower()}_root")
path = getattr(processor_ctx, f"{name.lower()}_root")
assert path.is_absolute(), f"{name} root {path} must be an absolute path."

def compile_artifacts(self):
Expand All @@ -99,12 +99,12 @@ def compile_artifacts(self):
cc.phase("Invoking artifact processors"),
get_cli_context().metrics.span("artifact_processors"),
):
if self._bundle_ctx.generated_root.exists():
if self._processor_ctx.generated_root.exists():
raise ClickException(
f"Path {self._bundle_ctx.generated_root} already exists. Please choose a different name for your generated directory in the project definition file."
f"Path {self._processor_ctx.generated_root} already exists. Please choose a different name for your generated directory in the project definition file."
)

for artifact in self._bundle_ctx.artifacts:
for artifact in self._processor_ctx.artifacts:
for processor in artifact.processors:
if self._is_enabled(processor):
artifact_processor = self._try_create_processor(
Expand Down Expand Up @@ -140,21 +140,21 @@ def _try_create_processor(
# No registered processor with the specified name
return None

processor_ctx = copy.copy(self._bundle_ctx)
processor_ctx = copy.copy(self._processor_ctx)
processor_subdirectory = re.sub(r"[^a-zA-Z0-9_$]", "_", processor_name)
processor_ctx.bundle_root = (
self._bundle_ctx.bundle_root / processor_subdirectory
self._processor_ctx.bundle_root / processor_subdirectory
)
processor_ctx.generated_root = (
self._bundle_ctx.generated_root / processor_subdirectory
self._processor_ctx.generated_root / processor_subdirectory
)
current_processor = processor_cls(processor_ctx)
self.cached_processors[processor_name] = current_processor

return current_processor

def _should_invoke_processors(self):
for artifact in self._bundle_ctx.artifacts:
for artifact in self._processor_ctx.artifacts:
for processor in artifact.processors:
if self._is_enabled(processor):
return True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def process(
Processes a Python setup script and generates the corresponding SQL commands.
"""
bundle_map = BundleMap(
project_root=self._bundle_ctx.project_root,
deploy_root=self._bundle_ctx.deploy_root,
project_root=self._processor_ctx.project_root,
deploy_root=self._processor_ctx.deploy_root,
)
bundle_map.add(artifact_to_process)

Expand All @@ -108,7 +108,7 @@ def process(
absolute=True, expand_directories=True, predicate=is_python_file_artifact
):
cc.message(
f"Found Python setup file: {src_file.relative_to(self._bundle_ctx.project_root)}"
f"Found Python setup file: {src_file.relative_to(self._processor_ctx.project_root)}"
)
files_to_process.append(src_file)

Expand Down Expand Up @@ -150,15 +150,15 @@ def _execute_in_sandbox(self, py_files: List[Path]) -> dict:
file_count = len(py_files)
cc.step(f"Processing {file_count} setup file{'s' if file_count > 1 else ''}")

manifest_path = find_manifest_file(deploy_root=self._bundle_ctx.deploy_root)
manifest_path = find_manifest_file(deploy_root=self._processor_ctx.deploy_root)

generated_root = self._bundle_ctx.generated_root
generated_root = self._processor_ctx.generated_root
generated_root.mkdir(exist_ok=True, parents=True)

env_vars = {
"_SNOWFLAKE_CLI_PROJECT_PATH": str(self._bundle_ctx.project_root),
"_SNOWFLAKE_CLI_PROJECT_PATH": str(self._processor_ctx.project_root),
"_SNOWFLAKE_CLI_SETUP_FILES": os.pathsep.join(map(str, py_files)),
"_SNOWFLAKE_CLI_APP_NAME": str(self._bundle_ctx.package_name),
"_SNOWFLAKE_CLI_APP_NAME": str(self._processor_ctx.package_name),
"_SNOWFLAKE_CLI_SQL_DEST_DIR": str(generated_root),
"_SNOWFLAKE_CLI_MANIFEST_PATH": str(manifest_path),
}
Expand All @@ -167,7 +167,7 @@ def _execute_in_sandbox(self, py_files: List[Path]) -> dict:
result = execute_script_in_sandbox(
script_source=DRIVER_PATH.read_text(),
env_type=ExecutionEnvironmentType.VENV,
cwd=self._bundle_ctx.bundle_root,
cwd=self._processor_ctx.bundle_root,
timeout=DEFAULT_TIMEOUT,
path=self.sandbox_root,
env_vars=env_vars,
Expand All @@ -187,7 +187,7 @@ def _execute_in_sandbox(self, py_files: List[Path]) -> dict:
def _edit_setup_sql(self, modifications: List[dict]) -> None:
cc.step("Patching setup script")
setup_file_path = find_setup_script_file(
deploy_root=self._bundle_ctx.deploy_root
deploy_root=self._processor_ctx.deploy_root
)

with self.edit_file(setup_file_path) as f:
Expand All @@ -208,7 +208,7 @@ def _edit_setup_sql(self, modifications: List[dict]) -> None:

def _edit_manifest(self, modifications: List[dict]) -> None:
cc.step("Patching manifest")
manifest_path = find_manifest_file(deploy_root=self._bundle_ctx.deploy_root)
manifest_path = find_manifest_file(deploy_root=self._processor_ctx.deploy_root)

with self.edit_file(manifest_path) as f:
manifest = yaml.safe_load(f.contents)
Expand All @@ -232,14 +232,14 @@ def _setup_mod_instruction_to_sql(self, mod_inst: dict) -> str:
if payload_type == "execute immediate":
file_path = payload.get("file_path")
if file_path:
sql_file_path = self._bundle_ctx.generated_root / file_path
return f"EXECUTE IMMEDIATE FROM '/{to_stage_path(sql_file_path.relative_to(self._bundle_ctx.deploy_root))}';"
sql_file_path = self._processor_ctx.generated_root / file_path
return f"EXECUTE IMMEDIATE FROM '/{to_stage_path(sql_file_path.relative_to(self._processor_ctx.deploy_root))}';"

raise ClickException(f"Unsupported instruction type received: {payload_type}")

@property
def sandbox_root(self):
return self._bundle_ctx.bundle_root / "venv"
return self._processor_ctx.bundle_root / "venv"

def _create_or_update_sandbox(self):
sandbox_root = self.sandbox_root
Expand All @@ -248,7 +248,7 @@ def _create_or_update_sandbox(self):
cc.step("Virtual environment found")
else:
cc.step(
f"Creating virtual environment in {sandbox_root.relative_to(self._bundle_ctx.project_root)}"
f"Creating virtual environment in {sandbox_root.relative_to(self._processor_ctx.project_root)}"
)
env_builder.ensure_created()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ def process(
get_cli_context().metrics.set_counter(CLICounterField.SNOWPARK_PROCESSOR, 1)

bundle_map = BundleMap(
project_root=self._bundle_ctx.project_root,
deploy_root=self._bundle_ctx.deploy_root,
project_root=self._processor_ctx.project_root,
deploy_root=self._processor_ctx.deploy_root,
)
bundle_map.add(artifact_to_process)

Expand Down Expand Up @@ -233,7 +233,7 @@ def process(
edit_setup_script_with_exec_imm_sql(
collected_sql_files=collected_sql_files,
deploy_root=bundle_map.deploy_root(),
generated_root=self._bundle_ctx.generated_root,
generated_root=self._processor_ctx.generated_root,
)

def _normalize_imports(
Expand Down Expand Up @@ -314,7 +314,7 @@ def collect_extension_functions(
self, bundle_map: BundleMap, processor_mapping: Optional[ProcessorMapping]
) -> Dict[Path, List[NativeAppExtensionFunction]]:
kwargs = (
_determine_virtual_env(self._bundle_ctx.project_root, processor_mapping)
_determine_virtual_env(self._processor_ctx.project_root, processor_mapping)
if processor_mapping is not None
else {}
)
Expand All @@ -330,11 +330,11 @@ def collect_extension_functions(
predicate=is_python_file_artifact,
)
):
src_file_name = src_file.relative_to(self._bundle_ctx.project_root)
src_file_name = src_file.relative_to(self._processor_ctx.project_root)
cc.step(f"Processing Snowpark annotations from {src_file_name}")
collected_extension_function_json = _execute_in_sandbox(
py_file=str(dest_file.resolve()),
deploy_root=self._bundle_ctx.deploy_root,
deploy_root=self._processor_ctx.deploy_root,
kwargs=kwargs,
)

Expand Down Expand Up @@ -365,9 +365,9 @@ def generate_new_sql_file_name(self, py_file: Path) -> Path:
"""
Generates the path of the SQL file under the generated root that corresponds to the given Python file (same path relative to the deploy root, with a .sql suffix), and creates its parent directories.
"""
relative_py_file = py_file.relative_to(self._bundle_ctx.deploy_root)
relative_py_file = py_file.relative_to(self._processor_ctx.deploy_root)
sql_file = Path(
self._bundle_ctx.generated_root, relative_py_file.with_suffix(".sql")
self._processor_ctx.generated_root, relative_py_file.with_suffix(".sql")
)
if sql_file.exists():
cc.warning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def expand_templates_in_file(
if src.is_dir():
return

src_file_name = src.relative_to(self._bundle_ctx.project_root)
src_file_name = str(src.relative_to(self._processor_ctx.project_root))

try:
with self.edit_file(dest) as file:
Expand Down Expand Up @@ -114,8 +114,8 @@ def process(
get_cli_context().metrics.set_counter(CLICounterField.TEMPLATES_PROCESSOR, 1)

bundle_map = BundleMap(
project_root=self._bundle_ctx.project_root,
deploy_root=self._bundle_ctx.deploy_root,
project_root=self._processor_ctx.project_root,
deploy_root=self._processor_ctx.deploy_root,
)
bundle_map.add(artifact_to_process)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@
from click import BadOptionUsage, ClickException
from pydantic import Field, field_validator
from snowflake.cli._plugins.connection.util import UIParameter
from snowflake.cli._plugins.nativeapp.artifact_processor_context import (
ArtifactProcessorContext,
)
from snowflake.cli._plugins.nativeapp.artifacts import (
VersionInfo,
build_bundle,
find_version_info_in_manifest_file,
)
from snowflake.cli._plugins.nativeapp.bundle_context import BundleContext
from snowflake.cli._plugins.nativeapp.codegen.compiler import NativeAppCompiler
from snowflake.cli._plugins.nativeapp.codegen.artifact_processor_registrar import (
ArtifactProcessorRegistrar,
)
from snowflake.cli._plugins.nativeapp.constants import (
ALLOWED_SPECIAL_COMMENTS,
COMMENT_COL,
Expand Down Expand Up @@ -522,15 +526,15 @@ def action_version_drop(
def _bundle(self):
model = self._entity_model
bundle_map = build_bundle(self.project_root, self.deploy_root, model.artifacts)
bundle_context = BundleContext(
bundle_context = ArtifactProcessorContext(
package_name=self.name,
artifacts=model.artifacts,
project_root=self.project_root,
bundle_root=self.bundle_root,
deploy_root=self.deploy_root,
generated_root=self.generated_root,
)
compiler = NativeAppCompiler(bundle_context)
compiler = ArtifactProcessorRegistrar(bundle_context)
compiler.compile_artifacts()
return bundle_map

Expand Down
8 changes: 5 additions & 3 deletions src/snowflake/cli/api/project/definition_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
from typing import Any, Dict, Literal, Optional

from click import ClickException
from snowflake.cli._plugins.nativeapp.artifact_processor_context import (
ArtifactProcessorContext,
)
from snowflake.cli._plugins.nativeapp.artifacts import (
bundle_artifacts,
)
from snowflake.cli._plugins.nativeapp.bundle_context import BundleContext
from snowflake.cli._plugins.nativeapp.codegen.templates.templates_processor import (
TemplatesProcessor,
)
Expand Down Expand Up @@ -467,7 +469,7 @@ def _convert_templates_in_files(
# files on disk outside of the artifacts we want to convert
with tempfile.TemporaryDirectory() as d:
deploy_root = Path(d)
bundle_ctx = BundleContext(
processor_ctx = ArtifactProcessorContext(
package_name=pkg_model.identifier,
artifacts=pkg_model.artifacts,
project_root=project_root,
Expand All @@ -477,7 +479,7 @@ def _convert_templates_in_files(
project_root / deploy_root / pkg_model.generated_root
),
)
template_processor = TemplatesProcessor(bundle_ctx)
template_processor = TemplatesProcessor(processor_ctx)
bundle_map = bundle_artifacts(
project_root, deploy_root, artifacts_to_template
)
Expand Down
6 changes: 4 additions & 2 deletions tests/nativeapp/codegen/snowpark/test_python_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from unittest import mock

import pytest
from snowflake.cli._plugins.nativeapp.bundle_context import BundleContext
from snowflake.cli._plugins.nativeapp.artifact_processor_context import (
ArtifactProcessorContext,
)
from snowflake.cli._plugins.nativeapp.codegen.sandbox import (
ExecutionEnvironmentType,
SandboxExecutionError,
Expand Down Expand Up @@ -54,7 +56,7 @@ def _get_bundle_context(
pkg_model: ApplicationPackageEntityModel, project_root: Path | None = None
):
project_root = project_root or Path().resolve()
return BundleContext(
return ArtifactProcessorContext(
package_name=pkg_model.fqn.name,
artifacts=pkg_model.artifacts,
project_root=project_root,
Expand Down
Loading

0 comments on commit 5a47d22

Please sign in to comment.