Skip to content

Commit

Permalink
♻️Pydantic V2 and SQLAlchemy warning fixes (ITISFoundation#6877)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanderegg authored Dec 2, 2024
1 parent 931595e commit dc35757
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -381,15 +381,6 @@ def LOG_LEVEL(self) -> LogLevel: # noqa: N802
def _valid_log_level(cls, value: str) -> str:
return cls.validate_log_level(value)

@field_validator("SERVICE_TRACKING_HEARTBEAT", mode="before")
@classmethod
def _validate_interval(
cls, value: str | datetime.timedelta
) -> int | datetime.timedelta:
if isinstance(value, str):
return int(value)
return value


def get_application_settings(app: FastAPI) -> ApplicationSettings:
return cast(ApplicationSettings, app.state.settings)
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
import dataclasses
import datetime
from collections.abc import Awaitable, Callable
from typing import Final
from unittest.mock import MagicMock
Expand Down Expand Up @@ -36,7 +37,9 @@ def wallet_id(faker: Faker, request: pytest.FixtureRequest) -> WalletID | None:
return faker.pyint(min_value=1) if request.param == "with_wallet" else None


_FAST_TIME_BEFORE_TERMINATION_SECONDS: Final[int] = 10
_FAST_TIME_BEFORE_TERMINATION_SECONDS: Final[datetime.timedelta] = datetime.timedelta(
seconds=10
)


@pytest.fixture
Expand Down Expand Up @@ -149,7 +152,7 @@ async def test_cluster_management_core_properly_removes_unused_instances(
mocked_dask_ping_scheduler.is_scheduler_busy.reset_mock()

# running the cluster management task after the heartbeat came in shall not remove anything
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS + 1)
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS.total_seconds() + 1)
await cluster_heartbeat(initialized_app, user_id=user_id, wallet_id=wallet_id)
await check_clusters(initialized_app)
await _assert_cluster_exist_and_state(
Expand All @@ -161,7 +164,7 @@ async def test_cluster_management_core_properly_removes_unused_instances(
mocked_dask_ping_scheduler.is_scheduler_busy.reset_mock()

# after waiting the termination time, running the task shall remove the cluster
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS + 1)
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS.total_seconds() + 1)
await check_clusters(initialized_app)
await _assert_cluster_exist_and_state(
ec2_client, instances=created_clusters, state="terminated"
Expand Down Expand Up @@ -201,7 +204,7 @@ async def test_cluster_management_core_properly_removes_workers_on_shutdown(
ec2_client, instance_ids=worker_instance_ids, state="running"
)
# after waiting the termination time, running the task shall remove the cluster
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS + 1)
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS.total_seconds() + 1)
await check_clusters(initialized_app)
await _assert_cluster_exist_and_state(
ec2_client, instances=created_clusters, state="terminated"
Expand Down Expand Up @@ -314,7 +317,7 @@ async def test_cluster_management_core_removes_broken_clusters_after_some_delay(
mocked_dask_ping_scheduler.is_scheduler_busy.reset_mock()

# waiting for the termination time will now terminate the cluster
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS + 1)
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS.total_seconds() + 1)
await check_clusters(initialized_app)
await _assert_cluster_exist_and_state(
ec2_client, instances=created_clusters, state="terminated"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -540,11 +540,7 @@ async def apply(
project_id: ProjectID,
iteration: Iteration,
) -> None:
"""schedules a pipeline for a given user, project and iteration.
Arguments:
wake_up_callback -- a callback function that is called in a separate thread everytime a pipeline node is completed
"""
"""apply the scheduling of a pipeline for a given user, project and iteration."""
with log_context(
_logger,
level=logging.INFO,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ async def list(
return [
CompRunsAtDB.model_validate(row)
async for row in conn.execute(
sa.select(comp_runs).where(sa.and_(*conditions))
sa.select(comp_runs).where(
sa.and_(True, *conditions) # noqa: FBT003
)
)
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,7 @@ async def _get_service_details(
node.version,
product_name,
)
obj: ServiceMetaDataPublished = ServiceMetaDataPublished.model_construct(
**service_details
)
obj: ServiceMetaDataPublished = ServiceMetaDataPublished(**service_details)
return obj


Expand All @@ -105,7 +103,7 @@ def _compute_node_requirements(
node_defined_resources[resource_name] = node_defined_resources.get(
resource_name, 0
) + min(resource_value.limit, resource_value.reservation)
return NodeRequirements.model_validate(node_defined_resources)
return NodeRequirements(**node_defined_resources)


def _compute_node_boot_mode(node_resources: ServiceResourcesDict) -> BootMode:
Expand Down Expand Up @@ -146,12 +144,12 @@ async def _get_node_infos(
None,
)

result: tuple[ServiceMetaDataPublished, ServiceExtras, SimcoreServiceLabels] = (
await asyncio.gather(
_get_service_details(catalog_client, user_id, product_name, node),
director_client.get_service_extras(node.key, node.version),
director_client.get_service_labels(node),
)
result: tuple[
ServiceMetaDataPublished, ServiceExtras, SimcoreServiceLabels
] = await asyncio.gather(
_get_service_details(catalog_client, user_id, product_name, node),
director_client.get_service_extras(node.key, node.version),
director_client.get_service_labels(node),
)
return result

Expand Down Expand Up @@ -189,7 +187,7 @@ async def _generate_task_image(
data.update(envs=_compute_node_envs(node_labels))
if node_extras and node_extras.container_spec:
data.update(command=node_extras.container_spec.command)
return Image.model_validate(data)
return Image(**data)


async def _get_pricing_and_hardware_infos(
Expand Down Expand Up @@ -247,9 +245,9 @@ async def _get_pricing_and_hardware_infos(
return pricing_info, hardware_info


_RAM_SAFE_MARGIN_RATIO: Final[float] = (
0.1 # NOTE: machines always have less available RAM than advertised
)
_RAM_SAFE_MARGIN_RATIO: Final[
float
] = 0.1 # NOTE: machines always have less available RAM than advertised
_CPUS_SAFE_MARGIN: Final[float] = 0.1


Expand All @@ -267,11 +265,11 @@ async def _update_project_node_resources_from_hardware_info(
if not hardware_info.aws_ec2_instances:
return
try:
unordered_list_ec2_instance_types: list[EC2InstanceTypeGet] = (
await get_instance_type_details(
rabbitmq_rpc_client,
instance_type_names=set(hardware_info.aws_ec2_instances),
)
unordered_list_ec2_instance_types: list[
EC2InstanceTypeGet
] = await get_instance_type_details(
rabbitmq_rpc_client,
instance_type_names=set(hardware_info.aws_ec2_instances),
)

assert unordered_list_ec2_instance_types # nosec
Expand Down Expand Up @@ -347,7 +345,7 @@ async def generate_tasks_list_from_project(
list_comp_tasks = []

unique_service_key_versions: set[ServiceKeyVersion] = {
ServiceKeyVersion.model_construct(
ServiceKeyVersion(
key=node.key, version=node.version
) # the service key version is frozen
for node in project.workbench.values()
Expand All @@ -366,9 +364,7 @@ async def generate_tasks_list_from_project(

for internal_id, node_id in enumerate(project.workbench, 1):
node: Node = project.workbench[node_id]
node_key_version = ServiceKeyVersion.model_construct(
key=node.key, version=node.version
)
node_key_version = ServiceKeyVersion(key=node.key, version=node.version)
node_details, node_extras, node_labels = key_version_to_node_infos.get(
node_key_version,
(None, None, None),
Expand Down Expand Up @@ -434,8 +430,8 @@ async def generate_tasks_list_from_project(
task_db = CompTaskAtDB(
project_id=project.uuid,
node_id=NodeID(node_id),
schema=NodeSchema.model_validate(
node_details.model_dump(
schema=NodeSchema(
**node_details.model_dump(
exclude_unset=True, by_alias=True, include={"inputs", "outputs"}
)
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def minimal_configuration(
redis_service: RedisSettings,
monkeypatch: pytest.MonkeyPatch,
faker: Faker,
with_disabled_auto_scheduling: mock.Mock,
with_disabled_scheduler_publisher: mock.Mock,
):
monkeypatch.setenv("DIRECTOR_V2_DYNAMIC_SIDECAR_ENABLED", "false")
monkeypatch.setenv("COMPUTATIONAL_BACKEND_DASK_CLIENT_ENABLED", "1")
Expand Down Expand Up @@ -588,11 +590,7 @@ async def test_create_computation_with_wallet(

@pytest.mark.parametrize(
"default_pricing_plan",
[
PricingPlanGet.model_construct(
**PricingPlanGet.model_config["json_schema_extra"]["examples"][0]
)
],
[PricingPlanGet(**PricingPlanGet.model_config["json_schema_extra"]["examples"][0])],
)
async def test_create_computation_with_wallet_with_invalid_pricing_unit_name_raises_422(
minimal_configuration: None,
Expand Down Expand Up @@ -631,7 +629,7 @@ async def test_create_computation_with_wallet_with_invalid_pricing_unit_name_rai
@pytest.mark.parametrize(
"default_pricing_plan",
[
PricingPlanGet.model_construct(
PricingPlanGet(
**PricingPlanGet.model_config["json_schema_extra"]["examples"][0] # type: ignore
)
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1606,7 +1606,9 @@ async def _return_random_task_result(job_id) -> TaskOutputData:
@pytest.fixture
def with_fast_service_heartbeat_s(monkeypatch: pytest.MonkeyPatch) -> int:
seconds = 1
monkeypatch.setenv("SERVICE_TRACKING_HEARTBEAT", f"{seconds}")
monkeypatch.setenv(
"SERVICE_TRACKING_HEARTBEAT", f"{datetime.timedelta(seconds=seconds)}"
)
return seconds


Expand Down
6 changes: 3 additions & 3 deletions services/director-v2/tests/unit/with_dbs/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,9 @@ async def _(user: dict[str, Any], **cluster_kwargs) -> Cluster:
.where(clusters.c.id == created_cluster.id)
):
access_rights_in_db[row.gid] = {
"read": row[cluster_to_groups.c.read],
"write": row[cluster_to_groups.c.write],
"delete": row[cluster_to_groups.c.delete],
"read": row.read,
"write": row.write,
"delete": row.delete,
}

return Cluster(
Expand Down

0 comments on commit dc35757

Please sign in to comment.