From ef7130ba43ea8729857b32984be6a2bf0ec04239 Mon Sep 17 00:00:00 2001 From: ort Date: Mon, 6 Apr 2020 09:21:04 +0900 Subject: [PATCH] query: process RHS according to its type + tests `Field.get_db_prep_lookup` merely "prepares" the right hand side operand, and returns a (sql, params) pair. For a plain value `params` will simply be the value itself; compound expressions with `as_sql` method, however, are never designed to work with `get_db_prep_lookup` and instead needs their `as_sql` method called directly to retrieve the pair. The error message has been changed from time to time across Django versions. Related issues include: * #89 "can't adapt type 'CombinedExpression'" in admin filter Since Django 1.10 removed `Field.get_db_prep_lookup`, `Lookup.get_db_prep_lookup` implementation is simply as follows. ```python def get_db_prep_lookup(self, value, connection): return ('%s', [value]) ``` `CombinedExpression` is an expression type returned by combining exprssions with operators, such as `__add__` or `bitor`. They are never meant to be passed to the SQL engine directly; instead, it has an `as_sql` method that returns what we want. `process_rhs` first applies any bilateral transforms on both sides, and if it finds either `as_sql` or `_as_sql` method on the RHS, calls it and wraps the SQL in the pair with parentheses; otherwise, it diverts to the usual `get_db_prep_lookup` as above. * #64 'CombinedExpression' object is not iterable when using an admin filter on Django 1.8 Same as above, except `Lookup.get_db_prep_lookup` is as follows. ```python def get_db_prep_lookup(self, value, connection): return ( '%s', self.lhs.output_field.get_db_prep_lookup( self.lookup_name, value, connection, prepared=True)) ``` Following is the relevant part of `Field.get_db_prep_lookup`. ```python def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): """ Returns field's value prepared for database lookup. """ if not prepared: value = self.get_prep_lookup(lookup_type, value) prepared = True if hasattr(value, 'get_compiler'): value = value.get_compiler(connection=connection) if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'): # If the value has a relabeled_clone method it means the # value will be handled later on. if hasattr(value, 'relabeled_clone'): return value if hasattr(value, 'as_sql'): sql, params = value.as_sql() else: sql, params = value._as_sql(connection=connection) return QueryWrapper(('(%s)' % sql), params) if lookup_type in ('month', 'day', 'week_day', 'hour', 'minute', 'second', 'search', 'regex', 'iregex', 'contains', 'icontains', 'iexact', 'startswith', 'endswith', 'istartswith', 'iendswith'): return [value] ``` Normally it is supposed to return the parameters for the SQL "prepared statement" as a `list`; however, if the value has either `as_sql` or `_as_sql`, it either returns the value directly if it also has `relabeled_clone` (which is the case for expressions) or wraps it in `QueryWrapper`. This is not a desired behavior, and using them would result in `TypeError` saying the object was not iterable since they were not lists in the first place. * #61 'SQLEvaluator' object is not iterable when using admin filter on Django 1.7 Same as above, except `SQLEvaluator` emerges from the SQL compiler wrapping expressions for internal use. - Update expressions: django.db.models.sql.compiler.SQLUpdateCompiler.as_sql: ... if hasattr(val, 'evaluate'): val = SQLEvaluator(val, self.query, allow_joins=False) ... - filter() resolution: django.db.models.sql.query.Query.add_filter: ... elif isinstance(value, ExpressionNode): # If value is a query expression, evaluate it value = SQLEvaluator(value, self, reuse=can_reuse) having_clause = value.contains_aggregate ... This commit also causes "regression" for the following issue; its legitimacy, however, is ambiguous considering that the lookup essentially "hijacks" the original definition of 'exact'. #26 admin BitFieldListFilter not working as expected It should be noted, however, that the original behavior is still possible to achieve by wrapping the `int` in a `BitHandler`. Fixing this correctly would involve treating `int` as `Bit` as before while showing a deprecation warning. Meanwhile `filter(flags=Value(4))` will find a record of which `flags` is exactly `4` across all versions. --- bitfield/query.py | 26 ++++++++++++++++---------- bitfield/tests/tests.py | 22 ++++++++++++++++++++++ 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/bitfield/query.py b/bitfield/query.py index c3a753d..0e97b17 100644 --- a/bitfield/query.py +++ b/bitfield/query.py @@ -5,22 +5,28 @@ class BitQueryLookupWrapper(Exact): # NOQA - def process_lhs(self, qn, connection, lhs=None): - lhs_sql, params = super(BitQueryLookupWrapper, self).process_lhs( - qn, connection, lhs) - if self.rhs: - lhs_sql = lhs_sql + ' & %s' - else: - lhs_sql = lhs_sql + ' | %s' - params.extend(self.get_db_prep_lookup(self.rhs, connection)[1]) - return lhs_sql, params + def process_lhs(self, compiler, connection, lhs=None): + lhs_sql, lhs_params = super(BitQueryLookupWrapper, self).process_lhs( + compiler, connection, lhs) + + if not isinstance(self.rhs, (BitHandler, Bit)): + return lhs_sql, lhs_params + + op = ' & ' if self.rhs else ' | ' + rhs_sql, rhs_params = self.process_rhs(compiler, connection) + params = list(lhs_params) + params.extend(rhs_params) + + return op.join((lhs_sql, rhs_sql)), params def get_db_prep_lookup(self, value, connection): v = value.mask if isinstance(value, (BitHandler, Bit)) else value return super(BitQueryLookupWrapper, self).get_db_prep_lookup(v, connection) def get_prep_lookup(self): - return self.rhs + if isinstance(self.rhs, (BitHandler, Bit)): + return self.rhs # resolve at later stage, in get_db_prep_lookup + return super(BitQueryLookupWrapper, self).get_prep_lookup() class BitQuerySaveWrapper(BitQueryLookupWrapper): diff --git a/bitfield/tests/tests.py b/bitfield/tests/tests.py index 4d03062..cbedc33 100644 --- a/bitfield/tests/tests.py +++ b/bitfield/tests/tests.py @@ -177,6 +177,28 @@ def test_select(self): self.assertFalse(BitFieldTestModel.objects.exclude(flags=BitFieldTestModel.flags.FLAG_0).exists()) self.assertFalse(BitFieldTestModel.objects.exclude(flags=BitFieldTestModel.flags.FLAG_1).exists()) + def test_select_complex_expression(self): + BitFieldTestModel.objects.create(flags=3) + self.assertTrue(BitFieldTestModel.objects.filter(flags=F('flags').bitor(BitFieldTestModel.flags.FLAG_1)).exists()) + self.assertTrue(BitFieldTestModel.objects.filter(flags=F('flags').bitor(BitFieldTestModel.flags.FLAG_0)).exists()) + self.assertTrue(BitFieldTestModel.objects.filter(flags=F('flags').bitor(BitFieldTestModel.flags.FLAG_0).bitor(BitFieldTestModel.flags.FLAG_1)).exists()) + self.assertTrue(BitFieldTestModel.objects.filter(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_0 | BitFieldTestModel.flags.FLAG_1)).exists()) + self.assertTrue(BitFieldTestModel.objects.filter(flags=F('flags').bitand(15)).exists()) + self.assertTrue(BitFieldTestModel.objects.exclude(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_2)).exists()) + self.assertTrue(BitFieldTestModel.objects.exclude(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_3)).exists()) + self.assertTrue(BitFieldTestModel.objects.exclude(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_2 | BitFieldTestModel.flags.FLAG_3)).exists()) + self.assertTrue(BitFieldTestModel.objects.exclude(flags=F('flags').bitand(12)).exists()) + + self.assertFalse(BitFieldTestModel.objects.exclude(flags=F('flags').bitor(BitFieldTestModel.flags.FLAG_1)).exists()) + self.assertFalse(BitFieldTestModel.objects.exclude(flags=F('flags').bitor(BitFieldTestModel.flags.FLAG_0)).exists()) + self.assertFalse(BitFieldTestModel.objects.exclude(flags=F('flags').bitor(BitFieldTestModel.flags.FLAG_0).bitor(BitFieldTestModel.flags.FLAG_1)).exists()) + self.assertFalse(BitFieldTestModel.objects.exclude(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_0 | BitFieldTestModel.flags.FLAG_1)).exists()) + self.assertFalse(BitFieldTestModel.objects.exclude(flags=F('flags').bitand(15)).exists()) + self.assertFalse(BitFieldTestModel.objects.filter(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_2)).exists()) + self.assertFalse(BitFieldTestModel.objects.filter(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_3)).exists()) + self.assertFalse(BitFieldTestModel.objects.filter(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_2 | BitFieldTestModel.flags.FLAG_3)).exists()) + self.assertFalse(BitFieldTestModel.objects.filter(flags=F('flags').bitand(12)).exists()) + def test_update(self): instance = BitFieldTestModel.objects.create(flags=0) self.assertFalse(instance.flags.FLAG_0)