Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix request model validation #3245

Merged
merged 6 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 14 additions & 22 deletions src/zenml/models/v2/base/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def validate_operation(cls, value: Any) -> Any:
def generate_query_conditions(
self,
table: Type[SQLModel],
) -> Union["ColumnElement[bool]"]:
) -> "ColumnElement[bool]":
"""Generate the query conditions for the database.

This method converts the Filter class into an appropriate SQLModel
Expand Down Expand Up @@ -291,11 +291,19 @@ def generate_query_conditions_from_column(self, column: Any) -> Any:
import sqlalchemy
from sqlalchemy_utils.functions import cast_if

from zenml.utils import uuid_utils

# For equality checks, compare the UUID directly
if self.operation == GenericFilterOps.EQUALS:
if not uuid_utils.is_valid_uuid(self.value):
return False

return column == self.value

if self.operation == GenericFilterOps.NOT_EQUALS:
if not uuid_utils.is_valid_uuid(self.value):
return True

return column != self.value

# For all other operations, cast and handle the column as string
Expand Down Expand Up @@ -703,16 +711,10 @@ def generate_name_or_id_query_conditions(

conditions = []

try:
filter_ = FilterGenerator(table).define_filter(
column="id", value=value, operator=operator
)
conditions.append(filter_.generate_query_conditions(table=table))
except ValueError:
# UUID filter with equal operators and no full UUID fail with
# a ValueError. In this case, we already know that the filter
# will not produce any result and can simply ignore it.
pass
filter_ = FilterGenerator(table).define_filter(
column="id", value=value, operator=operator
)
conditions.append(filter_.generate_query_conditions(table=table))

filter_ = FilterGenerator(table).define_filter(
column="name", value=value, operator=operator
Expand Down Expand Up @@ -1101,18 +1103,8 @@ def _define_uuid_filter(
A Filter object.

Raises:
ValueError: If the value is not a valid UUID.
ValueError: If the value for a oneof filter is not a list.
"""
# For equality checks, ensure that the value is a valid UUID.
if operator == GenericFilterOps.EQUALS and not isinstance(value, UUID):
try:
UUID(value)
except ValueError as e:
raise ValueError(
"Invalid value passed as UUID query parameter."
) from e

# For equality checks, ensure that the value is a valid UUID.
if operator == GenericFilterOps.ONEOF and not isinstance(value, list):
raise ValueError(ONEOF_ERROR)

Expand Down
7 changes: 4 additions & 3 deletions src/zenml/zen_server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,16 +421,17 @@ def f(model: Model = Depends(make_dependable(Model))):
"""
from fastapi import Query

from zenml.zen_server.exceptions import error_detail

def init_cls_and_handle_errors(*args: Any, **kwargs: Any) -> BaseModel:
from fastapi import HTTPException

try:
inspect.signature(init_cls_and_handle_errors).bind(*args, **kwargs)
return cls(*args, **kwargs)
except ValidationError as e:
for error in e.errors():
error["loc"] = tuple(["query"] + list(error["loc"]))
raise HTTPException(422, detail=e.errors())
detail = error_detail(e, exception_type=ValueError)
raise HTTPException(422, detail=detail)

params = {v.name: v for v in inspect.signature(cls).parameters.values()}
query_params = getattr(cls, "API_MULTI_INPUT_PARAMS", [])
Expand Down
12 changes: 1 addition & 11 deletions tests/unit/models/test_filter_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,21 +235,11 @@ def test_uuid_filter_model():
)


def test_uuid_filter_model_fails_for_invalid_uuids_on_equality():
"""Test filtering for equality with invalid UUID fails."""
with pytest.raises(ValueError):
uuid_value = "a92k34"
SomeFilterModel(uuid_field=f"{GenericFilterOps.EQUALS}:{uuid_value}")


def test_uuid_filter_model_succeeds_for_invalid_uuid_on_non_equality():
"""Test filtering with other UUID operations is possible with non-UUIDs."""
filter_value = "a92k34"
for filter_op in UUIDFilter.ALLOWED_OPS:
if (
filter_op == GenericFilterOps.EQUALS
or filter_op == GenericFilterOps.ONEOF
):
if filter_op == GenericFilterOps.ONEOF:
continue
filter_model = SomeFilterModel(
uuid_field=f"{filter_op}:{filter_value}"
Expand Down
Loading