diff --git a/.github/workflows/benchmarks.yaml b/.github/workflows/benchmarks.yaml index 776ff6de57be..1ed7e38865b1 100644 --- a/.github/workflows/benchmarks.yaml +++ b/.github/workflows/benchmarks.yaml @@ -105,8 +105,8 @@ jobs: # characters with an underscore sanitized_uniquename="${uniquename//[^a-zA-Z0-9_\-]/_}" - PREFECT_API_URL="http://127.0.0.1:4200/api" - python benches \ + PREFECT_API_URL="http://127.0.0.1:4200/api" \ + python -m benches \ --ignore=benches/bench_import.py \ --timeout=180 \ --benchmark-save="${sanitized_uniquename}" \ diff --git a/benches/bench_flows.py b/benches/bench_flows.py index 6882489f15f8..9494fc44a9c9 100644 --- a/benches/bench_flows.py +++ b/benches/bench_flows.py @@ -55,7 +55,10 @@ def benchmark_flow(): for _ in range(num_tasks): test_task() - benchmark(benchmark_flow) + if num_tasks > 100: + benchmark.pedantic(benchmark_flow) + else: + benchmark(benchmark_flow) @pytest.mark.parametrize("num_tasks", [10, 50, 100, 250]) @@ -68,7 +71,10 @@ async def benchmark_flow(): for _ in range(num_tasks): tg.start_soon(test_task) - benchmark(anyio.run, benchmark_flow) + if num_tasks > 100: + benchmark.pedantic(anyio.run, (benchmark_flow,)) + else: + benchmark(anyio.run, benchmark_flow) @pytest.mark.parametrize("num_flows", [5, 10, 20]) diff --git a/docs/3.0rc/api-ref/rest-api/server/schema.json b/docs/3.0rc/api-ref/rest-api/server/schema.json index a71a100a1879..76b7c0b85eb2 100644 --- a/docs/3.0rc/api-ref/rest-api/server/schema.json +++ b/docs/3.0rc/api-ref/rest-api/server/schema.json @@ -22271,6 +22271,16 @@ "title": "Prefect Server Csrf Token Expiration", "default": "PT1H" }, + "PREFECT_SERVER_ALLOW_EPHEMERAL_MODE": { + "type": "boolean", + "title": "Prefect Server Allow Ephemeral Mode", + "default": false + }, + "PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS": { + "type": "integer", + "title": "Prefect Server Ephemeral Startup Timeout Seconds", + "default": 10 + }, "PREFECT_UI_ENABLED": { "type": "boolean", "title": "Prefect Ui Enabled", diff --git a/src/prefect/blocks/core.py b/src/prefect/blocks/core.py index 59097417e46a..4a4216e3950b 100644 --- a/src/prefect/blocks/core.py +++ b/src/prefect/blocks/core.py @@ -902,7 +902,9 @@ class Custom(Block): loaded_block.save("my-custom-message", overwrite=True) ``` """ - block_document, block_document_name = await cls._get_block_document(name) + block_document, block_document_name = await cls._get_block_document( + name, client=client + ) return cls._load_from_block_document(block_document, validate=validate) diff --git a/src/prefect/cli/profile.py b/src/prefect/cli/profile.py index d31534731d0f..fa01cd79d25a 100644 --- a/src/prefect/cli/profile.py +++ b/src/prefect/cli/profile.py @@ -18,6 +18,7 @@ from prefect.cli._utilities import exit_with_error, exit_with_success from prefect.cli.cloud import CloudUnauthorizedError, get_cloud_client from prefect.cli.root import app, is_interactive +from prefect.client.base import determine_server_type from prefect.client.orchestration import ServerType, get_client from prefect.context import use_profile from prefect.exceptions import ObjectNotFound @@ -138,6 +139,13 @@ async def use(name: str): " in ephemeral mode." ), ), + ConnectionStatus.UNCONFIGURED: ( + exit_with_error, + ( + f"Prefect server URL not configured using profile {name!r} - please" + " configure the server URL or enable ephemeral mode." + ), + ), ConnectionStatus.INVALID_API: ( exit_with_error, "Error connecting to Prefect API URL", @@ -350,6 +358,7 @@ class ConnectionStatus(AutoEnum): CLOUD_UNAUTHORIZED = AutoEnum.auto() SERVER_CONNECTED = AutoEnum.auto() SERVER_ERROR = AutoEnum.auto() + UNCONFIGURED = AutoEnum.auto() EPHEMERAL = AutoEnum.auto() INVALID_API = AutoEnum.auto() @@ -373,14 +382,16 @@ async def check_server_connection(): try: # inform the user if Prefect API endpoints exist, but there are # connection issues + server_type = determine_server_type() + if server_type == ServerType.EPHEMERAL: + return ConnectionStatus.EPHEMERAL + elif server_type == ServerType.UNCONFIGURED: + return ConnectionStatus.UNCONFIGURED client = get_client(httpx_settings=httpx_settings) async with client: connect_error = await client.api_healthcheck() if connect_error is not None: return ConnectionStatus.SERVER_ERROR - elif client.server_type == ServerType.EPHEMERAL: - # if the client is using an ephemeral Prefect app, inform the user - return ConnectionStatus.EPHEMERAL else: return ConnectionStatus.SERVER_CONNECTED except Exception: @@ -390,6 +401,13 @@ async def check_server_connection(): except TypeError: # if no Prefect API URL has been set, httpx will throw a TypeError try: + # try to connect with the client anyway, it will likely use an + # ephemeral Prefect instance + server_type = determine_server_type() + if server_type == ServerType.EPHEMERAL: + return ConnectionStatus.EPHEMERAL + elif server_type == ServerType.UNCONFIGURED: + return ConnectionStatus.UNCONFIGURED client = get_client(httpx_settings=httpx_settings) if client.server_type == ServerType.EPHEMERAL: return ConnectionStatus.EPHEMERAL diff --git a/src/prefect/cli/root.py b/src/prefect/cli/root.py index bb98e6c6e6ac..a6aa37ea7afd 100644 --- a/src/prefect/cli/root.py +++ b/src/prefect/cli/root.py @@ -17,6 +17,7 @@ import prefect.settings from prefect.cli._types import PrefectTyper, SettingsOption from prefect.cli._utilities import with_cli_exception_handling +from prefect.client.base import determine_server_type from prefect.client.constants import SERVER_API_VERSION from prefect.client.orchestration import ServerType from prefect.logging.configuration import setup_logging @@ -117,16 +118,7 @@ async def version( "OS/Arch": f"{sys.platform}/{platform.machine()}", "Profile": prefect.context.get_settings_context().profile.name, } - - server_type: str - - try: - # We do not context manage the client because when using an ephemeral app we do not - # want to create the database or run migrations - client = prefect.get_client() - server_type = client.server_type.value - except Exception: - server_type = "" + server_type = determine_server_type() version_info["Server type"] = server_type.lower() diff --git a/src/prefect/client/base.py b/src/prefect/client/base.py index ed9929ee42c2..b2e9c55eef7a 100644 --- a/src/prefect/client/base.py +++ b/src/prefect/client/base.py @@ -35,10 +35,14 @@ from prefect.exceptions import PrefectHTTPStatusError from prefect.logging import get_logger from prefect.settings import ( + PREFECT_API_URL, PREFECT_CLIENT_MAX_RETRIES, PREFECT_CLIENT_RETRY_EXTRA_CODES, PREFECT_CLIENT_RETRY_JITTER_FACTOR, + PREFECT_CLOUD_API_URL, + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE, ) +from prefect.utilities.collections import AutoEnum from prefect.utilities.math import bounded_poisson_interval, clamped_poisson_interval # Datastores for lifespan management, keys should be a tuple of thread and app @@ -637,3 +641,33 @@ def __init__( ) pass + + +class ServerType(AutoEnum): + EPHEMERAL = AutoEnum.auto() + SERVER = AutoEnum.auto() + CLOUD = AutoEnum.auto() + UNCONFIGURED = AutoEnum.auto() + + +def determine_server_type() -> ServerType: + """ + Determine the server type based on the current settings. + + Returns: + - `ServerType.EPHEMERAL` if the ephemeral server is enabled + - `ServerType.SERVER` if a API URL is configured and it is not a cloud URL + - `ServerType.CLOUD` if an API URL is configured and it is a cloud URL + - `ServerType.UNCONFIGURED` if no API URL is configured and ephemeral mode is + not enabled + """ + api_url = PREFECT_API_URL.value() + if api_url is None: + if PREFECT_SERVER_ALLOW_EPHEMERAL_MODE.value(): + return ServerType.EPHEMERAL + else: + return ServerType.UNCONFIGURED + if api_url.startswith(PREFECT_CLOUD_API_URL.value()): + return ServerType.CLOUD + else: + return ServerType.SERVER diff --git a/src/prefect/client/orchestration.py b/src/prefect/client/orchestration.py index ec80086fcb5e..53a0a98f562a 100644 --- a/src/prefect/client/orchestration.py +++ b/src/prefect/client/orchestration.py @@ -132,9 +132,9 @@ PREFECT_API_URL, PREFECT_CLIENT_CSRF_SUPPORT_ENABLED, PREFECT_CLOUD_API_URL, + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE, PREFECT_UNIT_TEST_MODE, ) -from prefect.utilities.collections import AutoEnum if TYPE_CHECKING: from prefect.flows import Flow as FlowObject @@ -145,6 +145,7 @@ PrefectHttpxAsyncClient, PrefectHttpxSyncClient, PrefectHttpxSyncEphemeralClient, + ServerType, app_lifespan_context, ) @@ -152,12 +153,6 @@ R = TypeVar("R") -class ServerType(AutoEnum): - EPHEMERAL = AutoEnum.auto() - SERVER = AutoEnum.auto() - CLOUD = AutoEnum.auto() - - @overload def get_client( httpx_settings: Optional[Dict[str, Any]] = None, sync_client: Literal[False] = False @@ -194,8 +189,6 @@ def get_client( """ import prefect.context - settings_ctx = prefect.context.get_settings_context() - # try to load clients from a client context, if possible # only load clients that match the provided config / loop try: @@ -217,24 +210,36 @@ def get_client( return client_ctx.client api = PREFECT_API_URL.value() + server_type = None - if not api: + if not api and PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: # create an ephemeral API if none was provided - from prefect.server.api.server import create_app + from prefect.server.api.server import SubprocessASGIServer - api = create_app(settings_ctx.settings, ephemeral=True) + server = SubprocessASGIServer() + server.start() + assert server.server_process is not None, "Server process did not start" + + api = f"{server.address()}/api" + server_type = ServerType.EPHEMERAL + elif not api and not PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: + raise ValueError( + "No Prefect API URL provided. Please set PREFECT_API_URL to the address of a running Prefect server." + ) if sync_client: return SyncPrefectClient( api, api_key=PREFECT_API_KEY.value(), httpx_settings=httpx_settings, + server_type=server_type, ) else: return PrefectClient( api, api_key=PREFECT_API_KEY.value(), httpx_settings=httpx_settings, + server_type=server_type, ) @@ -271,6 +276,7 @@ def __init__( api_key: Optional[str] = None, api_version: Optional[str] = None, httpx_settings: Optional[Dict[str, Any]] = None, + server_type: Optional[ServerType] = None, ) -> None: httpx_settings = httpx_settings.copy() if httpx_settings else {} httpx_settings.setdefault("headers", {}) @@ -333,11 +339,14 @@ def __init__( # client will use a standard HTTP/1.1 connection instead. httpx_settings.setdefault("http2", PREFECT_API_ENABLE_HTTP2.value()) - self.server_type = ( - ServerType.CLOUD - if api.startswith(PREFECT_CLOUD_API_URL.value()) - else ServerType.SERVER - ) + if server_type: + self.server_type = server_type + else: + self.server_type = ( + ServerType.CLOUD + if api.startswith(PREFECT_CLOUD_API_URL.value()) + else ServerType.SERVER + ) # Connect to an in-process application elif isinstance(api, ASGIApp): @@ -3386,6 +3395,7 @@ def __init__( api_key: Optional[str] = None, api_version: Optional[str] = None, httpx_settings: Optional[Dict[str, Any]] = None, + server_type: Optional[ServerType] = None, ) -> None: httpx_settings = httpx_settings.copy() if httpx_settings else {} httpx_settings.setdefault("headers", {}) @@ -3444,11 +3454,14 @@ def __init__( # client will use a standard HTTP/1.1 connection instead. httpx_settings.setdefault("http2", PREFECT_API_ENABLE_HTTP2.value()) - self.server_type = ( - ServerType.CLOUD - if api.startswith(PREFECT_CLOUD_API_URL.value()) - else ServerType.SERVER - ) + if server_type: + self.server_type = server_type + else: + self.server_type = ( + ServerType.CLOUD + if api.startswith(PREFECT_CLOUD_API_URL.value()) + else ServerType.SERVER + ) # Connect to an in-process application elif isinstance(api, ASGIApp): @@ -4062,6 +4075,33 @@ def read_deployment( raise return DeploymentResponse.model_validate(response.json()) + def read_deployment_by_name( + self, + name: str, + ) -> DeploymentResponse: + """ + Query the Prefect API for a deployment by name. + + Args: + name: A deployed flow's name: / + + Raises: + prefect.exceptions.ObjectNotFound: If request returns 404 + httpx.RequestError: If request fails + + Returns: + a Deployment model representation of the deployment + """ + try: + response = self._client.get(f"/deployments/name/{name}") + except httpx.HTTPStatusError as e: + if e.response.status_code == status.HTTP_404_NOT_FOUND: + raise prefect.exceptions.ObjectNotFound(http_exc=e) from e + else: + raise + + return DeploymentResponse.model_validate(response.json()) + def create_artifact( self, artifact: ArtifactCreate, diff --git a/src/prefect/events/worker.py b/src/prefect/events/worker.py index 2dc7849a1e5d..5668e18ffae5 100644 --- a/src/prefect/events/worker.py +++ b/src/prefect/events/worker.py @@ -17,7 +17,6 @@ EventsClient, NullEventsClient, PrefectCloudEventsClient, - PrefectEphemeralEventsClient, PrefectEventsClient, ) from .related import related_resources_from_run_context @@ -97,7 +96,15 @@ def instance( elif should_emit_events_to_running_server(): client_type = PrefectEventsClient elif should_emit_events_to_ephemeral_server(): - client_type = PrefectEphemeralEventsClient + # create an ephemeral API if none was provided + from prefect.server.api.server import SubprocessASGIServer + + server = SubprocessASGIServer() + server.start() + assert server.server_process is not None, "Server process did not start" + + client_kwargs = {"api_url": f"{server.address()}/api"} + client_type = PrefectEventsClient else: client_type = NullEventsClient diff --git a/src/prefect/profiles.toml b/src/prefect/profiles.toml index fa9f8930cc55..1662f5a108f2 100644 --- a/src/prefect/profiles.toml +++ b/src/prefect/profiles.toml @@ -4,7 +4,7 @@ active = "ephemeral" [profiles.ephemeral] -PREFECT_API_DATABASE_CONNECTION_URL = "sqlite+aiosqlite:///prefect.db" +PREFECT_SERVER_ALLOW_EPHEMERAL_MODE = "true" [profiles.local] # You will need to set these values appropriately for your local development environment diff --git a/src/prefect/server/api/server.py b/src/prefect/server/api/server.py index ab68c97dc6d2..f15ae87c4cc9 100644 --- a/src/prefect/server/api/server.py +++ b/src/prefect/server/api/server.py @@ -3,17 +3,24 @@ """ import asyncio +import atexit +import contextlib import mimetypes import os +import random import shutil +import socket import sqlite3 +import subprocess +import time from contextlib import asynccontextmanager from functools import partial, wraps from hashlib import sha256 -from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Tuple +from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Tuple, Union import anyio import asyncpg +import httpx import sqlalchemy as sa import sqlalchemy.exc import sqlalchemy.orm.exc @@ -49,9 +56,12 @@ PREFECT_DEBUG_MODE, PREFECT_MEMO_STORE_PATH, PREFECT_MEMOIZE_BLOCK_AUTO_REGISTRATION, + PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS, PREFECT_UI_SERVE_BASE, + get_current_settings, ) from prefect.utilities.hashing import hash_objects +from prefect.utilities.processutils import get_sys_executable TITLE = "Prefect Server" API_TITLE = "Prefect Prefect REST API" @@ -737,3 +747,136 @@ def openapi(): APP_CACHE[cache_key] = app return app + + +subprocess_server_logger = get_logger() + + +class SubprocessASGIServer: + _instances: Dict[Union[int, None], "SubprocessASGIServer"] = {} + _port_range = range(8000, 9000) + + def __new__(cls, port: Optional[int] = None, *args, **kwargs): + """ + Return an instance of the server associated with the provided port. + Prevents multiple instances from being created for the same port. + """ + if port not in cls._instances: + instance = super().__new__(cls) + cls._instances[port] = instance + return cls._instances[port] + + def __init__(self, port: Optional[int] = None): + # This ensures initialization happens only once + if not hasattr(self, "_initialized"): + if port is None: + port = self.find_available_port() + assert port is not None, "Port must be provided or available" + self.port: int = port + self.server_process = None + self.server = None + self.running = False + self._initialized = True + + def find_available_port(self): + max_attempts = 10 + for _ in range(max_attempts): + port = random.choice(self._port_range) + if self.is_port_available(port): + return port + time.sleep(random.uniform(0.1, 0.5)) # Random backoff + raise RuntimeError("Unable to find an available port after multiple attempts") + + @staticmethod + def is_port_available(port: int): + with contextlib.closing( + socket.socket(socket.AF_INET, socket.SOCK_STREAM) + ) as sock: + try: + sock.bind(("127.0.0.1", port)) + return True + except socket.error: + return False + + def address(self) -> str: + return f"http://127.0.0.1:{self.port}" + + def start(self): + """ + Start the server in a separate process. Safe to call multiple times; only starts + the server once. + """ + if not self.running: + subprocess_server_logger.info(f"Starting server on {self.address()}") + try: + self.running = True + server_env = {"PREFECT_UI_ENABLED": "0"} + self.server_process = subprocess.Popen( + args=[ + get_sys_executable(), + "-m", + "uvicorn", + "--app-dir", + # quote wrapping needed for windows paths with spaces + f'"{prefect.__module_path__.parent}"', + "--factory", + "prefect.server.api.server:create_app", + "--host", + "127.0.0.1", + "--port", + str(self.port), + "--log-level", + "error", + "--lifespan", + "on", + ], + env={ + **os.environ, + **server_env, + **get_current_settings().to_environment_variables( + exclude_unset=True + ), + }, + ) + atexit.register(self.stop) + + with httpx.Client() as client: + response = None + elapsed_time = 0 + while ( + elapsed_time + < PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS.value() + ): + try: + response = client.get(f"{self.address()}/api/health") + except httpx.ConnectError: + pass + else: + if response.status_code == 200: + break + time.sleep(0.1) + elapsed_time += 0.1 + if response: + response.raise_for_status() + if not response: + raise RuntimeError( + "Timed out while attempting to connect to hosted test Prefect API." + ) + except Exception: + self.running = False + raise + + def stop(self): + subprocess_server_logger.info(f"Stopping server on {self.address()}") + if self.server_process: + self.server_process.terminate() + try: + self.server_process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.server_process.kill() + finally: + self.server_process = None + if self.port in self._instances: + del self._instances[self.port] + if self.running: + self.running = False diff --git a/src/prefect/settings.py b/src/prefect/settings.py index ba5b2ffae476..9b18cae66f62 100644 --- a/src/prefect/settings.py +++ b/src/prefect/settings.py @@ -1216,6 +1216,19 @@ def default_cloud_ui_url(settings, value): and usage patterns. """ +PREFECT_SERVER_ALLOW_EPHEMERAL_MODE = Setting(bool, default=False) +""" +Controls whether or not a subprocess server can be started when no API URL is provided. +""" + +PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS = Setting( + int, + default=10, +) +""" +The number of seconds to wait for an ephemeral server to respond on start up before erroring. +""" + PREFECT_UI_ENABLED = Setting( bool, default=True, diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index 0882ade66e24..b9abc33083a6 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -31,7 +31,7 @@ from typing_extensions import ParamSpec from prefect import Task -from prefect.client.orchestration import PrefectClient, SyncPrefectClient +from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client from prefect.client.schemas import TaskRun from prefect.client.schemas.objects import State, TaskRunInput from prefect.concurrency.asyncio import concurrency as aconcurrency @@ -1192,8 +1192,8 @@ async def initialize_run( """ with hydrated_context(self.context): - async with AsyncClientContext.get_or_create() as client_ctx: - self._client = client_ctx.client + async with AsyncClientContext.get_or_create(): + self._client = get_client() self._is_started = True try: if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: diff --git a/src/prefect/testing/fixtures.py b/src/prefect/testing/fixtures.py index 4c2c6a88f835..8326ccea08db 100644 --- a/src/prefect/testing/fixtures.py +++ b/src/prefect/testing/fixtures.py @@ -19,9 +19,11 @@ from prefect.events.clients import AssertingEventsClient from prefect.events.filters import EventFilter from prefect.events.worker import EventsWorker +from prefect.server.api.server import SubprocessASGIServer from prefect.server.events.pipeline import EventsPipeline from prefect.settings import ( PREFECT_API_URL, + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE, PREFECT_SERVER_CSRF_PROTECTION_ENABLED, get_current_settings, temporary_settings, @@ -61,6 +63,7 @@ async def hosted_api_server(unused_tcp_port_factory): The API URL """ port = unused_tcp_port_factory() + print(f"Running hosted API server on port {port}") # Will connect to the same database as normal test clients async with open_process( @@ -124,7 +127,7 @@ async def hosted_api_server(unused_tcp_port_factory): pass -@pytest.fixture +@pytest.fixture(autouse=True) def use_hosted_api_server(hosted_api_server): """ Sets `PREFECT_API_URL` to the test session's hosted API endpoint. @@ -138,6 +141,34 @@ def use_hosted_api_server(hosted_api_server): yield hosted_api_server +@pytest.fixture +def disable_hosted_api_server(): + """ + Disables the hosted API server by setting `PREFECT_API_URL` to `None`. + """ + with temporary_settings( + { + PREFECT_API_URL: None, + } + ): + yield hosted_api_server + + +@pytest.fixture +def enable_ephemeral_server(disable_hosted_api_server): + """ + Enables the ephemeral server by setting `PREFECT_SERVER_ALLOW_EPHEMERAL_MODE` to `True`. + """ + with temporary_settings( + { + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: True, + } + ): + yield hosted_api_server + + SubprocessASGIServer().stop() + + @pytest.fixture def mock_anyio_sleep(monkeypatch): """ diff --git a/tests/blocks/test_core.py b/tests/blocks/test_core.py index 33f3c0837612..af350a4943da 100644 --- a/tests/blocks/test_core.py +++ b/tests/blocks/test_core.py @@ -884,8 +884,12 @@ class ParentBlock(Block): }, } - async def test_block_load(self, test_block, block_document): - my_block = await test_block.load(block_document.name) + async def test_block_load( + self, test_block, block_document, in_memory_prefect_client + ): + my_block = await test_block.load( + block_document.name, client=in_memory_prefect_client + ) assert my_block._block_document_name == block_document.name assert my_block._block_document_id == block_document.id @@ -894,14 +898,21 @@ async def test_block_load(self, test_block, block_document): assert my_block.foo == "bar" async def test_block_load_loads__collections( - self, test_block, block_document: BlockDocument, monkeypatch + self, + test_block, + block_document: BlockDocument, + monkeypatch, + in_memory_prefect_client, ): mock_load_prefect_collections = Mock() monkeypatch.setattr( prefect.plugins, "load_prefect_collections", mock_load_prefect_collections ) - await Block.load(block_document.block_type.slug + "/" + block_document.name) + await Block.load( + block_document.block_type.slug + "/" + block_document.name, + client=in_memory_prefect_client, + ) mock_load_prefect_collections.assert_called_once() async def test_load_from_block_base_class(self): @@ -914,7 +925,7 @@ class Custom(Block): loaded_block = await Block.load("custom/my-custom-block") assert loaded_block.message == "hello" - async def test_load_nested_block(self, session): + async def test_load_nested_block(self, session, in_memory_prefect_client): class B(Block): _block_schema_type = "abc" @@ -1011,7 +1022,9 @@ class E(Block): await session.commit() - block_instance = await E.load("outer-block-document") + block_instance = await E.load( + "outer-block-document", client=in_memory_prefect_client + ) assert isinstance(block_instance, E) assert isinstance(block_instance.c, C) assert isinstance(block_instance.d, D) @@ -1062,7 +1075,9 @@ def save_block_flow(): block = await Test.load("test") assert block.a == "foo" - async def test_save_protected_block_with_new_block_schema_version(self, session): + async def test_save_protected_block_with_new_block_schema_version( + self, session, prefect_client: PrefectClient + ): """ This testcase would fail when block protection was enabled for block type updates and block schema creation. @@ -1078,9 +1093,7 @@ async def test_save_protected_block_with_new_block_schema_version(self, session) block_document_id = await JSON(value={"the_answer": 42}).save("test") - block_document = await models.block_documents.read_block_document_by_id( - session=session, block_document_id=block_document_id - ) + block_document = await prefect_client.read_block_document(block_document_id) assert block_document.block_schema.version == mock_version @@ -1538,7 +1551,9 @@ async def test_save_nested_block_without_references(self, InnerBlock, OuterBlock assert loaded_outer_block.contents._block_document_id is None assert loaded_outer_block.contents._block_document_name is None - async def test_save_and_load_block_with_secrets_includes_secret_data(self, session): + async def test_save_and_load_block_with_secrets_includes_secret_data( + self, prefect_client: PrefectClient + ): class SecretBlockB(Block): w: SecretDict x: SecretStr @@ -1548,27 +1563,27 @@ class SecretBlockB(Block): block = SecretBlockB(w=dict(secret="value"), x="x", y=b"y", z="z") await block.save("secret-block") - # read from DB without secrets - db_block_without_secrets = ( - await models.block_documents.read_block_document_by_id( - session=session, - block_document_id=block._block_document_id, - ) + # read from API without secrets + api_block = await prefect_client.read_block_document( + block._block_document_id, include_secrets=False ) - assert db_block_without_secrets.data == { + assert api_block.data == { "w": {"secret": "********"}, "x": "********", "y": "********", "z": "z", } - # read from DB with secrets - db_block = await models.block_documents.read_block_document_by_id( - session=session, - block_document_id=block._block_document_id, - include_secrets=True, + # read from API with secrets + api_block = await prefect_client.read_block_document( + block._block_document_id, include_secrets=True ) - assert db_block.data == {"w": {"secret": "value"}, "x": "x", "y": "y", "z": "z"} + assert api_block.data == { + "w": {"secret": "value"}, + "x": "x", + "y": "y", + "z": "z", + } # load block with secrets api_block = await SecretBlockB.load("secret-block") @@ -1578,7 +1593,7 @@ class SecretBlockB(Block): assert api_block.z == "z" async def test_save_and_load_nested_block_with_secrets_hardcoded_child( - self, session + self, prefect_client: PrefectClient ): class Child(Block): a: SecretStr @@ -1593,14 +1608,11 @@ class Parent(Block): block = Parent(a="a", b="b", child=dict(a="a", b="b", c=dict(secret="value"))) await block.save("secret-block") - # read from DB without secrets - db_block_without_secrets = ( - await models.block_documents.read_block_document_by_id( - session=session, - block_document_id=block._block_document_id, - ) + # read from API without secrets + api_block = await prefect_client.read_block_document( + block._block_document_id, include_secrets=False ) - assert db_block_without_secrets.data == { + assert api_block.data == { "a": "********", "b": "b", "child": { @@ -1611,13 +1623,11 @@ class Parent(Block): }, } - # read from DB with secrets - db_block = await models.block_documents.read_block_document_by_id( - session=session, - block_document_id=block._block_document_id, - include_secrets=True, + # read from API with secrets + api_block = await prefect_client.read_block_document( + block._block_document_id, include_secrets=True ) - assert db_block.data == { + assert api_block.data == { "a": "a", "b": "b", "child": { @@ -1636,7 +1646,9 @@ class Parent(Block): assert api_block.child.b == "b" assert api_block.child.c.get_secret_value() == {"secret": "value"} - async def test_save_and_load_nested_block_with_secrets_saved_child(self, session): + async def test_save_and_load_nested_block_with_secrets_saved_child( + self, prefect_client: PrefectClient + ): class Child(Block): a: SecretStr b: str @@ -1652,14 +1664,11 @@ class Parent(Block): block = Parent(a="a", b="b", child=child) await block.save("parent-block") - # read from DB without secrets - db_block_without_secrets = ( - await models.block_documents.read_block_document_by_id( - session=session, - block_document_id=block._block_document_id, - ) + # read from API without secrets + api_block = await prefect_client.read_block_document( + block._block_document_id, include_secrets=False ) - assert db_block_without_secrets.data == { + assert api_block.data == { "a": "********", "b": "b", "child": { @@ -1669,13 +1678,11 @@ class Parent(Block): }, } - # read from DB with secrets - db_block = await models.block_documents.read_block_document_by_id( - session=session, - block_document_id=block._block_document_id, - include_secrets=True, + # read from API with secrets + api_block = await prefect_client.read_block_document( + block._block_document_id, include_secrets=True ) - assert db_block.data == { + assert api_block.data == { "a": "a", "b": "b", "child": {"a": "a", "b": "b", "c": {"secret": "value"}}, diff --git a/tests/blocks/test_notifications.py b/tests/blocks/test_notifications.py index e90bacaecbe1..b77e3180067f 100644 --- a/tests/blocks/test_notifications.py +++ b/tests/blocks/test_notifications.py @@ -422,6 +422,7 @@ class TestCustomWebhook: async def test_notify_async(self): with respx.mock as xmock: xmock.post("https://example.com/") + xmock.route(host="localhost").pass_through() custom_block = CustomWebhookNotificationBlock( name="test name", @@ -444,6 +445,7 @@ async def test_notify_async(self): def test_notify_sync(self): with respx.mock as xmock: xmock.post("https://example.com/") + xmock.route(host="localhost").pass_through() custom_block = CustomWebhookNotificationBlock( name="test name", @@ -452,11 +454,7 @@ def test_notify_sync(self): secrets={"token": "someSecretToken"}, ) - @flow - def test_flow(): - custom_block.notify("test", "subject") - - test_flow() + custom_block.notify("test", "subject") last_req = xmock.calls.last.request assert last_req.headers["user-agent"] == "Prefect Notifications" diff --git a/tests/cli/test_block.py b/tests/cli/test_block.py index 597e4f6a687a..00732b2a0c05 100644 --- a/tests/cli/test_block.py +++ b/tests/cli/test_block.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from prefect.blocks import system @@ -47,7 +45,9 @@ def test_register_blocks_from_module_with_ui_url(): ) -def test_register_blocks_from_module_without_ui_url(): +def test_register_blocks_from_module_without_ui_url( + disable_hosted_api_server, enable_ephemeral_server +): with temporary_settings(set_defaults={PREFECT_UI_URL: None}): invoke_and_assert( ["block", "register", "-m", "prefect.blocks.core"], @@ -83,14 +83,15 @@ def test_register_blocks_from_invalid_module(): ) -def test_register_blocks_from_file(tmp_path, prefect_client: PrefectClient): +async def test_register_blocks_from_file(tmp_path, prefect_client: PrefectClient): test_file_path = tmp_path / "test.py" with open(test_file_path, "w") as f: f.write(TEST_BLOCK_CODE) with temporary_settings(set_defaults={PREFECT_UI_URL: "https://app.prefect.cloud"}): - invoke_and_assert( + await run_sync_in_worker_thread( + invoke_and_assert, ["block", "register", "-f", str(test_file_path)], expected_code=0, expected_output_contains=[ @@ -99,9 +100,7 @@ def test_register_blocks_from_file(tmp_path, prefect_client: PrefectClient): ], ) - block_type = asyncio.run( - prefect_client.read_block_type_by_slug(slug="testforfileregister") - ) + block_type = prefect_client.read_block_type_by_slug(slug="testforfileregister") assert block_type is not None @@ -297,14 +296,15 @@ def test_inspecting_a_block_type(tmp_path): ) -def test_deleting_a_block_type(tmp_path, prefect_client): +async def test_deleting_a_block_type(tmp_path, prefect_client): test_file_path = tmp_path / "test.py" with open(test_file_path, "w") as f: f.write(TEST_BLOCK_CODE) - invoke_and_assert( - ["block", "register", "-f", str(test_file_path)], + await run_sync_in_worker_thread( + invoke_and_assert, + command=["block", "register", "-f", str(test_file_path)], expected_code=0, expected_output_contains="Successfully registered 1 block", ) @@ -314,15 +314,16 @@ def test_deleting_a_block_type(tmp_path, prefect_client): "testforfileregister", ] - invoke_and_assert( - ["block", "type", "delete", "testforfileregister"], + await run_sync_in_worker_thread( + invoke_and_assert, + command=["block", "type", "delete", "testforfileregister"], expected_code=0, user_input="y", expected_output_contains=expected_output, ) with pytest.raises(ObjectNotFound): - asyncio.run(prefect_client.read_block_type_by_slug(slug="testforfileregister")) + await prefect_client.read_block_type_by_slug(slug="testforfileregister") def test_deleting_a_protected_block_type( diff --git a/tests/cli/test_config.py b/tests/cli/test_config.py index 9c4d3bb8a989..26fc8652286e 100644 --- a/tests/cli/test_config.py +++ b/tests/cli/test_config.py @@ -7,11 +7,11 @@ import prefect.settings from prefect.context import use_profile from prefect.settings import ( - PREFECT_API_DATABASE_CONNECTION_URL, PREFECT_API_DATABASE_TIMEOUT, PREFECT_API_KEY, PREFECT_LOGGING_TO_API_MAX_LOG_SIZE, PREFECT_PROFILES_PATH, + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE, PREFECT_TEST_SETTING, SETTING_VARIABLES, Profile, @@ -68,7 +68,7 @@ def test_set_using_default_profile(): assert "ephemeral" in profiles assert profiles["ephemeral"].settings == { PREFECT_TEST_SETTING: "DEBUG", - PREFECT_API_DATABASE_CONNECTION_URL: "sqlite+aiosqlite:///prefect.db", + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: "true", } diff --git a/tests/cli/test_profile.py b/tests/cli/test_profile.py index 768bc07d34b6..98ddca138222 100644 --- a/tests/cli/test_profile.py +++ b/tests/cli/test_profile.py @@ -9,11 +9,11 @@ from prefect.context import use_profile from prefect.settings import ( DEFAULT_PROFILES_PATH, - PREFECT_API_DATABASE_CONNECTION_URL, PREFECT_API_KEY, PREFECT_API_URL, PREFECT_DEBUG_MODE, PREFECT_PROFILES_PATH, + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE, Profile, ProfilesCollection, _read_profiles_from, @@ -70,7 +70,10 @@ def profiles(self): "PREFECT_API_URL": hosted_server_api_url, }, ), - Profile(name="ephemeral", settings={}), + Profile( + name="ephemeral", + settings={"PREFECT_SERVER_ALLOW_EPHEMERAL_MODE": True}, + ), ], active=None, ) @@ -665,7 +668,7 @@ def test_populate_defaults_migrates_default(self, temporary_profiles_path): assert new_profiles.active_name == "ephemeral" assert new_profiles["ephemeral"].settings == { PREFECT_API_KEY: "test_key", - PREFECT_API_DATABASE_CONNECTION_URL: "sqlite+aiosqlite:///prefect.db", + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: "true", } def test_show_profile_changes(self, capsys): diff --git a/tests/cli/test_version.py b/tests/cli/test_version.py index fcb2a57c0cad..5be9cbaf15b4 100644 --- a/tests/cli/test_version.py +++ b/tests/cli/test_version.py @@ -1,7 +1,8 @@ import platform import sqlite3 import sys -from unittest.mock import MagicMock, Mock +from textwrap import dedent +from unittest.mock import Mock import pendulum import pydantic @@ -9,13 +10,29 @@ import prefect from prefect.client.constants import SERVER_API_VERSION -from prefect.settings import PREFECT_API_URL, PREFECT_CLOUD_API_URL, temporary_settings +from prefect.settings import ( + PREFECT_API_URL, + PREFECT_CLOUD_API_URL, + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE, + temporary_settings, +) from prefect.testing.cli import invoke_and_assert -def test_version_ephemeral_server_type(): +def test_version_ephemeral_server_type(disable_hosted_api_server): + with temporary_settings( + { + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: True, + } + ): + invoke_and_assert( + ["version"], expected_output_contains="Server type: ephemeral" + ) + + +def test_version_unconfigured_server_type(disable_hosted_api_server): invoke_and_assert( - ["version"], expected_output_contains="Server type: ephemeral" + ["version"], expected_output_contains="Server type: unconfigured" ) @@ -39,14 +56,7 @@ def test_version_cloud_server_type(): ) -def test_version_client_error_server_type(monkeypatch): - monkeypatch.setattr("prefect.get_client", MagicMock(side_effect=ValueError)) - invoke_and_assert( - ["version"], expected_output_contains="Server type: " - ) - - -def test_correct_output_ephemeral_sqlite(monkeypatch): +def test_correct_output_ephemeral_sqlite(monkeypatch, disable_hosted_api_server): version_info = prefect.__version_info__ built = pendulum.parse(prefect.__version_info__["date"]) profile = prefect.context.get_settings_context().profile @@ -55,25 +65,33 @@ def test_correct_output_ephemeral_sqlite(monkeypatch): dialect().name = "sqlite" monkeypatch.setattr("prefect.server.utilities.database.get_dialect", dialect) - invoke_and_assert( - ["version"], - expected_output=f"""Version: {prefect.__version__} -API version: {SERVER_API_VERSION} -Python version: {platform.python_version()} -Git commit: {version_info['full-revisionid'][:8]} -Built: {built.to_day_datetime_string()} -OS/Arch: {sys.platform}/{platform.machine()} -Profile: {profile.name} -Server type: ephemeral -Pydantic version: {pydantic.__version__} -Server: - Database: sqlite - SQLite version: {sqlite3.sqlite_version} -""", - ) + with temporary_settings( + { + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: True, + } + ): + invoke_and_assert( + ["version"], + expected_output=dedent( + f""" + Version: {prefect.__version__} + API version: {SERVER_API_VERSION} + Python version: {platform.python_version()} + Git commit: {version_info['full-revisionid'][:8]} + Built: {built.to_day_datetime_string()} + OS/Arch: {sys.platform}/{platform.machine()} + Profile: {profile.name} + Server type: ephemeral + Pydantic version: {pydantic.__version__} + Server: + Database: sqlite + SQLite version: {sqlite3.sqlite_version} + """, + ), + ) -def test_correct_output_ephemeral_postgres(monkeypatch): +def test_correct_output_ephemeral_postgres(monkeypatch, disable_hosted_api_server): version_info = prefect.__version_info__ built = pendulum.parse(prefect.__version_info__["date"]) profile = prefect.context.get_settings_context().profile @@ -82,21 +100,29 @@ def test_correct_output_ephemeral_postgres(monkeypatch): dialect().name = "postgres" monkeypatch.setattr("prefect.server.utilities.database.get_dialect", dialect) - invoke_and_assert( - ["version"], - expected_output=f"""Version: {prefect.__version__} -API version: {SERVER_API_VERSION} -Python version: {platform.python_version()} -Git commit: {version_info['full-revisionid'][:8]} -Built: {built.to_day_datetime_string()} -OS/Arch: {sys.platform}/{platform.machine()} -Profile: {profile.name} -Server type: ephemeral -Pydantic version: {pydantic.__version__} -Server: - Database: postgres -""", - ) + with temporary_settings( + { + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: True, + } + ): + invoke_and_assert( + ["version"], + expected_output=dedent( + f""" + Version: {prefect.__version__} + API version: {SERVER_API_VERSION} + Python version: {platform.python_version()} + Git commit: {version_info['full-revisionid'][:8]} + Built: {built.to_day_datetime_string()} + OS/Arch: {sys.platform}/{platform.machine()} + Profile: {profile.name} + Server type: ephemeral + Pydantic version: {pydantic.__version__} + Server: + Database: postgres + """, + ), + ) @pytest.mark.usefixtures("use_hosted_api_server") diff --git a/tests/cli/test_work_pool.py b/tests/cli/test_work_pool.py index 06ae98183f46..9af78d399819 100644 --- a/tests/cli/test_work_pool.py +++ b/tests/cli/test_work_pool.py @@ -60,7 +60,7 @@ def readchar(): class TestCreate: @pytest.mark.usefixtures("mock_collection_registry") - async def test_create_work_pool(self, prefect_client, mock_collection_registry): + async def test_create_work_pool(self, prefect_client): pool_name = "my-pool" res = await run_sync_in_worker_thread( invoke_and_assert, diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index f4d7467490bf..bd4d0006ef98 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -11,13 +11,21 @@ import prefect import prefect.client import prefect.client.constants -from prefect.client.base import PrefectHttpxAsyncClient, PrefectResponse +from prefect.client.base import ( + PrefectHttpxAsyncClient, + PrefectResponse, + ServerType, + determine_server_type, +) from prefect.client.schemas.objects import CsrfToken from prefect.exceptions import PrefectHTTPStatusError from prefect.settings import ( + PREFECT_API_URL, PREFECT_CLIENT_MAX_RETRIES, PREFECT_CLIENT_RETRY_EXTRA_CODES, PREFECT_CLIENT_RETRY_JITTER_FACTOR, + PREFECT_CLOUD_API_URL, + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE, temporary_settings, ) from prefect.testing.utilities import AsyncMock @@ -739,3 +747,41 @@ async def test_passes_informative_user_agent( assert isinstance(request, httpx.Request) assert request.headers["User-Agent"] == "prefect/42.43.44 (API 45.46.47)" + + +class TestDetermineServerType: + @pytest.mark.parametrize( + "temp_settings, expected_type", + [ + ( + { + PREFECT_API_URL: "http://localhost:4200/api", + }, + ServerType.SERVER, + ), + ( + { + PREFECT_API_URL: None, + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: True, + }, + ServerType.EPHEMERAL, + ), + ( + { + PREFECT_API_URL: None, + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: False, + }, + ServerType.UNCONFIGURED, + ), + ( + { + PREFECT_CLOUD_API_URL: "https://api.prefect.cloud/api/", + PREFECT_API_URL: "https://api.prefect.cloud/api/accounts/foo/workspaces/bar", + }, + ServerType.CLOUD, + ), + ], + ) + def test_with_settings_variations(self, temp_settings, expected_type): + with temporary_settings(temp_settings): + assert determine_server_type() == expected_type diff --git a/tests/client/test_cloud_client.py b/tests/client/test_cloud_client.py index 549e9335af87..2a7841d52c46 100644 --- a/tests/client/test_cloud_client.py +++ b/tests/client/test_cloud_client.py @@ -64,7 +64,7 @@ async def test_cloud_client_follow_redirects(): assert client._client.follow_redirects is False -async def test_get_cloud_work_pool_types(mock_work_pool_types): +async def test_get_cloud_work_pool_types(): account_id = uuid.uuid4() workspace_id = uuid.uuid4() with temporary_settings( @@ -72,6 +72,23 @@ async def test_get_cloud_work_pool_types(mock_work_pool_types): PREFECT_API_URL: f"https://api.prefect.cloud/api/accounts/{account_id}/workspaces/{workspace_id}/" } ): - async with get_cloud_client() as client: - response = await client.read_worker_metadata() - assert response == mock_work_pool_types_response + with respx.mock( + assert_all_mocked=False, base_url=PREFECT_API_URL.value() + ) as respx_mock: + respx_mock.route( + M( + host="api.prefect.cloud", + path__regex=( + r"api/accounts/(.{36})/workspaces/(.{36})/collections/work_pool_types" + ), + ), + method="GET", + ).mock( + return_value=httpx.Response( + 200, + json=mock_work_pool_types_response, + ) + ) + async with get_cloud_client() as client: + response = await client.read_worker_metadata() + assert response == mock_work_pool_types_response diff --git a/tests/client/test_prefect_client.py b/tests/client/test_prefect_client.py index cf460ee538de..ffd3446d2c15 100644 --- a/tests/client/test_prefect_client.py +++ b/tests/client/test_prefect_client.py @@ -1,5 +1,6 @@ import json import os +import subprocess from contextlib import asynccontextmanager from datetime import timedelta from typing import Generator, List @@ -100,6 +101,28 @@ def test_get_client_cache_uses_profile_settings(self): assert isinstance(new_client, PrefectClient) assert new_client is not client + def test_get_client_starts_subprocess_server_when_enabled( + self, enable_ephemeral_server, monkeypatch + ): + popen_spy = MagicMock() + orig_popen = subprocess.Popen + + def popen_stub(*args, **kwargs): + popen_spy(*args, **kwargs) + return orig_popen(*args, **kwargs) + + monkeypatch.setattr("prefect.server.api.server.subprocess.Popen", popen_stub) + + get_client() + assert popen_spy.call_count == 1 + assert "prefect.server.api.server:create_app" in popen_spy.call_args[1]["args"] + + def test_get_client_rasises_error_when_no_api_url_and_no_ephemeral_mode( + self, disable_hosted_api_server + ): + with pytest.raises(ValueError, match="API URL"): + get_client() + class TestClientProxyAwareness: """Regression test for https://github.com/PrefectHQ/nebula/issues/2356, where @@ -1342,8 +1365,9 @@ async def test_prefect_api_tls_insecure_skip_verify_setting_set_to_true(monkeypa mock.assert_called_once_with( headers=ANY, verify=False, - transport=ANY, base_url=ANY, + limits=ANY, + http2=ANY, timeout=ANY, enable_csrf_support=ANY, ) @@ -1360,8 +1384,9 @@ async def test_prefect_api_tls_insecure_skip_verify_setting_set_to_false(monkeyp mock.assert_called_once_with( headers=ANY, verify=ANY, - transport=ANY, base_url=ANY, + limits=ANY, + http2=ANY, timeout=ANY, enable_csrf_support=ANY, ) @@ -1374,8 +1399,9 @@ async def test_prefect_api_tls_insecure_skip_verify_default_setting(monkeypatch) mock.assert_called_once_with( headers=ANY, verify=ANY, - transport=ANY, base_url=ANY, + limits=ANY, + http2=ANY, timeout=ANY, enable_csrf_support=ANY, ) @@ -1397,8 +1423,9 @@ async def test_prefect_api_ssl_cert_file_setting_explicitly_set(monkeypatch): mock.assert_called_once_with( headers=ANY, verify="my_cert.pem", - transport=ANY, base_url=ANY, + limits=ANY, + http2=ANY, timeout=ANY, enable_csrf_support=ANY, ) @@ -1420,8 +1447,9 @@ async def test_prefect_api_ssl_cert_file_default_setting(monkeypatch): mock.assert_called_once_with( headers=ANY, verify="my_cert.pem", - transport=ANY, base_url=ANY, + limits=ANY, + http2=ANY, timeout=ANY, enable_csrf_support=ANY, ) @@ -1443,8 +1471,9 @@ async def test_prefect_api_ssl_cert_file_default_setting_fallback(monkeypatch): mock.assert_called_once_with( headers=ANY, verify=certifi.where(), - transport=ANY, base_url=ANY, + limits=ANY, + http2=ANY, timeout=ANY, enable_csrf_support=ANY, ) @@ -1733,7 +1762,8 @@ async def test_delete_flow_run(prefect_client, flow_run): await prefect_client.delete_flow_run(flow_run.id) -def test_server_type_ephemeral(prefect_client): +def test_server_type_ephemeral(enable_ephemeral_server): + prefect_client = get_client() assert prefect_client.server_type == ServerType.EPHEMERAL @@ -2467,7 +2497,8 @@ async def test_delete_deployment_schedule_nonexistent( class TestPrefectClientCsrfSupport: - def test_enabled_ephemeral(self, prefect_client: PrefectClient): + def test_enabled_ephemeral(self, enable_ephemeral_server): + prefect_client = get_client() assert prefect_client.server_type == ServerType.EPHEMERAL assert prefect_client._client.enable_csrf_support diff --git a/tests/concurrency/test_context.py b/tests/concurrency/test_context.py index 41f76934d812..c1a0dc826278 100644 --- a/tests/concurrency/test_context.py +++ b/tests/concurrency/test_context.py @@ -3,7 +3,7 @@ import pytest -from prefect.client.orchestration import PrefectClient +from prefect.client.orchestration import PrefectClient, get_client from prefect.concurrency.asyncio import concurrency as aconcurrency from prefect.concurrency.context import ConcurrencyContext from prefect.concurrency.sync import concurrency @@ -41,10 +41,9 @@ async def test_concurrency_context_releases_slots_sync( ): def expensive_task(): with concurrency(concurrency_limit.name): + client = get_client() response = run_coro_as_sync( - prefect_client.read_global_concurrency_limit_by_name( - concurrency_limit.name - ) + client.read_global_concurrency_limit_by_name(concurrency_limit.name) ) assert response and response.active_slots == 1 diff --git a/tests/events/client/test_automations_server_compatibility.py b/tests/events/client/test_automations_server_compatibility.py index 12f5fc301bef..a65122b41a8f 100644 --- a/tests/events/client/test_automations_server_compatibility.py +++ b/tests/events/client/test_automations_server_compatibility.py @@ -139,18 +139,21 @@ def test_all_triggers_represented(): @pytest.mark.parametrize("trigger", EXAMPLE_TRIGGERS) -async def test_trigger_round_tripping(trigger: TriggerTypes): +async def test_trigger_round_tripping(trigger: TriggerTypes, in_memory_prefect_client): """Tests that any of the example client triggers can be round-tripped to the Prefect server""" - async with get_client() as client: - automation_id = await client.create_automation( - AutomationCore( - name="test", - trigger=trigger, - actions=[DoNothing()], - ) + # Using an in-memory client because the Pydantic model marshalling doesn't work + # with the hosted API server. It appears to chose the client-side model for EventTrigger + # instead of the server-side model. + # TODO: Fix the model resolution to work with the hosted API server + automation_id = await in_memory_prefect_client.create_automation( + AutomationCore( + name="test", + trigger=trigger, + actions=[DoNothing()], ) - automation = await client.read_automation(automation_id) + ) + automation = await in_memory_prefect_client.read_automation(automation_id) sent = trigger.model_dump() returned = automation.trigger.model_dump() diff --git a/tests/events/client/test_events_worker.py b/tests/events/client/test_events_worker.py index fe94be00e8a0..b338c49dd1dd 100644 --- a/tests/events/client/test_events_worker.py +++ b/tests/events/client/test_events_worker.py @@ -7,7 +7,6 @@ from prefect.events import Event from prefect.events.clients import ( AssertingEventsClient, - PrefectEphemeralEventsClient, PrefectEventsClient, ) from prefect.events.worker import EventsWorker @@ -34,12 +33,6 @@ def test_emits_event_via_client(asserting_events_worker: EventsWorker, event: Ev assert asserting_events_worker._client.events == [event] -def test_worker_instance_ephemeral_client_no_api_url(): - with temporary_settings(updates={PREFECT_API_URL: None}): - worker = EventsWorker.instance() - assert worker.client_type == PrefectEphemeralEventsClient - - def test_worker_instance_server_client_non_cloud_api_url(): with temporary_settings(updates={PREFECT_API_URL: "http://localhost:8080/api"}): worker = EventsWorker.instance() @@ -52,9 +45,13 @@ def test_worker_instance_client_non_cloud_api_url_events_enabled(): assert worker.client_type == PrefectEventsClient -def test_worker_instance_ephemeral_prefect_events_client(): +def test_worker_instance_ephemeral_prefect_events_client(enable_ephemeral_server): + """ + Getting an instance of the worker with ephemeral server mode enabled should + return a PrefectEventsClient pointing to the subprocess server. + """ worker = EventsWorker.instance() - assert worker.client_type == PrefectEphemeralEventsClient + assert worker.client_type == PrefectEventsClient async def test_includes_related_resources_from_run_context( diff --git a/tests/events/server/actions/test_calling_webhook.py b/tests/events/server/actions/test_calling_webhook.py index 3e992294fdc4..7559dd0395d4 100644 --- a/tests/events/server/actions/test_calling_webhook.py +++ b/tests/events/server/actions/test_calling_webhook.py @@ -93,9 +93,9 @@ async def take_a_picture_work_queue( @pytest.fixture -async def webhook_block_id() -> UUID: +async def webhook_block_id(in_memory_prefect_client) -> UUID: block = Webhook(method="POST", url="https://example.com", headers={"foo": "bar"}) - return await block.save(name="webhook-test") + return await block.save(name="webhook-test", client=in_memory_prefect_client) @pytest.fixture diff --git a/tests/fixtures/api.py b/tests/fixtures/api.py index fad347c04489..26adaf9e65d8 100644 --- a/tests/fixtures/api.py +++ b/tests/fixtures/api.py @@ -3,6 +3,7 @@ import httpx import pytest from fastapi import FastAPI +from fastapi.testclient import TestClient from httpx import ASGITransport, AsyncClient from prefect.server.api.server import create_app @@ -19,17 +20,27 @@ def app() -> FastAPI: @pytest.fixture -async def client(app: ASGIApp) -> AsyncGenerator[AsyncClient, Any]: +def test_client(app: FastAPI) -> TestClient: + return TestClient(app) + + +@pytest.fixture +async def client(app) -> AsyncGenerator[AsyncClient, Any]: """ Yield a test client for testing the api """ - transport = ASGITransport(app=app) async with httpx.AsyncClient( - transport=transport, base_url="https://test/api" + transport=ASGITransport(app=app), base_url="https://test/api" ) as async_client: yield async_client +@pytest.fixture +async def hosted_api_client(use_hosted_api_server) -> AsyncGenerator[AsyncClient, Any]: + async with httpx.AsyncClient(base_url=use_hosted_api_server) as async_client: + yield async_client + + @pytest.fixture async def client_with_unprotected_block_api( app: ASGIApp, diff --git a/tests/fixtures/client.py b/tests/fixtures/client.py index 5efeced99168..fde737b97e73 100644 --- a/tests/fixtures/client.py +++ b/tests/fixtures/client.py @@ -16,6 +16,20 @@ async def prefect_client( yield client +@pytest.fixture +async def in_memory_prefect_client(app) -> AsyncGenerator[PrefectClient, None]: + """ + Yield a test client that communicates with an in-memory server + """ + # This was created because we were getting test failures caused by the + # hosted API fixture using a different DB than the bare DB operations + # in tests. + # TODO: Figure out how to use the `prefect_client` fixture instead for + # tests/fixtures using this fixture. + async with PrefectClient(api=app) as client: + yield client + + @pytest.fixture def sync_prefect_client(test_database_connection_url): yield get_client(sync_client=True) diff --git a/tests/fixtures/collections_registry.py b/tests/fixtures/collections_registry.py index fba9a483af0a..3544234ed137 100644 --- a/tests/fixtures/collections_registry.py +++ b/tests/fixtures/collections_registry.py @@ -1,8 +1,11 @@ +from unittest.mock import ANY + import httpx import pytest import respx from prefect.server.api import collections +from prefect.settings import PREFECT_API_URL FAKE_DEFAULT_BASE_JOB_TEMPLATE = { "job_configuration": { @@ -427,8 +430,11 @@ def mock_collection_registry( with respx.mock( assert_all_mocked=False, assert_all_called=False, + base_url=PREFECT_API_URL.value(), ) as respx_mock: - respx_mock.get( - "https://raw.githubusercontent.com/PrefectHQ/prefect-collection-registry/main/views/aggregate-worker-metadata.json" - ).mock(return_value=httpx.Response(200, json=mock_body)) + respx_mock.get("/csrf-token", params={"client": ANY}).pass_through() + respx_mock.route(path__startswith="/work_pools/").pass_through() + respx_mock.get("/collections/views/aggregate-worker-metadata").mock( + return_value=httpx.Response(200, json=mock_body) + ) yield diff --git a/tests/fixtures/database.py b/tests/fixtures/database.py index 3812e8842583..47554b7a4939 100644 --- a/tests/fixtures/database.py +++ b/tests/fixtures/database.py @@ -1066,12 +1066,14 @@ async def notify(self, subject: str, body: str): @pytest.fixture -async def notifier_block(DebugPrintNotification: Type[NotificationBlock]): +async def notifier_block( + DebugPrintNotification: Type[NotificationBlock], in_memory_prefect_client +): # Ignore warnings from block reuse in fixture warnings.filterwarnings("ignore", category=UserWarning) block = DebugPrintNotification() - await block.save("debug-print-notification") + await block.save("debug-print-notification", client=in_memory_prefect_client) return block diff --git a/tests/runner/test_runner.py b/tests/runner/test_runner.py index 97d7c8c7f979..41e8269d392b 100644 --- a/tests/runner/test_runner.py +++ b/tests/runner/test_runner.py @@ -22,7 +22,7 @@ import prefect.runner from prefect import __version__, flow, serve, task -from prefect.client.orchestration import PrefectClient +from prefect.client.orchestration import PrefectClient, SyncPrefectClient from prefect.client.schemas.actions import DeploymentScheduleCreate from prefect.client.schemas.objects import StateType from prefect.client.schemas.schedules import CronSchedule, IntervalSchedule @@ -234,15 +234,15 @@ def type_container_input_flow(arg1: List[str]) -> str: def test_serve_can_create_multiple_deployments( self, - prefect_client: PrefectClient, + sync_prefect_client: SyncPrefectClient, ): deployment_1 = dummy_flow_1.to_deployment(__file__, interval=3600) deployment_2 = dummy_flow_2.to_deployment(__file__, cron="* * * * *") serve(deployment_1, deployment_2) - deployment = asyncio.run( - prefect_client.read_deployment_by_name(name="dummy-flow-1/test_runner") + deployment = sync_prefect_client.read_deployment_by_name( + name="dummy-flow-1/test_runner" ) assert deployment is not None @@ -250,8 +250,8 @@ def test_serve_can_create_multiple_deployments( seconds=3600 ) - deployment = asyncio.run( - prefect_client.read_deployment_by_name(name="dummy-flow-2/test_runner") + deployment = sync_prefect_client.read_deployment_by_name( + name="dummy-flow-2/test_runner" ) assert deployment is not None diff --git a/tests/runtime/test_flow_run.py b/tests/runtime/test_flow_run.py index 655d534921bc..2323bc8986d6 100644 --- a/tests/runtime/test_flow_run.py +++ b/tests/runtime/test_flow_run.py @@ -599,15 +599,19 @@ async def test_url_is_none_when_id_not_set(self, url_type): assert getattr(flow_run, url_type) is None @pytest.mark.parametrize( - "url_type, base_url_value", - [("api_url", PREFECT_API_URL.value()), ("ui_url", PREFECT_UI_URL.value())], + "url_type,", + ["api_url", "ui_url"], ) async def test_url_returns_correct_url_when_id_present( self, url_type, - base_url_value, ): test_id = "12345" + if url_type == "api_url": + base_url_value = PREFECT_API_URL.value() + elif url_type == "ui_url": + base_url_value = PREFECT_UI_URL.value() + expected_url = f"{base_url_value}/flow-runs/flow-run/{test_id}" with FlowRunContext.model_construct( @@ -618,20 +622,24 @@ async def test_url_returns_correct_url_when_id_present( assert not getattr(flow_run, url_type) @pytest.mark.parametrize( - "url_type, base_url_value", - [("api_url", PREFECT_API_URL.value()), ("ui_url", PREFECT_UI_URL.value())], + "url_type,", + ["api_url", "ui_url"], ) async def test_url_pulls_from_api_when_needed( self, monkeypatch, prefect_client, url_type, - base_url_value, ): run = await prefect_client.create_flow_run(flow=flow(lambda: None, name="test")) assert not getattr(flow_run, url_type) + if url_type == "api_url": + base_url_value = PREFECT_API_URL.value() + elif url_type == "ui_url": + base_url_value = PREFECT_UI_URL.value() + expected_url = f"{base_url_value}/flow-runs/flow-run/{str(run.id)}" monkeypatch.setenv(name="PREFECT__FLOW_RUN_ID", value=str(run.id)) diff --git a/tests/server/api/test_server.py b/tests/server/api/test_server.py index dd818620beb5..a5bd9b931426 100644 --- a/tests/server/api/test_server.py +++ b/tests/server/api/test_server.py @@ -1,8 +1,11 @@ +import contextlib +import socket import sqlite3 from unittest.mock import MagicMock, patch from uuid import uuid4 import asyncpg +import httpx import pytest import sqlalchemy as sa import toml @@ -10,9 +13,12 @@ from httpx import ASGITransport, AsyncClient from prefect.client.constants import SERVER_API_VERSION +from prefect.client.orchestration import get_client +from prefect.flows import flow from prefect.server.api.server import ( API_ROUTERS, SQLITE_LOCKED_MSG, + SubprocessASGIServer, _memoize_block_auto_registration, create_api_app, create_app, @@ -20,6 +26,7 @@ ) from prefect.settings import ( PREFECT_API_DATABASE_CONNECTION_URL, + PREFECT_API_URL, PREFECT_MEMO_STORE_PATH, PREFECT_MEMOIZE_BLOCK_AUTO_REGISTRATION, temporary_settings, @@ -405,3 +412,73 @@ async def test_changing_database_breaks_cache(self, enable_memoization): await _memoize_block_auto_registration(test_func)() assert test_func.call_count == 2 + + +class TestSubprocessASGIServer: + def test_singleton_on_port(self): + server_8000 = SubprocessASGIServer(port=8000) + assert server_8000 is SubprocessASGIServer(port=8000) + + server_random = SubprocessASGIServer() + assert server_random is SubprocessASGIServer() + + assert server_8000 is not server_random + + def test_find_available_port_returns_available_port(self): + server = SubprocessASGIServer() + port = server.find_available_port() + assert server.is_port_available(port) + assert 8000 <= port < 9000 + + def test_is_port_available_returns_true_for_available_port(self): + server = SubprocessASGIServer() + port = server.find_available_port() + assert server.is_port_available(port) + + def test_is_port_available_returns_false_for_unavailable_port(self): + server = SubprocessASGIServer() + with contextlib.closing( + socket.socket(socket.AF_INET, socket.SOCK_STREAM) + ) as sock: + sock.bind(("127.0.0.1", 12345)) + assert not server.is_port_available(12345) + + def test_start_is_idempotent(self, respx_mock, monkeypatch): + popen_mock = MagicMock() + monkeypatch.setattr("prefect.server.api.server.subprocess.Popen", popen_mock) + respx_mock.get("http://127.0.0.1:8000/api/health").respond(status_code=200) + server = SubprocessASGIServer(port=8000) + server.start() + server.start() + + assert popen_mock.call_count == 1 + + def test_address_returns_correct_address(self): + server = SubprocessASGIServer(port=8000) + assert server.address() == "http://127.0.0.1:8000" + + def test_start_and_stop_server(self): + server = SubprocessASGIServer() + server.start() + health_response = httpx.get(f"{server.address()}/api/health") + assert health_response.status_code == 200 + + server.stop() + with pytest.raises(httpx.RequestError): + httpx.get(f"{server.address()}/api/health") + + def test_run_a_flow_against_subprocess_server(self): + @flow + def f(): + return 42 + + server = SubprocessASGIServer() + server.start() + + with temporary_settings({PREFECT_API_URL: f"{server.address()}/api"}): + assert f() == 42 + + client = get_client(sync_client=True) + assert len(client.read_flow_runs()) == 1 + + server.stop() diff --git a/tests/server/models/test_block_documents.py b/tests/server/models/test_block_documents.py index 8175c7815481..1fffd3b4440d 100644 --- a/tests/server/models/test_block_documents.py +++ b/tests/server/models/test_block_documents.py @@ -1754,16 +1754,16 @@ async def test_updating_secret_block_document_with_obfuscated_result_is_ignored( # x was NOT overwritten assert block2.data["x"] != obfuscate_string(X) - async def test_block_with_list_of_secrets(self, session): + async def test_block_with_list_of_secrets(self, session, prefect_client): class ListSecretBlock(Block): x: List[SecretStr] # save the block orig_block = ListSecretBlock(x=["a", "b"]) - await orig_block.save(name="list-secret") + await orig_block.save(name="list-secret", client=prefect_client) # load the block - block = await ListSecretBlock.load("list-secret") + block = await ListSecretBlock.load("list-secret", client=prefect_client) assert block.x[0].get_secret_value() == "a" assert block.x[1].get_secret_value() == "b" diff --git a/tests/server/orchestration/api/test_block_documents.py b/tests/server/orchestration/api/test_block_documents.py index 2df059b5c45a..9b4b83a1e163 100644 --- a/tests/server/orchestration/api/test_block_documents.py +++ b/tests/server/orchestration/api/test_block_documents.py @@ -1416,7 +1416,7 @@ async def test_read_secret_block_documents_with_secrets( assert blocks[0].data["z"] == Z async def test_nested_block_secrets_are_obfuscated_when_all_blocks_are_saved( - self, client, session + self, hosted_api_client, session ): class ChildBlock(Block): x: SecretStr @@ -1435,7 +1435,9 @@ class ParentBlock(Block): block = ParentBlock(a=3, b="b", child=child) await block.save("nested-test") await session.commit() - response = await client.get(f"/block_documents/{block._block_document_id}") + response = await hosted_api_client.get( + f"/block_documents/{block._block_document_id}" + ) block = schemas.core.BlockDocument.model_validate(response.json()) assert block.data["a"] == 3 assert block.data["b"] == obfuscate_string("b") @@ -1443,7 +1445,7 @@ class ParentBlock(Block): assert block.data["child"]["y"] == Y assert block.data["child"]["z"] == {"secret": obfuscate_string(Z)} - async def test_nested_block_secrets_are_returned(self, client): + async def test_nested_block_secrets_are_returned(self, hosted_api_client): class ChildBlock(Block): x: SecretStr y: str @@ -1457,7 +1459,7 @@ class ParentBlock(Block): block = ParentBlock(a=3, b="b", child=ChildBlock(x=X, y=Y, z=dict(secret=Z))) await block.save("nested-test") - response = await client.get( + response = await hosted_api_client.get( f"/block_documents/{block._block_document_id}", params=dict(include_secrets=True), ) diff --git a/tests/server/orchestration/api/test_block_types.py b/tests/server/orchestration/api/test_block_types.py index 3d94b944c539..5d27535f67ec 100644 --- a/tests/server/orchestration/api/test_block_types.py +++ b/tests/server/orchestration/api/test_block_types.py @@ -491,13 +491,13 @@ async def test_install_system_block_types_multiple_times(self, client): await client.post("/block_types/install_system_block_types") await client.post("/block_types/install_system_block_types") - async def test_create_system_block_type(self, client, session): + async def test_create_system_block_type(self, hosted_api_client, session): # install system blocks - await client.post("/block_types/install_system_block_types") + await hosted_api_client.post("/block_types/install_system_block_types") # create a datetime block - datetime_block_type = await client.get("/block_types/slug/date-time") - datetime_block_schema = await client.post( + datetime_block_type = await hosted_api_client.get("/block_types/slug/date-time") + datetime_block_schema = await hosted_api_client.post( "/block_schemas/filter", json=dict( block_schemas=dict( @@ -507,7 +507,7 @@ async def test_create_system_block_type(self, client, session): ), ) block = prefect.blocks.system.DateTime(value="2022-01-01T00:00:00+00:00") - response = await client.post( + response = await hosted_api_client.post( "/block_documents/", json=block._to_block_document( name="my-test-date-time", diff --git a/tests/server/orchestration/api/test_deployments.py b/tests/server/orchestration/api/test_deployments.py index 82fafa1b5068..c02c2681176c 100644 --- a/tests/server/orchestration/api/test_deployments.py +++ b/tests/server/orchestration/api/test_deployments.py @@ -39,7 +39,7 @@ class TestCreateDeployment: async def test_create_oldstyle_deployment( self, session, - client, + hosted_api_client, flow, flow_function, storage_document_id, @@ -52,7 +52,7 @@ async def test_create_oldstyle_deployment( parameters={"foo": "bar"}, storage_document_id=storage_document_id, ).model_dump(mode="json") - response = await client.post("/deployments/", json=data) + response = await hosted_api_client.post("/deployments/", json=data) assert response.status_code == status.HTTP_201_CREATED assert response.json()["name"] == "My Deployment" assert response.json()["version"] == "mint" @@ -72,7 +72,7 @@ async def test_create_oldstyle_deployment( async def test_create_deployment( self, session, - client, + hosted_api_client, flow, flow_function, storage_document_id, @@ -88,7 +88,7 @@ async def test_create_deployment( job_variables={"cpu": 24}, storage_document_id=storage_document_id, ).model_dump(mode="json") - response = await client.post("/deployments/", json=data) + response = await hosted_api_client.post("/deployments/", json=data) assert response.status_code == status.HTTP_201_CREATED deployment_response = DeploymentResponse(**response.json()) @@ -318,7 +318,7 @@ async def test_default_work_queue_name_is_none(self, session, client, flow): async def test_create_deployment_respects_flow_id_name_uniqueness( self, session, - client, + hosted_api_client, flow, storage_document_id, ): @@ -328,7 +328,7 @@ async def test_create_deployment_respects_flow_id_name_uniqueness( paused=True, storage_document_id=storage_document_id, ).model_dump(mode="json") - response = await client.post("/deployments/", json=data) + response = await hosted_api_client.post("/deployments/", json=data) assert response.status_code == 201 assert response.json()["name"] == "My Deployment" deployment_id = response.json()["id"] @@ -340,7 +340,7 @@ async def test_create_deployment_respects_flow_id_name_uniqueness( paused=True, storage_document_id=storage_document_id, ).model_dump(mode="json") - response = await client.post("/deployments/", json=data) + response = await hosted_api_client.post("/deployments/", json=data) assert response.status_code == status.HTTP_200_OK assert response.json()["name"] == "My Deployment" assert response.json()["id"] == deployment_id @@ -355,7 +355,7 @@ async def test_create_deployment_respects_flow_id_name_uniqueness( paused=False, # CHANGED storage_document_id=storage_document_id, ).model_dump(mode="json") - response = await client.post("/deployments/", json=data) + response = await hosted_api_client.post("/deployments/", json=data) assert response.status_code == status.HTTP_200_OK assert response.json()["name"] == "My Deployment" assert response.json()["id"] == deployment_id diff --git a/tests/server/orchestration/api/test_infra_overrides.py b/tests/server/orchestration/api/test_infra_overrides.py index 3152fe56ec55..88a793ad05ab 100644 --- a/tests/server/orchestration/api/test_infra_overrides.py +++ b/tests/server/orchestration/api/test_infra_overrides.py @@ -400,7 +400,7 @@ async def test_creating_flow_run_with_missing_work_queue( async def test_base_job_template_default_references_to_blocks( self, session, - client, + hosted_api_client, k8s_credentials, ): # create a pool with a pool schema that has a default value referencing a block @@ -451,7 +451,7 @@ async def test_base_job_template_default_references_to_blocks( ) # create a flow run with no overrides - response = await client.post( + response = await hosted_api_client.post( f"/deployments/{deployment.id}/create_flow_run", json={} ) @@ -666,7 +666,7 @@ async def test_updating_flow_run_with_missing_work_queue( async def test_base_job_template_default_references_to_blocks( self, session, - client, + hosted_api_client, k8s_credentials, ): # create a pool with a pool schema that has a default value referencing a block @@ -718,7 +718,7 @@ async def test_base_job_template_default_references_to_blocks( # create a flow run with custom overrides updates = {"k8s_credentials": {"context_name": "foo", "config": {}}} - response = await client.post( + response = await hosted_api_client.post( f"/deployments/{deployment.id}/create_flow_run", json={"job_variables": updates}, ) @@ -727,7 +727,7 @@ async def test_base_job_template_default_references_to_blocks( # update the flow run to force it to refer to the default block's value flow_run_id = response.json()["id"] - response = await client.patch( + response = await hosted_api_client.patch( f"/flow_runs/{flow_run_id}", json={"job_variables": {}} ) assert response.status_code == 204, response.text diff --git a/tests/server/orchestration/api/test_task_run_subscriptions.py b/tests/server/orchestration/api/test_task_run_subscriptions.py index 1d92fd27ceed..08d99137259e 100644 --- a/tests/server/orchestration/api/test_task_run_subscriptions.py +++ b/tests/server/orchestration/api/test_task_run_subscriptions.py @@ -396,7 +396,7 @@ async def test_task_worker_basic_tracking( task_keys, expected_workers, client_id, - prefect_client, + test_client, ): for _ in range(num_connections): with authenticated_socket(app) as socket: @@ -404,7 +404,7 @@ async def test_task_worker_basic_tracking( {"type": "subscribe", "keys": task_keys, "client_id": client_id} ) - response = await prefect_client._client.post("/task_workers/filter") + response = test_client.post("api/task_workers/filter") assert response.status_code == 200 tracked_workers = response.json() assert len(tracked_workers) == expected_workers diff --git a/tests/server/orchestration/api/test_task_workers.py b/tests/server/orchestration/api/test_task_workers.py index 530d9f9fa498..f7ad7ce3b739 100644 --- a/tests/server/orchestration/api/test_task_workers.py +++ b/tests/server/orchestration/api/test_task_workers.py @@ -19,13 +19,13 @@ ], ) async def test_read_task_workers( - prefect_client, initial_workers, certain_tasks, expected_count + test_client, initial_workers, certain_tasks, expected_count ): for worker, tasks in initial_workers.items(): await observe_worker(tasks, worker) - response = await prefect_client._client.post( - "/task_workers/filter", + response = test_client.post( + "api/task_workers/filter", json={"task_worker_filter": {"task_keys": certain_tasks}} if certain_tasks else None, diff --git a/tests/server/services/test_flow_run_notifications.py b/tests/server/services/test_flow_run_notifications.py index 1572a2ee0ba0..b85b40b38edd 100644 --- a/tests/server/services/test_flow_run_notifications.py +++ b/tests/server/services/test_flow_run_notifications.py @@ -5,7 +5,7 @@ from prefect.server import models, schemas from prefect.server.services.flow_run_notifications import FlowRunNotifications -from prefect.settings import PREFECT_UI_URL, temporary_settings +from prefect.settings import PREFECT_API_URL, PREFECT_UI_URL, temporary_settings @pytest.fixture @@ -216,7 +216,7 @@ async def test_service_only_sends_notifications_for_matching_policy( @pytest.mark.parametrize( "provided_ui_url,expected_ui_url", [ - (None, "http://ephemeral-prefect/api"), + (None, "from-settings"), ("http://some-url", "http://some-url"), ], ) @@ -225,6 +225,8 @@ def test_get_ui_url_for_flow_run_id_with_ui_url( ): with temporary_settings({PREFECT_UI_URL: provided_ui_url}): url = FlowRunNotifications().get_ui_url_for_flow_run_id(flow_run_id=flow_run.id) + if expected_ui_url == "from-settings": + expected_ui_url = PREFECT_API_URL.value()[:-4] assert url == expected_ui_url + "/runs/flow-run/{flow_run_id}".format( flow_run_id=flow_run.id ) diff --git a/tests/test_background_tasks.py b/tests/test_background_tasks.py index 2bcfd4cb58f1..6e2d543f7d08 100644 --- a/tests/test_background_tasks.py +++ b/tests/test_background_tasks.py @@ -175,10 +175,18 @@ async def test_async_task_submission_creates_a_scheduled_task_run( async def test_scheduled_tasks_are_enqueued_server_side( - foo_task_with_result_storage: Task, prefect_client + foo_task_with_result_storage: Task, + in_memory_prefect_client: "PrefectClient", + monkeypatch, ): + # Need to mock `get_client` to return the in-memory client because we are directly inspecting + # changes in the server-side task queue. Ideally, we'd be able to inspect the task queue via + # the REST API for this test, but that's not currently possible. + # TODO: Add ways to inspect the task queue via the REST API + monkeypatch.setattr(prefect.tasks, "get_client", lambda: in_memory_prefect_client) + task_run_future = foo_task_with_result_storage.apply_async((42,)) - task_run = await prefect_client.read_task_run(task_run_future.task_run_id) + task_run = await in_memory_prefect_client.read_task_run(task_run_future.task_run_id) client_run: TaskRun = task_run assert client_run.state.is_scheduled() diff --git a/tests/test_context.py b/tests/test_context.py index 198654d9da18..da88fc170ea7 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -26,11 +26,11 @@ from prefect.results import ResultFactory from prefect.settings import ( DEFAULT_PROFILES_PATH, - PREFECT_API_DATABASE_CONNECTION_URL, PREFECT_API_KEY, PREFECT_API_URL, PREFECT_HOME, PREFECT_PROFILES_PATH, + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE, Profile, ProfilesCollection, save_profiles, @@ -299,9 +299,7 @@ def test_root_settings_context_default(self, monkeypatch): use_profile.assert_called_once_with( Profile( name="ephemeral", - settings={ - PREFECT_API_DATABASE_CONNECTION_URL: "sqlite+aiosqlite:///prefect.db" - }, + settings={PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: "true"}, source=DEFAULT_PROFILES_PATH, ), override_environment_variables=False, @@ -328,9 +326,7 @@ def test_root_settings_context_default_if_cli_args_do_not_match_format( use_profile.assert_called_once_with( Profile( name="ephemeral", - settings={ - PREFECT_API_DATABASE_CONNECTION_URL: "sqlite+aiosqlite:///prefect.db" - }, + settings={PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: "true"}, source=DEFAULT_PROFILES_PATH, ), override_environment_variables=False, diff --git a/tests/test_flows.py b/tests/test_flows.py index b4bc0ccdde2f..c61ad244efa4 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -26,7 +26,7 @@ import prefect.exceptions from prefect import flow, runtime, tags, task from prefect.blocks.core import Block -from prefect.client.orchestration import PrefectClient, get_client +from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client from prefect.client.schemas.schedules import ( CronSchedule, IntervalSchedule, @@ -3943,7 +3943,7 @@ def test_serve_prints_message(self, capsys): ) assert "$ prefect deployment run 'test-flow/test'" in captured.out - def test_serve_creates_deployment(self, prefect_client: PrefectClient): + def test_serve_creates_deployment(self, sync_prefect_client: SyncPrefectClient): self.flow.serve( name="test", tags=["price", "luggage"], @@ -3954,9 +3954,7 @@ def test_serve_creates_deployment(self, prefect_client: PrefectClient): paused=True, ) - deployment = asyncio.run( - prefect_client.read_deployment_by_name(name="test-flow/test") - ) + deployment = sync_prefect_client.read_deployment_by_name(name="test-flow/test") assert deployment is not None # Flow.serve should created deployments without a work queue or work pool @@ -3971,61 +3969,53 @@ def test_serve_creates_deployment(self, prefect_client: PrefectClient): assert deployment.paused assert not deployment.is_schedule_active - def test_serve_can_user_a_module_path_entrypoint(self, prefect_client): + def test_serve_can_user_a_module_path_entrypoint(self, sync_prefect_client): deployment = self.flow.serve( name="test", entrypoint_type=EntrypointType.MODULE_PATH ) - deployment = asyncio.run( - prefect_client.read_deployment_by_name(name="test-flow/test") - ) + deployment = sync_prefect_client.read_deployment_by_name(name="test-flow/test") assert deployment.entrypoint == f"{self.flow.__module__}.{self.flow.__name__}" - def test_serve_handles__file__(self, prefect_client: PrefectClient): + def test_serve_handles__file__(self, sync_prefect_client: SyncPrefectClient): self.flow.serve(__file__) - deployment = asyncio.run( - prefect_client.read_deployment_by_name(name="test-flow/test_flows") + deployment = sync_prefect_client.read_deployment_by_name( + name="test-flow/test_flows" ) assert deployment.name == "test_flows" def test_serve_creates_deployment_with_interval_schedule( - self, prefect_client: PrefectClient + self, sync_prefect_client: SyncPrefectClient ): self.flow.serve( "test", interval=3600, ) - deployment = asyncio.run( - prefect_client.read_deployment_by_name(name="test-flow/test") - ) + deployment = sync_prefect_client.read_deployment_by_name(name="test-flow/test") assert deployment is not None assert isinstance(deployment.schedule, IntervalSchedule) assert deployment.schedule.interval == datetime.timedelta(seconds=3600) def test_serve_creates_deployment_with_cron_schedule( - self, prefect_client: PrefectClient + self, sync_prefect_client: SyncPrefectClient ): self.flow.serve("test", cron="* * * * *") - deployment = asyncio.run( - prefect_client.read_deployment_by_name(name="test-flow/test") - ) + deployment = sync_prefect_client.read_deployment_by_name(name="test-flow/test") assert deployment is not None assert deployment.schedule == CronSchedule(cron="* * * * *") def test_serve_creates_deployment_with_rrule_schedule( - self, prefect_client: PrefectClient + self, sync_prefect_client: SyncPrefectClient ): self.flow.serve("test", rrule="FREQ=MINUTELY") - deployment = asyncio.run( - prefect_client.read_deployment_by_name(name="test-flow/test") - ) + deployment = sync_prefect_client.read_deployment_by_name(name="test-flow/test") assert deployment is not None assert deployment.schedule == RRuleSchedule(rrule="FREQ=MINUTELY") diff --git a/tests/test_settings.py b/tests/test_settings.py index 5e2eb0149da8..f47b3890cd31 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -23,6 +23,7 @@ PREFECT_LOGGING_LEVEL, PREFECT_LOGGING_SERVER_LEVEL, PREFECT_PROFILES_PATH, + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE, PREFECT_SERVER_API_HOST, PREFECT_SERVER_API_PORT, PREFECT_TEST_MODE, @@ -581,7 +582,7 @@ def test_load_profiles_with_ephemeral(self, temporary_profiles_path): expected = { "ephemeral": { PREFECT_API_KEY: "foo", - PREFECT_API_DATABASE_CONNECTION_URL: "sqlite+aiosqlite:///prefect.db", # default value + PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: "true", # default value }, "bar": {PREFECT_API_KEY: "bar"}, } @@ -592,9 +593,7 @@ def test_load_profiles_with_ephemeral(self, temporary_profiles_path): def test_load_profile_ephemeral(self): assert load_profile("ephemeral") == Profile( name="ephemeral", - settings={ - PREFECT_API_DATABASE_CONNECTION_URL: "sqlite+aiosqlite:///prefect.db" - }, + settings={PREFECT_SERVER_ALLOW_EPHEMERAL_MODE: "true"}, source=DEFAULT_PROFILES_PATH, ) diff --git a/tests/utilities/test_urls.py b/tests/utilities/test_urls.py index adeace677d3b..661eee3235cd 100644 --- a/tests/utilities/test_urls.py +++ b/tests/utilities/test_urls.py @@ -267,7 +267,7 @@ def test_url_for_missing_url(flow_run): ) -def test_url_for_with_default_base_url(flow_run): +def test_url_for_with_default_base_url(flow_run, enable_ephemeral_server): default_base_url = "https://default.prefect.io" expected_url = f"{default_base_url}/runs/flow-run/{flow_run.id}" assert ( @@ -280,7 +280,9 @@ def test_url_for_with_default_base_url(flow_run): ) -def test_url_for_with_default_base_url_with_path_fragment(flow_run): +def test_url_for_with_default_base_url_with_path_fragment( + flow_run, enable_ephemeral_server +): default_base_url = "https://default.prefect.io/api" expected_url = f"{default_base_url}/runs/flow-run/{flow_run.id}" assert ( @@ -293,7 +295,9 @@ def test_url_for_with_default_base_url_with_path_fragment(flow_run): ) -def test_url_for_with_default_base_url_with_path_fragment_and_slash(flow_run): +def test_url_for_with_default_base_url_with_path_fragment_and_slash( + flow_run, enable_ephemeral_server +): default_base_url = "https://default.prefect.io/api/" expected_url = f"{default_base_url}runs/flow-run/{flow_run.id}" assert ( diff --git a/tests/workers/test_base_worker.py b/tests/workers/test_base_worker.py index 47bf8cd4b024..97a1fc9407e6 100644 --- a/tests/workers/test_base_worker.py +++ b/tests/workers/test_base_worker.py @@ -413,9 +413,6 @@ def create_run_with_deployment(state): assert tracking_mock.call_count == 1 - # Multiple hits if worker's client is not being reused - assert caplog.text.count("Using ephemeral application") == 1 - async def test_base_worker_gets_job_configuration_when_syncing_with_backend_with_just_job_config( session, client diff --git a/tests/workers/test_utilities.py b/tests/workers/test_utilities.py index 52590bab3eb4..52b4f6d23506 100644 --- a/tests/workers/test_utilities.py +++ b/tests/workers/test_utilities.py @@ -55,13 +55,17 @@ def available(): @pytest.mark.usefixtures("mock_collection_registry_not_available") async def test_get_available_work_pool_types_without_collection_registry( - self, monkeypatch + self, monkeypatch, in_memory_prefect_client ): respx.routes def available(): return ["process"] + monkeypatch.setattr( + "prefect.client.collections.get_client", + lambda *args, **kwargs: in_memory_prefect_client, + ) monkeypatch.setattr(BaseWorker, "get_all_available_worker_types", available) work_pool_types = await get_available_work_pool_types()