From 2682d584c1eaa71c771e1e1806b0e4d6c6a47f57 Mon Sep 17 00:00:00 2001 From: Alex Streed Date: Mon, 29 Jul 2024 14:34:33 -0500 Subject: [PATCH] Fix more tests --- src/prefect/task_engine.py | 6 ++--- tests/test_flows.py | 38 ++++++++++++------------------- tests/workers/test_base_worker.py | 3 --- tests/workers/test_utilities.py | 6 ++++- 4 files changed, 22 insertions(+), 31 deletions(-) diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index a75fee665e25..edcd1b8feeaf 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -31,7 +31,7 @@ from typing_extensions import ParamSpec from prefect import Task -from prefect.client.orchestration import PrefectClient, SyncPrefectClient +from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client from prefect.client.schemas import TaskRun from prefect.client.schemas.objects import State, TaskRunInput from prefect.concurrency.asyncio import concurrency as aconcurrency @@ -1168,8 +1168,8 @@ async def initialize_run( """ with hydrated_context(self.context): - async with AsyncClientContext.get_or_create() as client_ctx: - self._client = client_ctx.client + async with AsyncClientContext.get_or_create(): + self._client = get_client() self._is_started = True try: if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: diff --git a/tests/test_flows.py b/tests/test_flows.py index b4bc0ccdde2f..c61ad244efa4 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -26,7 +26,7 @@ import prefect.exceptions from prefect import flow, runtime, tags, task from prefect.blocks.core import Block -from prefect.client.orchestration import PrefectClient, get_client +from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client from prefect.client.schemas.schedules import ( CronSchedule, IntervalSchedule, @@ -3943,7 +3943,7 @@ def test_serve_prints_message(self, capsys): ) assert "$ prefect deployment run 'test-flow/test'" in captured.out - def test_serve_creates_deployment(self, prefect_client: PrefectClient): + def test_serve_creates_deployment(self, sync_prefect_client: SyncPrefectClient): self.flow.serve( name="test", tags=["price", "luggage"], @@ -3954,9 +3954,7 @@ def test_serve_creates_deployment(self, prefect_client: PrefectClient): paused=True, ) - deployment = asyncio.run( - prefect_client.read_deployment_by_name(name="test-flow/test") - ) + deployment = sync_prefect_client.read_deployment_by_name(name="test-flow/test") assert deployment is not None # Flow.serve should created deployments without a work queue or work pool @@ -3971,61 +3969,53 @@ def test_serve_creates_deployment(self, prefect_client: PrefectClient): assert deployment.paused assert not deployment.is_schedule_active - def test_serve_can_user_a_module_path_entrypoint(self, prefect_client): + def test_serve_can_user_a_module_path_entrypoint(self, sync_prefect_client): deployment = self.flow.serve( name="test", entrypoint_type=EntrypointType.MODULE_PATH ) - deployment = asyncio.run( - prefect_client.read_deployment_by_name(name="test-flow/test") - ) + deployment = sync_prefect_client.read_deployment_by_name(name="test-flow/test") assert deployment.entrypoint == f"{self.flow.__module__}.{self.flow.__name__}" - def test_serve_handles__file__(self, prefect_client: PrefectClient): + def test_serve_handles__file__(self, sync_prefect_client: SyncPrefectClient): self.flow.serve(__file__) - deployment = asyncio.run( - prefect_client.read_deployment_by_name(name="test-flow/test_flows") + deployment = sync_prefect_client.read_deployment_by_name( + name="test-flow/test_flows" ) assert deployment.name == "test_flows" def test_serve_creates_deployment_with_interval_schedule( - self, prefect_client: PrefectClient + self, sync_prefect_client: SyncPrefectClient ): self.flow.serve( "test", interval=3600, ) - deployment = asyncio.run( - prefect_client.read_deployment_by_name(name="test-flow/test") - ) + deployment = sync_prefect_client.read_deployment_by_name(name="test-flow/test") assert deployment is not None assert isinstance(deployment.schedule, IntervalSchedule) assert deployment.schedule.interval == datetime.timedelta(seconds=3600) def test_serve_creates_deployment_with_cron_schedule( - self, prefect_client: PrefectClient + self, sync_prefect_client: SyncPrefectClient ): self.flow.serve("test", cron="* * * * *") - deployment = asyncio.run( - prefect_client.read_deployment_by_name(name="test-flow/test") - ) + deployment = sync_prefect_client.read_deployment_by_name(name="test-flow/test") assert deployment is not None assert deployment.schedule == CronSchedule(cron="* * * * *") def test_serve_creates_deployment_with_rrule_schedule( - self, prefect_client: PrefectClient + self, sync_prefect_client: SyncPrefectClient ): self.flow.serve("test", rrule="FREQ=MINUTELY") - deployment = asyncio.run( - prefect_client.read_deployment_by_name(name="test-flow/test") - ) + deployment = sync_prefect_client.read_deployment_by_name(name="test-flow/test") assert deployment is not None assert deployment.schedule == RRuleSchedule(rrule="FREQ=MINUTELY") diff --git a/tests/workers/test_base_worker.py b/tests/workers/test_base_worker.py index 47bf8cd4b024..97a1fc9407e6 100644 --- a/tests/workers/test_base_worker.py +++ b/tests/workers/test_base_worker.py @@ -413,9 +413,6 @@ def create_run_with_deployment(state): assert tracking_mock.call_count == 1 - # Multiple hits if worker's client is not being reused - assert caplog.text.count("Using ephemeral application") == 1 - async def test_base_worker_gets_job_configuration_when_syncing_with_backend_with_just_job_config( session, client diff --git a/tests/workers/test_utilities.py b/tests/workers/test_utilities.py index 52590bab3eb4..52b4f6d23506 100644 --- a/tests/workers/test_utilities.py +++ b/tests/workers/test_utilities.py @@ -55,13 +55,17 @@ def available(): @pytest.mark.usefixtures("mock_collection_registry_not_available") async def test_get_available_work_pool_types_without_collection_registry( - self, monkeypatch + self, monkeypatch, in_memory_prefect_client ): respx.routes def available(): return ["process"] + monkeypatch.setattr( + "prefect.client.collections.get_client", + lambda *args, **kwargs: in_memory_prefect_client, + ) monkeypatch.setattr(BaseWorker, "get_all_available_worker_types", available) work_pool_types = await get_available_work_pool_types()