Skip to content

Commit

Permalink
Working version of ast-based signature generation
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle committed May 13, 2024
1 parent 0401b48 commit d34425f
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 16 deletions.
21 changes: 8 additions & 13 deletions src/prefect/cli/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,15 @@
from prefect.deployments.steps.core import run_steps
from prefect.events import DeploymentTriggerTypes, TriggerTypes
from prefect.exceptions import ObjectNotFound, PrefectHTTPStatusError
from prefect.flows import load_flow_from_entrypoint
from prefect.flows import load_flow_name_from_entrypoint
from prefect.settings import (
PREFECT_DEFAULT_WORK_POOL_NAME,
PREFECT_UI_URL,
)
from prefect.utilities.annotations import NotSet
from prefect.utilities.asyncutils import run_sync_in_worker_thread
from prefect.utilities.callables import parameter_schema
from prefect.utilities.callables import (
parameter_schema_from_entrypoint,
)
from prefect.utilities.collections import get_from_dict
from prefect.utilities.slugify import slugify
from prefect.utilities.templating import (
Expand Down Expand Up @@ -474,26 +475,20 @@ async def _run_single_deploy(
"You can also provide an entrypoint in a prefect.yaml file."
)
deploy_config["entrypoint"] = await prompt_entrypoint(app.console)
if deploy_config.get("flow_name") and deploy_config.get("entrypoint"):
raise ValueError(
"Received an entrypoint and a flow name for this deployment. Please provide"
" either an entrypoint or a flow name."
)

# entrypoint logic
if deploy_config.get("entrypoint"):
flow = await run_sync_in_worker_thread(
load_flow_from_entrypoint, deploy_config["entrypoint"]
)
deploy_config["flow_name"] = flow.name
flow = load_flow_name_from_entrypoint(deploy_config["entrypoint"])

deployment_name = deploy_config.get("name")
if not deployment_name:
if not is_interactive():
raise ValueError("A deployment name must be provided.")
deploy_config["name"] = prompt("Deployment name", default="default")

deploy_config["parameter_openapi_schema"] = parameter_schema(flow)
deploy_config["parameter_openapi_schema"] = parameter_schema_from_entrypoint(
deploy_config["entrypoint"]
)

deploy_config["schedules"] = _construct_schedules(
deploy_config,
Expand Down
36 changes: 35 additions & 1 deletion src/prefect/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
# This file requires type-checking with pyright because mypy does not yet support PEP612
# See https://github.com/python/mypy/issues/8645

import ast
import datetime
import importlib.util
import inspect
import os
import tempfile
Expand Down Expand Up @@ -1640,7 +1642,9 @@ def load_flow_from_script(path: str, flow_name: str = None) -> Flow:
)


def load_flow_from_entrypoint(entrypoint: str) -> Flow:
def load_flow_from_entrypoint(
entrypoint: str,
) -> Flow:
"""
Extract a flow object from a script at an entrypoint by running all of the code in the file.
Expand Down Expand Up @@ -1793,3 +1797,33 @@ def my_other_flow(name):
)

await runner.start()


def load_flow_name_from_entrypoint(entrypoint: str):
if ":" in entrypoint:
# split by the last colon once to handle Windows paths with drive letters i.e C:\path\to\file.py:do_stuff
path, func_name = entrypoint.rsplit(":", maxsplit=1)
source_code = Path(path).read_text()
else:
path, func_name = entrypoint.rsplit(".", maxsplit=1)
spec = importlib.util.find_spec(path)
if not spec or not spec.origin:
raise ValueError(f"Could not find module {path!r}")
source_code = Path(spec.origin).read_text()
parsed_code = ast.parse(source_code)
func_def = next(
node
for node in ast.walk(parsed_code)
if isinstance(node, ast.FunctionDef) and node.name == func_name
)
for decorator in func_def.decorator_list:
if (
isinstance(decorator, ast.Call)
and getattr(decorator.func, "id", "") == "flow"
):
for keyword in decorator.keywords:
if keyword.arg == "name":
return keyword.value.s # Return the string value of the argument
return func_name.replace(
"_", "-"
) # If no matching decorator or keyword argument is found
164 changes: 162 additions & 2 deletions src/prefect/utilities/callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@
Utilities for working with Python callables.
"""

import ast
import importlib.util
import inspect
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import cloudpickle

from prefect._internal.pydantic import HAS_PYDANTIC_V2
from prefect._internal.pydantic.v1_schema import has_v1_type_as_param
from prefect.utilities.importtools import (
safe_load_namespace,
)

if HAS_PYDANTIC_V2:
import pydantic.v1 as pydantic
Expand All @@ -31,7 +37,9 @@
ReservedArgumentError,
SignatureMismatchError,
)
from prefect.logging.loggers import disable_logger
from prefect.logging.loggers import disable_logger, get_logger

