diff --git a/emission/core/wrapper/localdate.py b/emission/core/wrapper/localdate.py index a60eb04e4..1f71e737c 100644 --- a/emission/core/wrapper/localdate.py +++ b/emission/core/wrapper/localdate.py @@ -9,9 +9,6 @@ import arrow import emission.core.wrapper.wrapperbase as ecwb -# specify the order of time units, from largest to smallest -DATETIME_UNITS = ['year', 'month', 'day', 'hour', 'minute', 'second'] - class LocalDate(ecwb.WrapperBase): """ Supporting wrapper class that stores the expansions of the components diff --git a/emission/storage/decorations/local_date_queries.py b/emission/storage/decorations/local_date_queries.py index bc6e8bc53..de5ff8d60 100644 --- a/emission/storage/decorations/local_date_queries.py +++ b/emission/storage/decorations/local_date_queries.py @@ -11,52 +11,47 @@ import emission.core.wrapper.localdate as ecwl -def get_range_query(field_prefix, start_ld, end_ld): - units = [u for u in ecwl.DATETIME_UNITS if u in start_ld and u in end_ld] - logging.debug(f'get_range_query: units = {units}') - try: - gt_query = get_comparison_query(field_prefix, start_ld, end_ld, units, 'gt') - lt_query = get_comparison_query(field_prefix, end_ld, start_ld, units, 'lt') - logging.debug(f'get_range_query: gt_query = {gt_query}, lt_query = {lt_query}') - return { "$and": [gt_query, lt_query] } if gt_query and lt_query else {} - except AssertionError as e: - logging.error(f'Invalid range from {str(start_ld)} to {str(end_ld)}: {str(e)}') - return None - -def get_comparison_query(field_prefix, base_ld, limit_ld, units, gt_or_lt): - field_name = lambda i: f'{field_prefix}.{units[i]}' - and_conditions, or_conditions = [], [] - tiebreaker_index = -1 - for i, unit in enumerate(units): - # the range is inclusive, so if on the last unit we should use $lte / $gte instead of $lt / $gt - op = f'${gt_or_lt}e' if i == len(units)-1 else f'${gt_or_lt}' - if tiebreaker_index >= 0: - tiebreaker_conditions = [{ field_name(j): base_ld[units[j]] } for j in range(tiebreaker_index, i)] - tiebreaker_conditions.append({ field_name(i): { op: base_ld[unit] }}) - or_conditions.append({ "$and": tiebreaker_conditions }) - elif base_ld[unit] == limit_ld[unit]: - and_conditions.append({field_name(i): base_ld[unit]}) +def get_filter_query(field_name, start_local_dt, end_local_dt): + if list(start_local_dt.keys()) != list(end_local_dt.keys()): + raise RuntimeError("start_local_dt.keys() = %s does not match end_local_dt.keys() = %s" % + (list(start_local_dt.keys()), list(end_local_dt.keys()))) + query_result = {} + for key in start_local_dt: + curr_field = "%s.%s" % (field_name, key) + gte_lte_query = {} + try: + start_int = int(start_local_dt[key]) + except: + logging.info("start_local_dt[%s] = %s, not an integer, skipping" % + (key, start_local_dt[key])) + continue + + try: + end_int = int(end_local_dt[key]) + except: + logging.info("end_local_dt[%s] = %s, not an integer, skipping" % + (key, end_local_dt[key])) + continue + + is_rollover = start_int > end_int + + if is_rollover: + gte_lte_query = get_rollover_query(start_int, end_int) + else: + gte_lte_query = get_standard_query(start_int, end_int) + + if len(gte_lte_query) > 0: + query_result.update({curr_field: gte_lte_query}) else: - assert (base_ld[unit] < limit_ld[unit]) if gt_or_lt == 'gt' else (base_ld[unit] > limit_ld[unit]) - or_conditions.append({field_name(i): { op: base_ld[unit] }}) - tiebreaker_index = i - if and_conditions and or_conditions: - return { "$and": and_conditions + [{ "$or": or_conditions }] } - elif and_conditions: - return { "$and": and_conditions } - elif or_conditions: - return { "$or": or_conditions } - else: - return {} - -def yyyy_mm_dd_to_local_date(ymd: str) -> ecwl.LocalDate: - return ecwl.LocalDate({ - 'year': int(ymd[0:4]), - 'month': int(ymd[5:7]), - 'day': int(ymd[8:10]) - }) - -def get_yyyy_mm_dd_range_query(field_name, start_ymd: str, end_ymd: str) -> dict: - start_local_date = yyyy_mm_dd_to_local_date(start_ymd) - end_local_date = yyyy_mm_dd_to_local_date(end_ymd) - return get_range_query(field_name, start_local_date, end_local_date) + logging.info("key %s exists, skipping because upper AND lower bounds are missing" % key) + + logging.debug("In get_filter_query, returning query %s" % query_result) + return query_result + +def get_standard_query(start_int, end_int): + assert(start_int <= end_int) + return {'$gte': start_int, '$lte': end_int} + +def get_rollover_query(start_int, end_int): + assert(start_int > end_int) + return {'$not': {'$gt': end_int, '$lt': start_int}} diff --git a/emission/storage/timeseries/tcquery.py b/emission/storage/timeseries/tcquery.py index 830382dd2..88084bc08 100644 --- a/emission/storage/timeseries/tcquery.py +++ b/emission/storage/timeseries/tcquery.py @@ -11,8 +11,11 @@ class TimeComponentQuery(object): """ - Object that encapsulates a query for a particular time at the local time in - the timezone where the data was generated. + Object that encapsulates a query for filtering based on localdate objects. + This works as a set of filters for each localdate field, e.g. year, month, day, etc. + Useful for filtering on one or more localdate fields + e.g. TimeComponentQuery("data.start_local_dt", {"weekday": 0}, {"weekday": 4}) + For range queries, use FmtTimeQuery instead. """ def __init__(self, timeType, startLD, endLD): self.timeType = timeType @@ -20,4 +23,4 @@ def __init__(self, timeType, startLD, endLD): self.endLD = endLD def get_query(self): - return esdl.get_range_query(self.timeType, self.startLD, self.endLD) + return esdl.get_filter_query(self.timeType, self.startLD, self.endLD) diff --git a/emission/tests/storageTests/TestLocalDateQueries.py b/emission/tests/storageTests/TestLocalDateQueries.py index 68bb3fe42..f169be831 100644 --- a/emission/tests/storageTests/TestLocalDateQueries.py +++ b/emission/tests/storageTests/TestLocalDateQueries.py @@ -64,36 +64,46 @@ def testLocalDateReadWrite(self): self.assertEqual(ret_entry.data.local_dt.weekday, 2) self.assertEqual(ret_entry.data.fmt_time, "2016-04-13T15:32:09-07:00") - def testLocalRangeStandardQuery(self): + def testLocalDateFilterStandardQuery(self): """ Search for all entries between 8:18 and 8:20 local time, both inclusive """ start_local_dt = ecwl.LocalDate({'year': 2015, 'month': 8, 'hour': 8, 'minute': 18}) end_local_dt = ecwl.LocalDate({'year': 2015, 'month': 8, 'hour': 8, 'minute': 20}) final_query = {"user_id": self.testUUID} - final_query.update(esdl.get_range_query("data.local_dt", start_local_dt, end_local_dt)) + final_query.update(esdl.get_filter_query("data.local_dt", start_local_dt, end_local_dt)) entriesCnt = edb.get_timeseries_db().count_documents(final_query) self.assertEqual(15, entriesCnt) - def testLocalRangeRolloverQuery(self): + def testLocalDateFilterRolloverQuery(self): """ Search for all entries between 8:18 and 9:08 local time, both inclusive """ start_local_dt = ecwl.LocalDate({'year': 2015, 'month': 8, 'hour': 8, 'minute': 18}) end_local_dt = ecwl.LocalDate({'year': 2015, 'month': 8, 'hour': 9, 'minute': 8}) final_query = {"user_id": self.testUUID} - final_query.update(esdl.get_range_query("data.local_dt", start_local_dt, end_local_dt)) - entriesCnt = edb.get_timeseries_db().count_documents(final_query) - self.assertEqual(232, entriesCnt) + final_query.update(esdl.get_filter_query("data.local_dt", start_local_dt, end_local_dt)) + entries = edb.get_timeseries_db().find(final_query).sort('data.ts', pymongo.ASCENDING) + self.assertEqual(448, edb.get_timeseries_db().count_documents(final_query)) + + entries_list = list(entries) + + # Note that since this is a set of filters, as opposed to a range, this + # returns all entries between 18 and 8 in both hours. + # so 8:18 is valid, but so is 9:57 + self.assertEqual(ecwe.Entry(entries_list[0]).data.local_dt.hour, 8) + self.assertEqual(ecwe.Entry(entries_list[0]).data.local_dt.minute, 18) + self.assertEqual(ecwe.Entry(entries_list[-1]).data.local_dt.hour, 9) + self.assertEqual(ecwe.Entry(entries_list[-1]).data.local_dt.minute, 57) - def testLocalMatchingQuery(self): + def testLocalDateFilterMatchingQuery(self): """ Search for all entries that occur at minute = 8 from any hour """ start_local_dt = ecwl.LocalDate({'minute': 8}) end_local_dt = ecwl.LocalDate({'minute': 8}) final_query = {"user_id": self.testUUID} - final_query.update(esdl.get_range_query("data.local_dt", start_local_dt, end_local_dt)) + final_query.update(esdl.get_filter_query("data.local_dt", start_local_dt, end_local_dt)) entries_docs = edb.get_timeseries_db().find(final_query).sort("metadata.write_ts") self.assertEqual(20, edb.get_timeseries_db().count_documents(final_query)) entries = [ecwe.Entry(doc) for doc in entries_docs]