Skip to content

Commit

Permalink
Fix request model validation (#3245)
Browse files Browse the repository at this point in the history
* Fix request model validation error

* Don't throw value error for invalid uuids

* Docstring

* Fix tests

* Handle invalid uuid values
  • Loading branch information
schustmi authored Dec 13, 2024
1 parent d6fae4e commit 384cb8b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 36 deletions.
36 changes: 14 additions & 22 deletions src/zenml/models/v2/base/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def validate_operation(cls, value: Any) -> Any:
def generate_query_conditions(
self,
table: Type[SQLModel],
) -> Union["ColumnElement[bool]"]:
) -> "ColumnElement[bool]":
"""Generate the query conditions for the database.
This method converts the Filter class into an appropriate SQLModel
Expand Down Expand Up @@ -291,11 +291,19 @@ def generate_query_conditions_from_column(self, column: Any) -> Any:
import sqlalchemy
from sqlalchemy_utils.functions import cast_if

from zenml.utils import uuid_utils

# For equality checks, compare the UUID directly
if self.operation == GenericFilterOps.EQUALS:
if not uuid_utils.is_valid_uuid(self.value):
return False

return column == self.value

if self.operation == GenericFilterOps.NOT_EQUALS:
if not uuid_utils.is_valid_uuid(self.value):
return True

return column != self.value

# For all other operations, cast and handle the column as string
Expand Down Expand Up @@ -702,16 +710,10 @@ def generate_name_or_id_query_conditions(

conditions = []

try:
filter_ = FilterGenerator(table).define_filter(
column="id", value=value, operator=operator
)
conditions.append(filter_.generate_query_conditions(table=table))
except ValueError:
# UUID filter with equal operators and no full UUID fail with
# a ValueError. In this case, we already know that the filter
# will not produce any result and can simply ignore it.
pass
filter_ = FilterGenerator(table).define_filter(
column="id", value=value, operator=operator
)
conditions.append(filter_.generate_query_conditions(table=table))

filter_ = FilterGenerator(table).define_filter(
column="name", value=value, operator=operator
Expand Down Expand Up @@ -1105,18 +1107,8 @@ def _define_uuid_filter(
A Filter object.
Raises:
ValueError: If the value is not a valid UUID.
ValueError: If the value for a oneof filter is not a list.
"""
# For equality checks, ensure that the value is a valid UUID.
if operator == GenericFilterOps.EQUALS and not isinstance(value, UUID):
try:
UUID(value)
except ValueError as e:
raise ValueError(
"Invalid value passed as UUID query parameter."
) from e

# For equality checks, ensure that the value is a valid UUID.
if operator == GenericFilterOps.ONEOF and not isinstance(value, list):
raise ValueError(ONEOF_ERROR)

Expand Down
7 changes: 4 additions & 3 deletions src/zenml/zen_server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,16 +421,17 @@ def f(model: Model = Depends(make_dependable(Model))):
"""
from fastapi import Query

from zenml.zen_server.exceptions import error_detail

def init_cls_and_handle_errors(*args: Any, **kwargs: Any) -> BaseModel:
from fastapi import HTTPException

try:
inspect.signature(init_cls_and_handle_errors).bind(*args, **kwargs)
return cls(*args, **kwargs)
except ValidationError as e:
for error in e.errors():
error["loc"] = tuple(["query"] + list(error["loc"]))
raise HTTPException(422, detail=e.errors())
detail = error_detail(e, exception_type=ValueError)
raise HTTPException(422, detail=detail)

params = {v.name: v for v in inspect.signature(cls).parameters.values()}
query_params = getattr(cls, "API_MULTI_INPUT_PARAMS", [])
Expand Down
12 changes: 1 addition & 11 deletions tests/unit/models/test_filter_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,21 +235,11 @@ def test_uuid_filter_model():
)


def test_uuid_filter_model_fails_for_invalid_uuids_on_equality():
"""Test filtering for equality with invalid UUID fails."""
with pytest.raises(ValueError):
uuid_value = "a92k34"
SomeFilterModel(uuid_field=f"{GenericFilterOps.EQUALS}:{uuid_value}")


def test_uuid_filter_model_succeeds_for_invalid_uuid_on_non_equality():
"""Test filtering with other UUID operations is possible with non-UUIDs."""
filter_value = "a92k34"
for filter_op in UUIDFilter.ALLOWED_OPS:
if (
filter_op == GenericFilterOps.EQUALS
or filter_op == GenericFilterOps.ONEOF
):
if filter_op == GenericFilterOps.ONEOF:
continue
filter_model = SomeFilterModel(
uuid_field=f"{filter_op}:{filter_value}"
Expand Down

0 comments on commit 384cb8b

Please sign in to comment.