Skip to content

Commit

Permalink
✨ dynamic-services will fail if they have any required input that is …
Browse files Browse the repository at this point in the history
…not set (ITISFoundation#5845)

Co-authored-by: Andrei Neagu <[email protected]>
  • Loading branch information
GitHK and Andrei Neagu authored Jun 3, 2024
1 parent 00b0944 commit 0538066
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
DefaultPricingUnitNotFoundError,
NodeNotFoundError,
ProjectInvalidRightsError,
ProjectNodeRequiredInputsNotSetError,
ProjectNodeResourcesInsufficientRightsError,
ProjectNodeResourcesInvalidError,
ProjectNotFoundError,
Expand Down Expand Up @@ -105,6 +106,8 @@ async def wrapper(request: web.Request) -> web.StreamResponse:
raise web.HTTPConflict(reason=f"{exc}") from exc
except ClustersKeeperNotAvailableError as exc:
raise web.HTTPServiceUnavailable(reason=f"{exc}") from exc
except ProjectNodeRequiredInputsNotSetError as exc:
raise web.HTTPConflict(reason=f"{exc}") from exc

return wrapper

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import redis.exceptions
from models_library.projects import ProjectID
from models_library.projects_nodes_io import NodeID
from models_library.users import UserID

from ..errors import WebServerBaseError
Expand Down Expand Up @@ -136,6 +137,51 @@ class ProjectNodeResourcesInsufficientRightsError(BaseProjectError):
...


class ProjectNodeRequiredInputsNotSetError(BaseProjectError):
...


class ProjectNodeConnectionsMissingError(ProjectNodeRequiredInputsNotSetError):
msg_template = "Missing '{joined_unset_required_inputs}' connection(s) to '{node_with_required_inputs}'"

def __init__(
self,
*,
unset_required_inputs: list[str],
node_with_required_inputs: NodeID,
**ctx,
):
super().__init__(
joined_unset_required_inputs=", ".join(unset_required_inputs),
unset_required_inputs=unset_required_inputs,
node_with_required_inputs=node_with_required_inputs,
**ctx,
)
self.unset_required_inputs = unset_required_inputs
self.node_with_required_inputs = node_with_required_inputs


class ProjectNodeOutputPortMissingValueError(ProjectNodeRequiredInputsNotSetError):
msg_template = "Missing: {joined_start_message}"

def __init__(
self,
*,
unset_outputs_in_upstream: list[tuple[str, str]],
**ctx,
):
start_messages = [
f"'{input_key}' of '{service_name}'"
for input_key, service_name in unset_outputs_in_upstream
]
super().__init__(
joined_start_message=", ".join(start_messages),
unset_outputs_in_upstream=unset_outputs_in_upstream,
**ctx,
)
self.unset_outputs_in_upstream = unset_outputs_in_upstream


class DefaultPricingUnitNotFoundError(BaseProjectError):
msg_template = "Default pricing unit not found for node '{node_uuid}' in project '{project_uuid}'"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
from models_library.errors import ErrorDict
from models_library.products import ProductName
from models_library.projects import Project, ProjectID, ProjectIDStr
from models_library.projects_nodes import Node
from models_library.projects_nodes_io import NodeID, NodeIDStr
from models_library.projects_nodes import Node, OutputsDict
from models_library.projects_nodes_io import NodeID, NodeIDStr, PortLink
from models_library.projects_state import (
Owner,
ProjectLocked,
Expand Down Expand Up @@ -124,6 +124,9 @@
NodeNotFoundError,
ProjectInvalidRightsError,
ProjectLockError,
ProjectNodeConnectionsMissingError,
ProjectNodeOutputPortMissingValueError,
ProjectNodeRequiredInputsNotSetError,
ProjectNodeResourcesInvalidError,
ProjectOwnerNotFoundInTheProjectAccessRightsError,
ProjectStartsTooManyDynamicNodesError,
Expand Down Expand Up @@ -447,6 +450,56 @@ def _by_type_name(ec2: EC2InstanceTypeGet) -> bool:
raise ClustersKeeperNotAvailableError from exc


async def _check_project_node_has_all_required_inputs(
db: ProjectDBAPI, user_id: UserID, project_uuid: ProjectID, node_id: NodeID
) -> None:

project_dict, _ = await db.get_project(user_id, f"{project_uuid}")

nodes_map: dict[NodeID, Node] = {
NodeID(k): Node(**v) for k, v in project_dict["workbench"].items()
}
node = nodes_map[node_id]

unset_required_inputs: list[str] = []
unset_outputs_in_upstream: list[tuple[str, str]] = []

def _check_required_input(required_input_key: str) -> None:
input_entry: PortLink | None = None
if node.inputs:
input_entry = node.inputs.get(required_input_key, None)
if input_entry is None:
# NOT linked to any node connect service or set value manually(whichever applies)
unset_required_inputs.append(required_input_key)
return

source_node_id: NodeID = input_entry.node_uuid
source_output_key = input_entry.output

source_node = nodes_map[source_node_id]

output_entry: OutputsDict | None = None
if source_node.outputs:
output_entry = source_node.outputs.get(source_output_key, None)
if output_entry is None:
unset_outputs_in_upstream.append((source_output_key, source_node.label))

for required_input in node.inputs_required:
_check_required_input(required_input)

node_with_required_inputs = node.label
if unset_required_inputs:
raise ProjectNodeConnectionsMissingError(
unset_required_inputs=unset_required_inputs,
node_with_required_inputs=node_with_required_inputs,
)

if unset_outputs_in_upstream:
raise ProjectNodeOutputPortMissingValueError(
unset_outputs_in_upstream=unset_outputs_in_upstream
)


async def _start_dynamic_service(
request: web.Request,
*,
Expand All @@ -456,6 +509,7 @@ async def _start_dynamic_service(
user_id: UserID,
project_uuid: ProjectID,
node_uuid: NodeID,
graceful_start: bool = False,
) -> None:
if not _is_node_dynamic(service_key):
return
Expand All @@ -464,6 +518,20 @@ async def _start_dynamic_service(

db: ProjectDBAPI = ProjectDBAPI.get_from_app_context(request.app)

try:
await _check_project_node_has_all_required_inputs(
db, user_id, project_uuid, node_uuid
)
except ProjectNodeRequiredInputsNotSetError as e:
if graceful_start:
log.info(
"Did not start '%s' because of missing required inputs: %s",
node_uuid,
e,
)
return
raise

save_state = False
user_role: UserRole = await get_user_role(request.app, user_id)
if user_role > UserRole.GUEST:
Expand Down Expand Up @@ -1464,6 +1532,7 @@ async def run_project_dynamic_services(
user_id=user_id,
project_uuid=project["uuid"],
node_uuid=NodeID(service_uuid),
graceful_start=True,
)
for service_uuid, is_deprecated in zip(
services_to_start_uuids, deprecated_services, strict=True
Expand Down
117 changes: 117 additions & 0 deletions services/web/server/tests/unit/with_dbs/03/test_project_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,13 @@
from simcore_service_webserver.projects.db import ProjectAccessRights, ProjectDBAPI
from simcore_service_webserver.projects.exceptions import (
NodeNotFoundError,
ProjectNodeRequiredInputsNotSetError,
ProjectNotFoundError,
)
from simcore_service_webserver.projects.models import ProjectDict
from simcore_service_webserver.projects.projects_api import (
_check_project_node_has_all_required_inputs,
)
from simcore_service_webserver.users.exceptions import UserNotFoundError
from simcore_service_webserver.utils import to_datetime
from sqlalchemy.engine.result import Row
Expand Down Expand Up @@ -829,3 +833,116 @@ async def test_has_permission(
await db_api.has_permission(second_user["id"], project_id, permission)
is access_rights[permission]
), f"Found unexpected {permission=} for {access_rights=} of {user_role=} and {project_id=}"


def _fake_output_data() -> dict:
return {
"store": 0,
"path": "9f8207e6-144a-11ef-831f-0242ac140027/98b68cbe-9e22-4eb5-a91b-2708ad5317b7/outputs/output_2/output_2.zip",
"eTag": "ec3bc734d85359b660aab400147cd1ea",
}


def _fake_connect_to(output_number: int) -> dict:
return {
"nodeUuid": "98b68cbe-9e22-4eb5-a91b-2708ad5317b7",
"output": f"output_{output_number}",
}


@pytest.fixture
async def inserted_project(
logged_user: dict[str, Any],
insert_project_in_db: Callable[..., Awaitable[dict[str, Any]]],
fake_project: dict[str, Any],
downstream_inputs: dict,
downstream_required_inputs: list[str],
upstream_outputs: dict,
) -> dict:
fake_project["workbench"] = {
"98b68cbe-9e22-4eb5-a91b-2708ad5317b7": {
"key": "simcore/services/dynamic/jupyter-math",
"version": "2.0.10",
"label": "upstream",
"inputs": {},
"inputsUnits": {},
"inputNodes": [],
"thumbnail": "",
"outputs": upstream_outputs,
"runHash": "c6ae58f36a2e0f65f443441ecda023a451cb1b8051d01412d79aa03653e1a6b3",
},
"324d6ef2-a82c-414d-9001-dc84da1cbea3": {
"key": "simcore/services/dynamic/jupyter-math",
"version": "2.0.10",
"label": "downstream",
"inputs": downstream_inputs,
"inputsUnits": {},
"inputNodes": ["98b68cbe-9e22-4eb5-a91b-2708ad5317b7"],
"thumbnail": "",
"inputsRequired": downstream_required_inputs,
},
}

return await insert_project_in_db(fake_project, user_id=logged_user["id"])


@pytest.mark.parametrize(
"downstream_inputs,downstream_required_inputs,upstream_outputs,expected_error",
[
pytest.param(
{"input_1": _fake_connect_to(1)},
["input_1", "input_2"],
{},
"Missing 'input_2' connection(s) to 'downstream'",
id="missing_connection_on_input_2",
),
pytest.param(
{"input_1": _fake_connect_to(1), "input_2": _fake_connect_to(2)},
["input_1", "input_2"],
{"output_2": _fake_output_data()},
"Missing: 'output_1' of 'upstream'",
id="output_1_has_not_file",
),
],
)
@pytest.mark.parametrize("user_role", [(UserRole.USER)])
async def test_check_project_node_has_all_required_inputs_raises(
logged_user: dict[str, Any],
db_api: ProjectDBAPI,
inserted_project: dict,
expected_error: str,
):

with pytest.raises(ProjectNodeRequiredInputsNotSetError) as exc:
await _check_project_node_has_all_required_inputs(
db_api,
user_id=logged_user["id"],
project_uuid=UUID(inserted_project["uuid"]),
node_id=UUID("324d6ef2-a82c-414d-9001-dc84da1cbea3"),
)
assert f"{exc.value}" == expected_error


@pytest.mark.parametrize(
"downstream_inputs,downstream_required_inputs,upstream_outputs",
[
pytest.param(
{"input_1": _fake_connect_to(1), "input_2": _fake_connect_to(2)},
["input_1", "input_2"],
{"output_1": _fake_output_data(), "output_2": _fake_output_data()},
id="with_required_inputs_present",
),
],
)
@pytest.mark.parametrize("user_role", [(UserRole.USER)])
async def test_check_project_node_has_all_required_inputs_ok(
logged_user: dict[str, Any],
db_api: ProjectDBAPI,
inserted_project: dict,
):
await _check_project_node_has_all_required_inputs(
db_api,
user_id=logged_user["id"],
project_uuid=UUID(inserted_project["uuid"]),
node_id=UUID("324d6ef2-a82c-414d-9001-dc84da1cbea3"),
)

0 comments on commit 0538066

Please sign in to comment.