diff --git a/emission/core/wrapper/localdate.py b/emission/core/wrapper/localdate.py index 1f71e737c..a60eb04e4 100644 --- a/emission/core/wrapper/localdate.py +++ b/emission/core/wrapper/localdate.py @@ -9,6 +9,9 @@ import arrow import emission.core.wrapper.wrapperbase as ecwb +# specify the order of time units, from largest to smallest +DATETIME_UNITS = ['year', 'month', 'day', 'hour', 'minute', 'second'] + class LocalDate(ecwb.WrapperBase): """ Supporting wrapper class that stores the expansions of the components diff --git a/emission/storage/decorations/local_date_queries.py b/emission/storage/decorations/local_date_queries.py index 8425005eb..8a0e2d149 100644 --- a/emission/storage/decorations/local_date_queries.py +++ b/emission/storage/decorations/local_date_queries.py @@ -11,47 +11,40 @@ import emission.core.wrapper.localdate as ecwl -def get_range_query(field_name, start_local_dt, end_local_dt): - if list(start_local_dt.keys()) != list(end_local_dt.keys()): - raise RuntimeError("start_local_dt.keys() = %s does not match end_local_dt.keys() = %s" % - (list(start_local_dt.keys()), list(end_local_dt.keys()))) - query_result = {} - for key in start_local_dt: - curr_field = "%s.%s" % (field_name, key) - gte_lte_query = {} - try: - start_int = int(start_local_dt[key]) - except: - logging.info("start_local_dt[%s] = %s, not an integer, skipping" % - (key, start_local_dt[key])) - continue - - try: - end_int = int(end_local_dt[key]) - except: - logging.info("end_local_dt[%s] = %s, not an integer, skipping" % - (key, end_local_dt[key])) - continue - - is_rollover = start_int > end_int - - if is_rollover: - gte_lte_query = get_rollover_query(start_int, end_int) - else: - gte_lte_query = get_standard_query(start_int, end_int) - - if len(gte_lte_query) > 0: - query_result.update({curr_field: gte_lte_query}) +def get_range_query(field_prefix, start_ld, end_ld): + units = [u for u in ecwl.DATETIME_UNITS if u in start_ld and u in end_ld] + logging.debug(f'get_range_query: units = {units}') + try: + gt_query = get_comparison_query(field_prefix, start_ld, end_ld, units, 'gt') + lt_query = get_comparison_query(field_prefix, end_ld, start_ld, units, 'lt') + logging.debug(f'get_range_query: gt_query = {gt_query}, lt_query = {lt_query}') + return { "$and": [gt_query, lt_query] } if gt_query and lt_query else {} + except AssertionError as e: + logging.error(f'Invalid range from {str(start_ld)} to {str(end_ld)}: {str(e)}') + return None + +def get_comparison_query(field_prefix, base_ld, limit_ld, units, gt_or_lt): + field_name = lambda i: f'{field_prefix}.{units[i]}' + and_conditions, or_conditions = [], [] + tiebreaker_index = -1 + for i, unit in enumerate(units): + # the range is inclusive, so if on the last unit we should use $lte / $gte instead of $lt / $gt + op = f'${gt_or_lt}e' if i == len(units)-1 else f'${gt_or_lt}' + if tiebreaker_index >= 0: + tiebreaker_conditions = [{ field_name(j): base_ld[units[j]] } for j in range(tiebreaker_index, i)] + tiebreaker_conditions.append({ field_name(i): { op: base_ld[unit] }}) + or_conditions.append({ "$and": tiebreaker_conditions }) + elif base_ld[unit] == limit_ld[unit]: + and_conditions.append({field_name(i): base_ld[unit]}) else: - logging.info("key %s exists, skipping because upper AND lower range are missing" % key) - - logging.debug("In get_range_query, returning query %s" % query_result) - return query_result - -def get_standard_query(start_int, end_int): - assert(start_int <= end_int) - return {'$gte': start_int, '$lte': end_int} - -def get_rollover_query(start_int, end_int): - assert(start_int > end_int) - return {'$not': {'$gt': end_int, '$lt': start_int}} + assert (base_ld[unit] < limit_ld[unit]) if gt_or_lt == 'gt' else (base_ld[unit] > limit_ld[unit]) + or_conditions.append({field_name(i): { op: base_ld[unit] }}) + tiebreaker_index = i + if and_conditions and or_conditions: + return { "$and": and_conditions + [{ "$or": or_conditions }] } + elif and_conditions: + return { "$and": and_conditions } + elif or_conditions: + return { "$or": or_conditions } + else: + return {} diff --git a/emission/tests/storageTests/TestLocalDateQueries.py b/emission/tests/storageTests/TestLocalDateQueries.py index 3b7597d83..68bb3fe42 100644 --- a/emission/tests/storageTests/TestLocalDateQueries.py +++ b/emission/tests/storageTests/TestLocalDateQueries.py @@ -77,24 +77,14 @@ def testLocalRangeStandardQuery(self): def testLocalRangeRolloverQuery(self): """ - Search for all entries between 8:18 and 8:20 local time, both inclusive + Search for all entries between 8:18 and 9:08 local time, both inclusive """ start_local_dt = ecwl.LocalDate({'year': 2015, 'month': 8, 'hour': 8, 'minute': 18}) end_local_dt = ecwl.LocalDate({'year': 2015, 'month': 8, 'hour': 9, 'minute': 8}) final_query = {"user_id": self.testUUID} final_query.update(esdl.get_range_query("data.local_dt", start_local_dt, end_local_dt)) - entries = edb.get_timeseries_db().find(final_query).sort('data.ts', pymongo.ASCENDING) - self.assertEqual(448, edb.get_timeseries_db().count_documents(final_query)) - - entries_list = list(entries) - - # Note that since this is a set of filters, as opposed to a range, this - # returns all entries between 18 and 8 in both hours. - # so 8:18 is valid, but so is 9:57 - self.assertEqual(ecwe.Entry(entries_list[0]).data.local_dt.hour, 8) - self.assertEqual(ecwe.Entry(entries_list[0]).data.local_dt.minute, 18) - self.assertEqual(ecwe.Entry(entries_list[-1]).data.local_dt.hour, 9) - self.assertEqual(ecwe.Entry(entries_list[-1]).data.local_dt.minute, 57) + entriesCnt = edb.get_timeseries_db().count_documents(final_query) + self.assertEqual(232, entriesCnt) def testLocalMatchingQuery(self): """