diff --git a/services/storage/src/simcore_service_storage/db_access_layer.py b/services/storage/src/simcore_service_storage/db_access_layer.py index 7f3270a3587..aadaf9a87b9 100644 --- a/services/storage/src/simcore_service_storage/db_access_layer.py +++ b/services/storage/src/simcore_service_storage/db_access_layer.py @@ -45,6 +45,8 @@ from models_library.projects import ProjectID from models_library.projects_nodes_io import StorageFileID from models_library.users import GroupID, UserID +from simcore_postgres_database.models.project_to_groups import project_to_groups +from simcore_postgres_database.models.projects import projects from simcore_postgres_database.storage_models import file_meta_data, user_to_groups logger = logging.getLogger(__name__) @@ -112,6 +114,32 @@ def _aggregate_access_rights( return AccessRights.none() +def assemble_array_groups(user_group_ids: list[GroupID]) -> str: + return ( + "array[]::text[]" + if len(user_group_ids) == 0 + else f"""array[{', '.join(f"'{group_id}'" for group_id in user_group_ids)}]""" + ) + + +access_rights_subquery = ( + sa.select( + project_to_groups.c.project_uuid, + sa.func.jsonb_object_agg( + project_to_groups.c.gid, + sa.func.jsonb_build_object( + "read", + project_to_groups.c.read, + "write", + project_to_groups.c.write, + "delete", + project_to_groups.c.delete, + ), + ).label("access_rights"), + ).group_by(project_to_groups.c.project_uuid) +).subquery("access_rights_subquery") + + async def list_projects_access_rights( conn: SAConnection, user_id: UserID ) -> dict[ProjectID, AccessRights]: @@ -119,24 +147,25 @@ async def list_projects_access_rights( Returns access-rights of user (user_id) over all OWNED or SHARED projects """ - user_group_ids: list[int] = await _get_user_groups_ids(conn, user_id) + user_group_ids: list[GroupID] = await _get_user_groups_ids(conn, user_id) - smt = sa.DDL( - f"""\ - SELECT uuid, access_rights - FROM projects - WHERE ( - prj_owner = {user_id} - OR jsonb_exists_any( access_rights, ( - SELECT ARRAY( SELECT gid::TEXT FROM user_to_groups WHERE uid = {user_id} ) + query = ( + sa.select( + projects.c.uuid, + access_rights_subquery.c.access_rights, + ) + .select_from(projects.join(access_rights_subquery, isouter=True)) + .where( + (projects.c.prj_owner == user_id) + | sa.text( + f"jsonb_exists_any(access_rights_subquery.access_rights, {assemble_array_groups(user_group_ids)})" ) ) ) - """ - ) + projects_access_rights = {} - async for row in conn.execute(smt): + async for row in conn.execute(query): assert isinstance(row.access_rights, dict) # nosec assert isinstance(row.uuid, str) # nosec @@ -160,25 +189,26 @@ async def get_project_access_rights( """ Returns access-rights of user (user_id) over a project resource (project_id) """ - user_group_ids: list[int] = await _get_user_groups_ids(conn, user_id) - - stmt = sa.DDL( - f"""\ - SELECT prj_owner, access_rights - FROM projects - WHERE ( - ( uuid = '{project_id}' ) AND ( - prj_owner = {user_id} - OR jsonb_exists_any( access_rights, ( - SELECT ARRAY( SELECT gid::TEXT FROM user_to_groups WHERE uid = {user_id} ) - ) + user_group_ids: list[GroupID] = await _get_user_groups_ids(conn, user_id) + + query = ( + sa.select( + projects.c.prj_owner, + access_rights_subquery.c.access_rights, + ) + .select_from(projects.join(access_rights_subquery, isouter=True)) + .where( + (projects.c.uuid == f"{project_id}") + & ( + (projects.c.prj_owner == user_id) + | sa.text( + f"jsonb_exists_any(access_rights_subquery.access_rights, {assemble_array_groups(user_group_ids)})" ) ) ) - """ ) - result: ResultProxy = await conn.execute(stmt) + result: ResultProxy = await conn.execute(query) row: RowProxy | None = await result.first() if not row: diff --git a/services/storage/src/simcore_service_storage/db_file_meta_data.py b/services/storage/src/simcore_service_storage/db_file_meta_data.py index 35512a82db7..21c519e5483 100644 --- a/services/storage/src/simcore_service_storage/db_file_meta_data.py +++ b/services/storage/src/simcore_service_storage/db_file_meta_data.py @@ -64,8 +64,7 @@ async def get(conn: SAConnection, file_id: SimcoreS3FileID) -> FileMetaDataAtDB: raise FileMetaDataNotFoundError(file_id=file_id) -async def list_filter_with_partial_file_id( - conn: SAConnection, +def _list_filter_with_partial_file_id_stmt( *, user_or_project_filter: UserOrProjectFilter, file_id_prefix: str | None, @@ -74,15 +73,21 @@ async def list_filter_with_partial_file_id( only_files: bool, limit: int | None = None, offset: int | None = None, -) -> list[FileMetaDataAtDB]: - conditions = [] - - # user_or_project_filter - if user_id := user_or_project_filter.user_id: - conditions.append(file_meta_data.c.user_id == f"{user_id}") - elif project_ids := user_or_project_filter.project_ids: - # Check if project_ids is not empty and add condition - conditions.append(file_meta_data.c.project_id.in_(f"{_}" for _ in project_ids)) +): + conditions: list = [] + + # Checks access rights (project can be owned or shared) + user_id = user_or_project_filter.user_id + if user_id is not None: + project_ids = user_or_project_filter.project_ids + conditions.append( + sa.or_( + file_meta_data.c.user_id == f"{user_id}", + file_meta_data.c.project_id.in_(f"{_}" for _ in project_ids) + if project_ids + else False, + ) + ) # Optional filters if file_id_prefix: @@ -94,19 +99,36 @@ async def list_filter_with_partial_file_id( if sha256_checksum: conditions.append(file_meta_data.c.sha256_checksum == sha256_checksum) - where_clause = sa.and_(*conditions) - - stmt = ( - sa.select(file_meta_data).where(where_clause) - # sorted as oldest first - .order_by(file_meta_data.c.created_at.asc()) + return ( + sa.select(file_meta_data) + .where(sa.and_(*conditions)) + .order_by(file_meta_data.c.created_at.asc()) # sorted as oldest first + .offset(offset) + .limit(limit) ) - # Apply limit and offset if specified - if limit is not None: - stmt = stmt.limit(limit) - if offset is not None: - stmt = stmt.offset(offset) + +async def list_filter_with_partial_file_id( + conn: SAConnection, + *, + user_or_project_filter: UserOrProjectFilter, + file_id_prefix: str | None, + partial_file_id: str | None, + sha256_checksum: SHA256Str | None, + only_files: bool, + limit: int | None = None, + offset: int | None = None, +) -> list[FileMetaDataAtDB]: + + stmt = _list_filter_with_partial_file_id_stmt( + user_or_project_filter=user_or_project_filter, + file_id_prefix=file_id_prefix, + partial_file_id=partial_file_id, + sha256_checksum=sha256_checksum, + only_files=only_files, + limit=limit, + offset=offset, + ) return [FileMetaDataAtDB.from_orm(row) async for row in await conn.execute(stmt)] diff --git a/services/storage/src/simcore_service_storage/models.py b/services/storage/src/simcore_service_storage/models.py index 3b84e693043..d05099edd06 100644 --- a/services/storage/src/simcore_service_storage/models.py +++ b/services/storage/src/simcore_service_storage/models.py @@ -259,7 +259,7 @@ def unquote(cls, v): class UserOrProjectFilter(NamedTuple): - user_id: UserID | None + user_id: UserID | None # = None disables filter project_ids: list[ProjectID] diff --git a/services/storage/tests/fixtures/data_models.py b/services/storage/tests/fixtures/data_models.py index 738cd20fbcc..ae5816a427f 100644 --- a/services/storage/tests/fixtures/data_models.py +++ b/services/storage/tests/fixtures/data_models.py @@ -21,7 +21,9 @@ from pydantic import ByteSize, parse_obj_as from pytest_simcore.helpers.faker_factories import random_project, random_user from servicelib.utils import limited_gather +from simcore_postgres_database.models.project_to_groups import project_to_groups from simcore_postgres_database.storage_models import projects, users +from sqlalchemy.dialects.postgresql import insert as pg_insert from ..helpers.utils import get_updated_project @@ -51,7 +53,13 @@ async def _user_context(aiopg_engine: Engine, *, name: str) -> AsyncIterator[Use @pytest.fixture async def user_id(aiopg_engine: Engine) -> AsyncIterator[UserID]: - async with _user_context(aiopg_engine, name="test") as new_user_id: + async with _user_context(aiopg_engine, name="test-user") as new_user_id: + yield new_user_id + + +@pytest.fixture +async def other_user_id(aiopg_engine: Engine) -> AsyncIterator[UserID]: + async with _user_context(aiopg_engine, name="test-other-user") as new_user_id: yield new_user_id @@ -83,6 +91,52 @@ async def _creator(**kwargs) -> dict[str, Any]: ) +@pytest.fixture +async def create_project_access_rights( + aiopg_engine: Engine, +) -> AsyncIterator[Callable[[ProjectID, UserID, bool, bool, bool], Awaitable[None]]]: + _created = [] + + async def _creator( + project_id: ProjectID, user_id: UserID, read: bool, write: bool, delete: bool + ) -> None: + async with aiopg_engine.acquire() as conn: + result = await conn.execute( + project_to_groups.insert() + .values( + project_uuid=f"{project_id}", + gid=sa.select(users.c.primary_gid) + .where(users.c.id == f"{user_id}") + .scalar_subquery(), + read=read, + write=write, + delete=delete, + ) + .returning(sa.literal_column("*")) + ) + row = await result.fetchone() + assert row + _created.append( + (row[project_to_groups.c.project_uuid], row[project_to_groups.c.gid]) + ) + + yield _creator + + # cleanup + async with aiopg_engine.acquire() as conn: + await conn.execute( + project_to_groups.delete().where( + sa.or_( + *( + (project_to_groups.c.project_uuid == pid) + & (project_to_groups.c.gid == gid) + for pid, gid in _created + ) + ) + ) + ) + + @pytest.fixture async def project_id( create_project: Callable[[], Awaitable[dict[str, Any]]] @@ -142,6 +196,31 @@ async def _() -> None: .values(access_rights=access_rights) ) + # project_to_groups needs to be updated + for group_id, permissions in access_rights.items(): + insert_stmt = pg_insert(project_to_groups).values( + project_uuid=f"{project_id}", + gid=int(group_id), + read=permissions["read"], + write=permissions["write"], + delete=permissions["delete"], + created=sa.func.now(), + modified=sa.func.now(), + ) + on_update_stmt = insert_stmt.on_conflict_do_update( + index_elements=[ + project_to_groups.c.project_uuid, + project_to_groups.c.gid, + ], + set_={ + "read": insert_stmt.excluded.read, + "write": insert_stmt.excluded.write, + "delete": insert_stmt.excluded.delete, + "modified": sa.func.now(), + }, + ) + await conn.execute(on_update_stmt) + return _ diff --git a/services/storage/tests/unit/test_access_layer.py b/services/storage/tests/unit/test_db_access_layer.py similarity index 74% rename from services/storage/tests/unit/test_access_layer.py rename to services/storage/tests/unit/test_db_access_layer.py index 3f2b22e396f..f357b2909fc 100644 --- a/services/storage/tests/unit/test_access_layer.py +++ b/services/storage/tests/unit/test_db_access_layer.py @@ -1,11 +1,9 @@ -# pylint:disable=unused-variable -# pylint:disable=unused-argument -# pylint:disable=redefined-outer-name +# pylint: disable=redefined-outer-name +# pylint: disable=unused-argument +# pylint: disable=unused-variable +# pylint: disable=too-many-arguments -from typing import Iterable - -import pytest from aiopg.sa.engine import Engine from models_library.projects import ProjectID from models_library.users import UserID @@ -18,13 +16,6 @@ pytest_simcore_core_services_selection = ["postgres"] -@pytest.fixture -async def filemeta_id( - user_id: UserID, project_id: ProjectID, aiopg_engine: Engine -) -> Iterable[str]: - raise NotImplementedError() - - async def test_access_rights_on_owned_project( user_id: UserID, project_id: ProjectID, aiopg_engine: Engine ): diff --git a/services/storage/tests/unit/test_db_file_meta_data.py b/services/storage/tests/unit/test_db_file_meta_data.py new file mode 100644 index 00000000000..da94c6a5eeb --- /dev/null +++ b/services/storage/tests/unit/test_db_file_meta_data.py @@ -0,0 +1,94 @@ +# pylint: disable=redefined-outer-name +# pylint: disable=unused-argument +# pylint: disable=unused-variable +# pylint: disable=too-many-arguments + + +from uuid import uuid4 + +from faker import Faker +from simcore_postgres_database.utils import as_postgres_sql_query_str +from simcore_service_storage.db_file_meta_data import ( + _list_filter_with_partial_file_id_stmt, +) +from simcore_service_storage.models import UserOrProjectFilter + + +def test_building_sql_statements(faker: Faker): + def _check(func_smt, **kwargs): + print() + print(f"{func_smt.__name__:*^100}") + stmt = func_smt(**kwargs) + print() + print(as_postgres_sql_query_str(stmt)) + print() + + _check( + _list_filter_with_partial_file_id_stmt, + user_or_project_filter=UserOrProjectFilter( + user_id=None, project_ids=[uuid4() for _ in range(2)] + ), + file_id_prefix=None, + partial_file_id=None, + sha256_checksum=None, + only_files=True, + ) + # WHERE file_meta_data.is_directory IS false ORDER BY file_meta_data.created_at ASC + + _check( + _list_filter_with_partial_file_id_stmt, + user_or_project_filter=UserOrProjectFilter(user_id=42, project_ids=[]), + file_id_prefix=None, + partial_file_id=None, + sha256_checksum=None, + only_files=True, + ) + # WHERE file_meta_data.user_id = '42' AND file_meta_data.is_directory IS false ORDER BY file_meta_data.created_at ASC + + _check( + _list_filter_with_partial_file_id_stmt, + user_or_project_filter=UserOrProjectFilter( + user_id=42, project_ids=[uuid4() for _ in range(2)] + ), + file_id_prefix=None, + partial_file_id=None, + sha256_checksum=None, + only_files=True, + ) + # WHERE (file_meta_data.user_id = '42' OR file_meta_data.project_id IN ('18d5'..., )) AND file_meta_data.is_directory IS false ORDER BY file_meta_data.created_at ASC + + _check( + _list_filter_with_partial_file_id_stmt, + user_or_project_filter=UserOrProjectFilter( + user_id=42, project_ids=[uuid4() for _ in range(2)] + ), + file_id_prefix=None, + partial_file_id=None, + sha256_checksum=None, + only_files=True, + limit=10, + offset=1, + ) + # (file_meta_data.user_id = '42' OR file_meta_data.project_id IN ('3cd9704db' ...)) AND file_meta_data.is_directory IS false ORDER BY file_meta_data.created_at ASC LIMIT 10 OFFSET 1 + + # As used in SimcoreS3DataManager.list_files + _check( + _list_filter_with_partial_file_id_stmt, + user_or_project_filter=UserOrProjectFilter(user_id=42, project_ids=[]), + file_id_prefix=None, + partial_file_id="{project_id}/", + sha256_checksum=None, + only_files=False, + ) + + # As used in SimcoreS3DataManager.search_owned_files + _check( + _list_filter_with_partial_file_id_stmt, + user_or_project_filter=UserOrProjectFilter(user_id=42, project_ids=[]), + file_id_prefix="api/", + partial_file_id=None, + sha256_checksum=faker.sha256(), + only_files=True, + limit=10, + offset=0, + ) diff --git a/services/storage/tests/unit/test_handlers_files_metadata.py b/services/storage/tests/unit/test_handlers_files_metadata.py index 55cf83dbcac..dd0c8138ebb 100644 --- a/services/storage/tests/unit/test_handlers_files_metadata.py +++ b/services/storage/tests/unit/test_handlers_files_metadata.py @@ -5,8 +5,10 @@ import urllib.parse from collections.abc import Awaitable, Callable +from copy import deepcopy from pathlib import Path from random import choice +from typing import Protocol import pytest from aiohttp.test_utils import TestClient @@ -22,10 +24,24 @@ pytest_simcore_ops_services_selection = ["adminer"] +class CreateProjectAccessRightsCallable(Protocol): + async def __call__( + self, + project_id: ProjectID, + user_id: UserID, + read: bool, + write: bool, + delete: bool, + ) -> None: + ... + + async def test_get_files_metadata( upload_file: Callable[[ByteSize, str], Awaitable[tuple[Path, SimcoreS3FileID]]], + create_project_access_rights: CreateProjectAccessRightsCallable, client: TestClient, user_id: UserID, + other_user_id: UserID, location_id: int, project_id: ProjectID, faker: Faker, @@ -48,29 +64,52 @@ async def test_get_files_metadata( # now add some stuff there NUM_FILES = 10 file_size = parse_obj_as(ByteSize, "15Mib") - files_owned_by_us = [] - for _ in range(NUM_FILES): - files_owned_by_us.append(await upload_file(file_size, faker.file_name())) + files_owned_by_us = [ + await upload_file(file_size, faker.file_name()) for _ in range(NUM_FILES) + ] + assert files_owned_by_us + # we should find these files now response = await client.get(f"{url}") data, error = await assert_status(response, status.HTTP_200_OK) assert not error list_fmds = parse_obj_as(list[FileMetaDataGet], data) assert len(list_fmds) == NUM_FILES + + # checks project_id filter! + await create_project_access_rights( + project_id=project_id, + user_id=other_user_id, + read=True, + write=True, + delete=True, + ) + response = await client.get( + f"{url.update_query(project_id=str(project_id), user_id=other_user_id)}" + ) + previous_data = deepcopy(data) + data, error = await assert_status(response, status.HTTP_200_OK) + assert not error + list_fmds = parse_obj_as(list[FileMetaDataGet], data) + assert len(list_fmds) == (NUM_FILES) + assert previous_data == data + # create some more files but with a base common name NUM_FILES = 10 file_size = parse_obj_as(ByteSize, "15Mib") - files_with_common_name = [] - for _ in range(NUM_FILES): - files_with_common_name.append( - await upload_file(file_size, f"common_name-{faker.file_name()}") - ) + files_with_common_name = [ + await upload_file(file_size, f"common_name-{faker.file_name()}") + for _ in range(NUM_FILES) + ] + assert files_with_common_name + # we should find these files now response = await client.get(f"{url}") data, error = await assert_status(response, status.HTTP_200_OK) assert not error list_fmds = parse_obj_as(list[FileMetaDataGet], data) assert len(list_fmds) == (2 * NUM_FILES) + # we can filter them now response = await client.get(f"{url.update_query(uuid_filter='common_name')}") data, error = await assert_status(response, status.HTTP_200_OK)