Skip to content

Commit

Permalink
Allow flow parameter schema generation when dependencies are missing (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle committed May 29, 2024
1 parent 7dbb320 commit de7ef82
Show file tree
Hide file tree
Showing 7 changed files with 1,183 additions and 21 deletions.
30 changes: 13 additions & 17 deletions src/prefect/cli/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,15 @@
from prefect.deployments.steps.core import run_steps
from prefect.events import DeploymentTriggerTypes, TriggerTypes
from prefect.exceptions import ObjectNotFound, PrefectHTTPStatusError
from prefect.flows import load_flow_from_entrypoint
from prefect.flows import load_flow_argument_from_entrypoint
from prefect.settings import (
PREFECT_DEFAULT_WORK_POOL_NAME,
PREFECT_UI_URL,
)
from prefect.utilities.annotations import NotSet
from prefect.utilities.asyncutils import run_sync_in_worker_thread
from prefect.utilities.callables import parameter_schema
from prefect.utilities.callables import (
parameter_schema_from_entrypoint,
)
from prefect.utilities.collections import get_from_dict
from prefect.utilities.slugify import slugify
from prefect.utilities.templating import (
Expand Down Expand Up @@ -474,26 +475,20 @@ async def _run_single_deploy(
"You can also provide an entrypoint in a prefect.yaml file."
)
deploy_config["entrypoint"] = await prompt_entrypoint(app.console)
if deploy_config.get("flow_name") and deploy_config.get("entrypoint"):
raise ValueError(
"Received an entrypoint and a flow name for this deployment. Please provide"
" either an entrypoint or a flow name."
)

# entrypoint logic
if deploy_config.get("entrypoint"):
flow = await run_sync_in_worker_thread(
load_flow_from_entrypoint, deploy_config["entrypoint"]
)
deploy_config["flow_name"] = flow.name
deploy_config["flow_name"] = load_flow_argument_from_entrypoint(
deploy_config["entrypoint"], arg="name"
)

deployment_name = deploy_config.get("name")
if not deployment_name:
if not is_interactive():
raise ValueError("A deployment name must be provided.")
deploy_config["name"] = prompt("Deployment name", default="default")

deploy_config["parameter_openapi_schema"] = parameter_schema(flow)
deploy_config["parameter_openapi_schema"] = parameter_schema_from_entrypoint(
deploy_config["entrypoint"]
)

deploy_config["schedules"] = _construct_schedules(
deploy_config,
Expand Down Expand Up @@ -674,8 +669,9 @@ async def _run_single_deploy(
deploy_config["work_pool"]["job_variables"]["image"] = "{{ build-image.image }}"

if not deploy_config.get("description"):
deploy_config["description"] = flow.description

deploy_config["description"] = load_flow_argument_from_entrypoint(
deploy_config["entrypoint"], arg="description"
)
# save deploy_config before templating
deploy_config_before_templating = deepcopy(deploy_config)
## apply templating from build and push steps to the final deployment spec
Expand Down
60 changes: 59 additions & 1 deletion src/prefect/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
# This file requires type-checking with pyright because mypy does not yet support PEP612
# See https://github.com/python/mypy/issues/8645

import ast
import datetime
import importlib.util
import inspect
import os
import tempfile
Expand Down Expand Up @@ -1640,7 +1642,9 @@ def load_flow_from_script(path: str, flow_name: str = None) -> Flow:
)


def load_flow_from_entrypoint(entrypoint: str) -> Flow:
def load_flow_from_entrypoint(
entrypoint: str,
) -> Flow:
"""
Extract a flow object from a script at an entrypoint by running all of the code in the file.
Expand Down Expand Up @@ -1793,3 +1797,57 @@ def my_other_flow(name):
)

await runner.start()


def load_flow_argument_from_entrypoint(
    entrypoint: str, arg: str = "name"
) -> Optional[str]:
    """
    Extract a flow argument from an entrypoint string.

    Loads the source code of the entrypoint and statically reads the requested
    keyword argument from the `flow` decorator. The flow's module is parsed, not
    executed, so this works even when the flow's dependencies are not installed.

    Args:
        entrypoint: a string in the format `<path_to_script>:<flow_func_name>` or a
            dotted module path to a flow function
        arg: the name of the `flow` decorator keyword argument to extract

    Returns:
        The literal value passed for `arg` in the `flow` decorator. If no literal
        value is found and `arg` is `"name"`, the function name with underscores
        replaced by dashes is returned. Otherwise `None` is returned.

    Raises:
        ValueError: if the module cannot be located or the function is not
            defined in the source file.
    """
    if ":" in entrypoint:
        # split by the last colon once to handle Windows paths with drive letters e.g. C:\path\to\file.py:do_stuff
        path, func_name = entrypoint.rsplit(":", maxsplit=1)
        source_code = Path(path).read_text()
    else:
        path, func_name = entrypoint.rsplit(".", maxsplit=1)
        spec = importlib.util.find_spec(path)
        if not spec or not spec.origin:
            raise ValueError(f"Could not find module {path!r}")
        source_code = Path(spec.origin).read_text()

    parsed_code = ast.parse(source_code)
    func_def = next(
        (
            node
            for node in ast.walk(parsed_code)
            # Flows may be declared with either `def` or `async def`
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
            and node.name == func_name
        ),
        None,
    )
    if not func_def:
        raise ValueError(f"Could not find flow {func_name!r} in {path!r}")

    for decorator in func_def.decorator_list:
        if not isinstance(decorator, ast.Call):
            continue
        # Match both `@flow(...)` (ast.Name) and `@prefect.flow(...)` (ast.Attribute)
        target = decorator.func
        decorator_name = getattr(target, "id", None) or getattr(target, "attr", None)
        if decorator_name != "flow":
            continue
        for keyword in decorator.keywords:
            # Only literal values can be read without executing the module;
            # variables and expressions are skipped so the fallback below applies.
            if keyword.arg == arg and isinstance(keyword.value, ast.Constant):
                return keyword.value.value

    if arg == "name":
        # No literal name was given; derive the default from the function name
        return func_name.replace("_", "-")
    return None
192 changes: 190 additions & 2 deletions src/prefect/utilities/callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
Utilities for working with Python callables.
"""

import ast
import importlib.util
import inspect
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import cloudpickle
Expand Down Expand Up @@ -31,7 +34,10 @@
ReservedArgumentError,
SignatureMismatchError,
)
from prefect.logging.loggers import disable_logger
from prefect.logging.loggers import disable_logger, get_logger
from prefect.utilities.importtools import safe_load_namespace

logger = get_logger(__name__)


def get_call_parameters(
Expand Down Expand Up @@ -318,9 +324,62 @@ def parameter_schema(fn: Callable) -> ParameterSchema:
# `eval_str` is not available in Python < 3.10
signature = inspect.signature(fn)

docstrings = parameter_docstrings(inspect.getdoc(fn))

return generate_parameter_schema(signature, docstrings)


def parameter_schema_from_entrypoint(entrypoint: str) -> ParameterSchema:
    """
    Build a parameter schema for the function named by an entrypoint string.

    The function's source file is read as text, and its signature and docstring
    are extracted from that text to produce the schema. This is useful when
    instantiating the function may not be possible due to missing imports or
    other issues.

    Args:
        entrypoint: Either `<path_to_script>:<func_name>` (the text before the
            last colon is read as a file path) or a dotted path
            `package.module.func_name` (the module is located with
            `importlib.util.find_spec` and its origin file is read).

    Returns:
        ParameterSchema: The parameter schema for the function.

    Raises:
        ValueError: if a dotted entrypoint's module cannot be located.
    """
    if ":" not in entrypoint:
        # Dotted form: everything before the last dot is the module to locate
        module_path, func_name = entrypoint.rsplit(".", maxsplit=1)
        module_spec = importlib.util.find_spec(module_path)
        if not (module_spec and module_spec.origin):
            raise ValueError(f"Could not find module {module_path!r}")
        source_file = Path(module_spec.origin)
    else:
        # Partition on the last colon so Windows drive letters stay in the path
        file_path, _, func_name = entrypoint.rpartition(":")
        source_file = Path(file_path)

    source_code = source_file.read_text()
    return generate_parameter_schema(
        _generate_signature_from_source(source_code, func_name),
        parameter_docstrings(_get_docstring_from_source(source_code, func_name)),
    )


def generate_parameter_schema(
signature: inspect.Signature, docstrings: Dict[str, str]
) -> ParameterSchema:
"""
Generate a parameter schema from a function signature and docstrings.
To get a signature from a function, use `inspect.signature(fn)` or
`_generate_signature_from_source(source_code, func_name)`.
Args:
signature: The function signature.
docstrings: A dictionary mapping parameter names to docstrings.
Returns:
ParameterSchema: The parameter schema.
"""

model_fields = {}
aliases = {}
docstrings = parameter_docstrings(inspect.getdoc(fn))

class ModelConfig:
arbitrary_types_allowed = True
Expand Down Expand Up @@ -362,3 +421,132 @@ def raise_for_reserved_arguments(fn: Callable, reserved_arguments: Iterable[str]
raise ReservedArgumentError(
f"{argument!r} is a reserved argument name and cannot be used."
)


def _generate_signature_from_source(
    source_code: str, func_name: str
) -> inspect.Signature:
    """
    Extract the signature of a function from its source code.

    Will ignore missing imports and exceptions while loading local class definitions.
    Handles both `def` and `async def` functions, and positional-only,
    positional-or-keyword, `*args`, keyword-only and `**kwargs` parameters.

    Args:
        source_code: The source code where the function named `func_name` is declared.
        func_name: The name of the function.

    Returns:
        The signature of the function. Annotations that cannot be evaluated are
        replaced with `inspect.Parameter.empty`; default values that cannot be
        evaluated are replaced with `None`.

    Raises:
        ValueError: if no function named `func_name` exists in `source_code`.
    """
    # Load the namespace from the source code. Missing imports and exceptions while
    # loading local class definitions are ignored.
    namespace = safe_load_namespace(source_code)
    # Parse the source code into an AST
    parsed_code = ast.parse(source_code)

    func_def = next(
        (
            node
            for node in ast.walk(parsed_code)
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
            and node.name == func_name
        ),
        None,
    )
    if func_def is None:
        raise ValueError(f"Function {func_name} not found in source code")

    def _evaluate(node: ast.expr, description: str, fallback: Any) -> Any:
        # NOTE: this evaluates expressions taken from the inspected source file
        # inside the namespace loaded from that same file. Failures are logged
        # at debug level and replaced by `fallback` rather than raised.
        try:
            code = compile(ast.Expression(node), "<string>", "eval")
            return eval(code, namespace)
        except Exception as e:
            logger.debug("Failed to evaluate %s: %s", description, e)
            return fallback

    def _build_parameter(
        arg: ast.arg, kind: Any, default_node: Optional[ast.expr]
    ) -> inspect.Parameter:
        annotation = inspect.Parameter.empty
        if arg.annotation is not None:
            # A failed evaluation is equivalent to not having an annotation
            annotation = _evaluate(
                arg.annotation, f"annotation for {arg.arg}", inspect.Parameter.empty
            )
        default = inspect.Parameter.empty
        if default_node is not None:
            # Fall back to `None` so the parameter stays optional
            default = _evaluate(default_node, f"default value for {arg.arg}", None)
        return inspect.Parameter(
            arg.arg, kind, default=default, annotation=annotation
        )

    args = func_def.args
    positional = [
        (arg, inspect.Parameter.POSITIONAL_ONLY) for arg in args.posonlyargs
    ] + [(arg, inspect.Parameter.POSITIONAL_OR_KEYWORD) for arg in args.args]
    # `args.defaults` applies to the LAST n positional parameters (positional-only
    # and regular combined), so left-pad with None to align the two lists.
    positional_defaults = [None] * (len(positional) - len(args.defaults)) + list(
        args.defaults
    )

    parameters = [
        _build_parameter(arg, kind, default_node)
        for (arg, kind), default_node in zip(positional, positional_defaults)
    ]

    if args.vararg:
        parameters.append(
            inspect.Parameter(args.vararg.arg, inspect.Parameter.VAR_POSITIONAL)
        )

    # `kw_defaults` is aligned one-to-one with `kwonlyargs`; a None entry means
    # the keyword-only parameter is required.
    for arg, default_node in zip(args.kwonlyargs, args.kw_defaults):
        parameters.append(
            _build_parameter(arg, inspect.Parameter.KEYWORD_ONLY, default_node)
        )

    if args.kwarg:
        parameters.append(
            inspect.Parameter(args.kwarg.arg, inspect.Parameter.VAR_KEYWORD)
        )

    # Handle return annotation; a missing annotation is `Signature.empty`, which
    # is distinct from an explicit `-> None`.
    return_annotation = inspect.Signature.empty
    if func_def.returns is not None:
        return_annotation = _evaluate(
            func_def.returns, "return annotation", inspect.Signature.empty
        )

    return inspect.Signature(parameters, return_annotation=return_annotation)


def _get_docstring_from_source(source_code: str, func_name: str) -> Optional[str]:
    """
    Extract the docstring of a function from its source code.

    Both `def` and `async def` functions are supported. The source is parsed
    only; it is never executed.

    Args:
        source_code (str): The source code of the function.
        func_name (str): The name of the function.

    Returns:
        The raw (un-dedented) docstring of the function. If the function has no
        docstring, returns None.

    Raises:
        ValueError: if no function named `func_name` exists in `source_code`.
    """
    parsed_code = ast.parse(source_code)

    func_def = next(
        (
            node
            for node in ast.walk(parsed_code)
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
            and node.name == func_name
        ),
        None,
    )
    if func_def is None:
        raise ValueError(f"Function {func_name} not found in source code")

    # `ast.get_docstring` only accepts a leading *string* constant, so a bare
    # non-string expression such as `42` is not mistaken for a docstring.
    # `clean=False` keeps the text exactly as written in the source.
    return ast.get_docstring(func_def, clean=False)
Loading

0 comments on commit de7ef82

Please sign in to comment.