diff --git a/changes/3090.feature.md b/changes/3090.feature.md new file mode 100644 index 0000000000..e725ca5ff6 --- /dev/null +++ b/changes/3090.feature.md @@ -0,0 +1 @@ +Implement CRUD API for managing Harbor per-project Quota. diff --git a/src/ai/backend/client/func/group.py b/src/ai/backend/client/func/group.py index 9590ee948d..c8fa49254c 100644 --- a/src/ai/backend/client/func/group.py +++ b/src/ai/backend/client/func/group.py @@ -3,6 +3,7 @@ from ai.backend.client.output.fields import group_fields from ai.backend.client.output.types import FieldSpec +from ai.backend.common.utils import b64encode from ...cli.types import Undefined, undefined from ..session import api_session @@ -311,3 +312,101 @@ async def remove_users( } data = await api_session.get().Admin._query(query, variables) return data["modify_group"] + + @api_function + @classmethod + async def get_container_registry_quota(cls, group_id: str) -> int: + """ + Get Quota Limit for the group's container registry. + Currently only HarborV2 registry is supported. + + You need an admin privilege for this operation. + """ + query = textwrap.dedent( + """\ + query($id: String!) { + group_node(id: $id) { + registry_quota + } + } + """ + ) + + variables = {"id": b64encode(f"group_node:{group_id}")} + data = await api_session.get().Admin._query(query, variables) + return data["group_node"]["registry_quota"] + + @api_function + @classmethod + async def create_container_registry_quota(cls, group_id: str, quota: int) -> dict: + """ + Create Quota Limit for the group's container registry. + Currently only HarborV2 registry is supported. + + You need an admin privilege for this operation. + """ + query = textwrap.dedent( + """\ + mutation($scope_id: ScopeField!, $quota: Int!) { + create_container_registry_quota( + scope_id: $scope_id, quota: $quota) { + ok msg + } + } + """ + ) + + scope_id = f"project:{group_id}" + variables = {"scope_id": scope_id, "quota": quota} + data = await api_session.get().Admin._query(query, variables) + return data["create_container_registry_quota"] + + @api_function + @classmethod + async def update_container_registry_quota(cls, group_id: str, quota: int) -> dict: + """ + Update Quota Limit for the group's container registry. + Currently only HarborV2 registry is supported. + + You need an admin privilege for this operation. + """ + query = textwrap.dedent( + """\ + mutation($scope_id: ScopeField!, $quota: Int!) { + update_container_registry_quota( + scope_id: $scope_id, quota: $quota) { + ok msg + } + } + """ + ) + + scope_id = f"project:{group_id}" + variables = {"scope_id": scope_id, "quota": quota} + data = await api_session.get().Admin._query(query, variables) + return data["update_container_registry_quota"] + + @api_function + @classmethod + async def delete_container_registry_quota(cls, group_id: str) -> dict: + """ + Delete Quota Limit for the group's container registry. + Currently only HarborV2 registry is supported. + + You need an admin privilege for this operation. + """ + query = textwrap.dedent( + """\ + mutation($scope_id: ScopeField!) { + delete_container_registry_quota( + scope_id: $scope_id) { + ok msg + } + } + """ + ) + + scope_id = f"project:{group_id}" + variables = {"scope_id": scope_id} + data = await api_session.get().Admin._query(query, variables) + return data["delete_container_registry_quota"] diff --git a/src/ai/backend/common/utils.py b/src/ai/backend/common/utils.py index 5d98b54d33..c669f58814 100644 --- a/src/ai/backend/common/utils.py +++ b/src/ai/backend/common/utils.py @@ -425,3 +425,12 @@ def join_non_empty(*args, sep): """ filtered_args = [arg for arg in args if arg] return sep.join(filtered_args) + + +def b64encode(s: str) -> str: + """ + base64 encoding method of graphql_relay. + Use it in components where the graphql_relay package is unavailable. + """ + b: bytes = s.encode("utf-8") if isinstance(s, str) else s + return base64.b64encode(b).decode("ascii") diff --git a/src/ai/backend/manager/api/exceptions.py b/src/ai/backend/manager/api/exceptions.py index 23f381d30b..8db4ee7db9 100644 --- a/src/ai/backend/manager/api/exceptions.py +++ b/src/ai/backend/manager/api/exceptions.py @@ -242,6 +242,10 @@ class EndpointTokenNotFound(ObjectNotFound): object_name = "endpoint_token" +class ContainerRegistryNotFound(ObjectNotFound): + object_name = "container_registry" + + class TooManySessionsMatched(BackendError, web.HTTPNotFound): error_type = "https://api.backend.ai/probs/too-many-sessions-matched" error_title = "Too many sessions matched." diff --git a/src/ai/backend/manager/api/group.py b/src/ai/backend/manager/api/group.py new file mode 100644 index 0000000000..5434d6a88d --- /dev/null +++ b/src/ai/backend/manager/api/group.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Iterable, Tuple + +import aiohttp_cors +import trafaret as t +from aiohttp import web + +from ai.backend.common import validators as tx +from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.models.gql_models.container_registry_utils import ( + HarborQuotaManager, +) +from ai.backend.manager.models.rbac import ProjectScope + +if TYPE_CHECKING: + from .context import RootContext + +from .auth import superadmin_required +from .manager import READ_ALLOWED, server_status_required +from .types import CORSOptions, WebMiddleware +from .utils import check_api_params + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) + + +@server_status_required(READ_ALLOWED) +@superadmin_required +@check_api_params( + t.Dict({ + tx.AliasedKey(["group_id", "group"]): t.String, + tx.AliasedKey(["quota"]): t.Int, + }) +) +async def update_registry_quota(request: web.Request, params: Any) -> web.Response: + log.info("UPDATE_REGISTRY_QUOTA (gr:{})", params["group_id"]) + root_ctx: RootContext = request.app["_root.context"] + group_id = params["group_id"] + scope_id = ProjectScope(project_id=group_id, domain_name=None) + quota = int(params["quota"]) + + async with root_ctx.db.begin_session() as db_sess: + manager = await HarborQuotaManager.new(db_sess, scope_id) + await manager.update(quota) + + return web.json_response({}) + + +@server_status_required(READ_ALLOWED) +@superadmin_required +@check_api_params( + t.Dict({ + tx.AliasedKey(["group_id", "group"]): t.String, + }) +) +async def delete_registry_quota(request: web.Request, params: Any) -> web.Response: + log.info("DELETE_REGISTRY_QUOTA (gr:{})", params["group_id"]) + root_ctx: RootContext = request.app["_root.context"] + group_id = params["group_id"] + scope_id = ProjectScope(project_id=group_id, domain_name=None) + + async with root_ctx.db.begin_session() as db_sess: + manager = await HarborQuotaManager.new(db_sess, scope_id) + await manager.delete() + + return web.json_response({}) + + +@server_status_required(READ_ALLOWED) +@superadmin_required +@check_api_params( + t.Dict({ + tx.AliasedKey(["group_id", "group"]): t.String, + tx.AliasedKey(["quota"]): t.Int, + }) +) +async def create_registry_quota(request: web.Request, params: Any) -> web.Response: + log.info("CREATE_REGISTRY_QUOTA (gr:{})", params["group_id"]) + root_ctx: RootContext = request.app["_root.context"] + group_id = params["group_id"] + scope_id = ProjectScope(project_id=group_id, domain_name=None) + quota = int(params["quota"]) + + async with root_ctx.db.begin_session() as db_sess: + manager = await HarborQuotaManager.new(db_sess, scope_id) + await manager.create(quota) + + return web.json_response({}) + + +@server_status_required(READ_ALLOWED) +@superadmin_required +@check_api_params( + t.Dict({ + tx.AliasedKey(["group_id", "group"]): t.String, + }) +) +async def read_registry_quota(request: web.Request, params: Any) -> web.Response: + log.info("READ_REGISTRY_QUOTA (gr:{})", params["group_id"]) + root_ctx: RootContext = request.app["_root.context"] + group_id = params["group_id"] + scope_id = ProjectScope(project_id=group_id, domain_name=None) + + async with root_ctx.db.begin_session() as db_sess: + manager = await HarborQuotaManager.new(db_sess, scope_id) + quota = await manager.read() + + return web.json_response({"result": quota}) + + +def create_app( + default_cors_options: CORSOptions, +) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app["api_versions"] = (1, 2, 3, 4, 5) + app["prefix"] = "group" + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + cors.add(app.router.add_route("POST", "/registry-quota", create_registry_quota)) + cors.add(app.router.add_route("GET", "/registry-quota", read_registry_quota)) + cors.add(app.router.add_route("PATCH", "/registry-quota", update_registry_quota)) + cors.add(app.router.add_route("DELETE", "/registry-quota", delete_registry_quota)) + return app, [] diff --git a/src/ai/backend/manager/api/schema.graphql b/src/ai/backend/manager/api/schema.graphql index 6caa04e4ea..dcfebc74c6 100644 --- a/src/ai/backend/manager/api/schema.graphql +++ b/src/ai/backend/manager/api/schema.graphql @@ -698,6 +698,9 @@ type GroupNode implements Node { """Added in 24.03.7.""" container_registry: JSONString scaling_groups: [String] + + """Added in 24.12.0.""" + registry_quota: BigInt user_nodes(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): UserConnection } @@ -1844,6 +1847,15 @@ type Mutations { """Added in 24.12.0""" disassociate_container_registry_with_group(group_id: String!, registry_id: String!): DisassociateContainerRegistryWithGroup + + """Added in 24.12.0""" + create_container_registry_quota(quota: BigInt!, scope_id: ScopeField!): CreateContainerRegistryQuota + + """Added in 24.12.0""" + update_container_registry_quota(quota: BigInt!, scope_id: ScopeField!): UpdateContainerRegistryQuota + + """Added in 24.12.0""" + delete_container_registry_quota(scope_id: ScopeField!): DeleteContainerRegistryQuota create_container_registry(hostname: String!, props: CreateContainerRegistryInput!): CreateContainerRegistry modify_container_registry(hostname: String!, props: ModifyContainerRegistryInput!): ModifyContainerRegistry delete_container_registry(hostname: String!): DeleteContainerRegistry @@ -2570,6 +2582,24 @@ type DisassociateContainerRegistryWithGroup { msg: String } +"""Added in 24.12.0.""" +type CreateContainerRegistryQuota { + ok: Boolean + msg: String +} + +"""Added in 24.12.0.""" +type UpdateContainerRegistryQuota { + ok: Boolean + msg: String +} + +"""Added in 24.12.0.""" +type DeleteContainerRegistryQuota { + ok: Boolean + msg: String +} + type CreateContainerRegistry { container_registry: ContainerRegistry } diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py index 078c169577..9facab1268 100644 --- a/src/ai/backend/manager/models/gql.py +++ b/src/ai/backend/manager/models/gql.py @@ -74,7 +74,10 @@ ) from .gql_models.container_registry import ( AssociateContainerRegistryWithGroup, + CreateContainerRegistryQuota, + DeleteContainerRegistryQuota, DisassociateContainerRegistryWithGroup, + UpdateContainerRegistryQuota, ) from .gql_models.domain import ( CreateDomainNode, @@ -349,6 +352,16 @@ class Mutations(graphene.ObjectType): description="Added in 24.12.0" ) + create_container_registry_quota = CreateContainerRegistryQuota.Field( + description="Added in 24.12.0" + ) + update_container_registry_quota = UpdateContainerRegistryQuota.Field( + description="Added in 24.12.0" + ) + delete_container_registry_quota = DeleteContainerRegistryQuota.Field( + description="Added in 24.12.0" + ) + # Legacy mutations create_container_registry = CreateContainerRegistry.Field() modify_container_registry = ModifyContainerRegistry.Field() diff --git a/src/ai/backend/manager/models/gql_models/container_registry.py b/src/ai/backend/manager/models/gql_models/container_registry.py index ae44bef0e2..72010c565f 100644 --- a/src/ai/backend/manager/models/gql_models/container_registry.py +++ b/src/ai/backend/manager/models/gql_models/container_registry.py @@ -11,8 +11,11 @@ from ..association_container_registries_groups import ( AssociationContainerRegistriesGroupsRow, ) -from ..base import simple_db_mutate +from ..base import BigInt, simple_db_mutate +from ..rbac import ScopeType from ..user import UserRole +from .container_registry_utils import HarborQuotaManager +from .fields import ScopeField log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore @@ -70,3 +73,97 @@ async def mutate( .where(AssociationContainerRegistriesGroupsRow.group_id == group_id) ) return await simple_db_mutate(cls, info.context, delete_query) + + +class CreateContainerRegistryQuota(graphene.Mutation): + """Added in 24.12.0.""" + + allowed_roles = ( + UserRole.SUPERADMIN, + UserRole.ADMIN, + ) + + class Arguments: + scope_id = ScopeField(required=True) + quota = BigInt(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + scope_id: ScopeType, + quota: int | float, + ) -> Self: + async with info.context.db.begin_session() as db_sess: + try: + manager = await HarborQuotaManager.new(db_sess, scope_id) + await manager.create(int(quota)) + return cls(ok=True, msg="success") + except Exception as e: + return cls(ok=False, msg=str(e)) + + +class UpdateContainerRegistryQuota(graphene.Mutation): + """Added in 24.12.0.""" + + allowed_roles = ( + UserRole.SUPERADMIN, + UserRole.ADMIN, + ) + + class Arguments: + scope_id = ScopeField(required=True) + quota = BigInt(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + scope_id: ScopeType, + quota: int | float, + ) -> Self: + async with info.context.db.begin_session() as db_sess: + try: + manager = await HarborQuotaManager.new(db_sess, scope_id) + await manager.update(int(quota)) + return cls(ok=True, msg="success") + except Exception as e: + return cls(ok=False, msg=str(e)) + + +class DeleteContainerRegistryQuota(graphene.Mutation): + """Added in 24.12.0.""" + + allowed_roles = ( + UserRole.SUPERADMIN, + UserRole.ADMIN, + ) + + class Arguments: + scope_id = ScopeField(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + scope_id: ScopeType, + ) -> Self: + async with info.context.db.begin_session() as db_sess: + try: + manager = await HarborQuotaManager.new(db_sess, scope_id) + await manager.delete() + return cls(ok=True, msg="success") + except Exception as e: + return cls(ok=False, msg=str(e)) diff --git a/src/ai/backend/manager/models/gql_models/container_registry_utils.py b/src/ai/backend/manager/models/gql_models/container_registry_utils.py new file mode 100644 index 0000000000..af0e9546a6 --- /dev/null +++ b/src/ai/backend/manager/models/gql_models/container_registry_utils.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +import logging +import uuid +from typing import Any, TypedDict + +import aiohttp +import aiohttp.client_exceptions +import sqlalchemy as sa +import yarl +from sqlalchemy.ext.asyncio import AsyncSession as SASession +from sqlalchemy.orm import load_only + +from ai.backend.common.types import aobject +from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.api.exceptions import ( + ContainerRegistryNotFound, + GenericBadRequest, + InternalServerError, + NotImplementedAPI, + ObjectNotFound, +) + +from ...container_registry import ContainerRegistryRow +from ..association_container_registries_groups import ( + AssociationContainerRegistriesGroupsRow, +) +from ..group import GroupRow +from ..rbac import ProjectScope, ScopeType + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore + + +class HarborQuotaInfo(TypedDict): + previous_quota: int + quota_id: int + + +class HarborQuotaManager(aobject): + """ + Utility class for HarborV2 per-project Quota CRUD API. + """ + + db_sess: SASession + scope_id: ScopeType + group_row: GroupRow + registry: ContainerRegistryRow + project: str + project_id: uuid.UUID + + def __init__(self, db_sess: SASession, scope_id: ScopeType): + if not isinstance(scope_id, ProjectScope): + raise NotImplementedAPI("Quota mutation currently supports only the project scope.") + + self.db_sess = db_sess + self.scope_id = scope_id + + async def __ainit__(self) -> None: + assert isinstance(self.scope_id, ProjectScope) + + project_id = self.scope_id.project_id + group_query = ( + sa.select(GroupRow) + .where(GroupRow.id == project_id) + .options(load_only(GroupRow.container_registry)) + ) + result = await self.db_sess.execute(group_query) + group_row = result.scalar_one_or_none() + + if not HarborQuotaManager._is_valid_group_row(group_row): + raise ContainerRegistryNotFound( + f"Container registry info does not exist or is invalid in the group. (gr: {project_id})" + ) + + registry_name, project = ( + group_row.container_registry["registry"], + group_row.container_registry["project"], + ) + + registry_query = sa.select(ContainerRegistryRow).where( + (ContainerRegistryRow.registry_name == registry_name) + & (ContainerRegistryRow.project == project) + ) + + result = await self.db_sess.execute(registry_query) + registry = result.scalars().one_or_none() + + if not registry: + raise ContainerRegistryNotFound( + f"Specified container registry row not found. (cr: {registry_name}, gr: {project})" + ) + + self.group_row = group_row + self.registry = registry + self.project = project + self.project_id = project_id + + @classmethod + def _is_valid_group_row(cls, group_row: GroupRow) -> bool: + return ( + group_row + and group_row.container_registry + and "registry" in group_row.container_registry + and "project" in group_row.container_registry + ) + + async def _get_harbor_project_id( + self, sess: aiohttp.ClientSession, rqst_args: dict[str, Any] + ) -> str: + get_project_id_api = ( + yarl.URL(self.registry.url) / "api" / "v2.0" / "projects" / self.project + ) + + async with sess.get(get_project_id_api, allow_redirects=False, **rqst_args) as resp: + if resp.status != 200: + raise InternalServerError(f"Failed to get harbor project_id! response: {resp}") + + res = await resp.json() + harbor_project_id = res["project_id"] + return harbor_project_id + + async def _get_quota_info( + self, sess: aiohttp.ClientSession, rqst_args: dict[str, Any] + ) -> HarborQuotaInfo: + harbor_project_id = await self._get_harbor_project_id(sess, rqst_args) + get_quota_id_api = (yarl.URL(self.registry.url) / "api" / "v2.0" / "quotas").with_query({ + "reference": "project", + "reference_id": harbor_project_id, + }) + + async with sess.get(get_quota_id_api, allow_redirects=False, **rqst_args) as resp: + if resp.status != 200: + raise InternalServerError(f"Failed to get quota info! response: {resp}") + + res = await resp.json() + if not res: + raise ObjectNotFound(object_name="quota entity") + if len(res) > 1: + raise InternalServerError( + f"Multiple quota entities found. (project_id: {harbor_project_id})" + ) + + previous_quota = res[0]["hard"]["storage"] + quota_id = res[0]["id"] + + return HarborQuotaInfo(previous_quota=previous_quota, quota_id=quota_id) + + async def read(self) -> int: + if not self.registry.is_global: + get_assoc_query = sa.select( + sa.exists() + .where(AssociationContainerRegistriesGroupsRow.registry_id == self.registry.id) + .where(AssociationContainerRegistriesGroupsRow.group_id == self.group_row.row_id) + ) + assoc_exist = (await self.db_sess.execute(get_assoc_query)).scalar() + + if not assoc_exist: + raise ValueError("The group is not associated with the container registry.") + + ssl_verify = self.registry.ssl_verify + connector = aiohttp.TCPConnector(ssl=ssl_verify) + async with aiohttp.ClientSession(connector=connector) as sess: + rqst_args: dict[str, Any] = {} + rqst_args["auth"] = aiohttp.BasicAuth( + self.registry.username, + self.registry.password, + ) + + previous_quota = (await self._get_quota_info(sess, rqst_args))["previous_quota"] + if previous_quota == -1: + raise ObjectNotFound(object_name="quota entity") + + return previous_quota + + async def create(self, quota: int) -> None: + ssl_verify = self.registry.ssl_verify + connector = aiohttp.TCPConnector(ssl=ssl_verify) + async with aiohttp.ClientSession(connector=connector) as sess: + rqst_args: dict[str, Any] = {} + rqst_args["auth"] = aiohttp.BasicAuth( + self.registry.username, + self.registry.password, + ) + + quota_info = await self._get_quota_info(sess, rqst_args) + previous_quota, quota_id = quota_info["previous_quota"], quota_info["quota_id"] + + if previous_quota > 0: + raise GenericBadRequest(f"Quota limit already exists. (gr: {self.project_id})") + + put_quota_api = yarl.URL(self.registry.url) / "api" / "v2.0" / "quotas" / str(quota_id) + payload = {"hard": {"storage": quota}} + + async with sess.put( + put_quota_api, json=payload, allow_redirects=False, **rqst_args + ) as resp: + if resp.status != 200: + log.error(f"Failed to create quota! response: {resp}") + raise InternalServerError(f"Failed to create quota! response: {resp}") + + async def update(self, quota: int) -> None: + ssl_verify = self.registry.ssl_verify + connector = aiohttp.TCPConnector(ssl=ssl_verify) + async with aiohttp.ClientSession(connector=connector) as sess: + rqst_args: dict[str, Any] = {} + rqst_args["auth"] = aiohttp.BasicAuth( + self.registry.username, + self.registry.password, + ) + + quota_info = await self._get_quota_info(sess, rqst_args) + previous_quota, quota_id = quota_info["previous_quota"], quota_info["quota_id"] + + if previous_quota == -1: + raise ObjectNotFound(object_name="quota entity") + + put_quota_api = yarl.URL(self.registry.url) / "api" / "v2.0" / "quotas" / str(quota_id) + payload = {"hard": {"storage": quota}} + + async with sess.put( + put_quota_api, json=payload, allow_redirects=False, **rqst_args + ) as resp: + if resp.status != 200: + log.error(f"Failed to update quota! response: {resp}") + raise InternalServerError(f"Failed to update quota! response: {resp}") + + async def delete(self) -> None: + ssl_verify = self.registry.ssl_verify + connector = aiohttp.TCPConnector(ssl=ssl_verify) + async with aiohttp.ClientSession(connector=connector) as sess: + rqst_args: dict[str, Any] = {} + rqst_args["auth"] = aiohttp.BasicAuth( + self.registry.username, + self.registry.password, + ) + + quota_info = await self._get_quota_info(sess, rqst_args) + previous_quota, quota_id = quota_info["previous_quota"], quota_info["quota_id"] + + if previous_quota == -1: + raise ObjectNotFound(object_name="quota entity") + + put_quota_api = yarl.URL(self.registry.url) / "api" / "v2.0" / "quotas" / str(quota_id) + payload = {"hard": {"storage": -1}} # setting quota to -1 means delete + + async with sess.put( + put_quota_api, json=payload, allow_redirects=False, **rqst_args + ) as resp: + if resp.status != 200: + log.error(f"Failed to delete quota! response: {resp}") + raise InternalServerError(f"Failed to delete quota! response: {resp}") diff --git a/src/ai/backend/manager/models/gql_models/group.py b/src/ai/backend/manager/models/gql_models/group.py index 4bd98da161..afc58eb4ac 100644 --- a/src/ai/backend/manager/models/gql_models/group.py +++ b/src/ai/backend/manager/models/gql_models/group.py @@ -13,6 +13,7 @@ from graphene.types.datetime import DateTime as GQLDateTime from ..base import ( + BigInt, FilterExprArg, OrderExprArg, PaginatedConnectionField, @@ -26,6 +27,10 @@ from ..group import AssocGroupUserRow, GroupRow, ProjectType from ..minilang.ordering import OrderSpecItem, QueryOrderParser from ..minilang.queryfilter import FieldSpecItem, QueryFilterParser +from ..rbac import ProjectScope +from .container_registry_utils import ( + HarborQuotaManager, +) from .user import UserConnection, UserNode if TYPE_CHECKING: @@ -112,6 +117,8 @@ class Meta: lambda: graphene.String, ) + registry_quota = BigInt(description="Added in 24.12.0.") + user_nodes = PaginatedConnectionField( UserConnection, ) @@ -204,6 +211,13 @@ async def resolve_user_nodes( total_cnt = await db_session.scalar(cnt_query) return ConnectionResolverResult(result, cursor, pagination_order, page_size, total_cnt) + async def resolve_registry_quota(self, info: graphene.ResolveInfo) -> int: + graph_ctx = info.context + async with graph_ctx.db.begin_session() as db_sess: + scope_id = ProjectScope(project_id=self.id, domain_name=None) + manager = await HarborQuotaManager.new(db_sess, scope_id) + return await manager.read() + @classmethod async def get_node(cls, info: graphene.ResolveInfo, id) -> Self: graph_ctx: GraphQueryContext = info.context diff --git a/src/ai/backend/manager/server.py b/src/ai/backend/manager/server.py index fc85bc9c95..635b7831dd 100644 --- a/src/ai/backend/manager/server.py +++ b/src/ai/backend/manager/server.py @@ -187,6 +187,7 @@ ".image", ".userconfig", ".domainconfig", + ".group", ".groupconfig", ".logs", ] diff --git a/tests/manager/conftest.py b/tests/manager/conftest.py index 92d39c7ce8..b027e4b8fb 100644 --- a/tests/manager/conftest.py +++ b/tests/manager/conftest.py @@ -28,6 +28,7 @@ from unittest.mock import AsyncMock, MagicMock from urllib.parse import quote_plus as urlquote +import aiofiles.os import aiohttp import asyncpg import pytest @@ -419,7 +420,12 @@ async def database_engine(local_config, database): @pytest.fixture() -def database_fixture(local_config, test_db, database) -> Iterator[None]: +def extra_fixtures(): + return {} + + +@pytest.fixture() +def database_fixture(local_config, test_db, database, extra_fixtures) -> Iterator[None]: """ Populate the example data as fixtures to the database and delete them after use. @@ -430,12 +436,20 @@ def database_fixture(local_config, test_db, database) -> Iterator[None]: db_url = f"postgresql+asyncpg://{db_user}:{urlquote(db_pass)}@{db_addr}/{test_db}" build_root = Path(os.environ["BACKEND_BUILD_ROOT"]) + + extra_fixture_file = tempfile.NamedTemporaryFile(delete=False) + extra_fixture_file_path = Path(extra_fixture_file.name) + + with open(extra_fixture_file_path, "w") as f: + json.dump(extra_fixtures, f) + fixture_paths = [ build_root / "fixtures" / "manager" / "example-users.json", build_root / "fixtures" / "manager" / "example-keypairs.json", build_root / "fixtures" / "manager" / "example-set-user-main-access-keys.json", build_root / "fixtures" / "manager" / "example-resource-presets.json", build_root / "fixtures" / "manager" / "example-container-registries-harbor.json", + extra_fixture_file_path, ] async def init_fixture() -> None: @@ -460,6 +474,9 @@ async def init_fixture() -> None: yield async def clean_fixture() -> None: + if extra_fixture_file_path.exists(): + await aiofiles.os.remove(extra_fixture_file_path) + engine: SAEngine = sa.ext.asyncio.create_async_engine( db_url, connect_args=pgsql_connect_opts, diff --git a/tests/manager/models/test_container_registries.py b/tests/manager/models/test_container_registries.py index 69abc7efcd..e4b49d7dba 100644 --- a/tests/manager/models/test_container_registries.py +++ b/tests/manager/models/test_container_registries.py @@ -1,10 +1,16 @@ import pytest +from aioresponses import aioresponses from graphene import Schema from graphene.test import Client +from ai.backend.common.utils import b64encode +from ai.backend.manager.api.context import RootContext from ai.backend.manager.defs import PASSWORD_PLACEHOLDER from ai.backend.manager.models.gql import GraphQueryContext, Mutations, Queries from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.server import ( + database_ctx, +) CONTAINER_REGISTRY_FIELDS = """ hostname @@ -250,3 +256,93 @@ async def test_delete_container_registry(client: Client, database_engine: Extend response = await client.execute_async(query, variables=variables, context_value=context) assert response["data"] is None + + +FIXTURES_FOR_HARBOR_CRUD_TEST = [ + { + "container_registries": [ + { + "id": "00000000-0000-0000-0000-000000000000", + "type": "harbor2", + "url": "http://mock_registry", + "registry_name": "mock_registry", + "project": "mock_project", + "username": "mock_user", + "password": "mock_password", + "ssl_verify": False, + "is_global": True, + } + ], + "groups": [ + { + "id": "00000000-0000-0000-0000-000000000000", + "name": "mock_group", + "description": "", + "is_active": True, + "domain_name": "default", + "resource_policy": "default", + "total_resource_slots": {}, + "allowed_vfolder_hosts": {}, + "container_registry": { + "registry": "mock_registry", + "project": "mock_project", + }, + "type": "general", + } + ], + }, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("extra_fixtures", FIXTURES_FOR_HARBOR_CRUD_TEST) +async def test_harbor_read_project_quota( + client: Client, + database_fixture, + create_app_and_client, +): + test_app, _ = await create_app_and_client( + [ + database_ctx, + ], + [], + ) + + root_ctx: RootContext = test_app["_root.context"] + context = get_graphquery_context(root_ctx.db) + + # Arbitrary values for mocking Harbor API responses + HARBOR_PROJECT_ID = "123" + HARBOR_QUOTA_ID = 456 + HARBOR_QUOTA_VALUE = 1024 + + with aioresponses() as mocked: + # Mock the get project ID API call + get_project_id_url = "http://mock_registry/api/v2.0/projects/mock_project" + mocked.get(get_project_id_url, status=200, payload={"project_id": HARBOR_PROJECT_ID}) + + # Mock the get quota info API call + get_quota_url = f"http://mock_registry/api/v2.0/quotas?reference=project&reference_id={HARBOR_PROJECT_ID}" + mocked.get( + get_quota_url, + status=200, + payload=[{"id": HARBOR_QUOTA_ID, "hard": {"storage": HARBOR_QUOTA_VALUE}}], + ) + + groupnode_query = """ + query ($id: String!) { + group_node(id: $id) { + registry_quota + } + } + """ + + group_id = "00000000-0000-0000-0000-000000000000" + variables = { + "id": b64encode(f"group_node:{group_id}"), + } + + response = await client.execute_async( + groupnode_query, variables=variables, context_value=context + ) + assert response["data"]["group_node"]["registry_quota"] == HARBOR_QUOTA_VALUE