Skip to content

Commit

Permalink
Merge pull request #968 from JGreenlee/fix-localdate-ranges
Browse files Browse the repository at this point in the history
make localdate queries behave like canonical datetime range filters
  • Loading branch information
shankari authored May 17, 2024
2 parents 5bde217 + a326b1e commit 36432a8
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 56 deletions.
3 changes: 3 additions & 0 deletions emission/core/wrapper/localdate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import arrow
import emission.core.wrapper.wrapperbase as ecwb

# specify the order of time units, from largest to smallest
DATETIME_UNITS = ['year', 'month', 'day', 'hour', 'minute', 'second']

class LocalDate(ecwb.WrapperBase):
"""
Supporting wrapper class that stores the expansions of the components
Expand Down
79 changes: 36 additions & 43 deletions emission/storage/decorations/local_date_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,47 +11,40 @@

import emission.core.wrapper.localdate as ecwl

def get_range_query(field_name, start_local_dt, end_local_dt):
if list(start_local_dt.keys()) != list(end_local_dt.keys()):
raise RuntimeError("start_local_dt.keys() = %s does not match end_local_dt.keys() = %s" %
(list(start_local_dt.keys()), list(end_local_dt.keys())))
query_result = {}
for key in start_local_dt:
curr_field = "%s.%s" % (field_name, key)
gte_lte_query = {}
try:
start_int = int(start_local_dt[key])
except:
logging.info("start_local_dt[%s] = %s, not an integer, skipping" %
(key, start_local_dt[key]))
continue

try:
end_int = int(end_local_dt[key])
except:
logging.info("end_local_dt[%s] = %s, not an integer, skipping" %
(key, end_local_dt[key]))
continue

is_rollover = start_int > end_int

if is_rollover:
gte_lte_query = get_rollover_query(start_int, end_int)
else:
gte_lte_query = get_standard_query(start_int, end_int)

if len(gte_lte_query) > 0:
query_result.update({curr_field: gte_lte_query})
def get_range_query(field_prefix, start_ld, end_ld):
units = [u for u in ecwl.DATETIME_UNITS if u in start_ld and u in end_ld]
logging.debug(f'get_range_query: units = {units}')
try:
gt_query = get_comparison_query(field_prefix, start_ld, end_ld, units, 'gt')
lt_query = get_comparison_query(field_prefix, end_ld, start_ld, units, 'lt')
logging.debug(f'get_range_query: gt_query = {gt_query}, lt_query = {lt_query}')
return { "$and": [gt_query, lt_query] } if gt_query and lt_query else {}
except AssertionError as e:
logging.error(f'Invalid range from {str(start_ld)} to {str(end_ld)}: {str(e)}')
return None

def get_comparison_query(field_prefix, base_ld, limit_ld, units, gt_or_lt):
field_name = lambda i: f'{field_prefix}.{units[i]}'
and_conditions, or_conditions = [], []
tiebreaker_index = -1
for i, unit in enumerate(units):
# the range is inclusive, so if on the last unit we should use $lte / $gte instead of $lt / $gt
op = f'${gt_or_lt}e' if i == len(units)-1 else f'${gt_or_lt}'
if tiebreaker_index >= 0:
tiebreaker_conditions = [{ field_name(j): base_ld[units[j]] } for j in range(tiebreaker_index, i)]
tiebreaker_conditions.append({ field_name(i): { op: base_ld[unit] }})
or_conditions.append({ "$and": tiebreaker_conditions })
elif base_ld[unit] == limit_ld[unit]:
and_conditions.append({field_name(i): base_ld[unit]})
else:
logging.info("key %s exists, skipping because upper AND lower range are missing" % key)

logging.debug("In get_range_query, returning query %s" % query_result)
return query_result

def get_standard_query(start_int, end_int):
assert(start_int <= end_int)
return {'$gte': start_int, '$lte': end_int}

def get_rollover_query(start_int, end_int):
assert(start_int > end_int)
return {'$not': {'$gt': end_int, '$lt': start_int}}
assert (base_ld[unit] < limit_ld[unit]) if gt_or_lt == 'gt' else (base_ld[unit] > limit_ld[unit])
or_conditions.append({field_name(i): { op: base_ld[unit] }})
tiebreaker_index = i
if and_conditions and or_conditions:
return { "$and": and_conditions + [{ "$or": or_conditions }] }
elif and_conditions:
return { "$and": and_conditions }
elif or_conditions:
return { "$or": or_conditions }
else:
return {}
16 changes: 3 additions & 13 deletions emission/tests/storageTests/TestLocalDateQueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,14 @@ def testLocalRangeStandardQuery(self):

def testLocalRangeRolloverQuery(self):
"""
Search for all entries between 8:18 and 8:20 local time, both inclusive
Search for all entries between 8:18 and 9:08 local time, both inclusive
"""
start_local_dt = ecwl.LocalDate({'year': 2015, 'month': 8, 'hour': 8, 'minute': 18})
end_local_dt = ecwl.LocalDate({'year': 2015, 'month': 8, 'hour': 9, 'minute': 8})
final_query = {"user_id": self.testUUID}
final_query.update(esdl.get_range_query("data.local_dt", start_local_dt, end_local_dt))
entries = edb.get_timeseries_db().find(final_query).sort('data.ts', pymongo.ASCENDING)
self.assertEqual(448, edb.get_timeseries_db().count_documents(final_query))

entries_list = list(entries)

# Note that since this is a set of filters, as opposed to a range, this
# returns all entries between 18 and 8 in both hours.
# so 8:18 is valid, but so is 9:57
self.assertEqual(ecwe.Entry(entries_list[0]).data.local_dt.hour, 8)
self.assertEqual(ecwe.Entry(entries_list[0]).data.local_dt.minute, 18)
self.assertEqual(ecwe.Entry(entries_list[-1]).data.local_dt.hour, 9)
self.assertEqual(ecwe.Entry(entries_list[-1]).data.local_dt.minute, 57)
entriesCnt = edb.get_timeseries_db().count_documents(final_query)
self.assertEqual(232, entriesCnt)

def testLocalMatchingQuery(self):
"""
Expand Down

0 comments on commit 36432a8

Please sign in to comment.