diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 1c4d2cccfb..dd2a074b35 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -113,7 +113,7 @@ def validate_operation(cls, value: Any) -> Any: def generate_query_conditions( self, table: Type[SQLModel], - ) -> Union["ColumnElement[bool]"]: + ) -> "ColumnElement[bool]": """Generate the query conditions for the database. This method converts the Filter class into an appropriate SQLModel @@ -291,11 +291,19 @@ def generate_query_conditions_from_column(self, column: Any) -> Any: import sqlalchemy from sqlalchemy_utils.functions import cast_if + from zenml.utils import uuid_utils + # For equality checks, compare the UUID directly if self.operation == GenericFilterOps.EQUALS: + if not uuid_utils.is_valid_uuid(self.value): + return False + return column == self.value if self.operation == GenericFilterOps.NOT_EQUALS: + if not uuid_utils.is_valid_uuid(self.value): + return True + return column != self.value # For all other operations, cast and handle the column as string @@ -703,16 +711,10 @@ def generate_name_or_id_query_conditions( conditions = [] - try: - filter_ = FilterGenerator(table).define_filter( - column="id", value=value, operator=operator - ) - conditions.append(filter_.generate_query_conditions(table=table)) - except ValueError: - # UUID filter with equal operators and no full UUID fail with - # a ValueError. In this case, we already know that the filter - # will not produce any result and can simply ignore it. - pass + filter_ = FilterGenerator(table).define_filter( + column="id", value=value, operator=operator + ) + conditions.append(filter_.generate_query_conditions(table=table)) filter_ = FilterGenerator(table).define_filter( column="name", value=value, operator=operator @@ -1101,18 +1103,8 @@ def _define_uuid_filter( A Filter object. Raises: - ValueError: If the value is not a valid UUID. + ValueError: If the value for a oneof filter is not a list. """ - # For equality checks, ensure that the value is a valid UUID. - if operator == GenericFilterOps.EQUALS and not isinstance(value, UUID): - try: - UUID(value) - except ValueError as e: - raise ValueError( - "Invalid value passed as UUID query parameter." - ) from e - - # For equality checks, ensure that the value is a valid UUID. if operator == GenericFilterOps.ONEOF and not isinstance(value, list): raise ValueError(ONEOF_ERROR) diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index ff96c7a640..86414385ff 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -421,6 +421,8 @@ def f(model: Model = Depends(make_dependable(Model))): """ from fastapi import Query + from zenml.zen_server.exceptions import error_detail + def init_cls_and_handle_errors(*args: Any, **kwargs: Any) -> BaseModel: from fastapi import HTTPException @@ -428,9 +430,8 @@ def init_cls_and_handle_errors(*args: Any, **kwargs: Any) -> BaseModel: inspect.signature(init_cls_and_handle_errors).bind(*args, **kwargs) return cls(*args, **kwargs) except ValidationError as e: - for error in e.errors(): - error["loc"] = tuple(["query"] + list(error["loc"])) - raise HTTPException(422, detail=e.errors()) + detail = error_detail(e, exception_type=ValueError) + raise HTTPException(422, detail=detail) params = {v.name: v for v in inspect.signature(cls).parameters.values()} query_params = getattr(cls, "API_MULTI_INPUT_PARAMS", []) diff --git a/tests/unit/models/test_filter_models.py b/tests/unit/models/test_filter_models.py index 46b711bb7c..c0d69ea4d2 100644 --- a/tests/unit/models/test_filter_models.py +++ b/tests/unit/models/test_filter_models.py @@ -235,21 +235,11 @@ def test_uuid_filter_model(): ) -def test_uuid_filter_model_fails_for_invalid_uuids_on_equality(): - """Test filtering for equality with invalid UUID fails.""" - with pytest.raises(ValueError): - uuid_value = "a92k34" - SomeFilterModel(uuid_field=f"{GenericFilterOps.EQUALS}:{uuid_value}") - - def test_uuid_filter_model_succeeds_for_invalid_uuid_on_non_equality(): """Test filtering with other UUID operations is possible with non-UUIDs.""" filter_value = "a92k34" for filter_op in UUIDFilter.ALLOWED_OPS: - if ( - filter_op == GenericFilterOps.EQUALS - or filter_op == GenericFilterOps.ONEOF - ): + if filter_op == GenericFilterOps.ONEOF: continue filter_model = SomeFilterModel( uuid_field=f"{filter_op}:{filter_value}"