Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛storage: fixes query to file metadata #6184

Merged
82 changes: 56 additions & 26 deletions services/storage/src/simcore_service_storage/db_access_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
from models_library.projects import ProjectID
from models_library.projects_nodes_io import StorageFileID
from models_library.users import GroupID, UserID
from simcore_postgres_database.models.project_to_groups import project_to_groups
from simcore_postgres_database.models.projects import projects
from simcore_postgres_database.storage_models import file_meta_data, user_to_groups

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -112,31 +114,58 @@ def _aggregate_access_rights(
return AccessRights.none()


def assemble_array_groups(user_group_ids: list[GroupID]) -> str:
return (
"array[]::text[]"
if len(user_group_ids) == 0
else f"""array[{', '.join(f"'{group_id}'" for group_id in user_group_ids)}]"""
)


access_rights_subquery = (
sanderegg marked this conversation as resolved.
Show resolved Hide resolved
sa.select(
project_to_groups.c.project_uuid,
sa.func.jsonb_object_agg(
project_to_groups.c.gid,
sa.func.jsonb_build_object(
"read",
project_to_groups.c.read,
"write",
project_to_groups.c.write,
"delete",
project_to_groups.c.delete,
),
).label("access_rights"),
).group_by(project_to_groups.c.project_uuid)
).subquery("access_rights_subquery")


async def list_projects_access_rights(
conn: SAConnection, user_id: UserID
) -> dict[ProjectID, AccessRights]:
"""
Returns access-rights of user (user_id) over all OWNED or SHARED projects
"""

user_group_ids: list[int] = await _get_user_groups_ids(conn, user_id)
user_group_ids: list[GroupID] = await _get_user_groups_ids(conn, user_id)

smt = sa.DDL(
f"""\
SELECT uuid, access_rights
FROM projects
WHERE (
prj_owner = {user_id}
OR jsonb_exists_any( access_rights, (
SELECT ARRAY( SELECT gid::TEXT FROM user_to_groups WHERE uid = {user_id} )
query = (
sa.select(
projects.c.uuid,
access_rights_subquery.c.access_rights,
)
.select_from(projects.join(access_rights_subquery, isouter=True))
.where(
(projects.c.prj_owner == user_id)
| sa.text(
f"jsonb_exists_any(access_rights_subquery.access_rights, {assemble_array_groups(user_group_ids)})"
)
)
)
"""
)

projects_access_rights = {}

async for row in conn.execute(smt):
async for row in conn.execute(query):
assert isinstance(row.access_rights, dict) # nosec
assert isinstance(row.uuid, str) # nosec

Expand All @@ -160,25 +189,26 @@ async def get_project_access_rights(
"""
Returns access-rights of user (user_id) over a project resource (project_id)
"""
user_group_ids: list[int] = await _get_user_groups_ids(conn, user_id)

stmt = sa.DDL(
f"""\
SELECT prj_owner, access_rights
FROM projects
WHERE (
( uuid = '{project_id}' ) AND (
prj_owner = {user_id}
OR jsonb_exists_any( access_rights, (
SELECT ARRAY( SELECT gid::TEXT FROM user_to_groups WHERE uid = {user_id} )
)
user_group_ids: list[GroupID] = await _get_user_groups_ids(conn, user_id)

query = (
sa.select(
projects.c.prj_owner,
access_rights_subquery.c.access_rights,
)
.select_from(projects.join(access_rights_subquery, isouter=True))
.where(
(projects.c.uuid == f"{project_id}")
& (
(projects.c.prj_owner == user_id)
| sa.text(
f"jsonb_exists_any(access_rights_subquery.access_rights, {assemble_array_groups(user_group_ids)})"
)
)
)
"""
)

result: ResultProxy = await conn.execute(stmt)
result: ResultProxy = await conn.execute(query)
row: RowProxy | None = await result.first()

if not row:
Expand Down
66 changes: 44 additions & 22 deletions services/storage/src/simcore_service_storage/db_file_meta_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ async def get(conn: SAConnection, file_id: SimcoreS3FileID) -> FileMetaDataAtDB:
raise FileMetaDataNotFoundError(file_id=file_id)


async def list_filter_with_partial_file_id(
conn: SAConnection,
def _list_filter_with_partial_file_id_stmt(
*,
user_or_project_filter: UserOrProjectFilter,
file_id_prefix: str | None,
Expand All @@ -74,15 +73,21 @@ async def list_filter_with_partial_file_id(
only_files: bool,
limit: int | None = None,
offset: int | None = None,
) -> list[FileMetaDataAtDB]:
conditions = []

# user_or_project_filter
if user_id := user_or_project_filter.user_id:
conditions.append(file_meta_data.c.user_id == f"{user_id}")
elif project_ids := user_or_project_filter.project_ids:
# Check if project_ids is not empty and add condition
conditions.append(file_meta_data.c.project_id.in_(f"{_}" for _ in project_ids))
):
conditions: list = []

# Checks access rights (project can be owned or shared)
user_id = user_or_project_filter.user_id
if user_id is not None:
project_ids = user_or_project_filter.project_ids
conditions.append(
sa.or_(
file_meta_data.c.user_id == f"{user_id}",
file_meta_data.c.project_id.in_(f"{_}" for _ in project_ids)
if project_ids
else False,
)
)

# Optional filters
if file_id_prefix:
Expand All @@ -94,19 +99,36 @@ async def list_filter_with_partial_file_id(
if sha256_checksum:
conditions.append(file_meta_data.c.sha256_checksum == sha256_checksum)

where_clause = sa.and_(*conditions)

stmt = (
sa.select(file_meta_data).where(where_clause)
# sorted as oldest first
.order_by(file_meta_data.c.created_at.asc())
return (
sa.select(file_meta_data)
.where(sa.and_(*conditions))
.order_by(file_meta_data.c.created_at.asc()) # sorted as oldest first
.offset(offset)
.limit(limit)
)

# Apply limit and offset if specified
if limit is not None:
stmt = stmt.limit(limit)
if offset is not None:
stmt = stmt.offset(offset)

async def list_filter_with_partial_file_id(
conn: SAConnection,
*,
user_or_project_filter: UserOrProjectFilter,
file_id_prefix: str | None,
partial_file_id: str | None,
sha256_checksum: SHA256Str | None,
only_files: bool,
limit: int | None = None,
offset: int | None = None,
) -> list[FileMetaDataAtDB]:

stmt = _list_filter_with_partial_file_id_stmt(
user_or_project_filter=user_or_project_filter,
file_id_prefix=file_id_prefix,
partial_file_id=partial_file_id,
sha256_checksum=sha256_checksum,
only_files=only_files,
limit=limit,
offset=offset,
)

return [FileMetaDataAtDB.from_orm(row) async for row in await conn.execute(stmt)]

Expand Down
2 changes: 1 addition & 1 deletion services/storage/src/simcore_service_storage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def unquote(cls, v):


class UserOrProjectFilter(NamedTuple):
user_id: UserID | None
user_id: UserID | None # = None disables filter
pcrespov marked this conversation as resolved.
Show resolved Hide resolved
project_ids: list[ProjectID]


Expand Down
81 changes: 80 additions & 1 deletion services/storage/tests/fixtures/data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from pydantic import ByteSize, parse_obj_as
from pytest_simcore.helpers.faker_factories import random_project, random_user
from servicelib.utils import limited_gather
from simcore_postgres_database.models.project_to_groups import project_to_groups
from simcore_postgres_database.storage_models import projects, users
from sqlalchemy.dialects.postgresql import insert as pg_insert

from ..helpers.utils import get_updated_project

Expand Down Expand Up @@ -51,7 +53,13 @@ async def _user_context(aiopg_engine: Engine, *, name: str) -> AsyncIterator[Use

@pytest.fixture
async def user_id(aiopg_engine: Engine) -> AsyncIterator[UserID]:
async with _user_context(aiopg_engine, name="test") as new_user_id:
async with _user_context(aiopg_engine, name="test-user") as new_user_id:
yield new_user_id


@pytest.fixture
async def other_user_id(aiopg_engine: Engine) -> AsyncIterator[UserID]:
async with _user_context(aiopg_engine, name="test-other-user") as new_user_id:
yield new_user_id


Expand Down Expand Up @@ -83,6 +91,52 @@ async def _creator(**kwargs) -> dict[str, Any]:
)


@pytest.fixture
async def create_project_access_rights(
aiopg_engine: Engine,
) -> AsyncIterator[Callable[[ProjectID, UserID, bool, bool, bool], Awaitable[None]]]:
_created = []

async def _creator(
project_id: ProjectID, user_id: UserID, read: bool, write: bool, delete: bool
) -> None:
async with aiopg_engine.acquire() as conn:
result = await conn.execute(
project_to_groups.insert()
.values(
project_uuid=f"{project_id}",
gid=sa.select(users.c.primary_gid)
.where(users.c.id == f"{user_id}")
.scalar_subquery(),
read=read,
write=write,
delete=delete,
)
.returning(sa.literal_column("*"))
)
row = await result.fetchone()
assert row
_created.append(
(row[project_to_groups.c.project_uuid], row[project_to_groups.c.gid])
)

yield _creator

# cleanup
async with aiopg_engine.acquire() as conn:
await conn.execute(
project_to_groups.delete().where(
sa.or_(
*(
(project_to_groups.c.project_uuid == pid)
& (project_to_groups.c.gid == gid)
for pid, gid in _created
)
)
)
)


@pytest.fixture
async def project_id(
create_project: Callable[[], Awaitable[dict[str, Any]]]
Expand Down Expand Up @@ -142,6 +196,31 @@ async def _() -> None:
.values(access_rights=access_rights)
)

# project_to_groups needs to be updated
for group_id, permissions in access_rights.items():
insert_stmt = pg_insert(project_to_groups).values(
project_uuid=f"{project_id}",
gid=int(group_id),
read=permissions["read"],
write=permissions["write"],
delete=permissions["delete"],
created=sa.func.now(),
modified=sa.func.now(),
)
on_update_stmt = insert_stmt.on_conflict_do_update(
index_elements=[
project_to_groups.c.project_uuid,
project_to_groups.c.gid,
],
set_={
"read": insert_stmt.excluded.read,
"write": insert_stmt.excluded.write,
"delete": insert_stmt.excluded.delete,
"modified": sa.func.now(),
},
)
await conn.execute(on_update_stmt)

return _


Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# pylint:disable=unused-variable
# pylint:disable=unused-argument
# pylint:disable=redefined-outer-name
# pylint: disable=redefined-outer-name
# pylint: disable=unused-argument
# pylint: disable=unused-variable
# pylint: disable=too-many-arguments


from typing import Iterable

import pytest
from aiopg.sa.engine import Engine
from models_library.projects import ProjectID
from models_library.users import UserID
Expand All @@ -18,13 +16,6 @@
pytest_simcore_core_services_selection = ["postgres"]


@pytest.fixture
async def filemeta_id(
user_id: UserID, project_id: ProjectID, aiopg_engine: Engine
) -> Iterable[str]:
raise NotImplementedError()


async def test_access_rights_on_owned_project(
user_id: UserID, project_id: ProjectID, aiopg_engine: Engine
):
Expand Down
Loading
Loading