diff --git a/bitfield/query.py b/bitfield/query.py index 1898507..98b423e 100644 --- a/bitfield/query.py +++ b/bitfield/query.py @@ -5,22 +5,28 @@ class BitQueryLookupWrapper(Exact): # NOQA - def process_lhs(self, qn, connection, lhs=None): - lhs_sql, params = super(BitQueryLookupWrapper, self).process_lhs( - qn, connection, lhs) - if self.rhs: - lhs_sql = lhs_sql + ' & %s' - else: - lhs_sql = lhs_sql + ' | %s' - params.extend(self.get_db_prep_lookup(self.rhs, connection)[1]) - return lhs_sql, params + def process_lhs(self, compiler, connection, lhs=None): + lhs_sql, lhs_params = super(BitQueryLookupWrapper, self).process_lhs( + compiler, connection, lhs) + + if not isinstance(self.rhs, (BitHandler, Bit)): + return lhs_sql, lhs_params + + op = ' & ' if self.rhs else ' | ' + rhs_sql, rhs_params = self.process_rhs(compiler, connection) + params = list(lhs_params) + params.extend(rhs_params) + + return op.join((lhs_sql, rhs_sql)), params def get_db_prep_lookup(self, value, connection, prepared=False): v = value.mask if isinstance(value, (BitHandler, Bit)) else value return super(BitQueryLookupWrapper, self).get_db_prep_lookup(v, connection) def get_prep_lookup(self): - return self.rhs + if isinstance(self.rhs, (BitHandler, Bit)): + return self.rhs # resolve at later stage, in get_db_prep_lookup + return super(BitQueryLookupWrapper, self).get_prep_lookup() class BitQuerySaveWrapper(BitQueryLookupWrapper): diff --git a/bitfield/tests/tests.py b/bitfield/tests/tests.py index edc2e22..6b4d25c 100644 --- a/bitfield/tests/tests.py +++ b/bitfield/tests/tests.py @@ -178,6 +178,28 @@ def test_select(self): self.assertFalse(BitFieldTestModel.objects.exclude(flags=BitFieldTestModel.flags.FLAG_0).exists()) self.assertFalse(BitFieldTestModel.objects.exclude(flags=BitFieldTestModel.flags.FLAG_1).exists()) + def test_select_complex_expression(self): + BitFieldTestModel.objects.create(flags=3) + self.assertTrue(BitFieldTestModel.objects.filter(flags=F('flags').bitor(BitFieldTestModel.flags.FLAG_1)).exists()) + self.assertTrue(BitFieldTestModel.objects.filter(flags=F('flags').bitor(BitFieldTestModel.flags.FLAG_0)).exists()) + self.assertTrue(BitFieldTestModel.objects.filter(flags=F('flags').bitor(BitFieldTestModel.flags.FLAG_0).bitor(BitFieldTestModel.flags.FLAG_1)).exists()) + self.assertTrue(BitFieldTestModel.objects.filter(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_0 | BitFieldTestModel.flags.FLAG_1)).exists()) + self.assertTrue(BitFieldTestModel.objects.filter(flags=F('flags').bitand(15)).exists()) + self.assertTrue(BitFieldTestModel.objects.exclude(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_2)).exists()) + self.assertTrue(BitFieldTestModel.objects.exclude(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_3)).exists()) + self.assertTrue(BitFieldTestModel.objects.exclude(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_2 | BitFieldTestModel.flags.FLAG_3)).exists()) + self.assertTrue(BitFieldTestModel.objects.exclude(flags=F('flags').bitand(12)).exists()) + + self.assertFalse(BitFieldTestModel.objects.exclude(flags=F('flags').bitor(BitFieldTestModel.flags.FLAG_1)).exists()) + self.assertFalse(BitFieldTestModel.objects.exclude(flags=F('flags').bitor(BitFieldTestModel.flags.FLAG_0)).exists()) + self.assertFalse(BitFieldTestModel.objects.exclude(flags=F('flags').bitor(BitFieldTestModel.flags.FLAG_0).bitor(BitFieldTestModel.flags.FLAG_1)).exists()) + self.assertFalse(BitFieldTestModel.objects.exclude(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_0 | BitFieldTestModel.flags.FLAG_1)).exists()) + self.assertFalse(BitFieldTestModel.objects.exclude(flags=F('flags').bitand(15)).exists()) + self.assertFalse(BitFieldTestModel.objects.filter(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_2)).exists()) + self.assertFalse(BitFieldTestModel.objects.filter(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_3)).exists()) + self.assertFalse(BitFieldTestModel.objects.filter(flags=F('flags').bitand(BitFieldTestModel.flags.FLAG_2 | BitFieldTestModel.flags.FLAG_3)).exists()) + self.assertFalse(BitFieldTestModel.objects.filter(flags=F('flags').bitand(12)).exists()) + def test_update(self): instance = BitFieldTestModel.objects.create(flags=0) self.assertFalse(instance.flags.FLAG_0)