diff --git a/bitfield/models.py b/bitfield/models.py index 2ac13b4..19d0b08 100644 --- a/bitfield/models.py +++ b/bitfield/models.py @@ -6,7 +6,7 @@ from django.db.models.fields.subclassing import SubfieldBase except ImportError: # django 1.2 - from django.db.models.fields.subclassing import LegacyConnection as SubfieldBase + from django.db.models.fields.subclassing import LegacyConnection as SubfieldBase # NOQA from bitfield.forms import BitFormField from bitfield.query import BitQueryLookupWrapper @@ -96,7 +96,7 @@ def contribute_to_class(self, cls, name): class BitField(BigIntegerField): __metaclass__ = BitFieldMeta - def __init__(self, flags, *args, **kwargs): + def __init__(self, flags, default=None, *args, **kwargs): if isinstance(flags, dict): # Get only integer keys in correct range valid_keys = (k for k in flags.keys() if isinstance(k, int) and (0 <= k < MAX_FLAG_COUNT)) @@ -108,7 +108,13 @@ def __init__(self, flags, *args, **kwargs): if len(flags) > MAX_FLAG_COUNT: raise ValueError('Too many flags') - BigIntegerField.__init__(self, *args, **kwargs) + if isinstance(default, (list, tuple, set, frozenset)): + new_value = 0 + for flag in default: + new_value |= Bit(flags.index(flag)) + default = new_value + + BigIntegerField.__init__(self, default=default, *args, **kwargs) self.flags = flags def south_field_triple(self): diff --git a/bitfield/tests/models.py b/bitfield/tests/models.py index be6641a..be585e9 100644 --- a/bitfield/tests/models.py +++ b/bitfield/tests/models.py @@ -2,6 +2,7 @@ from bitfield import BitField, CompositeBitField + class BitFieldTestModel(models.Model): flags = BitField(flags=( 'FLAG_0', @@ -10,6 +11,7 @@ class BitFieldTestModel(models.Model): 'FLAG_3', ), default=3, db_column='another_name') + class CompositeBitFieldTestModel(models.Model): flags_1 = BitField(flags=( 'FLAG_0', @@ -27,4 +29,3 @@ class CompositeBitFieldTestModel(models.Model): 'flags_1', 'flags_2', )) - diff --git a/bitfield/tests/tests.py b/bitfield/tests/tests.py index 2da4314..5df60ea 100644 --- a/bitfield/tests/tests.py +++ b/bitfield/tests/tests.py @@ -1,6 +1,6 @@ import pickle -from django.db import connection +from django.db import connection, models from django.db.models import F from django.test import TestCase @@ -248,6 +248,17 @@ def test_dictionary_init(self): self.assertRaises(ValueError, BitField, flags={'wrongkey': 'wrongkey'}) self.assertRaises(ValueError, BitField, flags={'1': 'non_int_key'}) + def test_defaults_as_key_names(self): + class TestModel(models.Model): + flags = BitField(flags=( + 'FLAG_0', + 'FLAG_1', + 'FLAG_2', + 'FLAG_3', + ), default=('FLAG_1', 'FLAG_2')) + field = TestModel._meta.get_field('flags') + self.assertEquals(field.default, TestModel.flags.FLAG_1 | TestModel.flags.FLAG_2) + class BitFieldSerializationTest(TestCase): def test_can_unserialize_bithandler(self):