logger = get_logger(__name__)


def get_call_parameters(
Expand Down Expand Up @@ -318,9 +326,32 @@ def parameter_schema(fn: Callable) -> ParameterSchema:
# `eval_str` is not available in Python < 3.10
signature = inspect.signature(fn)

docstrings = parameter_docstrings(inspect.getdoc(fn))

return generate_parameter_schema(signature, docstrings)


def parameter_schema_from_entrypoint(entrypoint: str) -> ParameterSchema:
if ":" in entrypoint:
# split by the last colon once to handle Windows paths with drive letters i.e C:\path\to\file.py:do_stuff
path, func_name = entrypoint.rsplit(":", maxsplit=1)
source_code = Path(path).read_text()
else:
path, func_name = entrypoint.rsplit(".", maxsplit=1)
spec = importlib.util.find_spec(path)
if not spec or not spec.origin:
raise ValueError(f"Could not find module {path!r}")
source_code = Path(spec.origin).read_text()
signature = generate_signature_from_source(source_code, func_name)
docstring = get_docstring_from_source(source_code, func_name)
return generate_parameter_schema(signature, parameter_docstrings(docstring))


def generate_parameter_schema(
signature: inspect.Signature, docstrings: Dict[str, str]
) -> ParameterSchema:
model_fields = {}
aliases = {}
docstrings = parameter_docstrings(inspect.getdoc(fn))

class ModelConfig:
arbitrary_types_allowed = True
Expand Down Expand Up @@ -362,3 +393,132 @@ def raise_for_reserved_arguments(fn: Callable, reserved_arguments: Iterable[str]
raise ReservedArgumentError(
f"{argument!r} is a reserved argument name and cannot be used."
)


def generate_signature_from_source(
source_code: str, func_name: str
) -> inspect.Signature:
"""
Extract the signature of a function from its source code.
Will ignore missing imports and exceptions while loading local class definitions.
Args:
source_code: The source code where the function named `func_name` is declared.
func_name: The name of the function.
Returns:
The signature of the function.
"""
# Load the namespace from the source code. Missing imports and exceptions while
# loading local class definitions are ignored.
namespace = safe_load_namespace(source_code)
# Parse the source code into an AST
parsed_code = ast.parse(source_code)

func_def = next(
(
node
for node in ast.walk(parsed_code)
if isinstance(node, ast.FunctionDef) and node.name == func_name
),
None,
)
if func_def is None:
raise ValueError(f"Function {func_name} not found in source code")
parameters = []

for arg in func_def.args.args:
name = arg.arg
annotation = arg.annotation
if annotation is not None:
try:
# Compile and evaluate the annotation
ann_code = compile(ast.Expression(annotation), "<string>", "eval")
annotation = eval(ann_code, namespace)
except Exception as e:
# Don't raise an error if the annotation evaluation fails. Set the
# annotation to `inspect.Parameter.empty` instead which is equivalent to
# not having an annotation.
logger.debug("Failed to evaluate annotation for %s: %s", name, e)
annotation = inspect.Parameter.empty
else:
annotation = inspect.Parameter.empty

param = inspect.Parameter(
name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=annotation
)
parameters.append(param)

defaults = [None] * (
len(func_def.args.args) - len(func_def.args.defaults)
) + func_def.args.defaults
for param, default in zip(parameters, defaults):
if default is not None:
try:
def_code = compile(ast.Expression(default), "<string>", "eval")
default = eval(def_code)
except Exception as e:
logger.debug(
"Failed to evaluate default value for %s: %s", param.name, e
)
default = None # Set to None if evaluation fails
parameters[parameters.index(param)] = param.replace(default=default)

if func_def.args.vararg:
parameters.append(
inspect.Parameter(
func_def.args.vararg.arg, inspect.Parameter.VAR_POSITIONAL
)
)
if func_def.args.kwarg:
parameters.append(
inspect.Parameter(func_def.args.kwarg.arg, inspect.Parameter.VAR_KEYWORD)
)

# Handle return annotation
return_annotation = func_def.returns
if return_annotation is not None:
try:
ret_ann_code = compile(
ast.Expression(return_annotation), "<string>", "eval"
)
return_annotation = eval(ret_ann_code, namespace)
except Exception as e:
logger.debug("Failed to evaluate return annotation: %s", e)
return_annotation = inspect.Signature.empty

return inspect.Signature(parameters, return_annotation=return_annotation)


def get_docstring_from_source(source_code: str, func_name: str) -> Optional[str]:
"""
Extract the docstring of a function from its source code.
Args:
source_code (str): The source code of the function.
func_name (str): The name of the function.
Returns:
The docstring of the function. If the function has no docstring, returns None.
"""
parsed_code = ast.parse(source_code)

func_def = next(
(
node
for node in ast.walk(parsed_code)
if isinstance(node, ast.FunctionDef) and node.name == func_name
),
None,
)
if func_def is None:
raise ValueError(f"Function {func_name} not found in source code")

if (
func_def.body
and isinstance(func_def.body[0], ast.Expr)
and isinstance(func_def.body[0].value, (ast.Str, ast.Constant))
):
return func_def.body[0].value.s
return None
53 changes: 53 additions & 0 deletions src/prefect/utilities/importtools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import importlib
import importlib.util
import inspect
Expand All @@ -14,8 +15,11 @@
import fsspec

from prefect.exceptions import ScriptError
from prefect.logging.loggers import get_logger
from prefect.utilities.filesystem import filename, is_local_path, tmpchdir

logger = get_logger(__name__)


def to_qualified_name(obj: Any) -> str:
"""
Expand Down Expand Up @@ -356,3 +360,52 @@ def exec_module(self, _: ModuleType) -> None:
if self.callback is not None:
self.callback(self.alias)
sys.modules[self.alias] = root_module


def safe_load_namespace(source_code: str):
parsed_code = ast.parse(source_code)

namespace = {}

# Walk through the AST and find all import statements
for node in ast.walk(parsed_code):
if isinstance(node, ast.Import):
for alias in node.names:
module_name = alias.name
as_name = alias.asname if alias.asname else module_name
try:
# Attempt to import the module
namespace[as_name] = importlib.import_module(module_name)
logger.debug("Successfully imported %s", module_name)
except ImportError as e:
logger.debug(f"Failed to import {module_name}: {e}")
elif isinstance(node, ast.ImportFrom):
module_name = node.module
if module_name is None:
continue
try:
module = importlib.import_module(module_name)
for alias in node.names:
name = alias.name
asname = alias.asname if alias.asname else name
try:
# Get the specific attribute from the module
attribute = getattr(module, name)
namespace[asname] = attribute
except AttributeError as e:
logger.debug(
"Failed to retrieve %s from %s: %s", name, module_name, e
)
except ImportError as e:
logger.debug("Failed to import from %s: %s", node.module, e)

# Handle local class definitions
for node in ast.walk(parsed_code):
if isinstance(node, ast.ClassDef):
try:
# Compile and evaluate each class and function definition locally
code = compile(ast.Module(body=[node]), filename="<ast>", mode="exec")
exec(code, namespace)
except Exception as e:
logger.debug("Failed to compile class definition: %s", e)
return namespace

0 comments on commit d34425f

Please sign in to comment.