Skip to content

Commit

Permalink
backport #16149 (#16317)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Dec 11, 2024
1 parent b21233f commit 20b5334
Show file tree
Hide file tree
Showing 15 changed files with 1,087 additions and 946 deletions.
33 changes: 17 additions & 16 deletions flows/client_context_lifespan.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,20 @@
from packaging.version import Version

import prefect

# Only run these tests if the version is at least 2.13.0
if Version(prefect.__version__) < Version("2.13.0"):
raise NotImplementedError()

import asyncio
import random
import threading
from contextlib import asynccontextmanager
from typing import Callable
from unittest.mock import MagicMock

import anyio
from prefect._vendor.fastapi import FastAPI
from fastapi import FastAPI

import prefect
import prefect.context
import prefect.exceptions
from prefect.client.orchestration import PrefectClient


def make_lifespan(startup, shutdown) -> callable:
def make_lifespan(startup, shutdown) -> Callable:
async def lifespan(app):
try:
startup()
Expand All @@ -32,6 +26,7 @@ async def lifespan(app):


def client_context_lifespan_is_robust_to_threaded_concurrency():
print("testing that client context lifespan is robust to threaded concurrency")
startup, shutdown = MagicMock(), MagicMock()
app = FastAPI(lifespan=make_lifespan(startup, shutdown))

Expand Down Expand Up @@ -61,6 +56,7 @@ async def enter_client(context):


async def client_context_lifespan_is_robust_to_high_async_concurrency():
print("testing that client context lifespan is robust to high async concurrency")
startup, shutdown = MagicMock(), MagicMock()
app = FastAPI(lifespan=make_lifespan(startup, shutdown))

Expand All @@ -70,7 +66,7 @@ async def enter_client():
async with PrefectClient(app):
await anyio.sleep(random.random())

with anyio.fail_after(15):
with anyio.fail_after(30):
async with anyio.create_task_group() as tg:
for _ in range(1000):
tg.start_soon(enter_client)
Expand All @@ -80,6 +76,7 @@ async def enter_client():


async def client_context_lifespan_is_robust_to_mixed_concurrency():
print("testing that client context lifespan is robust to mixed concurrency")
startup, shutdown = MagicMock(), MagicMock()
app = FastAPI(lifespan=make_lifespan(startup, shutdown))

Expand All @@ -91,10 +88,14 @@ async def enter_client():

async def enter_client_many_times(context):
# We must re-enter the profile context in the new thread
with context:
async with anyio.create_task_group() as tg:
for _ in range(100):
tg.start_soon(enter_client)
try:
with context:
async with anyio.create_task_group() as tg:
for _ in range(10):
tg.start_soon(enter_client)
except Exception as e:
print(f"Error entering client many times {e}")
raise e

