Skip to content

Commit

Permalink
Fixes some server side tests that were failing due to ephemeral API c…
Browse files Browse the repository at this point in the history
…hanges
  • Loading branch information
desertaxle committed Jul 24, 2024
1 parent 9d1a9e5 commit 7fde075
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 50 deletions.
47 changes: 29 additions & 18 deletions src/prefect/server/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

import asyncio
import mimetypes
import multiprocessing
import os
import shutil
import socket
import sqlite3
import subprocess
import time
from contextlib import asynccontextmanager
from functools import partial, wraps
Expand All @@ -21,7 +21,6 @@
import sqlalchemy as sa
import sqlalchemy.exc
import sqlalchemy.orm.exc
import uvicorn
from fastapi import APIRouter, Depends, FastAPI, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
Expand Down Expand Up @@ -55,8 +54,10 @@
PREFECT_MEMO_STORE_PATH,
PREFECT_MEMOIZE_BLOCK_AUTO_REGISTRATION,
PREFECT_UI_SERVE_BASE,
get_current_settings,
)
from prefect.utilities.hashing import hash_objects
from prefect.utilities.processutils import get_sys_executable

TITLE = "Prefect Server"
API_TITLE = "Prefect Prefect REST API"
Expand Down Expand Up @@ -775,19 +776,6 @@ def find_available_port(self):
def address(self) -> str:
return f"http://127.0.0.1:{self.port}"

@staticmethod
def run_server(port):
app = create_app()
config = uvicorn.Config(
app=app,
host="127.0.0.1",
port=port,
log_level="error",
lifespan="on",
)

uvicorn.Server(config).run()

