Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom flow decorators to prefect deploy #14694

Merged
merged 11 commits into from
Jul 29, 2024
16 changes: 6 additions & 10 deletions src/prefect/cli/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@
from prefect.deployments.steps.core import run_steps
from prefect.events import DeploymentTriggerTypes, TriggerTypes
from prefect.exceptions import ObjectNotFound, PrefectHTTPStatusError
from prefect.flows import load_flow_arguments_from_entrypoint
from prefect.flows import load_flow_from_entrypoint
from prefect.settings import (
PREFECT_DEFAULT_WORK_POOL_NAME,
PREFECT_UI_URL,
)
from prefect.utilities.annotations import NotSet
from prefect.utilities.callables import (
parameter_schema_from_entrypoint,
parameter_schema,
)
from prefect.utilities.collections import get_from_dict
from prefect.utilities.slugify import slugify
Expand Down Expand Up @@ -471,21 +471,17 @@ async def _run_single_deploy(
)
deploy_config["entrypoint"] = await prompt_entrypoint(app.console)

flow_decorator_arguments = load_flow_arguments_from_entrypoint(
deploy_config["entrypoint"], arguments={"name", "description"}
)
flow = load_flow_from_entrypoint(deploy_config["entrypoint"])

deploy_config["flow_name"] = flow_decorator_arguments["name"]
deploy_config["flow_name"] = flow.name

deployment_name = deploy_config.get("name")
if not deployment_name:
if not is_interactive():
raise ValueError("A deployment name must be provided.")
deploy_config["name"] = prompt("Deployment name", default="default")

deploy_config["parameter_openapi_schema"] = parameter_schema_from_entrypoint(
deploy_config["entrypoint"]
)
deploy_config["parameter_openapi_schema"] = parameter_schema(flow)

