Skip to content

Commit

Permalink
Fix pydantic related errors due to handling None values (#439)
Browse files Browse the repository at this point in the history
  • Loading branch information
uittenbroekrobbert authored Dec 18, 2024
2 parents a825831 + c0bfcfd commit 36398f2
Show file tree
Hide file tree
Showing 13 changed files with 229 additions and 204 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
- id: check-toml
- id: detect-private-key
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.4
rev: v0.8.3
hooks:
- id: ruff
args: [--config, pyproject.toml]
Expand Down
6 changes: 5 additions & 1 deletion amt/api/routes/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,11 @@ async def get_assessment_card(
)

assessment_card_data = next(
(assessment for assessment in algorithm.system_card.assessments if assessment.name.lower() == assessment_card),
(
assessment
for assessment in algorithm.system_card.assessments
if assessment.name is not None and assessment.name.lower() == assessment_card
),
None,
)

Expand Down
2 changes: 1 addition & 1 deletion amt/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from .task import Task
from .user import User

__all__ = ["Task", "User", "Algorithm", "Organization"]
__all__ = ["Algorithm", "Organization", "Task", "User"]
18 changes: 9 additions & 9 deletions amt/schema/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,19 @@ class AlgorithmBase(BaseModel):

class AlgorithmNew(AlgorithmBase):
instruments: list[str] | str = []
type: str = Field(default=None)
open_source: str = Field(default=None)
risk_group: str = Field(default=None)
conformity_assessment_body: str = Field(default=None)
systemic_risk: str = Field(default=None)
transparency_obligations: str = Field(default=None)
type: str | None = Field(default=None)
open_source: str | None = Field(default=None)
risk_group: str | None = Field(default=None)
conformity_assessment_body: str | None = Field(default=None)
systemic_risk: str | None = Field(default=None)
transparency_obligations: str | None = Field(default=None)
role: list[str] | str = []
template_id: str = Field(default=None)
template_id: str | None = Field(default=None)
organization_id: int = Field()

@field_validator("organization_id", mode="before")
@classmethod
def ensure_required(cls, v: int | str) -> int: # noqa
def ensure_required(cls, v: int | str) -> int:
if isinstance(v, str) and v == "": # this is always a string
# TODO (Robbert): the error message from pydantic becomes 'Value error,
# missing' which is why a custom message will be applied
Expand All @@ -31,5 +31,5 @@ def ensure_required(cls, v: int | str) -> int: # noqa

@field_validator("instruments", "role")
@classmethod
def ensure_list(cls, v: list[str] | str) -> list[str]: # noqa
def ensure_list(cls, v: list[str] | str) -> list[str]:
return v if isinstance(v, list) else [v]
16 changes: 8 additions & 8 deletions amt/schema/assessment_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,20 @@


class AssessmentAuthor(BaseModel):
name: str = Field(default=None)
name: str | None = Field(default=None)


class AssessmentContent(BaseModel):
question: str = Field(default=None)
urn: str = Field(default=None)
answer: str = Field(default=None)
remarks: str = Field(default=None)
question: str | None = Field(default=None)
urn: str | None = Field(default=None)
answer: str | None = Field(default=None)
remarks: str | None = Field(default=None)
authors: list[AssessmentAuthor] = Field(default=[])
timestamp: datetime | None = Field(default=None)


class AssessmentCard(BaseModel):
name: str = Field(default=None)
urn: str = Field(default=None)
date: datetime = Field(default=None)
name: str | None = Field(default=None)
urn: str | None = Field(default=None)
date: datetime | None = Field(default=None)
contents: list[AssessmentContent] = Field(default=[])
14 changes: 7 additions & 7 deletions amt/schema/system_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@


class Reference(BaseModel):
name: str = Field(default=None)
link: str = Field(default=None)
name: str | None = Field(default=None)
link: str | None = Field(default=None)


# TODO: consider reusing classes, Owner is now also defined for Models
class Owner(BaseModel):
organization: str = Field(default=None)
organization: str | None = Field(default=None)
oin: str | None = Field(default=None)

def __init__(self, organization: str, oin: str | None = None, **data) -> None: # pyright: ignore # noqa
Expand All @@ -28,12 +28,12 @@ def __init__(self, organization: str, oin: str | None = None, **data) -> None:

class SystemCard(BaseModel):
schema_version: str = Field(default="0.1a10")
name: str = Field(default=None)
ai_act_profile: AiActProfile = Field(default=None)
name: str | None = Field(default=None)
ai_act_profile: AiActProfile | None = Field(default=None)
provenance: dict[str, Any] = Field(default={})
description: str = Field(default=None)
description: str | None = Field(default=None)
labels: list[dict[str, Any]] = Field(default=[])
status: str = Field(default=None)
status: str | None = Field(default=None)
instruments: list[InstrumentBase] = Field(default=[])
requirements: list[RequirementTask] = Field(default=[])
measures: list[MeasureTask] = Field(default=[])
Expand Down
4 changes: 2 additions & 2 deletions amt/schema/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


class MovedTask(BaseModel):
id: int = PydanticField(None, alias="taskId", strict=False)
status_id: int = PydanticField(None, alias="statusId", strict=False)
id: int | None = PydanticField(None, alias="taskId", strict=False)
status_id: int | None = PydanticField(None, alias="statusId", strict=False)
previous_sibling_id: int | None = PydanticField(None, alias="previousSiblingId", strict=False)
next_sibling_id: int | None = PydanticField(None, alias="nextSiblingId", strict=False)
6 changes: 4 additions & 2 deletions amt/services/instruments_and_requirements_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_first_lifecycle_idx(lifecycles: list[str]) -> int | None:
return 0


def get_instrument_result_from_system_card(urn: str, system_card: SystemCard) -> AssessmentCard | None:
def get_instrument_result_from_system_card(urn: str | None, system_card: SystemCard) -> AssessmentCard | None:
"""
Returns the results of the given instrument if it is found in the system card, otherwise None.
:param urn: the urn of the instrument
Expand All @@ -57,7 +57,9 @@ def get_instrument_result_from_system_card(urn: str, system_card: SystemCard) ->
return None


def get_task_timestamp_from_assessment_card(task_urn: str, assessment_card: AssessmentCard) -> datetime | None:
def get_task_timestamp_from_assessment_card(task_urn: str | None, assessment_card: AssessmentCard) -> datetime | None:
if task_urn is None:
raise ValueError("Can not get timestamp from assessment card if task_urn is none")
for content in assessment_card.contents:
if content.urn == task_urn and content.timestamp:
return content.timestamp
Expand Down
8 changes: 7 additions & 1 deletion amt/services/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ async def create_instrument_tasks(self, tasks: Sequence[InstrumentTask], algorit
)

async def move_task(
self, task_id: int, status_id: int, previous_sibling_id: int | None = None, next_sibling_id: int | None = None
self,
task_id: int | None,
status_id: int | None,
previous_sibling_id: int | None = None,
next_sibling_id: int | None = None,
) -> Task:
"""
Updates the task with the given task_id
Expand All @@ -68,6 +72,8 @@ async def move_task(
:param next_sibling_id: the id of the next sibling of the task or None
:return: the updated task
"""
if task_id is None or status_id is None:
raise ValueError("task_id or status_id must not be None")
task = await self.repository.find_by_id(task_id)

if status_id == Status.DONE:
Expand Down
Loading

0 comments on commit 36398f2

Please sign in to comment.