Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow flow parameter schema generation when dependencies are missing #13315

Merged
merged 10 commits into from
May 29, 2024
30 changes: 13 additions & 17 deletions src/prefect/cli/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,15 @@
from prefect.deployments.steps.core import run_steps
from prefect.events import DeploymentTriggerTypes, TriggerTypes
from prefect.exceptions import ObjectNotFound, PrefectHTTPStatusError
from prefect.flows import load_flow_from_entrypoint
from prefect.flows import load_flow_argument_from_entrypoint
from prefect.settings import (
PREFECT_DEFAULT_WORK_POOL_NAME,
PREFECT_UI_URL,
)
from prefect.utilities.annotations import NotSet
from prefect.utilities.asyncutils import run_sync_in_worker_thread
from prefect.utilities.callables import parameter_schema
from prefect.utilities.callables import (
parameter_schema_from_entrypoint,
)
from prefect.utilities.collections import get_from_dict
from prefect.utilities.slugify import slugify
from prefect.utilities.templating import (
Expand Down Expand Up @@ -470,26 +471,20 @@ async def _run_single_deploy(
"You can also provide an entrypoint in a prefect.yaml file."
)
deploy_config["entrypoint"] = await prompt_entrypoint(app.console)
if deploy_config.get("flow_name") and deploy_config.get("entrypoint"):
raise ValueError(
"Received an entrypoint and a flow name for this deployment. Please provide"
" either an entrypoint or a flow name."
)

# entrypoint logic
if deploy_config.get("entrypoint"):
flow = await run_sync_in_worker_thread(
load_flow_from_entrypoint, deploy_config["entrypoint"]
)
deploy_config["flow_name"] = flow.name
deploy_config["flow_name"] = load_flow_argument_from_entrypoint(
deploy_config["entrypoint"], arg="name"
)

deployment_name = deploy_config.get("name")
if not deployment_name:
if not is_interactive():
raise ValueError("A deployment name must be provided.")
deploy_config["name"] = prompt("Deployment name", default="default")

deploy_config["parameter_openapi_schema"] = parameter_schema(flow)
deploy_config["parameter_openapi_schema"] = parameter_schema_from_entrypoint(
deploy_config["entrypoint"]
)

deploy_config["schedules"] = _construct_schedules(
deploy_config,
Expand Down Expand Up @@ -669,8 +664,9 @@ async def _run_single_deploy(
deploy_config["work_pool"]["job_variables"]["image"] = "{{ build-image.image }}"

if not deploy_config.get("description"):
deploy_config["description"] = flow.description

deploy_config["description"] = load_flow_argument_from_entrypoint(
deploy_config["entrypoint"], arg="description"
)
# save deploy_config before templating
deploy_config_before_templating = deepcopy(deploy_config)
## apply templating from build and push steps to the final deployment spec
Expand Down
60 changes: 59 additions & 1 deletion src/prefect/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
# This file requires type-checking with pyright because mypy does not yet support PEP612
# See https://github.com/python/mypy/issues/8645

import ast
import datetime
import importlib.util
import inspect
import json
import os
Expand Down Expand Up @@ -1709,7 +1711,9 @@ def load_flow_from_script(path: str, flow_name: str = None) -> Flow:
)


def load_flow_from_entrypoint(entrypoint: str) -> Flow:
def load_flow_from_entrypoint(
entrypoint: str,
) -> Flow:
"""
Extract a flow object from a script at an entrypoint by running all of the code in the file.

Expand Down Expand Up @@ -1943,3 +1947,57 @@ async def load_flow_from_flow_run(
flow = await run_sync_in_worker_thread(load_flow_from_entrypoint, str(import_path))

return flow


def load_flow_argument_from_entrypoint(
entrypoint: str, arg: str = "name"
) -> Optional[str]:
"""
Extract a flow argument from an entrypoint string.

Loads the source code of the entrypoint and extracts the flow argument from the
`flow` decorator.

Args:
entrypoint: a string in the format `<path_to_script>:<flow_func_name>` or a module path
to a flow function

Returns:
The flow argument value
"""
if ":" in entrypoint:
# split by the last colon once to handle Windows paths with drive letters i.e C:\path\to\file.py:do_stuff
path, func_name = entrypoint.rsplit(":", maxsplit=1)
source_code = Path(path).read_text()
else:
path, func_name = entrypoint.rsplit(".", maxsplit=1)
spec = importlib.util.find_spec(path)
if not spec or not spec.origin:
raise ValueError(f"Could not find module {path!r}")
source_code = Path(spec.origin).read_text()
parsed_code = ast.parse(source_code)
func_def = next(
(
node
for node in ast.walk(parsed_code)
if isinstance(node, ast.FunctionDef) and node.name == func_name
),
None,
)
if not func_def:
raise ValueError(f"Could not find flow {func_name!r} in {path!r}")
for decorator in func_def.decorator_list:
if (
isinstance(decorator, ast.Call)
and getattr(decorator.func, "id", "") == "flow"
):
for keyword in decorator.keywords:
if keyword.arg == arg:
return (
keyword.value.value
) # Return the string value of the argument

if arg == "name":
return func_name.replace(
"_", "-"
) # If no matching decorator or keyword argument is found
192 changes: 190 additions & 2 deletions src/prefect/utilities/callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
Utilities for working with Python callables.
"""

import ast
import importlib.util
import inspect
import warnings
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import cloudpickle
Expand All @@ -26,9 +29,12 @@
ReservedArgumentError,
SignatureMismatchError,
)
from prefect.logging.loggers import disable_logger
from prefect.logging.loggers import disable_logger, get_logger
from prefect.utilities.annotations import allow_failure, quote, unmapped
from prefect.utilities.collections import isiterable
from prefect.utilities.importtools import safe_load_namespace

logger = get_logger(__name__)


def get_call_parameters(
Expand Down Expand Up @@ -321,9 +327,62 @@ def parameter_schema(fn: Callable) -> ParameterSchema:
# `eval_str` is not available in Python < 3.10
signature = inspect.signature(fn)

docstrings = parameter_docstrings(inspect.getdoc(fn))

return generate_parameter_schema(signature, docstrings)


def parameter_schema_from_entrypoint(entrypoint: str) -> ParameterSchema:
desertaxle marked this conversation as resolved.
Show resolved Hide resolved
"""
Generate a parameter schema from an entrypoint string.

Will load the source code of the function and extract the signature and docstring
to generate the schema.

Useful for generating a schema for a function when instantiating the function may
not be possible due to missing imports or other issues.

Args:
entrypoint: A string representing the entrypoint to a function. The string
should be in the format of `module.path.to.function:do_stuff`.

Returns:
ParameterSchema: The parameter schema for the function.
"""
if ":" in entrypoint:
# split by the last colon once to handle Windows paths with drive letters i.e C:\path\to\file.py:do_stuff
path, func_name = entrypoint.rsplit(":", maxsplit=1)
source_code = Path(path).read_text()
else:
path, func_name = entrypoint.rsplit(".", maxsplit=1)
spec = importlib.util.find_spec(path)
if not spec or not spec.origin:
raise ValueError(f"Could not find module {path!r}")
source_code = Path(spec.origin).read_text()
signature = _generate_signature_from_source(source_code, func_name)
docstring = _get_docstring_from_source(source_code, func_name)
return generate_parameter_schema(signature, parameter_docstrings(docstring))


def generate_parameter_schema(
desertaxle marked this conversation as resolved.
Show resolved Hide resolved
signature: inspect.Signature, docstrings: Dict[str, str]
) -> ParameterSchema:
"""
Generate a parameter schema from a function signature and docstrings.

To get a signature from a function, use `inspect.signature(fn)` or
`_generate_signature_from_source(source_code, func_name)`.

Args:
signature: The function signature.
docstrings: A dictionary mapping parameter names to docstrings.

Returns:
ParameterSchema: The parameter schema.
"""

model_fields = {}
aliases = {}
docstrings = parameter_docstrings(inspect.getdoc(fn))

if not has_v1_type_as_param(signature):
create_schema = create_v2_schema
Expand Down Expand Up @@ -369,6 +428,135 @@ def raise_for_reserved_arguments(fn: Callable, reserved_arguments: Iterable[str]
)


def _generate_signature_from_source(
source_code: str, func_name: str
) -> inspect.Signature:
"""
Extract the signature of a function from its source code.

Will ignore missing imports and exceptions while loading local class definitions.

Args:
source_code: The source code where the function named `func_name` is declared.
func_name: The name of the function.

Returns:
The signature of the function.
"""
# Load the namespace from the source code. Missing imports and exceptions while
# loading local class definitions are ignored.
namespace = safe_load_namespace(source_code)
# Parse the source code into an AST
parsed_code = ast.parse(source_code)

func_def = next(
(
node
for node in ast.walk(parsed_code)
if isinstance(node, ast.FunctionDef) and node.name == func_name
),
None,
)
if func_def is None:
raise ValueError(f"Function {func_name} not found in source code")
parameters = []

for arg in func_def.args.args:
name = arg.arg
annotation = arg.annotation
if annotation is not None:
try:
# Compile and evaluate the annotation
ann_code = compile(ast.Expression(annotation), "<string>", "eval")
annotation = eval(ann_code, namespace)
except Exception as e:
# Don't raise an error if the annotation evaluation fails. Set the
# annotation to `inspect.Parameter.empty` instead which is equivalent to
# not having an annotation.
logger.debug("Failed to evaluate annotation for %s: %s", name, e)
annotation = inspect.Parameter.empty
else:
annotation = inspect.Parameter.empty

param = inspect.Parameter(
name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=annotation
)
parameters.append(param)

defaults = [None] * (
len(func_def.args.args) - len(func_def.args.defaults)
) + func_def.args.defaults
for param, default in zip(parameters, defaults):
if default is not None:
try:
def_code = compile(ast.Expression(default), "<string>", "eval")
default = eval(def_code, namespace)
except Exception as e:
logger.debug(
"Failed to evaluate default value for %s: %s", param.name, e
)
default = None # Set to None if evaluation fails
parameters[parameters.index(param)] = param.replace(default=default)

if func_def.args.vararg:
parameters.append(
inspect.Parameter(
func_def.args.vararg.arg, inspect.Parameter.VAR_POSITIONAL
)
)
if func_def.args.kwarg:
parameters.append(
inspect.Parameter(func_def.args.kwarg.arg, inspect.Parameter.VAR_KEYWORD)
)

# Handle return annotation
return_annotation = func_def.returns
if return_annotation is not None:
try:
ret_ann_code = compile(
ast.Expression(return_annotation), "<string>", "eval"
)
return_annotation = eval(ret_ann_code, namespace)
except Exception as e:
logger.debug("Failed to evaluate return annotation: %s", e)
return_annotation = inspect.Signature.empty

return inspect.Signature(parameters, return_annotation=return_annotation)


def _get_docstring_from_source(source_code: str, func_name: str) -> Optional[str]:
"""
Extract the docstring of a function from its source code.

Args:
source_code (str): The source code of the function.
func_name (str): The name of the function.

Returns:
The docstring of the function. If the function has no docstring, returns None.
"""
parsed_code = ast.parse(source_code)

func_def = next(
(
node
for node in ast.walk(parsed_code)
if isinstance(node, ast.FunctionDef) and node.name == func_name
),
None,
)
if func_def is None:
raise ValueError(f"Function {func_name} not found in source code")

if (
func_def.body
and isinstance(func_def.body[0], ast.Expr)
and isinstance(func_def.body[0].value, ast.Constant)
):
return func_def.body[0].value.value
return None


def expand_mapping_parameters(
func: Callable, parameters: Dict[str, Any]
) -> List[Dict[str, Any]]:
Expand Down
Loading
Loading