diff --git a/services/web/server/src/simcore_service_webserver/projects/_nodes_handlers.py b/services/web/server/src/simcore_service_webserver/projects/_nodes_handlers.py index 67f3104a829..6a7109799e0 100644 --- a/services/web/server/src/simcore_service_webserver/projects/_nodes_handlers.py +++ b/services/web/server/src/simcore_service_webserver/projects/_nodes_handlers.py @@ -74,6 +74,7 @@ DefaultPricingUnitNotFoundError, NodeNotFoundError, ProjectInvalidRightsError, + ProjectNodeRequiredInputsNotSetError, ProjectNodeResourcesInsufficientRightsError, ProjectNodeResourcesInvalidError, ProjectNotFoundError, @@ -105,6 +106,8 @@ async def wrapper(request: web.Request) -> web.StreamResponse: raise web.HTTPConflict(reason=f"{exc}") from exc except ClustersKeeperNotAvailableError as exc: raise web.HTTPServiceUnavailable(reason=f"{exc}") from exc + except ProjectNodeRequiredInputsNotSetError as exc: + raise web.HTTPConflict(reason=f"{exc}") from exc return wrapper diff --git a/services/web/server/src/simcore_service_webserver/projects/exceptions.py b/services/web/server/src/simcore_service_webserver/projects/exceptions.py index e62c4ef78e2..ecd60a58c39 100644 --- a/services/web/server/src/simcore_service_webserver/projects/exceptions.py +++ b/services/web/server/src/simcore_service_webserver/projects/exceptions.py @@ -4,6 +4,7 @@ import redis.exceptions from models_library.projects import ProjectID +from models_library.projects_nodes_io import NodeID from models_library.users import UserID from ..errors import WebServerBaseError @@ -136,6 +137,51 @@ class ProjectNodeResourcesInsufficientRightsError(BaseProjectError): ... +class ProjectNodeRequiredInputsNotSetError(BaseProjectError): + ... + + +class ProjectNodeConnectionsMissingError(ProjectNodeRequiredInputsNotSetError): + msg_template = "Missing '{joined_unset_required_inputs}' connection(s) to '{node_with_required_inputs}'" + + def __init__( + self, + *, + unset_required_inputs: list[str], + node_with_required_inputs: NodeID, + **ctx, + ): + super().__init__( + joined_unset_required_inputs=", ".join(unset_required_inputs), + unset_required_inputs=unset_required_inputs, + node_with_required_inputs=node_with_required_inputs, + **ctx, + ) + self.unset_required_inputs = unset_required_inputs + self.node_with_required_inputs = node_with_required_inputs + + +class ProjectNodeOutputPortMissingValueError(ProjectNodeRequiredInputsNotSetError): + msg_template = "Missing: {joined_start_message}" + + def __init__( + self, + *, + unset_outputs_in_upstream: list[tuple[str, str]], + **ctx, + ): + start_messages = [ + f"'{input_key}' of '{service_name}'" + for input_key, service_name in unset_outputs_in_upstream + ] + super().__init__( + joined_start_message=", ".join(start_messages), + unset_outputs_in_upstream=unset_outputs_in_upstream, + **ctx, + ) + self.unset_outputs_in_upstream = unset_outputs_in_upstream + + class DefaultPricingUnitNotFoundError(BaseProjectError): msg_template = "Default pricing unit not found for node '{node_uuid}' in project '{project_uuid}'" diff --git a/services/web/server/src/simcore_service_webserver/projects/projects_api.py b/services/web/server/src/simcore_service_webserver/projects/projects_api.py index 5088775c7fb..28655495bfb 100644 --- a/services/web/server/src/simcore_service_webserver/projects/projects_api.py +++ b/services/web/server/src/simcore_service_webserver/projects/projects_api.py @@ -34,8 +34,8 @@ from models_library.errors import ErrorDict from models_library.products import ProductName from models_library.projects import Project, ProjectID, ProjectIDStr -from models_library.projects_nodes import Node -from models_library.projects_nodes_io import NodeID, NodeIDStr +from models_library.projects_nodes import Node, OutputsDict +from models_library.projects_nodes_io import NodeID, NodeIDStr, PortLink from models_library.projects_state import ( Owner, ProjectLocked, @@ -124,6 +124,9 @@ NodeNotFoundError, ProjectInvalidRightsError, ProjectLockError, + ProjectNodeConnectionsMissingError, + ProjectNodeOutputPortMissingValueError, + ProjectNodeRequiredInputsNotSetError, ProjectNodeResourcesInvalidError, ProjectOwnerNotFoundInTheProjectAccessRightsError, ProjectStartsTooManyDynamicNodesError, @@ -447,6 +450,56 @@ def _by_type_name(ec2: EC2InstanceTypeGet) -> bool: raise ClustersKeeperNotAvailableError from exc +async def _check_project_node_has_all_required_inputs( + db: ProjectDBAPI, user_id: UserID, project_uuid: ProjectID, node_id: NodeID +) -> None: + + project_dict, _ = await db.get_project(user_id, f"{project_uuid}") + + nodes_map: dict[NodeID, Node] = { + NodeID(k): Node(**v) for k, v in project_dict["workbench"].items() + } + node = nodes_map[node_id] + + unset_required_inputs: list[str] = [] + unset_outputs_in_upstream: list[tuple[str, str]] = [] + + def _check_required_input(required_input_key: str) -> None: + input_entry: PortLink | None = None + if node.inputs: + input_entry = node.inputs.get(required_input_key, None) + if input_entry is None: + # NOT linked to any node connect service or set value manually(whichever applies) + unset_required_inputs.append(required_input_key) + return + + source_node_id: NodeID = input_entry.node_uuid + source_output_key = input_entry.output + + source_node = nodes_map[source_node_id] + + output_entry: OutputsDict | None = None + if source_node.outputs: + output_entry = source_node.outputs.get(source_output_key, None) + if output_entry is None: + unset_outputs_in_upstream.append((source_output_key, source_node.label)) + + for required_input in node.inputs_required: + _check_required_input(required_input) + + node_with_required_inputs = node.label + if unset_required_inputs: + raise ProjectNodeConnectionsMissingError( + unset_required_inputs=unset_required_inputs, + node_with_required_inputs=node_with_required_inputs, + ) + + if unset_outputs_in_upstream: + raise ProjectNodeOutputPortMissingValueError( + unset_outputs_in_upstream=unset_outputs_in_upstream + ) + + async def _start_dynamic_service( request: web.Request, *, @@ -456,6 +509,7 @@ async def _start_dynamic_service( user_id: UserID, project_uuid: ProjectID, node_uuid: NodeID, + graceful_start: bool = False, ) -> None: if not _is_node_dynamic(service_key): return @@ -464,6 +518,20 @@ async def _start_dynamic_service( db: ProjectDBAPI = ProjectDBAPI.get_from_app_context(request.app) + try: + await _check_project_node_has_all_required_inputs( + db, user_id, project_uuid, node_uuid + ) + except ProjectNodeRequiredInputsNotSetError as e: + if graceful_start: + log.info( + "Did not start '%s' because of missing required inputs: %s", + node_uuid, + e, + ) + return + raise + save_state = False user_role: UserRole = await get_user_role(request.app, user_id) if user_role > UserRole.GUEST: @@ -1464,6 +1532,7 @@ async def run_project_dynamic_services( user_id=user_id, project_uuid=project["uuid"], node_uuid=NodeID(service_uuid), + graceful_start=True, ) for service_uuid, is_deprecated in zip( services_to_start_uuids, deprecated_services, strict=True diff --git a/services/web/server/tests/unit/with_dbs/03/test_project_db.py b/services/web/server/tests/unit/with_dbs/03/test_project_db.py index 8411f84ca23..ebf46bee580 100644 --- a/services/web/server/tests/unit/with_dbs/03/test_project_db.py +++ b/services/web/server/tests/unit/with_dbs/03/test_project_db.py @@ -33,9 +33,13 @@ from simcore_service_webserver.projects.db import ProjectAccessRights, ProjectDBAPI from simcore_service_webserver.projects.exceptions import ( NodeNotFoundError, + ProjectNodeRequiredInputsNotSetError, ProjectNotFoundError, ) from simcore_service_webserver.projects.models import ProjectDict +from simcore_service_webserver.projects.projects_api import ( + _check_project_node_has_all_required_inputs, +) from simcore_service_webserver.users.exceptions import UserNotFoundError from simcore_service_webserver.utils import to_datetime from sqlalchemy.engine.result import Row @@ -829,3 +833,116 @@ async def test_has_permission( await db_api.has_permission(second_user["id"], project_id, permission) is access_rights[permission] ), f"Found unexpected {permission=} for {access_rights=} of {user_role=} and {project_id=}" + + +def _fake_output_data() -> dict: + return { + "store": 0, + "path": "9f8207e6-144a-11ef-831f-0242ac140027/98b68cbe-9e22-4eb5-a91b-2708ad5317b7/outputs/output_2/output_2.zip", + "eTag": "ec3bc734d85359b660aab400147cd1ea", + } + + +def _fake_connect_to(output_number: int) -> dict: + return { + "nodeUuid": "98b68cbe-9e22-4eb5-a91b-2708ad5317b7", + "output": f"output_{output_number}", + } + + +@pytest.fixture +async def inserted_project( + logged_user: dict[str, Any], + insert_project_in_db: Callable[..., Awaitable[dict[str, Any]]], + fake_project: dict[str, Any], + downstream_inputs: dict, + downstream_required_inputs: list[str], + upstream_outputs: dict, +) -> dict: + fake_project["workbench"] = { + "98b68cbe-9e22-4eb5-a91b-2708ad5317b7": { + "key": "simcore/services/dynamic/jupyter-math", + "version": "2.0.10", + "label": "upstream", + "inputs": {}, + "inputsUnits": {}, + "inputNodes": [], + "thumbnail": "", + "outputs": upstream_outputs, + "runHash": "c6ae58f36a2e0f65f443441ecda023a451cb1b8051d01412d79aa03653e1a6b3", + }, + "324d6ef2-a82c-414d-9001-dc84da1cbea3": { + "key": "simcore/services/dynamic/jupyter-math", + "version": "2.0.10", + "label": "downstream", + "inputs": downstream_inputs, + "inputsUnits": {}, + "inputNodes": ["98b68cbe-9e22-4eb5-a91b-2708ad5317b7"], + "thumbnail": "", + "inputsRequired": downstream_required_inputs, + }, + } + + return await insert_project_in_db(fake_project, user_id=logged_user["id"]) + + +@pytest.mark.parametrize( + "downstream_inputs,downstream_required_inputs,upstream_outputs,expected_error", + [ + pytest.param( + {"input_1": _fake_connect_to(1)}, + ["input_1", "input_2"], + {}, + "Missing 'input_2' connection(s) to 'downstream'", + id="missing_connection_on_input_2", + ), + pytest.param( + {"input_1": _fake_connect_to(1), "input_2": _fake_connect_to(2)}, + ["input_1", "input_2"], + {"output_2": _fake_output_data()}, + "Missing: 'output_1' of 'upstream'", + id="output_1_has_not_file", + ), + ], +) +@pytest.mark.parametrize("user_role", [(UserRole.USER)]) +async def test_check_project_node_has_all_required_inputs_raises( + logged_user: dict[str, Any], + db_api: ProjectDBAPI, + inserted_project: dict, + expected_error: str, +): + + with pytest.raises(ProjectNodeRequiredInputsNotSetError) as exc: + await _check_project_node_has_all_required_inputs( + db_api, + user_id=logged_user["id"], + project_uuid=UUID(inserted_project["uuid"]), + node_id=UUID("324d6ef2-a82c-414d-9001-dc84da1cbea3"), + ) + assert f"{exc.value}" == expected_error + + +@pytest.mark.parametrize( + "downstream_inputs,downstream_required_inputs,upstream_outputs", + [ + pytest.param( + {"input_1": _fake_connect_to(1), "input_2": _fake_connect_to(2)}, + ["input_1", "input_2"], + {"output_1": _fake_output_data(), "output_2": _fake_output_data()}, + id="with_required_inputs_present", + ), + ], +) +@pytest.mark.parametrize("user_role", [(UserRole.USER)]) +async def test_check_project_node_has_all_required_inputs_ok( + logged_user: dict[str, Any], + db_api: ProjectDBAPI, + inserted_project: dict, +): + await _check_project_node_has_all_required_inputs( + db_api, + user_id=logged_user["id"], + project_uuid=UUID(inserted_project["uuid"]), + node_id=UUID("324d6ef2-a82c-414d-9001-dc84da1cbea3"), + )