Skip to content

Commit

Permalink
🐛storage: fixes query to file metadata (#6184)
Browse files Browse the repository at this point in the history
Co-authored-by: matusdrobuliak66 <[email protected]>
  • Loading branch information
pcrespov and matusdrobuliak66 authored Aug 15, 2024
1 parent 5cd7d13 commit 4341733
Show file tree
Hide file tree
Showing 7 changed files with 326 additions and 71 deletions.
82 changes: 56 additions & 26 deletions services/storage/src/simcore_service_storage/db_access_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
from models_library.projects import ProjectID
from models_library.projects_nodes_io import StorageFileID
from models_library.users import GroupID, UserID
from simcore_postgres_database.models.project_to_groups import project_to_groups
from simcore_postgres_database.models.projects import projects
from simcore_postgres_database.storage_models import file_meta_data, user_to_groups

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -112,31 +114,58 @@ def _aggregate_access_rights(
return AccessRights.none()


def assemble_array_groups(user_group_ids: list[GroupID]) -> str:
return (
"array[]::text[]"
if len(user_group_ids) == 0
else f"""array[{', '.join(f"'{group_id}'" for group_id in user_group_ids)}]"""
)


access_rights_subquery = (
sa.select(
project_to_groups.c.project_uuid,
sa.func.jsonb_object_agg(
project_to_groups.c.gid,
sa.func.jsonb_build_object(
"read",
project_to_groups.c.read,
"write",
project_to_groups.c.write,
"delete",
project_to_groups.c.delete,
),
).label("access_rights"),
).group_by(project_to_groups.c.project_uuid)
).subquery("access_rights_subquery")


async def list_projects_access_rights(
conn: SAConnection, user_id: UserID
) -> dict[ProjectID, AccessRights]:
"""
Returns access-rights of user (user_id) over all OWNED or SHARED projects
"""

user_group_ids: list[int] = await _get_user_groups_ids(conn, user_id)
user_group_ids: list[GroupID] = await _get_user_groups_ids(conn, user_id)

smt = sa.DDL(
f"""\
SELECT uuid, access_rights
FROM projects
WHERE (
prj_owner = {user_id}
OR jsonb_exists_any( access_rights, (
SELECT ARRAY( SELECT gid::TEXT FROM user_to_groups WHERE uid = {user_id} )
query = (
sa.select(
projects.c.uuid,
access_rights_subquery.c.access_rights,
)
.select_from(projects.join(access_rights_subquery, isouter=True))
.where(
(projects.c.prj_owner == user_id)
| sa.text(
f"jsonb_exists_any(access_rights_subquery.access_rights, {assemble_array_groups(user_group_ids)})"
)
)
)
"""
)

projects_access_rights = {}

async for row in conn.execute(smt):
async for row in conn.execute(query):
assert isinstance(row.access_rights, dict) # nosec
assert isinstance(row.uuid, str) # nosec

Expand All @@ -160,25 +189,26 @@ async def get_project_access_rights(
"""
Returns access-rights of user (user_id) over a project resource (project_id)
"""
user_group_ids: list[int] = await _get_user_groups_ids(conn, user_id)

stmt = sa.DDL(
f"""\
SELECT prj_owner, access_rights
FROM projects
WHERE (
( uuid = '{project_id}' ) AND (
prj_owner = {user_id}
OR jsonb_exists_any( access_rights, (
SELECT ARRAY( SELECT gid::TEXT FROM user_to_groups WHERE uid = {user_id} )
)
user_group_ids: list[GroupID] = await _get_user_groups_ids(conn, user_id)

query = (
sa.select(
projects.c.prj_owner,
access_rights_subquery.c.access_rights,
)
.select_from(projects.join(access_rights_subquery, isouter=True))
.where(
(projects.c.uuid == f"{project_id}")
& (
(projects.c.prj_owner == user_id)
| sa.text(
f"jsonb_exists_any(access_rights_subquery.access_rights, {assemble_array_groups(user_group_ids)})"
)
)
)
"""
)

result: ResultProxy = await conn.execute(stmt)
result: ResultProxy = await conn.execute(query)
row: RowProxy | None = await result.first()

if not row:
Expand Down
66 changes: 44 additions & 22 deletions services/storage/src/simcore_service_storage/db_file_meta_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ async def get(conn: SAConnection, file_id: SimcoreS3FileID) -> FileMetaDataAtDB:
raise FileMetaDataNotFoundError(file_id=file_id)


async def list_filter_with_partial_file_id(
conn: SAConnection,
def _list_filter_with_partial_file_id_stmt(
*,
user_or_project_filter: UserOrProjectFilter,
file_id_prefix: str | None,
Expand All @@ -74,15 +73,21 @@ async def list_filter_with_partial_file_id(
only_files: bool,
limit: int | None = None,
offset: int | None = None,
) -> list[FileMetaDataAtDB]:
conditions = []

# user_or_project_filter
if user_id := user_or_project_filter.user_id:
conditions.append(file_meta_data.c.user_id == f"{user_id}")
elif project_ids := user_or_project_filter.project_ids:
# Check if project_ids is not empty and add condition
conditions.append(file_meta_data.c.project_id.in_(f"{_}" for _ in project_ids))
):
conditions: list = []

# Checks access rights (project can be owned or shared)
user_id = user_or_project_filter.user_id
if user_id is not None:
project_ids = user_or_project_filter.project_ids
conditions.append(
sa.or_(
file_meta_data.c.user_id == f"{user_id}",
file_meta_data.c.project_id.in_(f"{_}" for _ in project_ids)
if project_ids
else False,
)
)

# Optional filters
if file_id_prefix:
Expand All @@ -94,19 +99,36 @@ async def list_filter_with_partial_file_id(
if sha256_checksum:
conditions.append(file_meta_data.c.sha256_checksum == sha256_checksum)

where_clause = sa.and_(*conditions)

stmt = (
sa.select(file_meta_data).where(where_clause)
# sorted as oldest first
.order_by(file_meta_data.c.created_at.asc())
return (
sa.select(file_meta_data)
.where(sa.and_(*conditions))
.order_by(file_meta_data.c.created_at.asc()) # sorted as oldest first
.offset(offset)
.limit(limit)
)

# Apply limit and offset if specified
if limit is not None:
stmt = stmt.limit(limit)
if offset is not None:
stmt = stmt.offset(offset)

async def list_filter_with_partial_file_id(
conn: SAConnection,
*,
user_or_project_filter: UserOrProjectFilter,
file_id_prefix: str | None,
partial_file_id: str | None,
sha256_checksum: SHA256Str | None,
only_files: bool,
limit: int | None = None,
offset: int | None = None,
) -> list[FileMetaDataAtDB]:

stmt = _list_filter_with_partial_file_id_stmt(
user_or_project_filter=user_or_project_filter,
file_id_prefix=file_id_prefix,
partial_file_id=partial_file_id,
sha256_checksum=sha256_checksum,
only_files=only_files,
limit=limit,
offset=offset,
)

return [FileMetaDataAtDB.from_orm(row) async for row in await conn.execute(stmt)]

Expand Down
2 changes: 1 addition & 1 deletion services/storage/src/simcore_service_storage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def unquote(cls, v):


class UserOrProjectFilter(NamedTuple):
user_id: UserID | None
user_id: UserID | None # = None disables filter
project_ids: list[ProjectID]


Expand Down
81 changes: 80 additions & 1 deletion services/storage/tests/fixtures/data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from pydantic import ByteSize, parse_obj_as
from pytest_simcore.helpers.faker_factories import random_project, random_user
from servicelib.utils import limited_gather
from simcore_postgres_database.models.project_to_groups import project_to_groups
from simcore_postgres_database.storage_models import projects, users
from sqlalchemy.dialects.postgresql import insert as pg_insert

from ..helpers.utils import get_updated_project

Expand Down Expand Up @@ -51,7 +53,13 @@ async def _user_context(aiopg_engine: Engine, *, name: str) -> AsyncIterator[Use

@pytest.fixture
async def user_id(aiopg_engine: Engine) -> AsyncIterator[UserID]:
async with _user_context(aiopg_engine, name="test") as new_user_id:
async with _user_context(aiopg_engine, name="test-user") as new_user_id:
yield new_user_id


@pytest.fixture
async def other_user_id(aiopg_engine: Engine) -> AsyncIterator[UserID]:
async with _user_context(aiopg_engine, name="test-other-user") as new_user_id:
yield new_user_id


Expand Down Expand Up @@ -83,6 +91,52 @@ async def _creator(**kwargs) -> dict[str, Any]:
)


@pytest.fixture
async def create_project_access_rights(
aiopg_engine: Engine,
) -> AsyncIterator[Callable[[ProjectID, UserID, bool, bool, bool], Awaitable[None]]]:
_created = []

async def _creator(
project_id: ProjectID, user_id: UserID, read: bool, write: bool, delete: bool
) -> None:
async with aiopg_engine.acquire() as conn:
result = await conn.execute(
project_to_groups.insert()
.values(
project_uuid=f"{project_id}",
gid=sa.select(users.c.primary_gid)
.where(users.c.id == f"{user_id}")
.scalar_subquery(),
read=read,
write=write,
delete=delete,
)
.returning(sa.literal_column("*"))
)
row = await result.fetchone()
assert row
_created.append(
(row[project_to_groups.c.project_uuid], row[project_to_groups.c.gid])
)

yield _creator

# cleanup
async with aiopg_engine.acquire() as conn:
await conn.execute(
project_to_groups.delete().where(
sa.or_(
*(
(project_to_groups.c.project_uuid == pid)
& (project_to_groups.c.gid == gid)
for pid, gid in _created
)
)
)
)


@pytest.fixture
async def project_id(
create_project: Callable[[], Awaitable[dict[str, Any]]]
Expand Down Expand Up @@ -142,6 +196,31 @@ async def _() -> None:
.values(access_rights=access_rights)
)

# project_to_groups needs to be updated
for group_id, permissions in access_rights.items():
insert_stmt = pg_insert(project_to_groups).values(
project_uuid=f"{project_id}",
gid=int(group_id),
read=permissions["read"],
write=permissions["write"],
delete=permissions["delete"],
created=sa.func.now(),
modified=sa.func.now(),
)
on_update_stmt = insert_stmt.on_conflict_do_update(
index_elements=[
project_to_groups.c.project_uuid,
project_to_groups.c.gid,
],
set_={
"read": insert_stmt.excluded.read,
"write": insert_stmt.excluded.write,
"delete": insert_stmt.excluded.delete,
"modified": sa.func.now(),
},
)
await conn.execute(on_update_stmt)

return _


Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# pylint:disable=unused-variable
# pylint:disable=unused-argument
# pylint:disable=redefined-outer-name
# pylint: disable=redefined-outer-name
# pylint: disable=unused-argument
# pylint: disable=unused-variable
# pylint: disable=too-many-arguments


from typing import Iterable

import pytest
from aiopg.sa.engine import Engine
from models_library.projects import ProjectID
from models_library.users import UserID
Expand All @@ -18,13 +16,6 @@
pytest_simcore_core_services_selection = ["postgres"]


@pytest.fixture
async def filemeta_id(
user_id: UserID, project_id: ProjectID, aiopg_engine: Engine
) -> Iterable[str]:
raise NotImplementedError()


async def test_access_rights_on_owned_project(
user_id: UserID, project_id: ProjectID, aiopg_engine: Engine
):
Expand Down
Loading

0 comments on commit 4341733

Please sign in to comment.