diff --git a/emission/tests/modellingTests/TestRunGreedyIncrementalModel.py b/emission/tests/modellingTests/TestRunGreedyIncrementalModel.py index 1529f8df5..656011f02 100644 --- a/emission/tests/modellingTests/TestRunGreedyIncrementalModel.py +++ b/emission/tests/modellingTests/TestRunGreedyIncrementalModel.py @@ -32,9 +32,15 @@ def setUp(self): logging.basicConfig( format='%(asctime)s:%(levelname)s:%(message)s', level=logging.DEBUG) + + # read test trips from a test file + input_file = 'emission/tests/data/real_examples/shankari_2016-06-20.expected_confirmed_trips' + with open(input_file, 'r') as f: + test_trips_json = json.load(f, object_hook=esj.wrapped_object_hook) + test_trips = [ecwe.Entry(r) for r in test_trips_json] + logging.debug(f'loaded {len(test_trips)} trips from {input_file}') - # emission/tests/data/real_examples/shankari_2016-06-20.expected_confirmed_trips - self.user_id = uuid.UUID('aa9fdec9-2944-446c-8ee2-50d79b3044d3') + self.user_id = test_trips[0]['user_id'] # all trips within the test file have the same user_id self.ts = esta.TimeSeries.get_time_series(self.user_id) self.new_trips_per_invocation = 3 self.model_type = eamumt.ModelType.GREEDY_SIMILARITY_BINNING @@ -52,15 +58,8 @@ def setUp(self): if len(existing_entries_for_user) != 0: raise Exception(f"test invariant failed, there should be no entries for user {self.user_id}") - # load in trips from a test file source - input_file = 'emission/tests/data/real_examples/shankari_2016-06-20.expected_confirmed_trips' - with open(input_file, 'r') as f: - trips_json = json.load(f, object_hook=esj.wrapped_object_hook) - trips = [ecwe.Entry(r) for r in trips_json] - logging.debug(f'loaded {len(trips)} trips from {input_file}') - self.ts.bulk_insert(trips) - - # confirm write to database succeeded + # write trips to database and confirm that they were written + self.ts.bulk_insert(test_trips) self.initial_data = list(self.ts.find_entries([esdatq.CONFIRMED_TRIP_KEY])) if len(self.initial_data) == 0: logging.debug(f'test setup failed while loading trips from file')