Skip to content

Commit

Permalink
make localdate queries behave like canonical datetime range filters
Browse files Browse the repository at this point in the history
The existing implementation that was here was addressing the localdate fields separately and treating them like a set of filters rather than one date range from start -> end.
The "set of filters" was good enough for some things, like filtering to a specific month, but not for the canonical datetime range filtering where you want everything between an arbitrary start and end datetime.
The complexity in the logic to form the necessary NoSQL queries is more than expected, but after the dust settles it's about the same amount of code as the old implementation was.

In short, this implementation finds what units are used by both start and end localdates, walks down them (from largest unit to smallest unit), and combines $and and $or blocks to determine the upper and lower bound
  • Loading branch information
JGreenlee committed May 16, 2024
1 parent 5bde217 commit 2c341e5
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 44 deletions.
3 changes: 3 additions & 0 deletions emission/core/wrapper/localdate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import arrow
import emission.core.wrapper.wrapperbase as ecwb

# specify the order of time units, from largest to smallest
DATETIME_UNITS = ['year', 'month', 'day', 'hour', 'minute', 'second']

class LocalDate(ecwb.WrapperBase):
"""
Supporting wrapper class that stores the expansions of the components
Expand Down
79 changes: 36 additions & 43 deletions emission/storage/decorations/local_date_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,47 +11,40 @@

import emission.core.wrapper.localdate as ecwl

def get_range_query(field_name, start_local_dt, end_local_dt):
if list(start_local_dt.keys()) != list(end_local_dt.keys()):
raise RuntimeError("start_local_dt.keys() = %s does not match end_local_dt.keys() = %s" %
(list(start_local_dt.keys()), list(end_local_dt.keys())))
query_result = {}
for key in start_local_dt:
curr_field = "%s.%s" % (field_name, key)
gte_lte_query = {}
try:
start_int = int(start_local_dt[key])
except:
logging.info("start_local_dt[%s] = %s, not an integer, skipping" %
(key, start_local_dt[key]))
continue

try:
end_int = int(end_local_dt[key])
except:
logging.info("end_local_dt[%s] = %s, not an integer, skipping" %
(key, end_local_dt[key]))
continue

is_rollover = start_int > end_int

if is_rollover:
gte_lte_query = get_rollover_query(start_int, end_int)
else:
gte_lte_query = get_standard_query(start_int, end_int)

if len(gte_lte_query) > 0:
query_result.update({curr_field: gte_lte_query})
def get_range_query(field_prefix, start_ld, end_ld):
units = [u for u in ecwl.DATETIME_UNITS if u in start_ld and u in end_ld]
logging.debug(f'get_range_query: units = {units}')
try:
gt_query = get_comparison_query(field_prefix, start_ld, end_ld, units, 'gt')
lt_query = get_comparison_query(field_prefix, end_ld, start_ld, units, 'lt')
logging.debug(f'get_range_query: gt_query = {gt_query}, lt_query = {lt_query}')
return { "$and": [gt_query, lt_query] } if gt_query and lt_query else {}
except AssertionError as e:
logging.error(f'Invalid range from {str(start_ld)} to {str(end_ld)}: {str(e)}')
return None

def get_comparison_query(field_prefix, base_ld, limit_ld, units, gt_or_lt):
field_name = lambda i: f'{field_prefix}.{units[i]}'
and_conditions, or_conditions = [], []
tiebreaker_index = -1
for i, unit in enumerate(units):
# the range is inclusive, so if on the last unit we should use $lte / $gte instead of $lt / $gt
op = f'${gt_or_lt}e' if i == len(units)-1 else f'${gt_or_lt}'
if tiebreaker_index >= 0:
tiebreaker_conditions = [{ field_name(j): base_ld[units[j]] } for j in range(tiebreaker_index, i)]
tiebreaker_conditions.append({ field_name(i): { op: base_ld[unit] }})
or_conditions.append({ "$and": tiebreaker_conditions })
elif base_ld[unit] == limit_ld[unit]:
and_conditions.append({field_name(i): base_ld[unit]})
else:
logging.info("key %s exists, skipping because upper AND lower range are missing" % key)

logging.debug("In get_range_query, returning query %s" % query_result)
return query_result

def get_standard_query(start_int, end_int):
assert(start_int <= end_int)
return {'$gte': start_int, '$lte': end_int}

def get_rollover_query(start_int, end_int):
assert(start_int > end_int)
return {'$not': {'$gt': end_int, '$lt': start_int}}
assert (base_ld[unit] < limit_ld[unit]) if gt_or_lt == 'gt' else (base_ld[unit] > limit_ld[unit])
or_conditions.append({field_name(i): { op: base_ld[unit] }})
tiebreaker_index = i
if and_conditions and or_conditions:
return { "$and": and_conditions + [{ "$or": or_conditions }] }
elif and_conditions:
return { "$and": and_conditions }
elif or_conditions:
return { "$or": or_conditions }
else:
return {}
2 changes: 1 addition & 1 deletion emission/tests/storageTests/TestLocalDateQueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def testLocalRangeStandardQuery(self):

def testLocalRangeRolloverQuery(self):
"""
Search for all entries between 8:18 and 8:20 local time, both inclusive
Search for all entries between 8:18 and 9:08 local time, both inclusive
"""
start_local_dt = ecwl.LocalDate({'year': 2015, 'month': 8, 'hour': 8, 'minute': 18})
end_local_dt = ecwl.LocalDate({'year': 2015, 'month': 8, 'hour': 9, 'minute': 8})
Expand Down

0 comments on commit 2c341e5

Please sign in to comment.