Skip to content

Commit

Permalink
Add support for custom flow decorators to prefect deploy (#14782)
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle authored Aug 1, 2024
1 parent 067cbc6 commit 15274df
Show file tree
Hide file tree
Showing 10 changed files with 682 additions and 84 deletions.
17 changes: 7 additions & 10 deletions src/prefect/cli/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@
from prefect.deployments.steps.core import run_steps
from prefect.events import DeploymentTriggerTypes, TriggerTypes
from prefect.exceptions import ObjectNotFound, PrefectHTTPStatusError
from prefect.flows import load_flow_arguments_from_entrypoint
from prefect.flows import load_flow_from_entrypoint
from prefect.settings import (
PREFECT_DEFAULT_WORK_POOL_NAME,
PREFECT_UI_URL,
)
from prefect.utilities.annotations import NotSet
from prefect.utilities.callables import (
parameter_schema_from_entrypoint,
parameter_schema,
)
from prefect.utilities.collections import get_from_dict
from prefect.utilities.slugify import slugify
Expand Down Expand Up @@ -481,20 +481,17 @@ async def _run_single_deploy(
)
deploy_config["entrypoint"] = await prompt_entrypoint(app.console)

flow_decorator_arguments = load_flow_arguments_from_entrypoint(
deploy_config["entrypoint"]
)
flow = load_flow_from_entrypoint(deploy_config["entrypoint"])

deploy_config["flow_name"] = flow.name

deploy_config["flow_name"] = flow_decorator_arguments["name"]
deployment_name = deploy_config.get("name")
if not deployment_name:
if not is_interactive():
raise ValueError("A deployment name must be provided.")
deploy_config["name"] = prompt("Deployment name", default="default")

deploy_config["parameter_openapi_schema"] = parameter_schema_from_entrypoint(
deploy_config["entrypoint"]
)
deploy_config["parameter_openapi_schema"] = parameter_schema(flow)

deploy_config["schedules"] = _construct_schedules(
deploy_config,
Expand Down Expand Up @@ -675,7 +672,7 @@ async def _run_single_deploy(
deploy_config["work_pool"]["job_variables"]["image"] = "{{ build-image.image }}"

if not deploy_config.get("description"):
deploy_config["description"] = flow_decorator_arguments.get("description")
deploy_config["description"] = flow.description

# save deploy_config before templating
deploy_config_before_templating = deepcopy(deploy_config)
Expand Down
159 changes: 151 additions & 8 deletions src/prefect/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -1676,6 +1676,7 @@ def load_flow_from_entrypoint(
if ":" in entrypoint:
# split by the last colon once to handle Windows paths with drive letters i.e C:\path\to\file.py:do_stuff
path, func_name = entrypoint.rsplit(":", maxsplit=1)

else:
path, func_name = entrypoint.rsplit(".", maxsplit=1)
try:
Expand All @@ -1684,15 +1685,13 @@ def load_flow_from_entrypoint(
raise MissingFlowError(
f"Flow function with name {func_name!r} not found in {path!r}. "
) from exc
except ScriptError as exc:
except ScriptError:
# If the flow has dependencies that are not installed in the current
# environment, fallback to loading the flow via AST parsing. The
# drawback of this approach is that we're unable to actually load the
# function, so we create a placeholder flow that will re-raise this
# exception when called.

# environment, fallback to loading the flow via AST parsing.
if use_placeholder_flow:
flow = load_placeholder_flow(entrypoint=entrypoint, raises=exc)
flow = safe_load_flow_from_entrypoint(entrypoint)
if flow is None:
raise
else:
raise

Expand Down Expand Up @@ -1855,6 +1854,147 @@ async def async_placeholder_flow(*args, **kwargs):
return Flow(**arguments)


def safe_load_flow_from_entrypoint(entrypoint: str) -> Optional[Flow]:
    """
    Load a flow from an entrypoint by evaluating its source in a best-effort
    namespace rather than importing the script normally.

    The source is loaded with `safe_load_namespace`. If the flow function's name
    is bound in the resulting namespace, that object is returned. Otherwise the
    function definition is handed to `_sanitize_and_load_flow`, which retries
    after stripping annotations and defaults that cannot be evaluated.

    Args:
        entrypoint: a string in the format `<path_to_script>:<flow_func_name>`
            or a module path to a flow function

    Returns:
        The object bound to the flow function's name, or None if the function
        could not be loaded even after sanitizing it.
    """
    func_def, source_code = _entrypoint_definition_and_source(entrypoint)

    path = None
    if ":" in entrypoint:
        # Split on the LAST colon only, so Windows paths with drive letters
        # (e.g. C:\path\to\file.py:do_stuff) keep their full file path instead
        # of being truncated to the drive letter.
        path = entrypoint.rsplit(":", maxsplit=1)[0]

    namespace = safe_load_namespace(source_code, filepath=path)
    if func_def.name in namespace:
        return namespace[func_def.name]

    # If the function is not in the namespace, it may be due to missing
    # dependencies for the function. Attempt to evaluate each annotation and
    # default value, drop the ones that fail, and see whether the function can
    # be compiled without them.
    return _sanitize_and_load_flow(func_def, namespace)


def _sanitize_and_load_flow(
    func_def: Union[ast.FunctionDef, ast.AsyncFunctionDef], namespace: Dict[str, Any]
) -> Optional[Flow]:
    """
    Attempt to load a flow from the function definition after sanitizing the
    annotations and defaults that can't be evaluated in `namespace`.

    The function definition is modified in place:
      * parameter and return annotations that fail to evaluate are removed
      * default values that fail to evaluate are replaced with `None`
      * keyword-only parameters without a default stay required

    The sanitized definition (including its decorators) is then compiled and
    executed inside `namespace`.

    Args:
        func_def: the function definition
        namespace: the namespace to load the function into

    Returns:
        The loaded function or None if the function can't be loaded
        after sanitizing the annotations and defaults.
    """
    # Concatenation builds a new list, so the appends below do not touch the AST.
    args = func_def.args.posonlyargs + func_def.args.args + func_def.args.kwonlyargs
    if func_def.args.vararg:
        args.append(func_def.args.vararg)
    if func_def.args.kwarg:
        args.append(func_def.args.kwarg)

    # Remove annotations that can't be evaluated
    for arg in args:
        if arg.annotation is None:
            continue
        error = _expression_evaluation_error(arg.annotation, namespace)
        if error is not None:
            logger.debug(
                "Failed to evaluate annotation for argument %s due to the following error. Ignoring annotation.",
                arg.arg,
                exc_info=error,
            )
            arg.annotation = None

    # Replace positional defaults that can't be evaluated. The list length must
    # be preserved because defaults are matched to the trailing parameters.
    new_defaults = []
    for default in func_def.args.defaults:
        error = _expression_evaluation_error(default, namespace)
        if error is None:
            new_defaults.append(default)
            continue
        logger.debug(
            "Failed to evaluate default value %s due to the following error. Ignoring default.",
            ast.dump(default),
            exc_info=error,
        )
        new_defaults.append(_none_constant_at(default))
    func_def.args.defaults = new_defaults

    # Replace keyword-only defaults that can't be evaluated. In the ast model a
    # `None` entry in `kw_defaults` means the keyword-only argument is required,
    # so those entries are kept as-is rather than being turned into a `None`
    # default (which would silently make a required parameter optional).
    new_kw_defaults = []
    for default in func_def.args.kw_defaults:
        if default is None:
            new_kw_defaults.append(None)
            continue
        error = _expression_evaluation_error(default, namespace)
        if error is None:
            new_kw_defaults.append(default)
            continue
        logger.debug(
            "Failed to evaluate default value %s due to the following error. Ignoring default.",
            ast.dump(default),
            exc_info=error,
        )
        new_kw_defaults.append(_none_constant_at(default))
    func_def.args.kw_defaults = new_kw_defaults

    # Remove the return annotation if it can't be evaluated
    if func_def.returns is not None:
        error = _expression_evaluation_error(func_def.returns, namespace)
        if error is not None:
            logger.debug(
                "Failed to evaluate return annotation due to the following error. Ignoring annotation.",
                exc_info=error,
            )
            func_def.returns = None

    # Attempt to compile the function without annotations and defaults that
    # can't be evaluated
    try:
        code = compile(
            ast.Module(body=[func_def], type_ignores=[]),
            filename="<ast>",
            mode="exec",
        )
        exec(code, namespace)
    except Exception as e:
        # Best-effort loader: report the failure and signal it with None.
        logger.debug("Failed to compile: %s", e)
        return None
    return namespace.get(func_def.name)


def _expression_evaluation_error(
    node: ast.expr, namespace: Dict[str, Any]
) -> Optional[Exception]:
    """Evaluate an expression node in `namespace`; return the exception raised, or None on success."""
    try:
        code = compile(ast.Expression(node), filename="<ast>", mode="eval")
        eval(code, namespace)
    except Exception as exc:
        return exc
    return None


def _none_constant_at(node: ast.expr) -> ast.Constant:
    """Build a `None` constant carrying the source location of `node` so it can be compiled in its place."""
    return ast.copy_location(ast.Constant(value=None), node)


def load_flow_arguments_from_entrypoint(
entrypoint: str, arguments: Optional[Union[List[str], Set[str]]] = None
) -> Dict[str, Any]:
Expand All @@ -1870,6 +2010,9 @@ def load_flow_arguments_from_entrypoint(
"""

func_def, source_code = _entrypoint_definition_and_source(entrypoint)
path = None
if ":" in entrypoint:
path = entrypoint.rsplit(":")[0]

if arguments is None:
# If no arguments are provided default to known arguments that are of
Expand Down Expand Up @@ -1905,7 +2048,7 @@ def load_flow_arguments_from_entrypoint(

# if the arg value is not a raw str (i.e. a variable or expression),
# then attempt to evaluate it
namespace = safe_load_namespace(source_code)
namespace = safe_load_namespace(source_code, filepath=path)
literal_arg_value = ast.get_source_segment(source_code, keyword.value)
cleaned_value = (
literal_arg_value.replace("\n", "") if literal_arg_value else ""
Expand Down
8 changes: 5 additions & 3 deletions src/prefect/utilities/callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,17 +346,19 @@ def parameter_schema_from_entrypoint(entrypoint: str) -> ParameterSchema:
Returns:
ParameterSchema: The parameter schema for the function.
"""
filepath = None
if ":" in entrypoint:
# split by the last colon once to handle Windows paths with drive letters i.e C:\path\to\file.py:do_stuff
path, func_name = entrypoint.rsplit(":", maxsplit=1)
source_code = Path(path).read_text()
filepath = path
else:
path, func_name = entrypoint.rsplit(".", maxsplit=1)
spec = importlib.util.find_spec(path)
if not spec or not spec.origin:
raise ValueError(f"Could not find module {path!r}")
source_code = Path(spec.origin).read_text()
signature = _generate_signature_from_source(source_code, func_name)
signature = _generate_signature_from_source(source_code, func_name, filepath)
docstring = _get_docstring_from_source(source_code, func_name)
return generate_parameter_schema(signature, parameter_docstrings(docstring))

Expand Down Expand Up @@ -424,7 +426,7 @@ def raise_for_reserved_arguments(fn: Callable, reserved_arguments: Iterable[str]


def _generate_signature_from_source(
source_code: str, func_name: str
source_code: str, func_name: str, filepath: Optional[str] = None
) -> inspect.Signature:
"""
Extract the signature of a function from its source code.
Expand All @@ -440,7 +442,7 @@ def _generate_signature_from_source(
"""
# Load the namespace from the source code. Missing imports and exceptions while
# loading local class definitions are ignored.
namespace = safe_load_namespace(source_code)
namespace = safe_load_namespace(source_code, filepath=filepath)
# Parse the source code into an AST
parsed_code = ast.parse(source_code)

Expand Down
Loading

0 comments on commit 15274df

Please sign in to comment.