diff --git a/tests/integration/api/test_augmentation.py b/tests/integration/api/test_augmentation.py index cae1b5b831..2f0b11a64c 100644 --- a/tests/integration/api/test_augmentation.py +++ b/tests/integration/api/test_augmentation.py @@ -31,11 +31,12 @@ def _test_augmentation( ).config train_config = config["data"]["train_subset"] train_config["input_size"] = (32, 32) + data_format = config["data"]["data_format"] # Load dataset dm_dataset = DmDataset.import_from( target_dataset_per_task[task_name], - format=config["data"]["data_format"], + format=data_format, ) mem_cache_handler = MemCacheHandlerSingleton.create( mode="sinlgeprocessing", @@ -60,6 +61,7 @@ def _test_augmentation( dm_subset=dm_dataset, cfg_subset=SubsetConfig(sampler=SamplerConfig(**train_config.pop("sampler", {})), **train_config), mem_cache_handler=mem_cache_handler, + data_format=data_format, ) # Check if all aug combinations are size-compatible