diff --git a/services/catalog/src/simcore_service_catalog/db/repositories/services.py b/services/catalog/src/simcore_service_catalog/db/repositories/services.py index 93f53d6bacb..afa9db295e4 100644 --- a/services/catalog/src/simcore_service_catalog/db/repositories/services.py +++ b/services/catalog/src/simcore_service_catalog/db/repositories/services.py @@ -44,7 +44,7 @@ _logger = logging.getLogger(__name__) -def is_newer( +def _is_newer( old: ServiceSpecificationsAtDB | None, new: ServiceSpecificationsAtDB, ) -> bool: @@ -54,7 +54,7 @@ def is_newer( ) -def merge_specs( +def _merge_specs( everyone_spec: ServiceSpecificationsAtDB | None, team_specs: dict[GroupID, ServiceSpecificationsAtDB], user_spec: ServiceSpecificationsAtDB | None, @@ -176,6 +176,7 @@ async def get_service( write_access: bool | None = None, product_name: str | None = None, ) -> ServiceMetaDataAtDB | None: + query = sa.select(services_meta_data).where( (services_meta_data.c.key == key) & (services_meta_data.c.version == version) @@ -242,9 +243,7 @@ async def create_or_update_service( await conn.execute(insert_stmt) return created_service - async def update_service( - self, patched_service: ServiceMetaDataAtDB - ) -> ServiceMetaDataAtDB: + async def update_service(self, patched_service: ServiceMetaDataAtDB) -> None: stmt_update = ( services_meta_data.update() @@ -259,14 +258,9 @@ async def update_service( exclude={"key", "version"}, ) ) - .returning(literal_column("*")) ) - async with self.db_engine.begin() as conn: - result = await conn.execute(stmt_update) - row = result.first() - assert row # nosec - return ServiceMetaDataAtDB.from_orm(row) + await conn.execute(stmt_update) async def can_get_service( self, @@ -290,6 +284,27 @@ async def can_get_service( ) return bool(result.scalar()) + async def can_update_service( + self, + # access-rights + product_name: ProductName, + user_id: UserID, + # get args + key: ServiceKey, + version: ServiceVersion, + ) -> bool: + async with self.db_engine.begin() as conn: + result = await conn.execute( + can_get_service_stmt( + product_name=product_name, + user_id=user_id, + access_rights=AccessRightsClauses.can_edit, + service_key=key, + service_version=version, + ) + ) + return bool(result.scalar()) + async def get_service_with_history( self, # access-rights @@ -576,16 +591,16 @@ async def get_service_specifications( continue # filter by group type group = gid_to_group_map[row.gid] - if (group.group_type == GroupTypeInModel.STANDARD) and is_newer( + if (group.group_type == GroupTypeInModel.STANDARD) and _is_newer( teams_specs.get(db_service_spec.gid), db_service_spec, ): teams_specs[db_service_spec.gid] = db_service_spec - elif (group.group_type == GroupTypeInModel.EVERYONE) and is_newer( + elif (group.group_type == GroupTypeInModel.EVERYONE) and _is_newer( everyone_specs, db_service_spec ): everyone_specs = db_service_spec - elif (group.group_type == GroupTypeInModel.PRIMARY) and is_newer( + elif (group.group_type == GroupTypeInModel.PRIMARY) and _is_newer( primary_specs, db_service_spec ): primary_specs = db_service_spec @@ -597,7 +612,7 @@ async def get_service_specifications( f"{exc}", ) - if merged_specifications := merge_specs( + if merged_specifications := _merge_specs( everyone_specs, teams_specs, primary_specs ): return ServiceSpecifications.parse_obj(merged_specifications) diff --git a/services/catalog/src/simcore_service_catalog/services/services_api.py b/services/catalog/src/simcore_service_catalog/services/services_api.py index 1747f80624f..af4721b07b9 100644 --- a/services/catalog/src/simcore_service_catalog/services/services_api.py +++ b/services/catalog/src/simcore_service_catalog/services/services_api.py @@ -172,6 +172,7 @@ async def get_service( version=service_version, ) if not service: + # no service found provided `access_rights` raise CatalogForbiddenError( name=f"{service_key}:{service_version}", service_key=service_key, @@ -216,9 +217,11 @@ async def update_service( product_name=product_name, ) - if not await repo.get_service_access_rights( + access_rights = await repo.get_service_access_rights( key=service_key, version=service_version, product_name=product_name - ): + ) + + if not access_rights: raise CatalogItemNotFoundError( name=f"{service_key}:{service_version}", service_key=service_key, @@ -227,13 +230,11 @@ async def update_service( product_name=product_name, ) - # Updates service_meta_data - if not await repo.update_service( - ServiceMetaDataAtDB( - key=service_key, - version=service_version, - **update.dict(exclude_unset=True), - ) + if not await repo.can_update_service( + product_name=product_name, + user_id=user_id, + key=service_key, + version=service_version, ): raise CatalogForbiddenError( name=f"{service_key}:{service_version}", @@ -243,14 +244,20 @@ async def update_service( product_name=product_name, ) + # Updates service_meta_data + await repo.update_service( + ServiceMetaDataAtDB( + key=service_key, + version=service_version, + **update.dict(exclude_unset=True), + ) + ) + # Updates service_access_rights (they can be added/removed/modified) if update.access_rights: # before - current_access_rights = await repo.get_service_access_rights( - service_key, service_version, product_name=product_name - ) - before_gids = [r.gid for r in current_access_rights] + previous_gids = [r.gid for r in access_rights] # new new_access_rights = [ @@ -267,17 +274,17 @@ async def update_service( await repo.upsert_service_access_rights(new_access_rights) # then delete the ones that were removed - remove_gids = [gid for gid in before_gids if gid not in update.access_rights] - delete_access_rights = [ + removed_access_rights = [ ServiceAccessRightsAtDB( key=service_key, version=service_version, gid=gid, product_name=product_name, ) - for gid in remove_gids + for gid in previous_gids + if gid not in update.access_rights ] - await repo.delete_service_access_rights(delete_access_rights) + await repo.delete_service_access_rights(removed_access_rights) return await get_service( repo=repo, diff --git a/services/catalog/tests/unit/with_dbs/conftest.py b/services/catalog/tests/unit/with_dbs/conftest.py index deb49f2c200..1681622314f 100644 --- a/services/catalog/tests/unit/with_dbs/conftest.py +++ b/services/catalog/tests/unit/with_dbs/conftest.py @@ -368,8 +368,8 @@ async def create_fake_service_data( fake_service, *fake_access_rights = create_fake_service_data( "simcore/services/dynamic/jupyterlab", "0.0.1", - team_access=None, - everyone_access=None, + team_access="xw", + everyone_access="x", product=target_product, ), diff --git a/services/catalog/tests/unit/with_dbs/test_api_rpc.py b/services/catalog/tests/unit/with_dbs/test_api_rpc.py index 289443fa82b..499683ecdff 100644 --- a/services/catalog/tests/unit/with_dbs/test_api_rpc.py +++ b/services/catalog/tests/unit/with_dbs/test_api_rpc.py @@ -4,26 +4,35 @@ # pylint: disable=too-many-arguments -from collections.abc import Callable +from collections.abc import AsyncIterator, Callable from typing import Any import pytest +from faker import Faker from fastapi import FastAPI from models_library.products import ProductName from models_library.rest_pagination import MAXIMUM_NUMBER_OF_ITEMS_PER_PAGE +from models_library.services_types import ServiceKey, ServiceVersion from models_library.users import UserID from pydantic import ValidationError +from pytest_simcore.helpers.faker_factories import random_user from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict +from pytest_simcore.helpers.postgres_tools import insert_and_get_row_lifespan from pytest_simcore.helpers.typing_env import EnvVarsDict from respx.router import MockRouter from servicelib.rabbitmq import RabbitMQRPCClient -from servicelib.rabbitmq.rpc_interfaces.catalog.errors import CatalogItemNotFoundError +from servicelib.rabbitmq.rpc_interfaces.catalog.errors import ( + CatalogForbiddenError, + CatalogItemNotFoundError, +) from servicelib.rabbitmq.rpc_interfaces.catalog.services import ( check_for_service, get_service, list_services_paginated, update_service, ) +from simcore_postgres_database.models.users import users +from sqlalchemy.ext.asyncio import AsyncEngine pytest_simcore_core_services_selection = [ "rabbit", @@ -224,3 +233,163 @@ async def test_rpc_check_for_service( service_key="simcore/services/dynamic/unknown", service_version="1.0.0", ) + + +@pytest.fixture +async def other_user( + user_id: UserID, + sqlalchemy_async_engine: AsyncEngine, + faker: Faker, +) -> AsyncIterator[dict[str, Any]]: + + _user = random_user(fake=faker, id=user_id + 1) + async with insert_and_get_row_lifespan( + sqlalchemy_async_engine, + table=users, + values=_user, + pk_col=users.c.id, + pk_value=_user["id"], + ) as row: + yield row + + +async def test_rpc_get_service_access_rights( + background_sync_task_mocked: None, + mocked_director_service_api: MockRouter, + rpc_client: RabbitMQRPCClient, + product_name: ProductName, + user: dict[str, Any], + user_id: UserID, + other_user: dict[str, Any], + app: FastAPI, + create_fake_service_data: Callable, + target_product: ProductName, +): + assert app + assert user["id"] == user_id + + # user_id owns a service (created in background_sync_task_mocked) + service_key = ServiceKey("simcore/services/comp/test-api-rpc-service-0") + service_version = ServiceVersion("0.0.0") + + service = await get_service( + rpc_client, + product_name=product_name, + user_id=user_id, + service_key=service_key, + service_version=service_version, + ) + assert service + assert service.access_rights + assert service.access_rights[user["primary_gid"]].write + assert service.access_rights[user["primary_gid"]].execute + + assert other_user["primary_gid"] not in service.access_rights + + # other_user does not have EXECUTE access ----------------- + with pytest.raises(CatalogForbiddenError, match=service_key): + await get_service( + rpc_client, + product_name=product_name, + user_id=other_user["id"], + service_key=service_key, + service_version=service_version, + ) + + # other_user does not have WRITE access + with pytest.raises(CatalogForbiddenError, match=service_key): + await update_service( + rpc_client, + product_name=product_name, + user_id=other_user["id"], + service_key=service_key, + service_version=service_version, + update={ + "name": "foo", + "description": "bar", + }, + ) + + # user_id gives "x access" to other_user ------------ + assert service.access_rights is not None + await update_service( + rpc_client, + product_name=product_name, + user_id=user_id, + service_key=service_key, + service_version=service_version, + update={ + "access_rights": { + **service.access_rights, + other_user["primary_gid"]: { + "execute": True, + "write": False, + }, + } + }, + ) + + # other user can now GET but NOT UPDATE + await get_service( + rpc_client, + product_name=product_name, + user_id=other_user["id"], + service_key=service_key, + service_version=service_version, + ) + + with pytest.raises(CatalogForbiddenError, match=service_key): + await update_service( + rpc_client, + product_name=product_name, + user_id=other_user["id"], + service_key=service_key, + service_version=service_version, + update={ + "name": "foo", + "description": "bar", + }, + ) + + # user_id gives "xw access" to other_user ------------------ + assert service.access_rights is not None + await update_service( + rpc_client, + product_name=product_name, + user_id=user_id, + service_key=service_key, + service_version=service_version, + update={ + "access_rights": { + **service.access_rights, + other_user["primary_gid"]: { + "execute": True, + "write": True, + }, + } + }, + ) + + # other_user can now update and get + await update_service( + rpc_client, + product_name=product_name, + user_id=other_user["id"], + service_key=service_key, + service_version=service_version, + update={ + "name": "foo", + "description": "bar", + }, + ) + updated_service = await get_service( + rpc_client, + product_name=product_name, + user_id=other_user["id"], + service_key=service_key, + service_version=service_version, + ) + assert updated_service.dict(include={"name", "description"}) == { + "name": "foo", + "description": "bar", + } diff --git a/services/catalog/tests/unit/with_dbs/test_db_repositories.py b/services/catalog/tests/unit/with_dbs/test_db_repositories.py index 86315691c99..3438492f740 100644 --- a/services/catalog/tests/unit/with_dbs/test_db_repositories.py +++ b/services/catalog/tests/unit/with_dbs/test_db_repositories.py @@ -381,11 +381,12 @@ async def test_get_and_update_service_meta_data( assert got.key == service_key assert got.version == service_version - updated = await services_repo.update_service( + await services_repo.update_service( ServiceMetaDataAtDB.construct( key=service_key, version=service_version, name="foo" - ) + ), ) + updated = await services_repo.get_service(service_key, service_version) assert got.copy(update={"name": "foo"}) == updated