From de7ef82089a05904de1f570ec955a1d84dce0465 Mon Sep 17 00:00:00 2001 From: Alexander Streed Date: Wed, 29 May 2024 09:42:08 -0500 Subject: [PATCH] Allow flow parameter schema generation when dependencies are missing (#13315) --- src/prefect/cli/deploy.py | 30 +- src/prefect/flows.py | 60 ++- src/prefect/utilities/callables.py | 192 +++++++- src/prefect/utilities/importtools.py | 71 +++ tests/test_flows.py | 103 +++- tests/utilities/test_callables.py | 670 +++++++++++++++++++++++++++ tests/utilities/test_importtools.py | 78 ++++ 7 files changed, 1183 insertions(+), 21 deletions(-) diff --git a/src/prefect/cli/deploy.py b/src/prefect/cli/deploy.py index 3fff1bfed356..aaa9ff1f6816 100644 --- a/src/prefect/cli/deploy.py +++ b/src/prefect/cli/deploy.py @@ -65,14 +65,15 @@ from prefect.deployments.steps.core import run_steps from prefect.events import DeploymentTriggerTypes, TriggerTypes from prefect.exceptions import ObjectNotFound, PrefectHTTPStatusError -from prefect.flows import load_flow_from_entrypoint +from prefect.flows import load_flow_argument_from_entrypoint from prefect.settings import ( PREFECT_DEFAULT_WORK_POOL_NAME, PREFECT_UI_URL, ) from prefect.utilities.annotations import NotSet -from prefect.utilities.asyncutils import run_sync_in_worker_thread -from prefect.utilities.callables import parameter_schema +from prefect.utilities.callables import ( + parameter_schema_from_entrypoint, +) from prefect.utilities.collections import get_from_dict from prefect.utilities.slugify import slugify from prefect.utilities.templating import ( @@ -474,18 +475,10 @@ async def _run_single_deploy( "You can also provide an entrypoint in a prefect.yaml file." ) deploy_config["entrypoint"] = await prompt_entrypoint(app.console) - if deploy_config.get("flow_name") and deploy_config.get("entrypoint"): - raise ValueError( - "Received an entrypoint and a flow name for this deployment. Please provide" - " either an entrypoint or a flow name." - ) - # entrypoint logic - if deploy_config.get("entrypoint"): - flow = await run_sync_in_worker_thread( - load_flow_from_entrypoint, deploy_config["entrypoint"] - ) - deploy_config["flow_name"] = flow.name + deploy_config["flow_name"] = load_flow_argument_from_entrypoint( + deploy_config["entrypoint"], arg="name" + ) deployment_name = deploy_config.get("name") if not deployment_name: @@ -493,7 +486,9 @@ async def _run_single_deploy( raise ValueError("A deployment name must be provided.") deploy_config["name"] = prompt("Deployment name", default="default") - deploy_config["parameter_openapi_schema"] = parameter_schema(flow) + deploy_config["parameter_openapi_schema"] = parameter_schema_from_entrypoint( + deploy_config["entrypoint"] + ) deploy_config["schedules"] = _construct_schedules( deploy_config, @@ -674,8 +669,9 @@ async def _run_single_deploy( deploy_config["work_pool"]["job_variables"]["image"] = "{{ build-image.image }}" if not deploy_config.get("description"): - deploy_config["description"] = flow.description - + deploy_config["description"] = load_flow_argument_from_entrypoint( + deploy_config["entrypoint"], arg="description" + ) # save deploy_config before templating deploy_config_before_templating = deepcopy(deploy_config) ## apply templating from build and push steps to the final deployment spec diff --git a/src/prefect/flows.py b/src/prefect/flows.py index 3a402d54b98a..be16a626864b 100644 --- a/src/prefect/flows.py +++ b/src/prefect/flows.py @@ -5,7 +5,9 @@ # This file requires type-checking with pyright because mypy does not yet support PEP612 # See https://github.com/python/mypy/issues/8645 +import ast import datetime +import importlib.util import inspect import os import tempfile @@ -1640,7 +1642,9 @@ def load_flow_from_script(path: str, flow_name: str = None) -> Flow: ) -def load_flow_from_entrypoint(entrypoint: str) -> Flow: +def load_flow_from_entrypoint( + entrypoint: str, +) -> Flow: """ Extract a flow object from a script at an entrypoint by running all of the code in the file. @@ -1793,3 +1797,57 @@ def my_other_flow(name): ) await runner.start() + + +def load_flow_argument_from_entrypoint( + entrypoint: str, arg: str = "name" +) -> Optional[str]: + """ + Extract a flow argument from an entrypoint string. + + Loads the source code of the entrypoint and extracts the flow argument from the + `flow` decorator. + + Args: + entrypoint: a string in the format `:` or a module path + to a flow function + + Returns: + The flow argument value + """ + if ":" in entrypoint: + # split by the last colon once to handle Windows paths with drive letters i.e C:\path\to\file.py:do_stuff + path, func_name = entrypoint.rsplit(":", maxsplit=1) + source_code = Path(path).read_text() + else: + path, func_name = entrypoint.rsplit(".", maxsplit=1) + spec = importlib.util.find_spec(path) + if not spec or not spec.origin: + raise ValueError(f"Could not find module {path!r}") + source_code = Path(spec.origin).read_text() + parsed_code = ast.parse(source_code) + func_def = next( + ( + node + for node in ast.walk(parsed_code) + if isinstance(node, ast.FunctionDef) and node.name == func_name + ), + None, + ) + if not func_def: + raise ValueError(f"Could not find flow {func_name!r} in {path!r}") + for decorator in func_def.decorator_list: + if ( + isinstance(decorator, ast.Call) + and getattr(decorator.func, "id", "") == "flow" + ): + for keyword in decorator.keywords: + if keyword.arg == arg: + return ( + keyword.value.value + ) # Return the string value of the argument + + if arg == "name": + return func_name.replace( + "_", "-" + ) # If no matching decorator or keyword argument is found diff --git a/src/prefect/utilities/callables.py b/src/prefect/utilities/callables.py index c087e0a2d349..9f60bb1b1b73 100644 --- a/src/prefect/utilities/callables.py +++ b/src/prefect/utilities/callables.py @@ -2,8 +2,11 @@ Utilities for working with Python callables. """ +import ast +import importlib.util import inspect from functools import partial +from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple import cloudpickle @@ -31,7 +34,10 @@ ReservedArgumentError, SignatureMismatchError, ) -from prefect.logging.loggers import disable_logger +from prefect.logging.loggers import disable_logger, get_logger +from prefect.utilities.importtools import safe_load_namespace + +logger = get_logger(__name__) def get_call_parameters( @@ -318,9 +324,62 @@ def parameter_schema(fn: Callable) -> ParameterSchema: # `eval_str` is not available in Python < 3.10 signature = inspect.signature(fn) + docstrings = parameter_docstrings(inspect.getdoc(fn)) + + return generate_parameter_schema(signature, docstrings) + + +def parameter_schema_from_entrypoint(entrypoint: str) -> ParameterSchema: + """ + Generate a parameter schema from an entrypoint string. + + Will load the source code of the function and extract the signature and docstring + to generate the schema. + + Useful for generating a schema for a function when instantiating the function may + not be possible due to missing imports or other issues. + + Args: + entrypoint: A string representing the entrypoint to a function. The string + should be in the format of `module.path.to.function:do_stuff`. + + Returns: + ParameterSchema: The parameter schema for the function. + """ + if ":" in entrypoint: + # split by the last colon once to handle Windows paths with drive letters i.e C:\path\to\file.py:do_stuff + path, func_name = entrypoint.rsplit(":", maxsplit=1) + source_code = Path(path).read_text() + else: + path, func_name = entrypoint.rsplit(".", maxsplit=1) + spec = importlib.util.find_spec(path) + if not spec or not spec.origin: + raise ValueError(f"Could not find module {path!r}") + source_code = Path(spec.origin).read_text() + signature = _generate_signature_from_source(source_code, func_name) + docstring = _get_docstring_from_source(source_code, func_name) + return generate_parameter_schema(signature, parameter_docstrings(docstring)) + + +def generate_parameter_schema( + signature: inspect.Signature, docstrings: Dict[str, str] +) -> ParameterSchema: + """ + Generate a parameter schema from a function signature and docstrings. + + To get a signature from a function, use `inspect.signature(fn)` or + `_generate_signature_from_source(source_code, func_name)`. + + Args: + signature: The function signature. + docstrings: A dictionary mapping parameter names to docstrings. + + Returns: + ParameterSchema: The parameter schema. + """ + model_fields = {} aliases = {} - docstrings = parameter_docstrings(inspect.getdoc(fn)) class ModelConfig: arbitrary_types_allowed = True @@ -362,3 +421,132 @@ def raise_for_reserved_arguments(fn: Callable, reserved_arguments: Iterable[str] raise ReservedArgumentError( f"{argument!r} is a reserved argument name and cannot be used." ) + + +def _generate_signature_from_source( + source_code: str, func_name: str +) -> inspect.Signature: + """ + Extract the signature of a function from its source code. + + Will ignore missing imports and exceptions while loading local class definitions. + + Args: + source_code: The source code where the function named `func_name` is declared. + func_name: The name of the function. + + Returns: + The signature of the function. + """ + # Load the namespace from the source code. Missing imports and exceptions while + # loading local class definitions are ignored. + namespace = safe_load_namespace(source_code) + # Parse the source code into an AST + parsed_code = ast.parse(source_code) + + func_def = next( + ( + node + for node in ast.walk(parsed_code) + if isinstance(node, ast.FunctionDef) and node.name == func_name + ), + None, + ) + if func_def is None: + raise ValueError(f"Function {func_name} not found in source code") + parameters = [] + + for arg in func_def.args.args: + name = arg.arg + annotation = arg.annotation + if annotation is not None: + try: + # Compile and evaluate the annotation + ann_code = compile(ast.Expression(annotation), "", "eval") + annotation = eval(ann_code, namespace) + except Exception as e: + # Don't raise an error if the annotation evaluation fails. Set the + # annotation to `inspect.Parameter.empty` instead which is equivalent to + # not having an annotation. + logger.debug("Failed to evaluate annotation for %s: %s", name, e) + annotation = inspect.Parameter.empty + else: + annotation = inspect.Parameter.empty + + param = inspect.Parameter( + name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=annotation + ) + parameters.append(param) + + defaults = [None] * ( + len(func_def.args.args) - len(func_def.args.defaults) + ) + func_def.args.defaults + for param, default in zip(parameters, defaults): + if default is not None: + try: + def_code = compile(ast.Expression(default), "", "eval") + default = eval(def_code, namespace) + except Exception as e: + logger.debug( + "Failed to evaluate default value for %s: %s", param.name, e + ) + default = None # Set to None if evaluation fails + parameters[parameters.index(param)] = param.replace(default=default) + + if func_def.args.vararg: + parameters.append( + inspect.Parameter( + func_def.args.vararg.arg, inspect.Parameter.VAR_POSITIONAL + ) + ) + if func_def.args.kwarg: + parameters.append( + inspect.Parameter(func_def.args.kwarg.arg, inspect.Parameter.VAR_KEYWORD) + ) + + # Handle return annotation + return_annotation = func_def.returns + if return_annotation is not None: + try: + ret_ann_code = compile( + ast.Expression(return_annotation), "", "eval" + ) + return_annotation = eval(ret_ann_code, namespace) + except Exception as e: + logger.debug("Failed to evaluate return annotation: %s", e) + return_annotation = inspect.Signature.empty + + return inspect.Signature(parameters, return_annotation=return_annotation) + + +def _get_docstring_from_source(source_code: str, func_name: str) -> Optional[str]: + """ + Extract the docstring of a function from its source code. + + Args: + source_code (str): The source code of the function. + func_name (str): The name of the function. + + Returns: + The docstring of the function. If the function has no docstring, returns None. + """ + parsed_code = ast.parse(source_code) + + func_def = next( + ( + node + for node in ast.walk(parsed_code) + if isinstance(node, ast.FunctionDef) and node.name == func_name + ), + None, + ) + if func_def is None: + raise ValueError(f"Function {func_name} not found in source code") + + if ( + func_def.body + and isinstance(func_def.body[0], ast.Expr) + and isinstance(func_def.body[0].value, ast.Constant) + ): + return func_def.body[0].value.value + return None diff --git a/src/prefect/utilities/importtools.py b/src/prefect/utilities/importtools.py index 0286a3c67fc1..262b58b3fcb3 100644 --- a/src/prefect/utilities/importtools.py +++ b/src/prefect/utilities/importtools.py @@ -1,3 +1,4 @@ +import ast import importlib import importlib.util import inspect @@ -14,8 +15,11 @@ import fsspec from prefect.exceptions import ScriptError +from prefect.logging.loggers import get_logger from prefect.utilities.filesystem import filename, is_local_path, tmpchdir +logger = get_logger(__name__) + def to_qualified_name(obj: Any) -> str: """ @@ -356,3 +360,70 @@ def exec_module(self, _: ModuleType) -> None: if self.callback is not None: self.callback(self.alias) sys.modules[self.alias] = root_module + + +def safe_load_namespace(source_code: str): + """ + Safely load a namespace from source code. + + This function will attempt to import all modules and classes defined in the source + code. If an import fails, the error is caught and the import is skipped. This function + will also attempt to compile and evaluate class and function definitions locally. + + Args: + source_code: The source code to load + + Returns: + The namespace loaded from the source code. Can be used when evaluating source + code. + """ + parsed_code = ast.parse(source_code) + + namespace = {} + + # Walk through the AST and find all import statements + for node in ast.walk(parsed_code): + if isinstance(node, ast.Import): + for alias in node.names: + module_name = alias.name + as_name = alias.asname if alias.asname else module_name + try: + # Attempt to import the module + namespace[as_name] = importlib.import_module(module_name) + logger.debug("Successfully imported %s", module_name) + except ImportError as e: + logger.debug(f"Failed to import {module_name}: {e}") + elif isinstance(node, ast.ImportFrom): + module_name = node.module + if module_name is None: + continue + try: + module = importlib.import_module(module_name) + for alias in node.names: + name = alias.name + asname = alias.asname if alias.asname else name + try: + # Get the specific attribute from the module + attribute = getattr(module, name) + namespace[asname] = attribute + except AttributeError as e: + logger.debug( + "Failed to retrieve %s from %s: %s", name, module_name, e + ) + except ImportError as e: + logger.debug("Failed to import from %s: %s", node.module, e) + + # Handle local class definitions + for node in ast.walk(parsed_code): + if isinstance(node, (ast.ClassDef, ast.FunctionDef)): + try: + # Compile and execute each class and function definition locally + code = compile( + ast.Module(body=[node], type_ignores=[]), + filename="", + mode="exec", + ) + exec(code, namespace) + except Exception as e: + logger.debug("Failed to compile class definition: %s", e) + return namespace diff --git a/tests/test_flows.py b/tests/test_flows.py index c701ba76ec0a..916d7ba6a25c 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -47,7 +47,11 @@ ReservedArgumentError, ) from prefect.filesystems import LocalFileSystem -from prefect.flows import Flow, load_flow_from_entrypoint +from prefect.flows import ( + Flow, + load_flow_argument_from_entrypoint, + load_flow_from_entrypoint, +) from prefect.runtime import flow_run as flow_run_ctx from prefect.server.schemas.core import TaskRunResult from prefect.server.schemas.filters import FlowFilter, FlowRunFilter @@ -3882,3 +3886,100 @@ async def test_suppress_console_output( ) assert not capsys.readouterr().out + + +class TestLoadFlowArgumentFromEntrypoint: + def test_load_flow_name_from_entrypoint(self, tmp_path: Path): + flow_source = dedent( + """ + + from prefect import flow + + @flow(name="My custom name") + def flow_function(name: str) -> str: + return name + """ + ) + + tmp_path.joinpath("flow.py").write_text(flow_source) + + entrypoint = f"{tmp_path.joinpath('flow.py')}:flow_function" + + result = load_flow_argument_from_entrypoint(entrypoint, "name") + + assert result == "My custom name" + + def test_load_flow_name_from_entrypoint_no_name(self, tmp_path: Path): + flow_source = dedent( + """ + + from prefect import flow + + @flow + def flow_function(name: str) -> str: + return name + """ + ) + + tmp_path.joinpath("flow.py").write_text(flow_source) + + entrypoint = f"{tmp_path.joinpath('flow.py')}:flow_function" + + result = load_flow_argument_from_entrypoint(entrypoint, "name") + + assert result == "flow-function" + + def test_load_flow_description_from_entrypoint(self, tmp_path: Path): + flow_source = dedent( + """ + + from prefect import flow + + @flow(description="My custom description") + def flow_function(name: str) -> str: + return name + """ + ) + + tmp_path.joinpath("flow.py").write_text(flow_source) + + entrypoint = f"{tmp_path.joinpath('flow.py')}:flow_function" + + result = load_flow_argument_from_entrypoint(entrypoint, "description") + + assert result == "My custom description" + + def test_load_flow_description_from_entrypoint_no_description(self, tmp_path: Path): + flow_source = dedent( + """ + + from prefect import flow + + @flow + def flow_function(name: str) -> str: + return name + """ + ) + + tmp_path.joinpath("flow.py").write_text(flow_source) + + entrypoint = f"{tmp_path.joinpath('flow.py')}:flow_function" + + result = load_flow_argument_from_entrypoint(entrypoint, "description") + + assert result is None + + def test_load_no_flow(self, tmp_path: Path): + flow_source = dedent( + """ + + from prefect import flow + """ + ) + + tmp_path.joinpath("flow.py").write_text(flow_source) + + entrypoint = f"{tmp_path.joinpath('flow.py')}:flow_function" + + with pytest.raises(ValueError, match="Could not find flow"): + load_flow_argument_from_entrypoint(entrypoint, "name") diff --git a/tests/utilities/test_callables.py b/tests/utilities/test_callables.py index 11e906135097..185ce859a49d 100644 --- a/tests/utilities/test_callables.py +++ b/tests/utilities/test_callables.py @@ -1,5 +1,7 @@ import datetime from enum import Enum +from pathlib import Path +from textwrap import dedent from typing import Any, Dict, List, Tuple, Union import pendulum @@ -919,3 +921,671 @@ def foo(a, b): with pytest.raises(ValueError): callables.collapse_variadic_parameters(foo, parameters) + + +class TestEntrypointToSchema: + def test_function_not_found(self, tmp_path: Path): + source_code = dedent( + """ + def f(): + pass + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + + with pytest.raises(ValueError): + callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:g") + + def test_simple_function_with_no_arguments(self, tmp_path: Path): + source_code = dedent( + """ + def f(): + pass + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + assert schema.model_dump_for_openapi() == { + "properties": {}, + "title": "Parameters", + "type": "object", + "required": [], + "definitions": {}, + } + + def test_function_with_pydantic_base_model_collisions(self, tmp_path: Path): + source_code = dedent( + """ + def f( + json, + copy, + parse_obj, + parse_raw, + parse_file, + from_orm, + schema, + schema_json, + construct, + validate, + foo, + ): + pass + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + assert schema.model_dump_for_openapi() == { + "title": "Parameters", + "type": "object", + "properties": { + "foo": {"title": "foo", "position": 10}, + "json": {"title": "json", "position": 0}, + "copy": {"title": "copy", "position": 1}, + "parse_obj": {"title": "parse_obj", "position": 2}, + "parse_raw": {"title": "parse_raw", "position": 3}, + "parse_file": {"title": "parse_file", "position": 4}, + "from_orm": {"title": "from_orm", "position": 5}, + "schema": {"title": "schema", "position": 6}, + "schema_json": {"title": "schema_json", "position": 7}, + "construct": {"title": "construct", "position": 8}, + "validate": {"title": "validate", "position": 9}, + }, + "required": [ + "json", + "copy", + "parse_obj", + "parse_raw", + "parse_file", + "from_orm", + "schema", + "schema_json", + "construct", + "validate", + "foo", + ], + "definitions": {}, + } + + def test_function_with_one_required_argument(self, tmp_path: Path): + source_code = dedent( + """ + def f(x): + pass + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + assert schema.model_dump_for_openapi() == { + "title": "Parameters", + "type": "object", + "properties": {"x": {"title": "x", "position": 0}}, + "required": ["x"], + "definitions": {}, + } + + def test_function_with_one_optional_argument(self, tmp_path: Path): + source_code = dedent( + """ + def f(x=42): + pass + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + assert schema.model_dump_for_openapi() == { + "title": "Parameters", + "type": "object", + "properties": {"x": {"title": "x", "default": 42, "position": 0}}, + "required": [], + "definitions": {}, + } + + def test_function_with_one_optional_annotated_argument(self, tmp_path: Path): + source_code = dedent( + """ + def f(x: int = 42): + pass + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + assert schema.model_dump_for_openapi() == { + "title": "Parameters", + "type": "object", + "properties": { + "x": {"title": "x", "default": 42, "type": "integer", "position": 0} + }, + "definitions": {}, + "required": [], + } + + def test_function_with_two_arguments(self, tmp_path: Path): + source_code = dedent( + """ + def f(x: int, y: float = 5.0): + pass + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + assert schema.model_dump_for_openapi() == { + "title": "Parameters", + "type": "object", + "properties": { + "x": {"title": "x", "type": "integer", "position": 0}, + "y": {"title": "y", "default": 5.0, "type": "number", "position": 1}, + }, + "required": ["x"], + "definitions": {}, + } + + def test_function_with_datetime_arguments(self, tmp_path: Path): + source_code = dedent( + """ + import pendulum + import datetime + + def f( + x: datetime.datetime, + y: pendulum.DateTime = pendulum.datetime(2025, 1, 1), + z: datetime.timedelta = datetime.timedelta(seconds=5), + ): + pass + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + expected_schema = { + "title": "Parameters", + "type": "object", + "properties": { + "x": { + "format": "date-time", + "position": 0, + "title": "x", + "type": "string", + }, + "y": { + "default": "2025-01-01T00:00:00Z", + "format": "date-time", + "position": 1, + "title": "y", + "type": "string", + }, + "z": { + "default": "PT5S", + "format": "duration", + "position": 2, + "title": "z", + "type": "string", + }, + }, + "required": ["x"], + "definitions": {}, + } + assert schema.model_dump_for_openapi() == expected_schema + + def test_function_with_enum_argument(self, tmp_path: Path): + class Color(Enum): + RED = "RED" + GREEN = "GREEN" + BLUE = "BLUE" + + source_code = dedent( + """ + from enum import Enum + + class Color(Enum): + RED = "RED" + GREEN = "GREEN" + BLUE = "BLUE" + + def f(x: Color = Color.RED): + pass + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + + expected_schema = { + "title": "Parameters", + "type": "object", + "properties": { + "x": { + "allOf": [{"$ref": "#/definitions/Color"}], + "default": "RED", + "position": 0, + "title": "x", + } + }, + "definitions": { + "Color": { + "enum": ["RED", "GREEN", "BLUE"], + "title": "Color", + "type": "string", + } + }, + "required": [], + } + assert schema.model_dump_for_openapi() == expected_schema + + def test_function_with_generic_arguments(self, tmp_path: Path): + source_code = dedent( + """ + from typing import List, Dict, Any, Tuple, Union + + def f( + a: List[str], + b: Dict[str, Any], + c: Any, + d: Tuple[int, float], + e: Union[str, bytes, int], + ): + pass + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + + expected_schema = { + "title": "Parameters", + "type": "object", + "properties": { + "a": { + "items": {"type": "string"}, + "position": 0, + "title": "a", + "type": "array", + }, + "b": {"position": 1, "title": "b", "type": "object"}, + "c": {"position": 2, "title": "c"}, + "d": { + "maxItems": 2, + "minItems": 2, + "position": 3, + "prefixItems": [{"type": "integer"}, {"type": "number"}], + "title": "d", + "type": "array", + }, + "e": { + "anyOf": [ + {"type": "string"}, + {"format": "binary", "type": "string"}, + {"type": "integer"}, + ], + "position": 4, + "title": "e", + }, + }, + "required": ["a", "b", "c", "d", "e"], + "definitions": {}, + } + + assert schema.model_dump_for_openapi() == expected_schema + + def test_function_with_user_defined_type(self, tmp_path: Path): + source_code = dedent( + """ + class Foo: + y: int + + def f(x: Foo): + pass + """ + ) + + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + assert schema.model_dump_for_openapi() == { + "title": "Parameters", + "type": "object", + "properties": {"x": {"title": "x", "position": 0}}, + "required": ["x"], + "definitions": {}, + } + + def test_function_with_user_defined_pydantic_model(self, tmp_path: Path): + source_code = dedent( + """ + import pydantic + + class Foo(pydantic.BaseModel): + y: int + z: str + + def f(x: Foo): + pass + """ + ) + + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + assert schema.model_dump_for_openapi() == { + "definitions": { + "Foo": { + "properties": { + "y": {"title": "Y", "type": "integer"}, + "z": {"title": "Z", "type": "string"}, + }, + "required": ["y", "z"], + "title": "Foo", + "type": "object", + } + }, + "properties": { + "x": { + "allOf": [{"$ref": "#/definitions/Foo"}], + "title": "x", + "position": 0, + } + }, + "required": ["x"], + "title": "Parameters", + "type": "object", + } + + def test_function_with_pydantic_model_default_across_v1_and_v2( + self, tmp_path: Path + ): + source_code = dedent( + """ + import pydantic + + class Foo(pydantic.BaseModel): + bar: str + + def f(foo: Foo = Foo(bar="baz")): + pass + """ + ) + + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + assert schema.model_dump_for_openapi() == { + "title": "Parameters", + "type": "object", + "properties": { + "foo": { + "allOf": [{"$ref": "#/definitions/Foo"}], + "default": {"bar": "baz"}, + "position": 0, + "title": "foo", + } + }, + "definitions": { + "Foo": { + "properties": {"bar": {"title": "Bar", "type": "string"}}, + "required": ["bar"], + "title": "Foo", + "type": "object", + } + }, + "required": [], + } + + def test_function_with_complex_args_across_v1_and_v2(self, tmp_path: Path): + source_code = dedent( + """ + import pydantic + import pendulum + import datetime + from enum import Enum + from typing import List + + class Foo(pydantic.BaseModel): + bar: str + + class Color(Enum): + RED = "RED" + GREEN = "GREEN" + BLUE = "BLUE" + + def f( + a: int, + s: List[None], + m: Foo, + i: int = 0, + x: float = 1.0, + model: Foo = Foo(bar="bar"), + pdt: pendulum.DateTime = pendulum.datetime(2025, 1, 1), + pdate: pendulum.Date = pendulum.date(2025, 1, 1), + pduration: pendulum.Duration = pendulum.duration(seconds=5), + c: Color = Color.BLUE, + ): + pass + """ + ) + + datetime_schema = { + "title": "pdt", + "default": "2025-01-01T00:00:00+00:00", + "position": 6, + "type": "string", + "format": "date-time", + } + duration_schema = { + "title": "pduration", + "default": 5.0, + "position": 8, + "type": "number", + "format": "time-delta", + } + enum_schema = { + "enum": ["RED", "GREEN", "BLUE"], + "title": "Color", + "type": "string", + "description": "An enumeration.", + } + + # these overrides represent changes in how pydantic generates schemas in v2 + datetime_schema["default"] = "2025-01-01T00:00:00Z" + duration_schema["default"] = "PT5S" + duration_schema["type"] = "string" + duration_schema["format"] = "duration" + enum_schema.pop("description") + + schema = tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + + assert schema.model_dump_for_openapi() == { + "title": "Parameters", + "type": "object", + "properties": { + "a": {"position": 0, "title": "a", "type": "integer"}, + "s": { + "items": {"type": "null"}, + "position": 1, + "title": "s", + "type": "array", + }, + "m": { + "allOf": [{"$ref": "#/definitions/Foo"}], + "position": 2, + "title": "m", + }, + "i": {"default": 0, "position": 3, "title": "i", "type": "integer"}, + "x": {"default": 1.0, "position": 4, "title": "x", "type": "number"}, + "model": { + "allOf": [{"$ref": "#/definitions/Foo"}], + "default": {"bar": "bar"}, + "position": 5, + "title": "model", + }, + "pdt": datetime_schema, + "pdate": { + "title": "pdate", + "default": "2025-01-01", + "position": 7, + "type": "string", + "format": "date", + }, + "pduration": duration_schema, + "c": { + "title": "c", + "default": "BLUE", + "position": 9, + "allOf": [{"$ref": "#/definitions/Color"}], + }, + }, + "required": ["a", "s", "m"], + "definitions": { + "Foo": { + "properties": {"bar": {"title": "Bar", "type": "string"}}, + "required": ["bar"], + "title": "Foo", + "type": "object", + }, + "Color": enum_schema, + }, + } + + def test_function_with_secretstr(self, tmp_path: Path): + source_code = dedent( + """ + from pydantic import SecretStr + + def f(x: SecretStr): + pass + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + assert schema.model_dump_for_openapi() == { + "title": "Parameters", + "type": "object", + "properties": { + "x": { + "title": "x", + "position": 0, + "format": "password", + "type": "string", + "writeOnly": True, + }, + }, + "required": ["x"], + "definitions": {}, + } + + def test_function_with_v1_secretstr_from_compat_module(self, tmp_path: Path): + source_code = dedent( + """ + import pydantic.v1 as pydantic + + def f(x: pydantic.SecretStr): + pass + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + assert schema.model_dump_for_openapi() == { + "title": "Parameters", + "type": "object", + "properties": { + "x": { + "title": "x", + "position": 0, + }, + }, + "required": ["x"], + "definitions": {}, + } + + def test_flow_with_args_docstring(self, tmp_path: Path): + source_code = dedent( + ''' + def f(x): + """Function f. + + Args: + x: required argument x + """ + ''' + ) + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + assert schema.model_dump_for_openapi() == { + "title": "Parameters", + "type": "object", + "properties": { + "x": {"title": "x", "description": "required argument x", "position": 0} + }, + "required": ["x"], + "definitions": {}, + } + + def test_flow_without_args_docstring(self, tmp_path: Path): + source_code = dedent( + ''' + def f(x): + """Function f.""" + ''' + ) + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + assert schema.model_dump_for_openapi() == { + "title": "Parameters", + "type": "object", + "properties": {"x": {"title": "x", "position": 0}}, + "required": ["x"], + "definitions": {}, + } + + def test_flow_with_complex_args_docstring(self, tmp_path: Path): + source_code = dedent( + ''' + def f(x, y): + """Function f. + + Second line of docstring. + + Args: + x: required argument x + y (str): required typed argument y + with second line + + Returns: + None: nothing + """ + ''' + ) + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + assert schema.model_dump_for_openapi() == { + "title": "Parameters", + "type": "object", + "properties": { + "x": { + "title": "x", + "description": "required argument x", + "position": 0, + }, + "y": { + "title": "y", + "description": "required typed argument y\nwith second line", + "position": 1, + }, + }, + "required": ["x", "y"], + "definitions": {}, + } + + def test_does_not_raise_when_missing_dependencies(self, tmp_path: Path): + source_code = dedent( + """ + import bipitty_boopity + + def f(x): + pass + """ + ) + tmp_path.joinpath("test.py").write_text(source_code) + schema = callables.parameter_schema_from_entrypoint(f"{tmp_path}/test.py:f") + + assert schema.model_dump_for_openapi() == { + "title": "Parameters", + "type": "object", + "properties": {"x": {"title": "x", "position": 0}}, + "required": ["x"], + "definitions": {}, + } diff --git a/tests/utilities/test_importtools.py b/tests/utilities/test_importtools.py index 7284fd1fafce..c4756743769b 100644 --- a/tests/utilities/test_importtools.py +++ b/tests/utilities/test_importtools.py @@ -2,6 +2,7 @@ import runpy import sys from pathlib import Path +from textwrap import dedent from types import ModuleType from unittest.mock import MagicMock from uuid import uuid4 @@ -17,6 +18,7 @@ from_qualified_name, import_object, lazy_import, + safe_load_namespace, to_qualified_name, ) @@ -238,3 +240,79 @@ def test_import_object_from_module_with_relative_imports_expected_failures( # Python would raise the same error with pytest.raises((ValueError, ImportError)): runpy.run_module(import_path) + + +def test_safe_load_namespace(): + source_code = dedent( + """ + import math + from datetime import datetime + from pydantic import BaseModel + + class MyModel(BaseModel): + x: int + + def my_fn(): + return 42 + + x = 10 + y = math.sqrt(x) + now = datetime.now() + """ + ) + + namespace = safe_load_namespace(source_code) + + # module-level imports should be present + assert "math" in namespace + assert "datetime" in namespace + assert "BaseModel" in namespace + # module-level variables should not be present + assert "x" not in namespace + assert "y" not in namespace + assert "now" not in namespace + # module-level classes should be present + assert "MyModel" in namespace + # module-level functions should be present + assert "my_fn" in namespace + + assert namespace["MyModel"].__name__ == "MyModel" + + +def test_safe_load_namespace_ignores_import_errors(): + source_code = dedent( + """ + import flibbidy + + from pydantic import BaseModel + + class MyModel(BaseModel): + x: int + """ + ) + + # should not raise an ImportError + namespace = safe_load_namespace(source_code) + + assert "flibbidy" not in namespace + # other imports and classes should be present + assert "BaseModel" in namespace + assert "MyModel" in namespace + assert namespace["MyModel"].__name__ == "MyModel" + + +def test_safe_load_namespace_ignore_class_declaration_errors(): + source_code = dedent( + """ + from fake_pandas import DataFrame + + class CoolDataFrame(DataFrame): + pass + """ + ) + + # should not raise any errors + namespace = safe_load_namespace(source_code) + + assert "DataFrame" not in namespace + assert "CoolDataFrame" not in namespace