From 7fde075dc68fe9157c975e481d4492424c07f4d0 Mon Sep 17 00:00:00 2001 From: Alex Streed Date: Wed, 24 Jul 2024 16:31:44 -0500 Subject: [PATCH] Fixes some server side tests that were failing due to ephemeral API changes --- src/prefect/server/api/server.py | 47 ++++++++++++------- tests/fixtures/api.py | 17 +++++-- tests/fixtures/client.py | 4 +- .../orchestration/api/test_block_documents.py | 10 ++-- .../orchestration/api/test_block_types.py | 10 ++-- .../orchestration/api/test_deployments.py | 16 +++---- .../orchestration/api/test_infra_overrides.py | 10 ++-- .../api/test_task_run_subscriptions.py | 4 +- .../orchestration/api/test_task_workers.py | 6 +-- 9 files changed, 74 insertions(+), 50 deletions(-) diff --git a/src/prefect/server/api/server.py b/src/prefect/server/api/server.py index 7674bf3abcd6..1edc38523d31 100644 --- a/src/prefect/server/api/server.py +++ b/src/prefect/server/api/server.py @@ -4,11 +4,11 @@ import asyncio import mimetypes -import multiprocessing import os import shutil import socket import sqlite3 +import subprocess import time from contextlib import asynccontextmanager from functools import partial, wraps @@ -21,7 +21,6 @@ import sqlalchemy as sa import sqlalchemy.exc import sqlalchemy.orm.exc -import uvicorn from fastapi import APIRouter, Depends, FastAPI, Request, status from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError @@ -55,8 +54,10 @@ PREFECT_MEMO_STORE_PATH, PREFECT_MEMOIZE_BLOCK_AUTO_REGISTRATION, PREFECT_UI_SERVE_BASE, + get_current_settings, ) from prefect.utilities.hashing import hash_objects +from prefect.utilities.processutils import get_sys_executable TITLE = "Prefect Server" API_TITLE = "Prefect Prefect REST API" @@ -775,19 +776,6 @@ def find_available_port(self): def address(self) -> str: return f"http://127.0.0.1:{self.port}" - @staticmethod - def run_server(port): - app = create_app() - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - lifespan="on", - ) - - uvicorn.Server(config).run() - def start(self): """ Start the server in a separate process. Safe to call multiple times; only starts @@ -797,10 +785,33 @@ def start(self): get_logger().info(f"Starting server on {self.address()}") try: self.running = True - self.server_process = multiprocessing.Process( - target=self.run_server, kwargs={"port": self.port}, daemon=True + self.server_process = subprocess.Popen( + args=[ + get_sys_executable(), + "-m", + "uvicorn", + "--app-dir", + # quote wrapping needed for windows paths with spaces + f'"{prefect.__module_path__.parent}"', + "--factory", + "prefect.server.api.server:create_app", + "--host", + "127.0.0.1", + "--port", + str(self.port), + "--log-level", + "error", + "--lifespan", + "on", + ], + env={ + **os.environ, + **get_current_settings().to_environment_variables( + exclude_unset=True + ), + }, ) - self.server_process.start() + with httpx.Client() as client: response = None elapsed_time = 0 diff --git a/tests/fixtures/api.py b/tests/fixtures/api.py index fad347c04489..26adaf9e65d8 100644 --- a/tests/fixtures/api.py +++ b/tests/fixtures/api.py @@ -3,6 +3,7 @@ import httpx import pytest from fastapi import FastAPI +from fastapi.testclient import TestClient from httpx import ASGITransport, AsyncClient from prefect.server.api.server import create_app @@ -19,17 +20,27 @@ def app() -> FastAPI: @pytest.fixture -async def client(app: ASGIApp) -> AsyncGenerator[AsyncClient, Any]: +def test_client(app: FastAPI) -> TestClient: + return TestClient(app) + + +@pytest.fixture +async def client(app) -> AsyncGenerator[AsyncClient, Any]: """ Yield a test client for testing the api """ - transport = ASGITransport(app=app) async with httpx.AsyncClient( - transport=transport, base_url="https://test/api" + transport=ASGITransport(app=app), base_url="https://test/api" ) as async_client: yield async_client +@pytest.fixture +async def hosted_api_client(use_hosted_api_server) -> AsyncGenerator[AsyncClient, Any]: + async with httpx.AsyncClient(base_url=use_hosted_api_server) as async_client: + yield async_client + + @pytest.fixture async def client_with_unprotected_block_api( app: ASGIApp, diff --git a/tests/fixtures/client.py b/tests/fixtures/client.py index 5efeced99168..e8b9801c42ae 100644 --- a/tests/fixtures/client.py +++ b/tests/fixtures/client.py @@ -10,14 +10,14 @@ @pytest.fixture async def prefect_client( - test_database_connection_url: str, + test_database_connection_url: str, use_hosted_api_server ) -> AsyncGenerator[PrefectClient, None]: async with get_client() as client: yield client @pytest.fixture -def sync_prefect_client(test_database_connection_url): +def sync_prefect_client(test_database_connection_url, use_hosted_api_server): yield get_client(sync_client=True) diff --git a/tests/server/orchestration/api/test_block_documents.py b/tests/server/orchestration/api/test_block_documents.py index 2df059b5c45a..9b4b83a1e163 100644 --- a/tests/server/orchestration/api/test_block_documents.py +++ b/tests/server/orchestration/api/test_block_documents.py @@ -1416,7 +1416,7 @@ async def test_read_secret_block_documents_with_secrets( assert blocks[0].data["z"] == Z async def test_nested_block_secrets_are_obfuscated_when_all_blocks_are_saved( - self, client, session + self, hosted_api_client, session ): class ChildBlock(Block): x: SecretStr @@ -1435,7 +1435,9 @@ class ParentBlock(Block): block = ParentBlock(a=3, b="b", child=child) await block.save("nested-test") await session.commit() - response = await client.get(f"/block_documents/{block._block_document_id}") + response = await hosted_api_client.get( + f"/block_documents/{block._block_document_id}" + ) block = schemas.core.BlockDocument.model_validate(response.json()) assert block.data["a"] == 3 assert block.data["b"] == obfuscate_string("b") @@ -1443,7 +1445,7 @@ class ParentBlock(Block): assert block.data["child"]["y"] == Y assert block.data["child"]["z"] == {"secret": obfuscate_string(Z)} - async def test_nested_block_secrets_are_returned(self, client): + async def test_nested_block_secrets_are_returned(self, hosted_api_client): class ChildBlock(Block): x: SecretStr y: str @@ -1457,7 +1459,7 @@ class ParentBlock(Block): block = ParentBlock(a=3, b="b", child=ChildBlock(x=X, y=Y, z=dict(secret=Z))) await block.save("nested-test") - response = await client.get( + response = await hosted_api_client.get( f"/block_documents/{block._block_document_id}", params=dict(include_secrets=True), ) diff --git a/tests/server/orchestration/api/test_block_types.py b/tests/server/orchestration/api/test_block_types.py index 3d94b944c539..5d27535f67ec 100644 --- a/tests/server/orchestration/api/test_block_types.py +++ b/tests/server/orchestration/api/test_block_types.py @@ -491,13 +491,13 @@ async def test_install_system_block_types_multiple_times(self, client): await client.post("/block_types/install_system_block_types") await client.post("/block_types/install_system_block_types") - async def test_create_system_block_type(self, client, session): + async def test_create_system_block_type(self, hosted_api_client, session): # install system blocks - await client.post("/block_types/install_system_block_types") + await hosted_api_client.post("/block_types/install_system_block_types") # create a datetime block - datetime_block_type = await client.get("/block_types/slug/date-time") - datetime_block_schema = await client.post( + datetime_block_type = await hosted_api_client.get("/block_types/slug/date-time") + datetime_block_schema = await hosted_api_client.post( "/block_schemas/filter", json=dict( block_schemas=dict( @@ -507,7 +507,7 @@ async def test_create_system_block_type(self, client, session): ), ) block = prefect.blocks.system.DateTime(value="2022-01-01T00:00:00+00:00") - response = await client.post( + response = await hosted_api_client.post( "/block_documents/", json=block._to_block_document( name="my-test-date-time", diff --git a/tests/server/orchestration/api/test_deployments.py b/tests/server/orchestration/api/test_deployments.py index 854ed60faa14..5f5b10cc40ef 100644 --- a/tests/server/orchestration/api/test_deployments.py +++ b/tests/server/orchestration/api/test_deployments.py @@ -39,7 +39,7 @@ class TestCreateDeployment: async def test_create_oldstyle_deployment( self, session, - client, + hosted_api_client, flow, flow_function, storage_document_id, @@ -52,7 +52,7 @@ async def test_create_oldstyle_deployment( parameters={"foo": "bar"}, storage_document_id=storage_document_id, ).model_dump(mode="json") - response = await client.post("/deployments/", json=data) + response = await hosted_api_client.post("/deployments/", json=data) assert response.status_code == status.HTTP_201_CREATED assert response.json()["name"] == "My Deployment" assert response.json()["version"] == "mint" @@ -72,7 +72,7 @@ async def test_create_oldstyle_deployment( async def test_create_deployment( self, session, - client, + hosted_api_client, flow, flow_function, storage_document_id, @@ -88,7 +88,7 @@ async def test_create_deployment( job_variables={"cpu": 24}, storage_document_id=storage_document_id, ).model_dump(mode="json") - response = await client.post("/deployments/", json=data) + response = await hosted_api_client.post("/deployments/", json=data) assert response.status_code == status.HTTP_201_CREATED deployment_response = DeploymentResponse(**response.json()) @@ -318,7 +318,7 @@ async def test_default_work_queue_name_is_none(self, session, client, flow): async def test_create_deployment_respects_flow_id_name_uniqueness( self, session, - client, + hosted_api_client, flow, storage_document_id, ): @@ -328,7 +328,7 @@ async def test_create_deployment_respects_flow_id_name_uniqueness( paused=True, storage_document_id=storage_document_id, ).model_dump(mode="json") - response = await client.post("/deployments/", json=data) + response = await hosted_api_client.post("/deployments/", json=data) assert response.status_code == 201 assert response.json()["name"] == "My Deployment" deployment_id = response.json()["id"] @@ -340,7 +340,7 @@ async def test_create_deployment_respects_flow_id_name_uniqueness( paused=True, storage_document_id=storage_document_id, ).model_dump(mode="json") - response = await client.post("/deployments/", json=data) + response = await hosted_api_client.post("/deployments/", json=data) assert response.status_code == status.HTTP_200_OK assert response.json()["name"] == "My Deployment" assert response.json()["id"] == deployment_id @@ -355,7 +355,7 @@ async def test_create_deployment_respects_flow_id_name_uniqueness( paused=False, # CHANGED storage_document_id=storage_document_id, ).model_dump(mode="json") - response = await client.post("/deployments/", json=data) + response = await hosted_api_client.post("/deployments/", json=data) assert response.status_code == status.HTTP_200_OK assert response.json()["name"] == "My Deployment" assert response.json()["id"] == deployment_id diff --git a/tests/server/orchestration/api/test_infra_overrides.py b/tests/server/orchestration/api/test_infra_overrides.py index 3152fe56ec55..88a793ad05ab 100644 --- a/tests/server/orchestration/api/test_infra_overrides.py +++ b/tests/server/orchestration/api/test_infra_overrides.py @@ -400,7 +400,7 @@ async def test_creating_flow_run_with_missing_work_queue( async def test_base_job_template_default_references_to_blocks( self, session, - client, + hosted_api_client, k8s_credentials, ): # create a pool with a pool schema that has a default value referencing a block @@ -451,7 +451,7 @@ async def test_base_job_template_default_references_to_blocks( ) # create a flow run with no overrides - response = await client.post( + response = await hosted_api_client.post( f"/deployments/{deployment.id}/create_flow_run", json={} ) @@ -666,7 +666,7 @@ async def test_updating_flow_run_with_missing_work_queue( async def test_base_job_template_default_references_to_blocks( self, session, - client, + hosted_api_client, k8s_credentials, ): # create a pool with a pool schema that has a default value referencing a block @@ -718,7 +718,7 @@ async def test_base_job_template_default_references_to_blocks( # create a flow run with custom overrides updates = {"k8s_credentials": {"context_name": "foo", "config": {}}} - response = await client.post( + response = await hosted_api_client.post( f"/deployments/{deployment.id}/create_flow_run", json={"job_variables": updates}, ) @@ -727,7 +727,7 @@ async def test_base_job_template_default_references_to_blocks( # update the flow run to force it to refer to the default block's value flow_run_id = response.json()["id"] - response = await client.patch( + response = await hosted_api_client.patch( f"/flow_runs/{flow_run_id}", json={"job_variables": {}} ) assert response.status_code == 204, response.text diff --git a/tests/server/orchestration/api/test_task_run_subscriptions.py b/tests/server/orchestration/api/test_task_run_subscriptions.py index 1d92fd27ceed..08d99137259e 100644 --- a/tests/server/orchestration/api/test_task_run_subscriptions.py +++ b/tests/server/orchestration/api/test_task_run_subscriptions.py @@ -396,7 +396,7 @@ async def test_task_worker_basic_tracking( task_keys, expected_workers, client_id, - prefect_client, + test_client, ): for _ in range(num_connections): with authenticated_socket(app) as socket: @@ -404,7 +404,7 @@ async def test_task_worker_basic_tracking( {"type": "subscribe", "keys": task_keys, "client_id": client_id} ) - response = await prefect_client._client.post("/task_workers/filter") + response = test_client.post("api/task_workers/filter") assert response.status_code == 200 tracked_workers = response.json() assert len(tracked_workers) == expected_workers diff --git a/tests/server/orchestration/api/test_task_workers.py b/tests/server/orchestration/api/test_task_workers.py index 530d9f9fa498..f7ad7ce3b739 100644 --- a/tests/server/orchestration/api/test_task_workers.py +++ b/tests/server/orchestration/api/test_task_workers.py @@ -19,13 +19,13 @@ ], ) async def test_read_task_workers( - prefect_client, initial_workers, certain_tasks, expected_count + test_client, initial_workers, certain_tasks, expected_count ): for worker, tasks in initial_workers.items(): await observe_worker(tasks, worker) - response = await prefect_client._client.post( - "/task_workers/filter", + response = test_client.post( + "api/task_workers/filter", json={"task_worker_filter": {"task_keys": certain_tasks}} if certain_tasks else None,