deploy_config["schedules"] = _construct_schedules(
deploy_config,
Expand Down Expand Up @@ -654,7 +650,7 @@ async def _run_single_deploy(
deploy_config["work_pool"]["job_variables"]["image"] = "{{ build-image.image }}"

if not deploy_config.get("description"):
deploy_config["description"] = flow_decorator_arguments.get("description")
deploy_config["description"] = flow.description
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm loving this, much cleaner.


# save deploy_config before templating
deploy_config_before_templating = deepcopy(deploy_config)
Expand Down
157 changes: 150 additions & 7 deletions src/prefect/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -1734,14 +1734,13 @@ def load_flow_from_entrypoint(
raise MissingFlowError(
f"Flow function with name {func_name!r} not found in {path!r}. "
) from exc
except ScriptError as exc:
except ScriptError:
# If the flow has dependencies that are not installed in the current
# environment, fallback to loading the flow via AST parsing. The
# drawback of this approach is that we're unable to actually load the
# function, so we create a placeholder flow that will re-raise this
# exception when called.
# environment, fallback to loading the flow via AST parsing.
if use_placeholder_flow:
flow = load_placeholder_flow(entrypoint=entrypoint, raises=exc)
flow = safe_load_flow_from_entrypoint(entrypoint)
if flow is None:
raise
else:
raise

Expand Down Expand Up @@ -1976,6 +1975,147 @@ async def async_placeholder_flow(*args, **kwargs):
return Flow(**arguments)


def safe_load_flow_from_entrypoint(entrypoint: str) -> Optional[Flow]:
"""
Load a flow from an entrypoint and return None if an exception is raised.

Args:
entrypoint: a string in the format `<path_to_script>:<flow_func_name>`
or a module path to a flow function
"""
func_def, source_code = _entrypoint_definition_and_source(entrypoint)
path = None
if ":" in entrypoint:
path = entrypoint.rsplit(":")[0]
namespace = safe_load_namespace(source_code, filepath=path)
if func_def.name in namespace:
return namespace[func_def.name]
else:
# If the function is not in the namespace, if may be due to missing dependencies
# for the function. We will attempt to compile each annotation and default value
# and remove them from the function definition to see if the function can be
# compiled without them.

return _sanitize_and_load_flow(func_def, namespace)


def _sanitize_and_load_flow(
func_def: Union[ast.FunctionDef, ast.AsyncFunctionDef], namespace: Dict[str, Any]
) -> Optional[Flow]:
"""
Attempt to load a flow from the function definition after sanitizing the annotations
and defaults that can't be compiled.

Args:
func_def: the function definition
namespace: the namespace to load the function into

Returns:
The loaded function or None if the function can't be loaded
after sanitizing the annotations and defaults.
"""
args = func_def.args.posonlyargs + func_def.args.args + func_def.args.kwonlyargs
if func_def.args.vararg:
args.append(func_def.args.vararg)
if func_def.args.kwarg:
args.append(func_def.args.kwarg)
# Remove annotations that can't be compiled
for arg in args:
if arg.annotation is not None:
try:
code = compile(
ast.Expression(arg.annotation),
filename="<ast>",
mode="eval",
)
exec(code, namespace)
except Exception as e:
logger.debug(
"Failed to evaluate annotation for argument %s due to the following error. Ignoring annotation.",
arg.arg,
exc_info=e,
)
arg.annotation = None

# Remove defaults that can't be compiled
new_defaults = []
for default in func_def.args.defaults:
try:
code = compile(ast.Expression(default), "<ast>", "eval")
exec(code, namespace)
new_defaults.append(default)
except Exception as e:
logger.debug(
"Failed to evaluate default value %s due to the following error. Ignoring default.",
default,
exc_info=e,
)
new_defaults.append(
ast.Constant(
value=None, lineno=default.lineno, col_offset=default.col_offset
)
)
func_def.args.defaults = new_defaults

# Remove kw_defaults that can't be compiled
new_kw_defaults = []
for default in func_def.args.kw_defaults:
if default is not None:
try:
code = compile(ast.Expression(default), "<ast>", "eval")
exec(code, namespace)
new_kw_defaults.append(default)
except Exception as e:
logger.debug(
"Failed to evaluate default value %s due to the following error. Ignoring default.",
default,
exc_info=e,
)
new_kw_defaults.append(
ast.Constant(
value=None,
lineno=default.lineno,
col_offset=default.col_offset,
)
)
else:
new_kw_defaults.append(
ast.Constant(
value=None,
lineno=func_def.lineno,
col_offset=func_def.col_offset,
)
)
func_def.args.kw_defaults = new_kw_defaults

if func_def.returns is not None:
try:
code = compile(
ast.Expression(func_def.returns), filename="<ast>", mode="eval"
)
exec(code, namespace)
except Exception as e:
logger.debug(
"Failed to evaluate return annotation due to the following error. Ignoring annotation.",
exc_info=e,
)
func_def.returns = None

# Attempt to compile the function without annotations and defaults that
# can't be compiled
try:
code = compile(
ast.Module(body=[func_def], type_ignores=[]),
filename="<ast>",
mode="exec",
)
exec(code, namespace)
except Exception as e:
logger.debug("Failed to compile: %s", e)
else:
return namespace.get(func_def.name)


def load_flow_arguments_from_entrypoint(
entrypoint: str, arguments: Optional[Union[List[str], Set[str]]] = None
) -> dict[str, Any]:
Expand All @@ -1991,6 +2131,9 @@ def load_flow_arguments_from_entrypoint(
"""

func_def, source_code = _entrypoint_definition_and_source(entrypoint)
path = None
if ":" in entrypoint:
path = entrypoint.rsplit(":")[0]

if arguments is None:
# If no arguments are provided default to known arguments that are of
Expand Down Expand Up @@ -2026,7 +2169,7 @@ def load_flow_arguments_from_entrypoint(

# if the arg value is not a raw str (i.e. a variable or expression),
# then attempt to evaluate it
namespace = safe_load_namespace(source_code)
namespace = safe_load_namespace(source_code, filepath=path)
literal_arg_value = ast.get_source_segment(source_code, keyword.value)
cleaned_value = (
literal_arg_value.replace("\n", "") if literal_arg_value else ""
Expand Down
8 changes: 5 additions & 3 deletions src/prefect/utilities/callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,17 +364,19 @@ def parameter_schema_from_entrypoint(entrypoint: str) -> ParameterSchema:
Returns:
ParameterSchema: The parameter schema for the function.
"""
filepath = None
if ":" in entrypoint:
# split by the last colon once to handle Windows paths with drive letters i.e C:\path\to\file.py:do_stuff
path, func_name = entrypoint.rsplit(":", maxsplit=1)
source_code = Path(path).read_text()
filepath = path
else:
path, func_name = entrypoint.rsplit(".", maxsplit=1)
spec = importlib.util.find_spec(path)
if not spec or not spec.origin:
raise ValueError(f"Could not find module {path!r}")
source_code = Path(spec.origin).read_text()
signature = _generate_signature_from_source(source_code, func_name)
signature = _generate_signature_from_source(source_code, func_name, filepath)
docstring = _get_docstring_from_source(source_code, func_name)
return generate_parameter_schema(signature, parameter_docstrings(docstring))

Expand Down Expand Up @@ -444,7 +446,7 @@ def raise_for_reserved_arguments(fn: Callable, reserved_arguments: Iterable[str]


def _generate_signature_from_source(
source_code: str, func_name: str
source_code: str, func_name: str, filepath: Optional[str] = None
) -> inspect.Signature:
"""
Extract the signature of a function from its source code.
Expand All @@ -460,7 +462,7 @@ def _generate_signature_from_source(
"""
# Load the namespace from the source code. Missing imports and exceptions while
# loading local class definitions are ignored.
namespace = safe_load_namespace(source_code)
namespace = safe_load_namespace(source_code, filepath=filepath)
# Parse the source code into an AST
parsed_code = ast.parse(source_code)

Expand Down
Loading