From d34425f51c3648e8e0e693f879b26b6bcb8eaa3c Mon Sep 17 00:00:00 2001 From: Alex Streed Date: Mon, 13 May 2024 12:34:27 -0500 Subject: [PATCH] Working version of ast-based signature generation --- src/prefect/cli/deploy.py | 21 ++-- src/prefect/flows.py | 36 +++++- src/prefect/utilities/callables.py | 164 ++++++++++++++++++++++++++- src/prefect/utilities/importtools.py | 53 +++++++++ 4 files changed, 258 insertions(+), 16 deletions(-) diff --git a/src/prefect/cli/deploy.py b/src/prefect/cli/deploy.py index 3fff1bfed356f..551bb18637e8a 100644 --- a/src/prefect/cli/deploy.py +++ b/src/prefect/cli/deploy.py @@ -65,14 +65,15 @@ from prefect.deployments.steps.core import run_steps from prefect.events import DeploymentTriggerTypes, TriggerTypes from prefect.exceptions import ObjectNotFound, PrefectHTTPStatusError -from prefect.flows import load_flow_from_entrypoint +from prefect.flows import load_flow_name_from_entrypoint from prefect.settings import ( PREFECT_DEFAULT_WORK_POOL_NAME, PREFECT_UI_URL, ) from prefect.utilities.annotations import NotSet -from prefect.utilities.asyncutils import run_sync_in_worker_thread -from prefect.utilities.callables import parameter_schema +from prefect.utilities.callables import ( + parameter_schema_from_entrypoint, +) from prefect.utilities.collections import get_from_dict from prefect.utilities.slugify import slugify from prefect.utilities.templating import ( @@ -474,18 +475,10 @@ async def _run_single_deploy( "You can also provide an entrypoint in a prefect.yaml file." ) deploy_config["entrypoint"] = await prompt_entrypoint(app.console) - if deploy_config.get("flow_name") and deploy_config.get("entrypoint"): - raise ValueError( - "Received an entrypoint and a flow name for this deployment. Please provide" - " either an entrypoint or a flow name." - ) # entrypoint logic if deploy_config.get("entrypoint"): - flow = await run_sync_in_worker_thread( - load_flow_from_entrypoint, deploy_config["entrypoint"] - ) - deploy_config["flow_name"] = flow.name + flow = load_flow_name_from_entrypoint(deploy_config["entrypoint"]) deployment_name = deploy_config.get("name") if not deployment_name: @@ -493,7 +486,9 @@ async def _run_single_deploy( raise ValueError("A deployment name must be provided.") deploy_config["name"] = prompt("Deployment name", default="default") - deploy_config["parameter_openapi_schema"] = parameter_schema(flow) + deploy_config["parameter_openapi_schema"] = parameter_schema_from_entrypoint( + deploy_config["entrypoint"] + ) deploy_config["schedules"] = _construct_schedules( deploy_config, diff --git a/src/prefect/flows.py b/src/prefect/flows.py index 3a402d54b98ac..d348d2b307e0e 100644 --- a/src/prefect/flows.py +++ b/src/prefect/flows.py @@ -5,7 +5,9 @@ # This file requires type-checking with pyright because mypy does not yet support PEP612 # See https://github.com/python/mypy/issues/8645 +import ast import datetime +import importlib.util import inspect import os import tempfile @@ -1640,7 +1642,9 @@ def load_flow_from_script(path: str, flow_name: str = None) -> Flow: ) -def load_flow_from_entrypoint(entrypoint: str) -> Flow: +def load_flow_from_entrypoint( + entrypoint: str, +) -> Flow: """ Extract a flow object from a script at an entrypoint by running all of the code in the file. @@ -1793,3 +1797,33 @@ def my_other_flow(name): ) await runner.start() + + +def load_flow_name_from_entrypoint(entrypoint: str): + if ":" in entrypoint: + # split by the last colon once to handle Windows paths with drive letters i.e C:\path\to\file.py:do_stuff + path, func_name = entrypoint.rsplit(":", maxsplit=1) + source_code = Path(path).read_text() + else: + path, func_name = entrypoint.rsplit(".", maxsplit=1) + spec = importlib.util.find_spec(path) + if not spec or not spec.origin: + raise ValueError(f"Could not find module {path!r}") + source_code = Path(spec.origin).read_text() + parsed_code = ast.parse(source_code) + func_def = next( + node + for node in ast.walk(parsed_code) + if isinstance(node, ast.FunctionDef) and node.name == func_name + ) + for decorator in func_def.decorator_list: + if ( + isinstance(decorator, ast.Call) + and getattr(decorator.func, "id", "") == "flow" + ): + for keyword in decorator.keywords: + if keyword.arg == "name": + return keyword.value.s # Return the string value of the argument + return func_name.replace( + "_", "-" + ) # If no matching decorator or keyword argument is found diff --git a/src/prefect/utilities/callables.py b/src/prefect/utilities/callables.py index c087e0a2d3493..b7f1ea91527a5 100644 --- a/src/prefect/utilities/callables.py +++ b/src/prefect/utilities/callables.py @@ -2,14 +2,20 @@ Utilities for working with Python callables. """ +import ast +import importlib.util import inspect from functools import partial +from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple import cloudpickle from prefect._internal.pydantic import HAS_PYDANTIC_V2 from prefect._internal.pydantic.v1_schema import has_v1_type_as_param +from prefect.utilities.importtools import ( + safe_load_namespace, +) if HAS_PYDANTIC_V2: import pydantic.v1 as pydantic @@ -31,7 +37,9 @@ ReservedArgumentError, SignatureMismatchError, ) -from prefect.logging.loggers import disable_logger +from prefect.logging.loggers import disable_logger, get_logger + +logger = get_logger(__name__) def get_call_parameters( @@ -318,9 +326,32 @@ def parameter_schema(fn: Callable) -> ParameterSchema: # `eval_str` is not available in Python < 3.10 signature = inspect.signature(fn) + docstrings = parameter_docstrings(inspect.getdoc(fn)) + + return generate_parameter_schema(signature, docstrings) + + +def parameter_schema_from_entrypoint(entrypoint: str) -> ParameterSchema: + if ":" in entrypoint: + # split by the last colon once to handle Windows paths with drive letters i.e C:\path\to\file.py:do_stuff + path, func_name = entrypoint.rsplit(":", maxsplit=1) + source_code = Path(path).read_text() + else: + path, func_name = entrypoint.rsplit(".", maxsplit=1) + spec = importlib.util.find_spec(path) + if not spec or not spec.origin: + raise ValueError(f"Could not find module {path!r}") + source_code = Path(spec.origin).read_text() + signature = generate_signature_from_source(source_code, func_name) + docstring = get_docstring_from_source(source_code, func_name) + return generate_parameter_schema(signature, parameter_docstrings(docstring)) + + +def generate_parameter_schema( + signature: inspect.Signature, docstrings: Dict[str, str] +) -> ParameterSchema: model_fields = {} aliases = {} - docstrings = parameter_docstrings(inspect.getdoc(fn)) class ModelConfig: arbitrary_types_allowed = True @@ -362,3 +393,132 @@ def raise_for_reserved_arguments(fn: Callable, reserved_arguments: Iterable[str] raise ReservedArgumentError( f"{argument!r} is a reserved argument name and cannot be used." ) + + +def generate_signature_from_source( + source_code: str, func_name: str +) -> inspect.Signature: + """ + Extract the signature of a function from its source code. + + Will ignore missing imports and exceptions while loading local class definitions. + + Args: + source_code: The source code where the function named `func_name` is declared. + func_name: The name of the function. + + Returns: + The signature of the function. + """ + # Load the namespace from the source code. Missing imports and exceptions while + # loading local class definitions are ignored. + namespace = safe_load_namespace(source_code) + # Parse the source code into an AST + parsed_code = ast.parse(source_code) + + func_def = next( + ( + node + for node in ast.walk(parsed_code) + if isinstance(node, ast.FunctionDef) and node.name == func_name + ), + None, + ) + if func_def is None: + raise ValueError(f"Function {func_name} not found in source code") + parameters = [] + + for arg in func_def.args.args: + name = arg.arg + annotation = arg.annotation + if annotation is not None: + try: + # Compile and evaluate the annotation + ann_code = compile(ast.Expression(annotation), "", "eval") + annotation = eval(ann_code, namespace) + except Exception as e: + # Don't raise an error if the annotation evaluation fails. Set the + # annotation to `inspect.Parameter.empty` instead which is equivalent to + # not having an annotation. + logger.debug("Failed to evaluate annotation for %s: %s", name, e) + annotation = inspect.Parameter.empty + else: + annotation = inspect.Parameter.empty + + param = inspect.Parameter( + name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=annotation + ) + parameters.append(param) + + defaults = [None] * ( + len(func_def.args.args) - len(func_def.args.defaults) + ) + func_def.args.defaults + for param, default in zip(parameters, defaults): + if default is not None: + try: + def_code = compile(ast.Expression(default), "", "eval") + default = eval(def_code) + except Exception as e: + logger.debug( + "Failed to evaluate default value for %s: %s", param.name, e + ) + default = None # Set to None if evaluation fails + parameters[parameters.index(param)] = param.replace(default=default) + + if func_def.args.vararg: + parameters.append( + inspect.Parameter( + func_def.args.vararg.arg, inspect.Parameter.VAR_POSITIONAL + ) + ) + if func_def.args.kwarg: + parameters.append( + inspect.Parameter(func_def.args.kwarg.arg, inspect.Parameter.VAR_KEYWORD) + ) + + # Handle return annotation + return_annotation = func_def.returns + if return_annotation is not None: + try: + ret_ann_code = compile( + ast.Expression(return_annotation), "", "eval" + ) + return_annotation = eval(ret_ann_code, namespace) + except Exception as e: + logger.debug("Failed to evaluate return annotation: %s", e) + return_annotation = inspect.Signature.empty + + return inspect.Signature(parameters, return_annotation=return_annotation) + + +def get_docstring_from_source(source_code: str, func_name: str) -> Optional[str]: + """ + Extract the docstring of a function from its source code. + + Args: + source_code (str): The source code of the function. + func_name (str): The name of the function. + + Returns: + The docstring of the function. If the function has no docstring, returns None. + """ + parsed_code = ast.parse(source_code) + + func_def = next( + ( + node + for node in ast.walk(parsed_code) + if isinstance(node, ast.FunctionDef) and node.name == func_name + ), + None, + ) + if func_def is None: + raise ValueError(f"Function {func_name} not found in source code") + + if ( + func_def.body + and isinstance(func_def.body[0], ast.Expr) + and isinstance(func_def.body[0].value, (ast.Str, ast.Constant)) + ): + return func_def.body[0].value.s + return None diff --git a/src/prefect/utilities/importtools.py b/src/prefect/utilities/importtools.py index 0286a3c67fc12..257cb982e7bd3 100644 --- a/src/prefect/utilities/importtools.py +++ b/src/prefect/utilities/importtools.py @@ -1,3 +1,4 @@ +import ast import importlib import importlib.util import inspect @@ -14,8 +15,11 @@ import fsspec from prefect.exceptions import ScriptError +from prefect.logging.loggers import get_logger from prefect.utilities.filesystem import filename, is_local_path, tmpchdir +logger = get_logger(__name__) + def to_qualified_name(obj: Any) -> str: """ @@ -356,3 +360,52 @@ def exec_module(self, _: ModuleType) -> None: if self.callback is not None: self.callback(self.alias) sys.modules[self.alias] = root_module + + +def safe_load_namespace(source_code: str): + parsed_code = ast.parse(source_code) + + namespace = {} + + # Walk through the AST and find all import statements + for node in ast.walk(parsed_code): + if isinstance(node, ast.Import): + for alias in node.names: + module_name = alias.name + as_name = alias.asname if alias.asname else module_name + try: + # Attempt to import the module + namespace[as_name] = importlib.import_module(module_name) + logger.debug("Successfully imported %s", module_name) + except ImportError as e: + logger.debug(f"Failed to import {module_name}: {e}") + elif isinstance(node, ast.ImportFrom): + module_name = node.module + if module_name is None: + continue + try: + module = importlib.import_module(module_name) + for alias in node.names: + name = alias.name + asname = alias.asname if alias.asname else name + try: + # Get the specific attribute from the module + attribute = getattr(module, name) + namespace[asname] = attribute + except AttributeError as e: + logger.debug( + "Failed to retrieve %s from %s: %s", name, module_name, e + ) + except ImportError as e: + logger.debug("Failed to import from %s: %s", node.module, e) + + # Handle local class definitions + for node in ast.walk(parsed_code): + if isinstance(node, ast.ClassDef): + try: + # Compile and evaluate each class and function definition locally + code = compile(ast.Module(body=[node]), filename="", mode="exec") + exec(code, namespace) + except Exception as e: + logger.debug("Failed to compile class definition: %s", e) + return namespace