def start(self):
"""
Start the server in a separate process. Safe to call multiple times; only starts
Expand All @@ -797,10 +785,33 @@ def start(self):
get_logger().info(f"Starting server on {self.address()}")
try:
self.running = True
self.server_process = multiprocessing.Process(
target=self.run_server, kwargs={"port": self.port}, daemon=True
self.server_process = subprocess.Popen(
args=[
get_sys_executable(),
"-m",
"uvicorn",
"--app-dir",
# quote wrapping needed for windows paths with spaces
f'"{prefect.__module_path__.parent}"',
"--factory",
"prefect.server.api.server:create_app",
"--host",
"127.0.0.1",
"--port",
str(self.port),
"--log-level",
"error",
"--lifespan",
"on",
],
env={
**os.environ,
**get_current_settings().to_environment_variables(
exclude_unset=True
),
},
)
self.server_process.start()

with httpx.Client() as client:
response = None
elapsed_time = 0
Expand Down
17 changes: 14 additions & 3 deletions tests/fixtures/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import httpx
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient

from prefect.server.api.server import create_app
Expand All @@ -19,17 +20,27 @@ def app() -> FastAPI:


@pytest.fixture
async def client(app: ASGIApp) -> AsyncGenerator[AsyncClient, Any]:
def test_client(app: FastAPI) -> TestClient:
return TestClient(app)


@pytest.fixture
async def client(app) -> AsyncGenerator[AsyncClient, Any]:
"""
Yield a test client for testing the api
"""
transport = ASGITransport(app=app)
async with httpx.AsyncClient(
transport=transport, base_url="https://test/api"
transport=ASGITransport(app=app), base_url="https://test/api"
) as async_client:
yield async_client


@pytest.fixture
async def hosted_api_client(use_hosted_api_server) -> AsyncGenerator[AsyncClient, Any]:
async with httpx.AsyncClient(base_url=use_hosted_api_server) as async_client:
yield async_client


@pytest.fixture
async def client_with_unprotected_block_api(
app: ASGIApp,
Expand Down
4 changes: 2 additions & 2 deletions tests/fixtures/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@

@pytest.fixture
async def prefect_client(
test_database_connection_url: str,
test_database_connection_url: str, use_hosted_api_server
) -> AsyncGenerator[PrefectClient, None]:
async with get_client() as client:
yield client


@pytest.fixture
def sync_prefect_client(test_database_connection_url):
def sync_prefect_client(test_database_connection_url, use_hosted_api_server):
yield get_client(sync_client=True)


Expand Down
10 changes: 6 additions & 4 deletions tests/server/orchestration/api/test_block_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,7 +1416,7 @@ async def test_read_secret_block_documents_with_secrets(
assert blocks[0].data["z"] == Z

async def test_nested_block_secrets_are_obfuscated_when_all_blocks_are_saved(
self, client, session
self, hosted_api_client, session
):
class ChildBlock(Block):
x: SecretStr
Expand All @@ -1435,15 +1435,17 @@ class ParentBlock(Block):
block = ParentBlock(a=3, b="b", child=child)
await block.save("nested-test")
await session.commit()
response = await client.get(f"/block_documents/{block._block_document_id}")
response = await hosted_api_client.get(
f"/block_documents/{block._block_document_id}"
)
block = schemas.core.BlockDocument.model_validate(response.json())
assert block.data["a"] == 3
assert block.data["b"] == obfuscate_string("b")
assert block.data["child"]["x"] == obfuscate_string(X)
assert block.data["child"]["y"] == Y
assert block.data["child"]["z"] == {"secret": obfuscate_string(Z)}

async def test_nested_block_secrets_are_returned(self, client):
async def test_nested_block_secrets_are_returned(self, hosted_api_client):
class ChildBlock(Block):
x: SecretStr
y: str
Expand All @@ -1457,7 +1459,7 @@ class ParentBlock(Block):
block = ParentBlock(a=3, b="b", child=ChildBlock(x=X, y=Y, z=dict(secret=Z)))
await block.save("nested-test")

response = await client.get(
response = await hosted_api_client.get(
f"/block_documents/{block._block_document_id}",
params=dict(include_secrets=True),
)
Expand Down
10 changes: 5 additions & 5 deletions tests/server/orchestration/api/test_block_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,13 +491,13 @@ async def test_install_system_block_types_multiple_times(self, client):
await client.post("/block_types/install_system_block_types")
await client.post("/block_types/install_system_block_types")

async def test_create_system_block_type(self, client, session):
async def test_create_system_block_type(self, hosted_api_client, session):
# install system blocks
await client.post("/block_types/install_system_block_types")
await hosted_api_client.post("/block_types/install_system_block_types")

# create a datetime block
datetime_block_type = await client.get("/block_types/slug/date-time")
datetime_block_schema = await client.post(
datetime_block_type = await hosted_api_client.get("/block_types/slug/date-time")
datetime_block_schema = await hosted_api_client.post(
"/block_schemas/filter",
json=dict(
block_schemas=dict(
Expand All @@ -507,7 +507,7 @@ async def test_create_system_block_type(self, client, session):
),
)
block = prefect.blocks.system.DateTime(value="2022-01-01T00:00:00+00:00")
response = await client.post(
response = await hosted_api_client.post(
"/block_documents/",
json=block._to_block_document(
name="my-test-date-time",
Expand Down
16 changes: 8 additions & 8 deletions tests/server/orchestration/api/test_deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class TestCreateDeployment:
async def test_create_oldstyle_deployment(
self,
session,
client,
hosted_api_client,
flow,
flow_function,
storage_document_id,
Expand All @@ -52,7 +52,7 @@ async def test_create_oldstyle_deployment(
parameters={"foo": "bar"},
storage_document_id=storage_document_id,
).model_dump(mode="json")
response = await client.post("/deployments/", json=data)
response = await hosted_api_client.post("/deployments/", json=data)
assert response.status_code == status.HTTP_201_CREATED
assert response.json()["name"] == "My Deployment"
assert response.json()["version"] == "mint"
Expand All @@ -72,7 +72,7 @@ async def test_create_oldstyle_deployment(
async def test_create_deployment(
self,
session,
client,
hosted_api_client,
flow,
flow_function,
storage_document_id,
Expand All @@ -88,7 +88,7 @@ async def test_create_deployment(
job_variables={"cpu": 24},
storage_document_id=storage_document_id,
).model_dump(mode="json")
response = await client.post("/deployments/", json=data)
response = await hosted_api_client.post("/deployments/", json=data)
assert response.status_code == status.HTTP_201_CREATED

deployment_response = DeploymentResponse(**response.json())
Expand Down Expand Up @@ -318,7 +318,7 @@ async def test_default_work_queue_name_is_none(self, session, client, flow):
async def test_create_deployment_respects_flow_id_name_uniqueness(
self,
session,
client,
hosted_api_client,
flow,
storage_document_id,
):
Expand All @@ -328,7 +328,7 @@ async def test_create_deployment_respects_flow_id_name_uniqueness(
paused=True,
storage_document_id=storage_document_id,
).model_dump(mode="json")
response = await client.post("/deployments/", json=data)
response = await hosted_api_client.post("/deployments/", json=data)
assert response.status_code == 201
assert response.json()["name"] == "My Deployment"
deployment_id = response.json()["id"]
Expand All @@ -340,7 +340,7 @@ async def test_create_deployment_respects_flow_id_name_uniqueness(
paused=True,
storage_document_id=storage_document_id,
).model_dump(mode="json")
response = await client.post("/deployments/", json=data)
response = await hosted_api_client.post("/deployments/", json=data)
assert response.status_code == status.HTTP_200_OK
assert response.json()["name"] == "My Deployment"
assert response.json()["id"] == deployment_id
Expand All @@ -355,7 +355,7 @@ async def test_create_deployment_respects_flow_id_name_uniqueness(
paused=False, # CHANGED
storage_document_id=storage_document_id,
).model_dump(mode="json")
response = await client.post("/deployments/", json=data)
response = await hosted_api_client.post("/deployments/", json=data)
assert response.status_code == status.HTTP_200_OK
assert response.json()["name"] == "My Deployment"
assert response.json()["id"] == deployment_id
Expand Down
10 changes: 5 additions & 5 deletions tests/server/orchestration/api/test_infra_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ async def test_creating_flow_run_with_missing_work_queue(
async def test_base_job_template_default_references_to_blocks(
self,
session,
client,
hosted_api_client,
k8s_credentials,
):
# create a pool with a pool schema that has a default value referencing a block
Expand Down Expand Up @@ -451,7 +451,7 @@ async def test_base_job_template_default_references_to_blocks(
)

# create a flow run with no overrides
response = await client.post(
response = await hosted_api_client.post(
f"/deployments/{deployment.id}/create_flow_run", json={}
)

Expand Down Expand Up @@ -666,7 +666,7 @@ async def test_updating_flow_run_with_missing_work_queue(
async def test_base_job_template_default_references_to_blocks(
self,
session,
client,
hosted_api_client,
k8s_credentials,
):
# create a pool with a pool schema that has a default value referencing a block
Expand Down Expand Up @@ -718,7 +718,7 @@ async def test_base_job_template_default_references_to_blocks(

# create a flow run with custom overrides
updates = {"k8s_credentials": {"context_name": "foo", "config": {}}}
response = await client.post(
response = await hosted_api_client.post(
f"/deployments/{deployment.id}/create_flow_run",
json={"job_variables": updates},
)
Expand All @@ -727,7 +727,7 @@ async def test_base_job_template_default_references_to_blocks(

# update the flow run to force it to refer to the default block's value
flow_run_id = response.json()["id"]
response = await client.patch(
response = await hosted_api_client.patch(
f"/flow_runs/{flow_run_id}", json={"job_variables": {}}
)
assert response.status_code == 204, response.text
Expand Down
4 changes: 2 additions & 2 deletions tests/server/orchestration/api/test_task_run_subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,15 +396,15 @@ async def test_task_worker_basic_tracking(
task_keys,
expected_workers,
client_id,
prefect_client,
test_client,
):
for _ in range(num_connections):
with authenticated_socket(app) as socket:
socket.send_json(
{"type": "subscribe", "keys": task_keys, "client_id": client_id}
)

response = await prefect_client._client.post("/task_workers/filter")
response = test_client.post("api/task_workers/filter")
assert response.status_code == 200
tracked_workers = response.json()
assert len(tracked_workers) == expected_workers
Expand Down
6 changes: 3 additions & 3 deletions tests/server/orchestration/api/test_task_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
],
)
async def test_read_task_workers(
prefect_client, initial_workers, certain_tasks, expected_count
test_client, initial_workers, certain_tasks, expected_count
):
for worker, tasks in initial_workers.items():
await observe_worker(tasks, worker)

response = await prefect_client._client.post(
"/task_workers/filter",
response = test_client.post(
"api/task_workers/filter",
json={"task_worker_filter": {"task_keys": certain_tasks}}
if certain_tasks
else None,
Expand Down

0 comments on commit 7fde075

Please sign in to comment.