Skip to content

Commit

Permalink
Implement unique logic for in_interval(), add tests (swisscom#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
brki committed Aug 30, 2016
1 parent d034838 commit 5b9b816
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 18 deletions.
123 changes: 105 additions & 18 deletions versions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@

if VERSION[:2] >= (1, 8):
from django.db.models.sql.datastructures import Join
from django.db.models.expressions import RawSQL
if VERSION[:2] >= (1, 7):
from django.apps.registry import apps
from django.core.exceptions import SuspiciousOperation, ObjectDoesNotExist
from django.db import transaction
from django.db import transaction, connections
from django.db.models.base import Model
from django.db.models import Q
from django.db.models.constants import LOOKUP_SEP
Expand All @@ -52,6 +53,10 @@ def get_utc_now():
return datetime.datetime.utcnow().replace(tzinfo=utc)


def db_vendor(connection_string):
    """
    Look up the vendor identifier of a configured database connection.

    :param str connection_string: alias of the DB connection (for example 'default'),
        used as the key into Django's ``connections`` handler.
    :return: the ``vendor`` attribute of that connection (compared against
        'postgresql' elsewhere in this module).
    """
    connection = connections[connection_string]
    return connection.vendor


class SimpleEqualityMixin(object):
    """Mixin giving value-style equality: same exact type and same instance attributes."""

    def __eq__(self, other):
        """
        Compare this object with ``other``.

        :param other: the object to compare against.
        :return: False when the two objects are of different types, otherwise the
            result of comparing both instances' ``__dict__``.
        """
        if type(self) == type(other):
            return self.__dict__ == other.__dict__
        return False
Expand All @@ -67,10 +72,16 @@ class QueryTimeInterval(SimpleEqualityMixin):
"""If true, only include last version, otherwise include all versions in interval"""

def __init__(self, start_time, end_time, unique=False):
    """
    Describe the time interval used for an interval query.

    :param datetime.datetime start_time: beginning of the interval.
    :param datetime.datetime end_time: end of the interval; must not be earlier
        than ``start_time``.
    :param bool unique: if true, only include the last version, otherwise include
        all versions in the interval. Only available when using Django >= 1.8.
    :raises ValueError: if either bound is not a ``datetime.datetime`` instance, or
        if ``start_time`` is later than ``end_time``.
    :raises AssertionError: if ``unique`` is requested on Django < 1.8.
    """
    # ``datetime`` is the module in this file, so the class has to be referenced
    # as ``datetime.datetime``; passing the module to isinstance() raises TypeError.
    if not (isinstance(start_time, datetime.datetime) and isinstance(end_time, datetime.datetime)):
        raise ValueError("start_time and end_time parameters need to be datetime objects")
    if end_time < start_time:
        raise ValueError("start_time ({}) must not be later than end_time ({})".format(
            start_time, end_time
        ))
    if unique and VERSION[:2] < (1, 8):
        # Raised explicitly instead of through an ``assert`` statement so that the
        # check is not stripped when Python runs with -O; the exception type is
        # kept as AssertionError for callers that already handle it.
        raise AssertionError(
            "Unique option for interval queries is only available when using Django >= 1.8")

    self.start_time = start_time
    self.end_time = end_time
    self.unique = unique
Expand Down Expand Up @@ -147,6 +158,19 @@ def as_of(self, time=None):
"""
return self.get_queryset().as_of(time)

def in_interval(self, start, end, unique=False):
    """
    Restrict the manager's queryset to Versionables present in a given interval.

    :param datetime start: The timestamp (including timezone info) representing the beginning of the interval.
    :param datetime end: The timestamp (including timezone info) representing the end of the interval.
    :param bool unique: If true, find only the latest matching version, otherwise find all matching versions.
    :return: A QuerySet containing the base for a timestamped query.
    """
    querytime = QueryTime(
        type=QueryTime.TYPE_INTERVAL,
        active=True,
        interval=QueryTimeInterval(start_time=start, end_time=end, unique=unique)
    )
    queryset = self.get_queryset()
    return queryset.with_querytime(querytime)

def with_querytime(self, querytime):
"""
Filters Versionables with the given QueryTime
Expand Down Expand Up @@ -437,7 +461,7 @@ class VersionedExtraWhere(ExtraWhere):
"""
A specific implementation of ExtraWhere;
Before as_sql can be called on an object, ensure that calls to
- set_as_of and
- set_querytime and
- set_joined_alias
have been done
"""
Expand Down Expand Up @@ -530,11 +554,12 @@ def get_compiler(self, *args, **kwargs):
(e.g. by adding a filter to the queryset) does not allow the caching of related
object to work (they are attached to a queryset; filter() returns a new queryset).
"""
if self.querytime.active and (not hasattr(self, '_querytime_filter_added') or not self._querytime_filter_added):
if self.querytime.active and not getattr(self, '_querytime_filter_added', False):
using = kwargs.get('using', 'default')
if self.querytime.type == QueryTime.TYPE_POINT_IN_TIME:
self.add_point_in_time_filter()
elif self.querytime.type == QueryTime.TYPE_INTERVAL:
self.add_interval_filter()
self.add_interval_filter(using)
else:
raise RuntimeError("Unrecognized QueryTime.type")

Expand All @@ -553,19 +578,81 @@ def add_point_in_time_filter(self):
& Q(version_start_date__lte=time)
)

def add_interval_filter(self, using='default'):
    """
    Adds a filter to select only those versions of objects that occur in the interval defined
    by self.querytime.interval.

    When the interval is flagged as unique, the work is delegated to
    add_interval_unique_filter(); otherwise every version overlapping the interval
    is selected.

    :param str using: name of DB connection as defined in the settings (for example 'default').
        Defaults to 'default' so that callers of the former zero-argument form keep working.
    :return:
    """
    interval = self.querytime.interval
    if interval.unique:
        self.add_interval_unique_filter(interval, using)
    else:
        # A version overlaps the interval when it starts before the interval ends and
        # is either still open (no end date) or ends at/after the interval start.
        starts_before_end = Q(version_start_date__lt=interval.end_time)
        ends_after_start = Q(version_end_date__isnull=True) | Q(version_end_date__gte=interval.start_time)
        self.add_q(starts_before_end & ends_after_start)

def add_interval_unique_filter(self, interval, using):
    """
    Adds a filter to select only the most recent version of the objects in the interval.

    "Most recent" is decided per ``identity`` by the greatest ``version_end_date``,
    where a NULL end date (a version that is still open) ranks above any date.

    :param QueryTimeInterval interval: interval
    :param str using: name of DB connection as defined in the settings (for example 'default').
    :return:
    """
    # Quote the table name with the backend's own quoting rules so that names that
    # are case-sensitive or collide with SQL keywords are referenced correctly.
    table_name = connections[using].ops.quote_name(self.get_meta().db_table)

    if db_vendor(using) == 'postgresql':
        # Use postgresql-specific syntax, that should be faster than generic sql.
        raw_sql = """
            SELECT t.id FROM (
                SELECT id, rank() OVER
                    (PARTITION BY identity
                     ORDER BY version_end_date IS NULL DESC, version_end_date DESC
                    ) AS rank
                FROM {table}
                WHERE version_start_date < %s
                AND (version_end_date IS NULL OR version_end_date >= %s)
            ) t
            WHERE t.rank = 1
            """.format(table=table_name)
        latest_restriction = Q(id__in=RawSQL(raw_sql, [interval.end_time, interval.start_time]))
    else:
        # Note this makes a separate query to first get the ids of the latest versions.
        # For some reason it didn't work on sqlite when using a subquery to provide the
        # list of ids.
        # It also suffers from a year 10000 bug.
        far_future = datetime.datetime(year=9999, month=12, day=31, tzinfo=utc)
        # Anti-join: keep a row (vtc1) only if no other in-interval row (vtc2) of the
        # same identity has a later end date.
        raw_sql = """
            SELECT vtc1.id
            FROM {table} vtc1
            LEFT OUTER JOIN {table} vtc2
            ON (
                vtc2.version_start_date < %s
                AND (vtc2.version_end_date IS NULL OR vtc2.version_end_date >= %s)
                AND vtc1.identity = vtc2.identity
                AND (CASE WHEN vtc1.version_end_date IS NULL THEN %s ELSE vtc1.version_end_date END) <
                    (CASE WHEN vtc2.version_end_date IS NULL THEN %s ELSE vtc2.version_end_date END)
            )
            WHERE vtc2.id IS NULL
            AND vtc1.version_start_date < %s
            AND (vtc1.version_end_date IS NULL OR vtc1.version_end_date >= %s)
            """.format(table=table_name)

        params = [
            interval.end_time, interval.start_time,  # vtc2 interval restrictions
            far_future, far_future,  # far future date replaces null version_end_date for comparison
            interval.end_time, interval.start_time  # vtc1 interval restrictions
        ]

        with connections[using].cursor() as c:
            c.execute(raw_sql, params)
            id_list = [row[0] for row in c.fetchall()]
        # Always build the restriction, even for an empty id list: an empty ``__in``
        # yields an empty result, whereas skipping it would leave the name unbound
        # (or the query unrestricted) when nothing matches the interval.
        latest_restriction = Q(id__in=id_list)

    self.add_q(latest_restriction)


def build_filter(self, filter_expr, **kwargs):
"""
Expand Down
102 changes: 102 additions & 0 deletions versions_tests/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2546,3 +2546,105 @@ def test_detach_with_relations(self):
t = Team.objects.current.get(pk=t_pk)
self.assertEqual({p, p2}, set(t.player_set.all()))
self.assertEqual([], list(t2.player_set.all()))

class IntervalQueryingTest(TestCase):
    """Tests for ``in_interval()`` queries, with and without the ``unique`` option."""

    def setUp(self):
        """
        Build a timeline of City versions around the timestamps t0..t6.

        The "big interval" used by the tests runs from t0 to t5. The short sleeps
        separate the recorded timestamps from the surrounding object changes.
        """
        # Created before t0; Rome is never deleted, Beichuan is deleted after t5.
        self.encompasses_interval_non_terminated = City.objects.create(name="Rome")
        self.encompasses_interval_terminated = City.objects.create(name="Beichuan")
        # Created, cloned and deleted before t0, so it lies entirely before the interval.
        self.doomed_city = City.objects.create(name="Carthage")
        self.doomed_city = self.doomed_city.clone()
        self.doomed_city.delete()
        # Created before t0 and deleted between t3 and t4.
        self.overlaps_start = City.objects.create(name="Palenque")
        sleep(0.001)
        self.t0 = get_utc_now()
        # Big interval starts here

        sleep(0.001)
        self.only_exists_in_interval = City.objects.create(name="Indecisiveness")
        self.t1 = get_utc_now()

        sleep(0.001)
        self.only_exists_in_interval = self.only_exists_in_interval.clone()
        self.only_exists_in_interval.name = "Indecisive"
        self.only_exists_in_interval.save()
        self.t2 = get_utc_now()

        sleep(0.001)
        self.only_exists_in_interval = self.only_exists_in_interval.clone()
        self.only_exists_in_interval.name = "Umm"
        self.only_exists_in_interval.save()
        self.t3 = get_utc_now()

        sleep(0.001)
        self.overlaps_start.delete()
        self.overlaps_end_non_terminated = City.objects.create(name="San Francisco")
        self.only_exists_in_interval = self.only_exists_in_interval.clone()
        self.only_exists_in_interval.name = "Anyone here?"
        self.only_exists_in_interval.save()
        self.t4 = get_utc_now()

        sleep(0.001)
        self.overlaps_end_terminated = City.objects.create(name="Bodie")
        self.only_exists_in_interval.delete()
        self.t5 = get_utc_now()
        # Big interval ends here

        sleep(0.001)
        # Everything below happens after the big interval has ended.
        bodie_v2 = self.overlaps_end_terminated.clone()
        bodie_v2.name += '-v2'
        bodie_v2.save()
        bodie_v2.delete()
        self.encompasses_interval_terminated.delete()
        self.newer_city = City.objects.create(name="Eko Atlantic")
        self.newer_terminated_city = City.objects.create(name="Eko Pacific")
        self.t6 = get_utc_now()

        sleep(0.001)
        self.newer_terminated_city.delete()

    def test_simple_model_interval(self):
        """Without ``unique``, every version overlapping the interval is returned."""

        city_versions = City.objects.in_interval(self.t1, self.t3, unique=False)

        self.assertSetEqual(
            {'Indecisiveness', 'Indecisive', 'Umm', 'Rome', 'Palenque', 'Beichuan'},
            {c.name for c in city_versions})
        self.assertEqual(4, len({c.identity for c in city_versions}))

        big_interval_cities = City.objects.in_interval(self.t0, self.t5, unique=False)
        self.assertEqual(6, len({c.identity for c in big_interval_cities}))
        self.assertEqual(9, len(big_interval_cities))
        self.assertSetEqual(
            {self.encompasses_interval_terminated.identity,
             self.encompasses_interval_non_terminated.identity,
             self.overlaps_start.identity,
             self.only_exists_in_interval.identity,
             self.overlaps_end_non_terminated.identity,
             self.overlaps_end_terminated.identity},
            {c.identity for c in big_interval_cities}
        )

    def test_interval_with_unique(self):
        """With ``unique``, only one version per identity is returned for the interval."""
        big_interval_cities = City.objects.in_interval(self.t0, self.t5, unique=True)
        self.assertEqual(6, len(big_interval_cities))
        self.assertEqual(6, len({c.identity for c in big_interval_cities}))
        self.assertSetEqual(
            {self.encompasses_interval_terminated.identity,
             self.encompasses_interval_non_terminated.identity,
             self.overlaps_start.identity,
             self.only_exists_in_interval.identity,
             self.overlaps_end_non_terminated.identity,
             self.overlaps_end_terminated.identity},
            {c.identity for c in big_interval_cities}
        )

        latest_version = big_interval_cities.filter(identity=self.only_exists_in_interval.identity)
        self.assertEqual(1, len(latest_version))
        # assertEqual instead of the deprecated assertEquals alias.
        self.assertEqual("Anyone here?", latest_version[0].name)

        latest_version = big_interval_cities.filter(identity=self.overlaps_end_terminated.identity)
        self.assertEqual(1, len(latest_version))
        bodie = latest_version[0]
        self.assertEqual("Bodie", bodie.name)
        # The matching "Bodie" version was superseded after the interval ended,
        # so it must not be reported as the current version.
        self.assertFalse(bodie.is_current)

0 comments on commit 5b9b816

Please sign in to comment.