threads = [
threading.Thread(
Expand All @@ -104,7 +105,7 @@ async def enter_client_many_times(context):
prefect.context.SettingsContext.get().copy(),
),
)
for _ in range(100)
for _ in range(10)
]
for thread in threads:
thread.start()
Expand Down
44 changes: 24 additions & 20 deletions scripts/run-integration-flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@
Example:
PREFECT_API_URL="http://localhost:4200" ./scripts/run-integration-flows.py
PREFECT_API_URL="http://localhost:4200/api" ./scripts/run-integration-flows.py
"""

import os
import runpy
import subprocess
import sys
from pathlib import Path
from typing import Union
from typing import List, Union

import prefect
from prefect import __version__
Expand All @@ -31,25 +30,30 @@
)


def run_script(script_path: str):
print(f" {script_path} ".center(90, "-"), flush=True)
try:
result = subprocess.run(
["python", script_path], capture_output=True, text=True, check=True
)
return result.stdout, result.stderr, None
except subprocess.CalledProcessError as e:
return e.stdout, e.stderr, e


def run_flows(search_path: Union[str, Path]):
count = 0
print(f"Running integration tests with client version: {__version__}")
server_version = os.environ.get("TEST_SERVER_VERSION")
if server_version:
print(f"and server version: {server_version}")

for file in sorted(Path(search_path).glob("**/*.py")):
print(f" {file.relative_to(search_path)} ".center(90, "-"), flush=True)
scripts = sorted(Path(search_path).glob("**/*.py"))
errors: List[Exception] = []
for script in scripts:
print(f"Running {script}")
try:
runpy.run_path(file, run_name="__main__")
except NotImplementedError:
print(f"Skipping {file}: not supported by this version of Prefect")
print("".center(90, "-") + "\n", flush=True)
count += 1

if not count:
print(f"No Python files found at {search_path}")
exit(1)
run_script(str(script))
except Exception as e:
print(f"Error running {script}: {e}")
errors.append(e)

assert not errors, "Errors occurred while running flows"


if __name__ == "__main__":
Expand Down
19 changes: 17 additions & 2 deletions src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import datetime
import ssl
import warnings
from contextlib import AsyncExitStack
from typing import (
Expand Down Expand Up @@ -275,12 +276,18 @@ def __init__(
httpx_settings.setdefault("headers", {})

if PREFECT_API_TLS_INSECURE_SKIP_VERIFY:
httpx_settings.setdefault("verify", False)
# Create an unverified context for insecure connections
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
httpx_settings.setdefault("verify", ctx)
else:
cert_file = PREFECT_API_SSL_CERT_FILE.value()
if not cert_file:
cert_file = certifi.where()
httpx_settings.setdefault("verify", cert_file)
# Create a verified context with the certificate file
ctx = ssl.create_default_context(cafile=cert_file)
httpx_settings.setdefault("verify", ctx)

if api_version is None:
api_version = SERVER_API_VERSION
Expand Down Expand Up @@ -3438,11 +3445,19 @@ def __init__(

if PREFECT_API_TLS_INSECURE_SKIP_VERIFY:
httpx_settings.setdefault("verify", False)
# Create an unverified context for insecure connections
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
httpx_settings.setdefault("verify", ctx)
else:
cert_file = PREFECT_API_SSL_CERT_FILE.value()
if not cert_file:
cert_file = certifi.where()
httpx_settings.setdefault("verify", cert_file)
# Create a verified context with the certificate file
ctx = ssl.create_default_context(cafile=cert_file)
httpx_settings.setdefault("verify", ctx)

if api_version is None:
api_version = SERVER_API_VERSION
Expand Down
29 changes: 14 additions & 15 deletions tests/blocks/test_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def test_invalid_to_phone_numbers_raises_warning(self, caplog):

class TestCustomWebhook:
async def test_notify_async(self):
with respx.mock as xmock:
with respx.mock(using="httpx") as xmock:
xmock.post("https://example.com/")

custom_block = CustomWebhookNotificationBlock(
Expand All @@ -570,14 +570,14 @@ async def test_notify_async(self):
assert last_req.headers["user-agent"] == "Prefect Notifications"
assert (
last_req.content
== b'{"msg": "subject\\ntest", "token": "someSecretToken"}'
== b'{"msg":"subject\\ntest","token":"someSecretToken"}'
)
assert last_req.extensions == {
"timeout": {"connect": 10, "pool": 10, "read": 10, "write": 10}
}

def test_notify_sync(self):
with respx.mock as xmock:
with respx.mock(using="httpx") as xmock:
xmock.post("https://example.com/")

custom_block = CustomWebhookNotificationBlock(
Expand All @@ -592,14 +592,14 @@ def test_notify_sync(self):
assert last_req.headers["user-agent"] == "Prefect Notifications"
assert (
last_req.content
== b'{"msg": "subject\\ntest", "token": "someSecretToken"}'
== b'{"msg":"subject\\ntest","token":"someSecretToken"}'
)
assert last_req.extensions == {
"timeout": {"connect": 10, "pool": 10, "read": 10, "write": 10}
}

def test_user_agent_override(self):
with respx.mock as xmock:
with respx.mock(using="httpx") as xmock:
xmock.post("https://example.com/")

custom_block = CustomWebhookNotificationBlock(
Expand All @@ -615,14 +615,14 @@ def test_user_agent_override(self):
assert last_req.headers["user-agent"] == "CustomUA"
assert (
last_req.content
== b'{"msg": "subject\\ntest", "token": "someSecretToken"}'
== b'{"msg":"subject\\ntest","token":"someSecretToken"}'
)
assert last_req.extensions == {
"timeout": {"connect": 10, "pool": 10, "read": 10, "write": 10}
}

def test_timeout_override(self):
with respx.mock as xmock:
with respx.mock(using="httpx") as xmock:
xmock.post("https://example.com/")

custom_block = CustomWebhookNotificationBlock(
Expand All @@ -637,14 +637,14 @@ def test_timeout_override(self):
last_req = xmock.calls.last.request
assert (
last_req.content
== b'{"msg": "subject\\ntest", "token": "someSecretToken"}'
== b'{"msg":"subject\\ntest","token":"someSecretToken"}'
)
assert last_req.extensions == {
"timeout": {"connect": 30, "pool": 30, "read": 30, "write": 30}
}

def test_request_cookie(self):
with respx.mock as xmock:
with respx.mock(using="httpx") as xmock:
xmock.post("https://example.com/")

custom_block = CustomWebhookNotificationBlock(
Expand All @@ -661,14 +661,14 @@ def test_request_cookie(self):
assert last_req.headers["cookie"] == "key=secretCookieValue"
assert (
last_req.content
== b'{"msg": "subject\\ntest", "token": "someSecretToken"}'
== b'{"msg":"subject\\ntest","token":"someSecretToken"}'
)
assert last_req.extensions == {
"timeout": {"connect": 30, "pool": 30, "read": 30, "write": 30}
}

def test_subst_nested_list(self):
with respx.mock as xmock:
with respx.mock(using="httpx") as xmock:
xmock.post("https://example.com/")

custom_block = CustomWebhookNotificationBlock(
Expand All @@ -685,14 +685,14 @@ def test_subst_nested_list(self):
assert last_req.headers["user-agent"] == "Prefect Notifications"
assert (
last_req.content
== b'{"data": {"sub1": [{"in-list": "test", "name": "test name"}]}}'
== b'{"data":{"sub1":[{"in-list":"test","name":"test name"}]}}'
)
assert last_req.extensions == {
"timeout": {"connect": 10, "pool": 10, "read": 10, "write": 10}
}

def test_subst_none(self):
with respx.mock as xmock:
with respx.mock(using="httpx") as xmock:
xmock.post("https://example.com/")

custom_block = CustomWebhookNotificationBlock(
Expand All @@ -707,8 +707,7 @@ def test_subst_none(self):
last_req = xmock.calls.last.request
assert last_req.headers["user-agent"] == "Prefect Notifications"
assert (
last_req.content
== b'{"msg": "null\\ntest", "token": "someSecretToken"}'
last_req.content == b'{"msg":"null\\ntest","token":"someSecretToken"}'
)
assert last_req.extensions == {
"timeout": {"connect": 10, "pool": 10, "read": 10, "write": 10}
Expand Down
Loading

0 comments on commit 20b5334

Please sign in to comment.