Skip to content

Commit

Permalink
api.main: use node endpoints for all type of Node subtypes
Browse files Browse the repository at this point in the history
Implement a mechanism for dynamic polymorphism on Node objects by using
explicit pydantic object validation depending on the node kind and
storing all Node objects as plain nodes in the same DB collection
regardless of their type.

Signed-off-by: Ricardo Cañuelo <[email protected]>
  • Loading branch information
Ricardo Cañuelo committed Nov 3, 2023
1 parent 679f689 commit 8952151
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 54 deletions.
81 changes: 36 additions & 45 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import os
import re
from typing import List, Union
from typing import List
from fastapi import (
Depends,
FastAPI,
Expand All @@ -34,12 +34,11 @@
from .models import (
Node,
Hierarchy,
Regression,
User,
UserGroup,
UserProfile,
Password,
get_model_from_kind
parse_node_obj,
)
from .paginator_models import PageModel
from .pubsub import PubSub, Subscription
Expand Down Expand Up @@ -457,13 +456,12 @@ async def translate_null_query_params(query_params: dict):
return translated


@app.get('/node/{node_id}', response_model=Union[Regression, Node],
@app.get('/node/{node_id}', response_model=Node,
response_model_by_alias=False)
async def get_node(node_id: str, kind: str = "node"):
async def get_node(node_id: str):
"""Get node information from the provided node id"""
try:
model = get_model_from_kind(kind)
return await db.find_by_id(model, node_id)
return await db.find_by_id(Node, node_id)
except KeyError as error:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand All @@ -487,7 +485,7 @@ def serialize_paginated_data(model, data: list):


@app.get('/nodes', response_model=PageModel)
async def get_nodes(request: Request, kind: str = "node"):
async def get_nodes(request: Request):
"""Get all the nodes if no request parameters have passed.
Get all the matching nodes otherwise, within the pagination limit."""
query_params = dict(request.query_params)
Expand All @@ -499,7 +497,9 @@ async def get_nodes(request: Request, kind: str = "node"):
query_params = await translate_null_query_params(query_params)

try:
model = get_model_from_kind(kind)
# Query using the base Node model, regardless of the specific
# node type
model = Node
translated_params = model.translate_fields(query_params)
paginated_resp = await db.find_by_attributes(model, translated_params)
paginated_resp.items = serialize_paginated_data(
Expand All @@ -515,15 +515,17 @@ async def get_nodes(request: Request, kind: str = "node"):


@app.get('/count', response_model=int)
async def get_nodes_count(request: Request, kind: str = "node"):
async def get_nodes_count(request: Request):
"""Get the count of all the nodes if no request parameters have passed.
Get the count of all the matching nodes otherwise."""
query_params = dict(request.query_params)

query_params = await translate_null_query_params(query_params)

try:
model = get_model_from_kind(kind)
# Query using the base Node model, regardless of the specific
# node type
model = Node
translated_params = model.translate_fields(query_params)
return await db.count(model, translated_params)
except KeyError as error:
Expand All @@ -545,6 +547,10 @@ async def _verify_user_group_existence(user_groups: List[str]):
@app.post('/node', response_model=Node, response_model_by_alias=False)
async def post_node(node: Node, current_user: str = Depends(get_user)):
"""Create a new node"""
# Explicit pydantic model validation
parse_node_obj(node)

# [TODO] Implement sanity checks depending on the node kind
if node.parent:
parent = await db.find_by_id(Node, node.parent)
if not parent:
Expand All @@ -555,6 +561,10 @@ async def post_node(node: Node, current_user: str = Depends(get_user)):

await _verify_user_group_existence(node.user_groups)
node.owner = current_user.profile.username

# The node is handled as a generic Node by the DB, regardless of its
# specific kind. The concrete Node submodel (Kbuild, Checkout, etc.)
# is only used for data format validation
obj = await db.create(node)
data = _get_node_event_data('created', obj)
await pubsub.publish_cloudevent('node', data)
Expand All @@ -572,19 +582,28 @@ async def put_node(node_id: str, node: Node,
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Node not found with id: {node.id}"
)
is_valid, message = node_from_id.validate_node_state_transition(

# Sanity checks
# Note: do not update node ownership fields, don't update 'state'
# until we've checked the state transition is valid.
update_data = node.dict(exclude={'user', 'user_groups', 'state'})
new_node_def = node_from_id.copy(update=update_data)
# 1- Parse and validate node to specific subtype
specialized_node = parse_node_obj(new_node_def)

# 2 - State transition checks
is_valid, message = specialized_node.validate_node_state_transition(
node.state)
if not is_valid:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=message
)
# Now we can update the state
new_node_def.state = node.state

# Do not update node ownership fields
update_data = node.dict(exclude={'user', 'user_groups'})
node = node_from_id.copy(update=update_data)

obj = await db.update(node)
# Update node in the DB
obj = await db.update(new_node_def)
data = _get_node_event_data('updated', obj)
await pubsub.publish_cloudevent('node', data)
return obj
Expand Down Expand Up @@ -653,34 +672,6 @@ async def publish(raw: dict, channel: str, user: User = Depends(get_user)):
await pubsub.publish_cloudevent(channel, data, attributes)


# -----------------------------------------------------------------------------
# Regression

@app.post('/regression', response_model=Regression,
response_model_by_alias=False)
async def post_regression(regression: Regression,
token: str = Depends(get_user)):
"""Create a new regression"""
obj = await db.create(regression)
operation = 'created'
await pubsub.publish_cloudevent('regression', {'op': operation,
'id': str(obj.id)})
return obj


@app.put('/regression/{regression_id}', response_model=Regression,
response_model_by_alias=False)
async def put_regression(regression_id: str, regression: Regression,
token: str = Depends(get_user)):
"""Update an already added regression"""
regression.id = ObjectId(regression_id)
obj = await db.update(regression)
operation = 'updated'
await pubsub.publish_cloudevent('regression', {'op': operation,
'id': str(obj.id)})
return obj


app = VersionedFastAPI(
app,
version_format='{major}',
Expand Down
23 changes: 14 additions & 9 deletions api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"""KernelCI API model definitions"""

from datetime import datetime, timedelta
from typing import Any, Optional, Dict, List
from typing import Any, Optional, Dict, List, ClassVar
import enum
from bson import ObjectId
from pydantic import (
Expand Down Expand Up @@ -214,6 +214,7 @@ def get_timeout(self):

class Node(DatabaseModel):
"""KernelCI primitive object to model a node in a hierarchy"""
class_kind: ClassVar[str] = 'node'
kind: str = Field(
default='node',
description="Type of the object"
Expand Down Expand Up @@ -384,6 +385,7 @@ class Config:

class Checkout(Node):
"""API model for checkout nodes"""
class_kind: ClassVar[str] = 'checkout'
kind: str = Field(
default='checkout',
description='Type of the object',
Expand Down Expand Up @@ -420,6 +422,7 @@ class Config:

class Kbuild(Node):
"""API model for kbuild (kernel builds) nodes"""
class_kind: ClassVar[str] = 'kbuild'
kind: str = Field(
default='kbuild',
description='Type of the object',
Expand Down Expand Up @@ -451,6 +454,7 @@ class Config:

class Test(Node):
"""API model for test nodes"""
class_kind: ClassVar[str] = 'test'
kind: str = Field(
default='test',
description='Type of the object',
Expand All @@ -477,7 +481,7 @@ class Config:

class Regression(Node):
"""API model for regression tracking"""

class_kind: ClassVar[str] = 'regression'
kind: str = Field(
default='regression',
description='Type of the object',
Expand All @@ -503,10 +507,11 @@ class Regression(Node):
]


def get_model_from_kind(kind: str):
"""Get model from kind parameter"""
models = {
"node": Node,
"regression": Regression
}
return models[kind]
def parse_node_obj(node: Node):
"""Parses a generic Node object using the appropriate Node submodel
depending on its 'kind'.
"""
for c in type(node).__subclasses__():

Check warning on line 514 in api/models.py

View workflow job for this annotation

GitHub Actions / Lint

Variable name "c" doesn't conform to snake_case naming style
if node.kind == c.class_kind:
return c.parse_obj(node)
raise ValueError(f"Unsupported node kind: {node.kind}")

0 comments on commit 8952151

Please sign in to comment.