diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 60d06fde0..8f5474d0f 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -15,7 +15,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
- python-version: ['3.9', '3.10', '3.11', '3.12']
+ python-version: ['3.10', '3.11', '3.12']
exclude:
- os: windows-latest
python-version: '3.12'
diff --git a/.gitignore b/.gitignore
index a4d961a62..3f9bbe3c3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -152,8 +152,7 @@ lightning_logs/
**/results/*
**/cache/*
**/outputs/*
-tests/tmp/cache/*
-tests/tmp/results/*
+tests/tmp/*
user_cfgs/*
inbuilt_cfgs/config.yml
@@ -162,8 +161,14 @@ inbuilt_cfgs/mf_config.yml
.vscode/settings.json
# Exclusions
+!docs/
+docs/*
+!docs/images/
+docs/images/*
!docs/images/*.png
+
!tests/fixtures/data/**/*.png
!tests/fixtures/data/**/*.txt
!tests/fixtures/data/**/*.csv
!tests/fixtures/data/**/*.tif
+!tests/fixtures/data/**/*.json
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 772a54862..4917e48f1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -18,7 +18,7 @@ repos:
- id: requirements-txt-fixer
- repo: https://github.com/psf/black
- rev: 24.4.2
+ rev: 24.8.0
hooks:
- id: black
diff --git a/README.md b/README.md
index 4beccd94b..3b6c917e4 100644
--- a/README.md
+++ b/README.md
@@ -57,7 +57,7 @@ datasets with upcoming support for [torchvision](https://pytorch.org/vision/stab
Required Python modules for `minerva` are stated in the `setup.cfg`.
-`minerva` currently only supports `python` 3.9 -- 3.12.
+`minerva` currently only supports `python` 3.10 -- 3.12.
(back to top)
diff --git a/minerva/datasets/collators.py b/minerva/datasets/collators.py
index 7d00b88f3..2534e2b15 100644
--- a/minerva/datasets/collators.py
+++ b/minerva/datasets/collators.py
@@ -39,45 +39,36 @@
# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple
+from typing import Any, Callable, Iterable
+from hydra.utils import get_method
from torchgeo.datasets.utils import stack_samples
-from minerva.utils import utils
-
# =====================================================================================================================
# METHODS
# =====================================================================================================================
def get_collator(
- collator_params: Optional[Dict[str, str]] = None
+ collator_target: str = "torchgeo.datasets.stack_samples",
) -> Callable[..., Any]:
"""Gets the function defined in parameters to collate samples together to form a batch.
Args:
- collator_params (dict[str, str]): Optional; Dictionary that must contain keys for
- ``'module'`` and ``'name'`` of the collation function. Defaults to ``config['collator']``.
+ collator_target (str): Dot based import path for collator method.
+ Defaults to :meth:`torchgeo.datasets.stack_samples`
Returns:
- ~typing.Callable[..., ~typing.Any]: Collation function found from parameters given.
+ ~typing.Callable[..., ~typing.Any]: Collation function found from target path given.
"""
collator: Callable[..., Any]
- if collator_params is not None:
- module = collator_params.pop("module", "")
- if module == "":
- collator = globals()[collator_params["name"]]
- else:
- collator = utils.func_by_str(module, collator_params["name"])
- else:
- collator = stack_samples
-
+ collator = get_method(collator_target)
assert callable(collator)
return collator
def stack_sample_pairs(
- samples: Iterable[Tuple[Dict[Any, Any], Dict[Any, Any]]]
-) -> Tuple[Dict[Any, Any], Dict[Any, Any]]:
+ samples: Iterable[tuple[dict[Any, Any], dict[Any, Any]]]
+) -> tuple[dict[Any, Any], dict[Any, Any]]:
"""Takes a list of paired sample dicts and stacks them into a tuple of batches of sample dicts.
Args:
diff --git a/minerva/datasets/dfc.py b/minerva/datasets/dfc.py
index bbc5bf277..49f8255f7 100644
--- a/minerva/datasets/dfc.py
+++ b/minerva/datasets/dfc.py
@@ -38,7 +38,7 @@
# =====================================================================================================================
from glob import glob
from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional
import numpy as np
import pandas as pd
@@ -84,7 +84,7 @@ class BaseSenS12MS(NonGeoDataset):
S2_BANDS_MR = [5, 6, 7, 9, 12, 13]
S2_BANDS_LR = [1, 10, 11]
- splits: List[str] = []
+ splits: list[str] = []
igbp = False
@@ -133,13 +133,13 @@ def __init__(
# Make sure parent dir exists.
assert self.root.exists()
- self.samples: List[Dict[str, str]]
+ self.samples: list[dict[str, str]]
def load_sample(
self,
- sample: Dict[str, str],
+ sample: dict[str, str],
index: int,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""Util function for reading data from single sample.
Args:
@@ -192,7 +192,7 @@ def get_ninputs(self) -> int:
n_inputs += 2
return n_inputs
- def get_display_channels(self) -> Tuple[List[int], int]:
+ def get_display_channels(self) -> tuple[list[int], int]:
"""Select channels for preview images.
Returns:
@@ -303,7 +303,7 @@ def load_lc(self, path: str) -> Tensor:
return lc
- def __getitem__(self, index: int) -> Dict[str, Any]:
+ def __getitem__(self, index: int) -> dict[str, Any]:
"""Get a single example from the dataset"""
# Get and load sample from index file.
@@ -393,11 +393,11 @@ def __init__(
def plot(
self,
- sample: Dict[str, Tensor],
+ sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- classes: Optional[Dict[int, str]] = None,
- colours: Optional[Dict[int, str]] = None,
+ classes: Optional[dict[int, str]] = None,
+ colours: Optional[dict[int, str]] = None,
) -> Figure:
"""Plot a sample from the dataset.
@@ -420,10 +420,15 @@ def plot(
# Reorder image from BGR to RGB.
from kornia.color import BgrToRgb
+ from minerva.transforms import AdjustGamma
+
bgr_to_rgb = BgrToRgb()
+ adjust_gamma = AdjustGamma(gamma=0.8)
+
+ # Reorder channels from BGR to RGB and adjust the gamma ready for plotting.
+ image = adjust_gamma(bgr_to_rgb(sample["image"][:3])).permute(1, 2, 0).numpy()
- image = bgr_to_rgb(sample["image"][:3])
- image = image.permute(1, 2, 0).numpy()
+ block_size_factor = 8
# Use inbuilt class colours and classes mappings.
if classes is None or colours is None:
@@ -454,7 +459,17 @@ def plot(
# Plot the image.
axs[0].imshow(image)
- axs[0].axis("off")
+
+ # Sets tick intervals to block size.
+ axs[0].set_xticks(
+ np.arange(0, image.shape[0] + 1, image.shape[0] // block_size_factor)
+ )
+ axs[0].set_yticks(
+ np.arange(0, image.shape[1] + 1, image.shape[1] // block_size_factor)
+ )
+
+ # Add grid overlay.
+ axs[0].grid(which="both", color="#CCCCCC", linestyle=":")
# Plot the ground truth mask and predicted mask.
mask_plot: Axes
@@ -462,22 +477,54 @@ def plot(
mask_plot = axs[1].imshow(
mask, cmap=cmap, vmin=vmin, vmax=vmax, interpolation="none"
)
- axs[1].axis("off")
+
+ # Sets tick intervals to block size.
+ axs[1].set_xticks(
+ np.arange(0, mask.shape[0] + 1, mask.shape[0] // block_size_factor)
+ )
+ axs[1].set_yticks(
+ np.arange(0, mask.shape[1] + 1, mask.shape[1] // block_size_factor)
+ )
+
+ # Add grid overlay.
+ axs[1].grid(which="both", color="#CCCCCC", linestyle=":")
+
if showing_prediction:
axs[2].imshow(
pred, cmap=cmap, vmin=vmin, vmax=vmax, interpolation="none"
)
- axs[2].axis("off")
+
+ # Sets tick intervals to block size.
+ axs[2].set_xticks(
+ np.arange(0, pred.shape[0] + 1, pred.shape[0] // block_size_factor)
+ )
+ axs[2].set_yticks(
+ np.arange(0, pred.shape[1] + 1, pred.shape[1] // block_size_factor)
+ )
+
+ # Add grid overlay.
+ axs[2].grid(which="both", color="#CCCCCC", linestyle=":")
+
elif showing_prediction:
mask_plot = axs[1].imshow(
pred, cmap=cmap, vmin=vmin, vmax=vmax, interpolation="none"
)
- axs[1].axis("off")
+
+ # Sets tick intervals to block size.
+ axs[1].set_xticks(
+ np.arange(0, pred.shape[0] + 1, pred.shape[0] // block_size_factor)
+ )
+ axs[1].set_yticks(
+ np.arange(0, pred.shape[1] + 1, pred.shape[1] // block_size_factor)
+ )
+
+ # Add grid overlay.
+ axs[1].grid(which="both", color="#CCCCCC", linestyle=":")
if showing_mask or showing_prediction:
# Plots colour bar onto figure.
clb = fig.colorbar(
- mask_plot,
+ mask_plot, # type: ignore[arg-type]
ax=axs,
location="top",
ticks=np.arange(0, len(colours)),
diff --git a/minerva/datasets/factory.py b/minerva/datasets/factory.py
index eaaedf027..c3e4b3adb 100644
--- a/minerva/datasets/factory.py
+++ b/minerva/datasets/factory.py
@@ -49,10 +49,10 @@
import re
from copy import deepcopy
from datetime import timedelta
-from inspect import signature
from pathlib import Path
-from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
+from typing import Any, Iterable, Optional
+import hydra
import numpy as np
import pandas as pd
import torch
@@ -60,17 +60,15 @@
from omegaconf import OmegaConf
from pandas import DataFrame
from rasterio.crs import CRS
-from torch.utils.data import DataLoader, Sampler
+from torch.utils.data import DataLoader
from torchgeo.datasets import GeoDataset, NonGeoDataset, RasterDataset
-from torchgeo.samplers import BatchGeoSampler, GeoSampler
from torchgeo.samplers.utils import _to_tuple
-from minerva.samplers import DistributedSamplerWrapper
+from minerva.samplers import DistributedSamplerWrapper, get_sampler
from minerva.transforms import MinervaCompose, init_auto_norm, make_transformations
from minerva.utils import universal_path, utils
from .collators import get_collator, stack_sample_pairs
-from .paired import PairedGeoDataset, PairedNonGeoDataset
from .utils import (
MinervaConcatDataset,
cache_dataset,
@@ -78,7 +76,6 @@
intersect_datasets,
load_all_samples,
load_dataset_from_cache,
- make_bounding_box,
masks_or_labels,
unionise_datasets,
)
@@ -88,92 +85,67 @@
# METHODS
# =====================================================================================================================
def create_subdataset(
- dataset_class: Union[Callable[..., GeoDataset], Callable[..., NonGeoDataset]],
- paths: Union[str, Iterable[str]],
- subdataset_params: Dict[Literal["params"], Dict[str, Any]],
+ paths: str | Iterable[str],
+ subdataset_params: dict[str, Any],
transformations: Optional[Any],
sample_pairs: bool = False,
-) -> Union[GeoDataset, NonGeoDataset]:
+) -> GeoDataset | NonGeoDataset:
"""Creates a sub-dataset based on the parameters supplied.
Args:
- dataset_class (Callable[..., ~typing.datasets.GeoDataset]): Constructor for the sub-dataset.
paths (str | ~typing.Iterable[str]): Paths to where the data for the dataset is located.
subdataset_params (dict[Literal[params], dict[str, ~typing.Any]]): Parameters for the sub-dataset.
transformations (~typing.Any): Transformations to apply to this sub-dataset.
sample_pairs (bool): Will configure the dataset for paired sampling. Defaults to False.
Returns:
- ~torchgeo.datasets.GeoDataset: Subdataset requested.
+ ~torchgeo.datasets.GeoDataset | ~torchgeo.datasets.NonGeoDataset: Subdataset requested.
"""
- copy_params = deepcopy(subdataset_params)
+ params = deepcopy(subdataset_params)
- if "crs" in copy_params["params"]:
- copy_params["params"]["crs"] = CRS.from_epsg(copy_params["params"]["crs"])
+ crs = None
+ if "crs" in params:
+ crs = CRS.from_epsg(params["crs"])
- if sample_pairs:
- if "paths" in signature(dataset_class).parameters:
- return PairedGeoDataset(
- dataset_class, # type: ignore[arg-type]
- paths=paths,
- transforms=transformations,
- **copy_params["params"],
- )
+ subdataset: GeoDataset | NonGeoDataset
+ if "paths" in params:
+ if sample_pairs:
+ params["dataset"] = params["_target_"]
+ params["_target_"] = "minerva.datasets.paired.PairedGeoDataset"
+
+ subdataset = hydra.utils.instantiate(
+ params, paths=paths, transforms=transformations, crs=crs
+ )
+
+ elif "root" in params:
+ if isinstance(paths, list):
+ paths = paths[0]
+ assert isinstance(paths, str)
+
+ if sample_pairs:
+ params["dataset"] = params["_target_"]
+ params["_target_"] = "minerva.datasets.paired.PairedNonGeoDataset"
+
+ subdataset = hydra.utils.instantiate(
+ params, root=paths, transforms=transformations
+ )
- elif "season_transform" in signature(dataset_class).parameters:
- if isinstance(paths, list):
- paths = paths[0]
- assert isinstance(paths, str)
- del copy_params["params"]["season_transform"]
- return PairedNonGeoDataset(
- dataset_class, # type: ignore[arg-type]
- root=paths,
- transforms=transformations,
- season=True,
- season_transform="pair",
- **copy_params["params"],
- )
- elif "root" in signature(dataset_class).parameters:
- if isinstance(paths, list):
- paths = paths[0]
- assert isinstance(paths, str)
- return PairedNonGeoDataset(
- dataset_class, # type: ignore[arg-type]
- root=paths,
- transforms=transformations,
- **copy_params["params"],
- )
- else:
- raise TypeError
else:
- if "paths" in signature(dataset_class).parameters:
- return dataset_class(
- paths=paths,
- transforms=transformations,
- **copy_params["params"],
- )
- elif "root" in signature(dataset_class).parameters:
- if isinstance(paths, list):
- paths = paths[0]
- assert isinstance(paths, str)
- return dataset_class(
- root=paths,
- transforms=transformations,
- **copy_params["params"],
- )
- else:
- raise TypeError
+ raise TypeError
+
+ return subdataset
def get_subdataset(
- data_directory: Union[Iterable[str], str, Path],
- dataset_params: Dict[str, Any],
+ data_directory: Iterable[str] | str | Path,
+ dataset_params: dict[str, Any],
key: str,
transformations: Optional[Any],
sample_pairs: bool = False,
cache: bool = True,
- cache_dir: Union[str, Path] = "",
-) -> Union[GeoDataset, NonGeoDataset]:
+ cache_dir: str | Path = "",
+ auto_norm: bool = False,
+) -> GeoDataset | NonGeoDataset:
"""Get a subdataset based on the parameters specified.
If ``cache==True``, this will attempt to load a cached version of the dataset instance.
@@ -196,17 +168,13 @@ def get_subdataset(
# Get the params for this sub-dataset.
sub_dataset_params = dataset_params[key]
- # Get the constructor for the class of dataset defined in params.
- _sub_dataset: Callable[..., GeoDataset] = utils.func_by_str(
- module_path=sub_dataset_params["module"], func=sub_dataset_params["name"]
- )
-
# Construct the path to the sub-dataset's files.
sub_dataset_paths = utils.compile_dataset_paths(
- universal_path(data_directory), sub_dataset_params["paths"]
+ universal_path(data_directory),
+ sub_dataset_params.get("paths", sub_dataset_params.get("root")),
)
- sub_dataset: Optional[Union[GeoDataset, NonGeoDataset]]
+ sub_dataset: Optional[GeoDataset | NonGeoDataset]
if cache or sub_dataset_params.get("cache_dataset"):
this_hash = utils.make_hash(sub_dataset_params)
@@ -230,7 +198,6 @@ def get_subdataset(
if rank == 0:
print(f"\nCreating dataset on {rank}...")
sub_dataset = create_subdataset(
- _sub_dataset,
sub_dataset_paths,
sub_dataset_params,
transformations,
@@ -255,7 +222,6 @@ def get_subdataset(
else:
print("\nCreating dataset...")
sub_dataset = create_subdataset(
- _sub_dataset,
sub_dataset_paths,
sub_dataset_params,
transformations,
@@ -267,7 +233,6 @@ def get_subdataset(
else:
sub_dataset = create_subdataset(
- _sub_dataset,
sub_dataset_paths,
sub_dataset_params,
transformations,
@@ -275,16 +240,31 @@ def get_subdataset(
)
assert sub_dataset is not None
+
+ if auto_norm:
+ if isinstance(sub_dataset, RasterDataset):
+ init_auto_norm(
+ sub_dataset,
+ length=sub_dataset_params.get("length"),
+ roi=sub_dataset_params.get("roi"),
+ )
+ else:
+ raise TypeError( # pragma: no cover
+ "AutoNorm only supports normalisation of data "
+ + f"from RasterDatasets, not {type(sub_dataset)}!"
+ )
+
return sub_dataset
def make_dataset(
- data_directory: Union[Iterable[str], str, Path],
- dataset_params: Dict[Any, Any],
+ data_directory: Iterable[str] | str | Path,
+ dataset_params: dict[Any, Any],
sample_pairs: bool = False,
+ change_detection: bool = False,
cache: bool = True,
- cache_dir: Union[str, Path] = "",
-) -> Tuple[Any, List[Any]]:
+ cache_dir: str | Path = "",
+) -> tuple[Any, list[Any]]:
"""Constructs a dataset object from ``n`` sub-datasets given by the parameters supplied.
Args:
@@ -293,6 +273,8 @@ def make_dataset(
dataset_params (dict[~typing.Any, ~typing.Any]): Dictionary of parameters defining each sub-datasets to be used.
sample_pairs (bool): Optional; ``True`` if paired sampling. This will ensure paired samples are handled
correctly in the datasets.
+ change_detection (bool): Flag for a change detection dataset which has
+ ``"image1"`` and ``"image2"`` keys rather than ``"image"``.
cache (bool): Cache the dataset or load from cache if pre-existing. Defaults to True.
cache_dir (str | ~pathlib.Path): Path to the directory to save the cached dataset (if ``cache==True``).
Defaults to CWD.
@@ -303,9 +285,7 @@ def make_dataset(
"""
# --+ MAKE SUB-DATASETS +=========================================================================================+
# List to hold all the sub-datasets defined by dataset_params to be intersected together into a single dataset.
- sub_datasets: Union[
- List[GeoDataset], List[Union[NonGeoDataset, MinervaConcatDataset]]
- ] = []
+ sub_datasets: list[GeoDataset] | list[NonGeoDataset | MinervaConcatDataset] = []
if OmegaConf.is_config(dataset_params):
dataset_params = OmegaConf.to_object(dataset_params) # type: ignore[assignment]
@@ -339,74 +319,75 @@ def make_dataset(
add_target_transforms = type_dataset_params["transforms"]
continue
- type_subdatasets: Union[List[GeoDataset], List[NonGeoDataset]] = []
+ type_subdatasets: list[GeoDataset] | list[NonGeoDataset] = []
multi_datasets_exist = False
- auto_norm = None
+ auto_norm = False
master_transforms: Optional[Any] = None
- for area_key in type_dataset_params.keys():
+
+ for sub_type_key in type_dataset_params.keys():
# If any of these keys are present, this must be a parameter set for a singular dataset at this level.
- if area_key in ("module", "name", "params", "paths"):
+ if sub_type_key in ("_target_", "paths", "root"):
multi_datasets_exist = False
continue
# If there are transforms specified, make them. These could cover a single dataset or many.
- elif area_key == "transforms":
- if isinstance(type_dataset_params[area_key], dict):
- transform_params = type_dataset_params[area_key]
- auto_norm = transform_params.get("AutoNorm")
+ elif sub_type_key == "transforms":
+ if isinstance(type_dataset_params[sub_type_key], dict):
+ transform_params = type_dataset_params[sub_type_key]
+ auto_norm = transform_params.get("AutoNorm", False)
# If transforms aren't specified for a particular modality of the sample,
# assume they're for the same type as the dataset.
if (
not ("image", "mask", "label")
- in type_dataset_params[area_key].keys()
+ in type_dataset_params[sub_type_key].keys()
):
- transform_params = {type_key: type_dataset_params[area_key]}
+ transform_params = {type_key: type_dataset_params[sub_type_key]}
else:
transform_params = False
- master_transforms = make_transformations(transform_params)
+ master_transforms = make_transformations(
+ transform_params, change_detection=change_detection
+ )
# Assuming that these keys are names of datasets.
- else:
+ elif sub_type_key == "subdatasets":
multi_datasets_exist = True
- if isinstance(type_dataset_params[area_key].get("transforms"), dict):
- transform_params = type_dataset_params[area_key]["transforms"]
- auto_norm = transform_params.get("AutoNorm")
+ if isinstance(
+ type_dataset_params[sub_type_key].get("transforms"), dict
+ ):
+ transform_params = type_dataset_params[sub_type_key]["transforms"]
+ auto_norm = transform_params.get("AutoNorm", False)
else:
transform_params = False
- transformations = make_transformations({type_key: transform_params})
+ transformations = make_transformations(
+ {type_key: transform_params}, change_detection=change_detection
+ )
# Send the params for this area key back through this function to make the sub-dataset.
- sub_dataset = get_subdataset(
- data_directory,
- type_dataset_params,
- area_key,
- transformations,
- sample_pairs=sample_pairs,
- cache=cache,
- cache_dir=cache_dir,
- )
+ for area_key in type_dataset_params[sub_type_key]:
+ sub_dataset = get_subdataset(
+ data_directory,
+ type_dataset_params[sub_type_key],
+ area_key,
+ transformations,
+ sample_pairs=sample_pairs,
+ cache=cache,
+ cache_dir=cache_dir,
+ auto_norm=auto_norm,
+ )
- # Performs an auto-normalisation initialisation which finds the mean and std of the dataset
- # to make a transform, then adds the transform to the dataset's existing transforms.
- if auto_norm:
- if isinstance(sub_dataset, RasterDataset):
- init_auto_norm(sub_dataset, auto_norm)
- else:
- raise TypeError( # pragma: no cover
- "AutoNorm only supports normalisation of data "
- + f"from RasterDatasets, not {type(sub_dataset)}!"
- )
+ # Reset back to False.
+ auto_norm = False
- # Reset back to None.
- auto_norm = None
+ type_subdatasets.append(sub_dataset) # type: ignore[arg-type]
- type_subdatasets.append(sub_dataset) # type: ignore[arg-type]
+ else:
+ continue
# Unionise all the sub-datsets of this modality together.
if multi_datasets_exist:
@@ -425,20 +406,11 @@ def make_dataset(
sample_pairs=sample_pairs,
cache=cache,
cache_dir=cache_dir,
+ auto_norm=auto_norm,
)
- # Performs an auto-normalisation initialisation which finds the mean and std of the dataset
- # to make a transform, then adds the transform to the dataset's existing transforms.
- if auto_norm:
- if isinstance(sub_dataset, RasterDataset):
- init_auto_norm(sub_dataset, auto_norm)
-
- # Reset back to None.
- auto_norm = None
- else:
- raise TypeError( # pragma: no cover
- f"AutoNorm only supports normalisation of data from RasterDatasets, not {type(sub_dataset)}!"
- )
+ # Reset back to False.
+ auto_norm = False
sub_datasets.append(sub_dataset) # type: ignore[arg-type]
@@ -450,14 +422,16 @@ def make_dataset(
if add_target_transforms is not None:
target_key = masks_or_labels(dataset_params)
- target_transforms = make_transformations({target_key: add_target_transforms})
+ target_transforms = make_transformations(
+ {target_key: add_target_transforms}, change_detection=change_detection
+ )
if hasattr(dataset, "transforms"):
if isinstance(dataset.transforms, MinervaCompose):
assert target_transforms is not None
- dataset.transforms += target_transforms
+ dataset.transforms += target_transforms # type: ignore[union-attr]
else:
- dataset.transforms = target_transforms
+ dataset.transforms = target_transforms # type: ignore[union-attr]
else:
raise TypeError(
f"dataset of type {type(dataset)} has no ``transforms`` atttribute!"
@@ -465,14 +439,15 @@ def make_dataset(
if add_multi_modal_transforms is not None:
multi_modal_transforms = make_transformations(
- {"both": add_multi_modal_transforms}
+ {"both": add_multi_modal_transforms},
+ change_detection=change_detection,
)
if hasattr(dataset, "transforms"):
if isinstance(dataset.transforms, MinervaCompose):
assert multi_modal_transforms is not None
- dataset.transforms += multi_modal_transforms
+ dataset.transforms += multi_modal_transforms # type: ignore[union-attr]
else:
- dataset.transforms = multi_modal_transforms
+ dataset.transforms = multi_modal_transforms # type: ignore[union-attr]
else:
raise TypeError(
f"dataset of type {type(dataset)} has no ``transforms`` atttribute!"
@@ -482,17 +457,18 @@ def make_dataset(
def construct_dataloader(
- data_directory: Union[Iterable[str], str, Path],
- dataset_params: Dict[str, Any],
- sampler_params: Dict[str, Any],
- dataloader_params: Dict[str, Any],
+ data_directory: Iterable[str] | str | Path,
+ dataset_params: dict[str, Any],
+ sampler_params: dict[str, Any],
+ dataloader_params: dict[str, Any],
batch_size: int,
- collator_params: Optional[Dict[str, Any]] = None,
+ collator_target: str = "torchgeo.datasets.stack_samples",
rank: int = 0,
world_size: int = 1,
sample_pairs: bool = False,
+ change_detection: bool = False,
cache: bool = True,
- cache_dir: Union[Path, str] = "",
+ cache_dir: Path | str = "",
) -> DataLoader[Iterable[Any]]:
"""Constructs a :class:`~torch.utils.data.DataLoader` object from the parameters provided for the
datasets, sampler, collator and transforms.
@@ -505,12 +481,14 @@ def construct_dataloader(
to sample from the dataset.
dataloader_params (dict[str, ~typing.Any]): Dictionary of parameters for the DataLoader itself.
batch_size (int): Number of samples per (global) batch.
- collator_params (dict[str, ~typing.Any]): Optional; Dictionary of parameters defining the function to collate
+ collator_target (str): Import target path for collator function to collate
and stack samples from the sampler.
rank (int): Optional; The rank of this process for distributed computing.
world_size (int): Optional; The total number of processes within a distributed run.
sample_pairs (bool): Optional; True if paired sampling. This will wrap the collation function
for paired samples.
+ change_detection (bool): Flag for if using a change detection dataset which has
+ ``"image1"`` and ``"image2"`` keys rather than ``"image"``.
Returns:
~torch.utils.data.DataLoader: Object to handle the returning of batched samples from the dataset.
@@ -519,42 +497,29 @@ def construct_dataloader(
data_directory,
dataset_params,
sample_pairs=sample_pairs,
+ change_detection=change_detection,
cache=cache,
cache_dir=cache_dir,
)
# --+ MAKE SAMPLERS +=============================================================================================+
- _sampler: Callable[..., Union[BatchGeoSampler, GeoSampler]] = utils.func_by_str(
- module_path=sampler_params["module"], func=sampler_params["name"]
- )
+ per_device_batch_size = None
- batch_sampler = True if re.search(r"Batch", sampler_params["name"]) else False
+ batch_sampler = True if re.search(r"Batch", sampler_params["_target_"]) else False
if batch_sampler:
- sampler_params["params"]["batch_size"] = batch_size
-
+ per_device_batch_size = sampler_params.get("batch_size", batch_size)
if dist.is_available() and dist.is_initialized(): # type: ignore[attr-defined]
- assert (
- sampler_params["params"]["batch_size"] % world_size == 0
- ) # pragma: no cover
+ assert per_device_batch_size % world_size == 0 # pragma: no cover
per_device_batch_size = (
- sampler_params["params"]["batch_size"] // world_size
+ per_device_batch_size // world_size
) # pragma: no cover
- sampler_params["params"][
- "batch_size"
- ] = per_device_batch_size # pragma: no cover
-
- sampler: Sampler[Any]
- if "roi" in signature(_sampler).parameters:
- sampler = _sampler(
- subdatasets[0],
- roi=make_bounding_box(sampler_params["roi"]),
- **sampler_params["params"],
- )
- else:
- sampler = _sampler(subdatasets[0], **sampler_params["params"])
+
+ sampler = get_sampler(
+ sampler_params, subdatasets[0], batch_size=per_device_batch_size
+ )
# --+ MAKE DATALOADERS +==========================================================================================+
- collator = get_collator(collator_params)
+ collator = get_collator(collator_target)
# Add batch size from top-level parameters to the dataloader parameters.
dataloader_params["batch_size"] = batch_size
@@ -588,11 +553,11 @@ def construct_dataloader(
def _add_class_transform(
- class_matrix: Dict[int, int], dataset_params: Dict[str, Any], target_key: str
-) -> Dict[str, Any]:
+ class_matrix: dict[int, int], dataset_params: dict[str, Any], target_key: str
+) -> dict[str, Any]:
class_transform = {
"ClassTransform": {
- "module": "minerva.transforms",
+ "_target_": "minerva.transforms.ClassTransform",
"transform": class_matrix,
}
}
@@ -615,12 +580,13 @@ def _make_loader(
dataset_params,
sampler_params,
dataloader_params,
- collator_params,
+ collator_target,
class_matrix,
batch_size,
model_type,
elim,
sample_pairs,
+ change_detection,
cache,
):
target_key = None
@@ -640,10 +606,11 @@ def _make_loader(
sampler_params,
dataloader_params,
batch_size,
- collator_params=collator_params,
+ collator_target=collator_target,
rank=rank,
world_size=world_size,
sample_pairs=sample_pairs,
+ change_detection=change_detection,
cache=cache,
cache_dir=cache_dir,
)
@@ -651,9 +618,9 @@ def _make_loader(
# Calculates number of batches.
assert hasattr(loaders.dataset, "__len__")
n_batches = int(
- sampler_params["params"].get(
+ sampler_params.get(
"length",
- sampler_params["params"].get("num_samples", len(loaders.dataset)),
+ sampler_params.get("num_samples", len(loaders.dataset)),
)
/ batch_size
)
@@ -667,11 +634,11 @@ def make_loaders(
p_dist: bool = False,
task_name: Optional[str] = None,
**params,
-) -> Tuple[
- Union[Dict[str, DataLoader[Iterable[Any]]], DataLoader[Iterable[Any]]],
- Union[Dict[str, int], int],
- List[Tuple[int, int]],
- Dict[Any, Any],
+) -> tuple[
+ dict[str, DataLoader[Iterable[Any]]] | DataLoader[Iterable[Any]],
+ dict[str, int] | int,
+ list[tuple[int, int]],
+ dict[Any, Any],
]:
"""Constructs train, validation and test datasets and places into :class:`~torch.utils.data.DataLoader` objects.
@@ -696,9 +663,7 @@ def make_loaders(
sampler_params (dict[str, ~typing.Any]): Parameters to construct the samplers for each mode of model fitting.
transform_params (dict[str, ~typing.Any]): Parameters to construct the transforms for each dataset.
See documentation for the structure of these.
- collator (dict[str, ~typing.Any]): Defines the collator to use that will collate samples together into batches.
- Contains the ``module`` key to define the import path and the ``name`` key
- for name of the collation function.
+ collator (str): Optional; Defines the collator to use that will collate samples together into batches.
sample_pairs (bool): Activates paired sampling for Siamese models. Only used for ``train`` datasets.
Returns:
@@ -717,14 +682,14 @@ def make_loaders(
cache_dir = params["cache_dir"]
# Gets out the parameters for the DataLoaders from params.
- dataloader_params: Dict[Any, Any] = deepcopy(
+ dataloader_params: dict[Any, Any] = deepcopy(
utils.fallback_params("loader_params", task_params, params)
)
if OmegaConf.is_config(dataloader_params):
dataloader_params = OmegaConf.to_object(dataloader_params) # type: ignore[assignment]
- dataset_params: Dict[str, Any] = utils.fallback_params(
+ dataset_params: dict[str, Any] = utils.fallback_params(
"dataset_params", task_params, params
)
@@ -734,7 +699,7 @@ def make_loaders(
batch_size: int = utils.fallback_params("batch_size", task_params, params)
model_type = utils.fallback_params("model_type", task_params, params)
- class_dist: List[Tuple[int, int]] = [(0, 0)]
+ class_dist: list[tuple[int, int]] = [(0, 0)]
classes = utils.fallback_params("classes", data_config, params, None)
cmap_dict = utils.fallback_params("colours", data_config, params, None)
@@ -745,28 +710,26 @@ def make_loaders(
if n_classes:
classes = {i: f"class {i}" for i in range(n_classes)}
- new_classes: Dict[int, str] = {}
- new_colours: Dict[int, str] = {}
- class_matrix: Dict[int, int] = {}
+ new_classes: dict[int, str] = {}
+ new_colours: dict[int, str] = {}
+ class_matrix: dict[int, int] = {}
+
+ sample_pairs = utils.fallback_params("sample_pairs", task_params, params, False)
- sample_pairs: Union[bool, Any] = utils.fallback_params(
- "sample_pairs", task_params, params, False
+ change_detection = utils.fallback_params(
+ "change_detection", task_params, params, False
)
- if not isinstance(sample_pairs, bool): # pragma: no cover
- sample_pairs = False
elim = utils.fallback_params("elim", task_params, params, False)
cache = utils.fallback_params("cache_dataset", task_params, params, True)
- n_batches: Union[Dict[str, int], int]
- loaders: Union[Dict[str, DataLoader[Iterable[Any]]], DataLoader[Iterable[Any]]]
+ n_batches: dict[str, int] | int
+ loaders: dict[str, DataLoader[Iterable[Any]]] | DataLoader[Iterable[Any]]
- collator_params = deepcopy(utils.fallback_params("collator", task_params, params))
- if OmegaConf.is_config(collator_params):
- collator_params = OmegaConf.to_object(collator_params)
+ collator_target = utils.fallback_params("collator", task_params, params, None)
if "sampler" in dataset_params.keys():
- sampler_params: Dict[str, Any] = dataset_params["sampler"]
+ sampler_params: dict[str, Any] = dataset_params["sampler"]
if not utils.check_substrings_in_string(model_type, "siamese"):
new_classes, class_matrix, new_colours, class_dist = get_data_specs(
@@ -778,7 +741,8 @@ def make_loaders(
dataset_params,
sampler_params,
dataloader_params,
- collator_params,
+ collator_target,
+ change_detection=change_detection,
elim=elim,
)
@@ -791,12 +755,13 @@ def make_loaders(
dataset_params,
sampler_params,
dataloader_params,
- collator_params,
+ collator_target,
class_matrix,
batch_size,
model_type,
elim=elim,
sample_pairs=sample_pairs,
+ change_detection=change_detection,
cache=cache,
)
@@ -806,7 +771,7 @@ def make_loaders(
loaders = {}
for mode in dataset_params.keys():
- mode_sampler_params: Dict[str, Any] = dataset_params[mode]["sampler"]
+ mode_sampler_params: dict[str, Any] = dataset_params[mode]["sampler"]
if (
not utils.check_substrings_in_string(model_type, "siamese")
@@ -821,7 +786,8 @@ def make_loaders(
dataset_params[mode],
mode_sampler_params,
dataloader_params,
- collator_params,
+ collator_target,
+ change_detection=change_detection,
elim=elim,
)
@@ -835,12 +801,13 @@ def make_loaders(
dataset_params[mode],
mode_sampler_params,
dataloader_params,
- collator_params,
+ collator_target,
class_matrix,
batch_size,
model_type,
elim=elim,
sample_pairs=sample_pairs if mode == "train" else False,
+ change_detection=change_detection,
cache=cache,
)
@@ -883,15 +850,16 @@ def make_loaders(
def get_data_specs(
- manifest_name: Union[str, Path],
- classes: Dict[int, str],
- cmap_dict: Dict[int, str],
- cache_dir: Optional[Union[str, Path]] = None,
- data_dir: Optional[Union[str, Path]] = None,
- dataset_params: Optional[Dict[str, Any]] = None,
- sampler_params: Optional[Dict[str, Any]] = None,
- dataloader_params: Optional[Dict[str, Any]] = None,
- collator_params: Optional[Dict[str, Any]] = None,
+ manifest_name: str | Path,
+ classes: dict[int, str],
+ cmap_dict: dict[int, str],
+ cache_dir: Optional[str | Path] = None,
+ data_dir: Optional[str | Path] = None,
+ dataset_params: Optional[dict[str, Any]] = None,
+ sampler_params: Optional[dict[str, Any]] = None,
+ dataloader_params: Optional[dict[str, Any]] = None,
+ collator_target: str = "torchgeo.datasets.stack_samples",
+ change_detection: bool = False,
elim: bool = True,
):
# Load manifest from cache for this dataset.
@@ -901,7 +869,8 @@ def get_data_specs(
dataset_params,
sampler_params,
dataloader_params,
- collator_params=collator_params,
+ collator_target=collator_target,
+ change_detection=change_detection,
)
class_dist = utils.modes_from_manifest(manifest, classes)
@@ -923,12 +892,13 @@ def get_data_specs(
def get_manifest(
- manifest_path: Union[str, Path],
- data_dir: Optional[Union[str, Path]] = None,
- dataset_params: Optional[Dict[str, Any]] = None,
- sampler_params: Optional[Dict[str, Any]] = None,
- loader_params: Optional[Dict[str, Any]] = None,
- collator_params: Optional[Dict[str, Any]] = None,
+ manifest_path: str | Path,
+ data_dir: Optional[str | Path] = None,
+ dataset_params: Optional[dict[str, Any]] = None,
+ sampler_params: Optional[dict[str, Any]] = None,
+ loader_params: Optional[dict[str, Any]] = None,
+ collator_target: str = "torchgeo.datasets.stack_samples",
+ change_detection: bool = False,
) -> DataFrame:
"""Attempts to return the :class:`~pandas.DataFrame` located at ``manifest_path``.
@@ -948,6 +918,8 @@ def get_manifest(
manifest_path (str | ~pathlib.Path): Path (including filename and extension) to the manifest
saved as a ``csv``.
task_name (str): Optional; Name of the task to which the dataset to create a manifest of belongs to.
+ change_detection (bool): Flag for if using a change detection dataset which has
+ ``"image1"`` and ``"image2"`` keys rather than ``"image"``.
Returns:
~pandas.DataFrame: Manifest either loaded from ``manifest_path`` or created from parameters in :data:`CONFIG`.
@@ -969,7 +941,8 @@ def get_manifest(
dataset_params,
sampler_params,
loader_params,
- collator_params=collator_params,
+ collator_target=collator_target,
+ change_detection=change_detection,
)
print(f"MANIFEST TO FILE -----> {manifest_path}")
@@ -983,11 +956,12 @@ def get_manifest(
def make_manifest(
- data_dir: Union[str, Path],
- dataset_params: Dict[str, Any],
- sampler_params: Dict[str, Any],
- loader_params: Dict[str, Any],
- collator_params: Optional[Dict[str, Any]] = None,
+ data_dir: str | Path,
+ dataset_params: dict[str, Any],
+ sampler_params: dict[str, Any],
+ loader_params: dict[str, Any],
+ collator_target: str = "torchgeo.datasets.stack_samples",
+ change_detection: bool = False,
) -> DataFrame:
"""Constructs a manifest of the dataset detailing each sample therein.
@@ -996,12 +970,14 @@ def make_manifest(
Args:
mf_config (dict[~typing.Any, ~typing.Any]): Config to use to construct the manifest with.
task_name (str): Optional; Name of the task to which the dataset to create a manifest of belongs to.
+ change_detection (bool): Flag for if using a change detection dataset which has
+ ``"image1"`` and ``"image2"`` keys rather than ``"image"``.
Returns:
~pandas.DataFrame: The completed manifest as a :class:`~pandas.DataFrame`.
"""
- def delete_transforms(params: Dict[str, Any]) -> None:
+ def delete_transforms(params: dict[str, Any]) -> None:
assert target_key is not None
if params[target_key] is None:
return
@@ -1025,24 +1001,22 @@ def delete_transforms(params: Dict[str, Any]) -> None:
if OmegaConf.is_config(_dataset_params):
_dataset_params = OmegaConf.to_object(_dataset_params) # type: ignore[assignment]
- if _sampler_params["name"] in (
+ if _sampler_params["_target_"].rpartition(".")[2] in (
"RandomGeoSampler",
"RandomPairGeoSampler",
"RandomBatchGeoSampler",
"RandomPairBatchGeoSampler",
):
- _sampler_params["module"] = "torchgeo.samplers"
- _sampler_params["name"] = "GridGeoSampler"
- _sampler_params["params"]["stride"] = [
- 0.9 * x for x in _to_tuple(_sampler_params["params"]["size"])
+ _sampler_params["_target_"] = "torchgeo.samplers.GridGeoSampler"
+ _sampler_params["stride"] = [
+ 0.9 * x for x in _to_tuple(_sampler_params["size"])
]
- if _sampler_params["name"] == "RandomSampler":
- _sampler_params["name"] = "SequentialSampler"
- _sampler_params["params"] = {}
+ if _sampler_params["_target_"].rpartition(".")[2] == "RandomSampler":
+ _sampler_params = {"_target_": "torch.utils.data.sampler.SequentialSampler"}
- if "length" in _sampler_params["params"]:
- del _sampler_params["params"]["length"]
+ if "length" in _sampler_params:
+ del _sampler_params["length"]
# Ensure there are no errant `ClassTransform` transforms in the parameters from previous runs.
# A `ClassTransform` can only be defined with a correct manifest so we cannot use an old one to
@@ -1063,7 +1037,8 @@ def delete_transforms(params: Dict[str, Any]) -> None:
_sampler_params,
loader_params,
batch_size=1, # To prevent issues with stacking different sized patches, set batch size to 1.
- collator_params=collator_params,
+ collator_target=collator_target,
+ change_detection=change_detection,
cache=False,
)
diff --git a/minerva/datasets/multispectral.py b/minerva/datasets/multispectral.py
index 9ece53eb2..de2749845 100644
--- a/minerva/datasets/multispectral.py
+++ b/minerva/datasets/multispectral.py
@@ -56,14 +56,14 @@
class MultiSpectralDataset(VisionDataset, MinervaNonGeoDataset):
"""Generic dataset class for multi-spectral images that works within :mod:`torchgeo`"""
- all_bands: List[str] = []
- rgb_bands: List[str] = []
+ all_bands: tuple[str, ...] = ()
+ rgb_bands: tuple[str, ...] = ()
def __init__(
self,
root: str,
transforms: Optional[Callable[..., Any]] = None,
- bands: Optional[List[str]] = None,
+ bands: Optional[tuple[str, ...]] = None,
as_type=np.float32,
) -> None:
super().__init__(root, transform=transforms, target_transform=None)
diff --git a/minerva/datasets/paired.py b/minerva/datasets/paired.py
index 278069561..5b2b4e53f 100644
--- a/minerva/datasets/paired.py
+++ b/minerva/datasets/paired.py
@@ -44,8 +44,9 @@
# =====================================================================================================================
import random
from inspect import signature
-from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union, overload
+from typing import Any, Callable, Optional, Sequence, Union, overload
+import hydra
import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
@@ -82,7 +83,7 @@ class PairedGeoDataset(RasterDataset):
def __new__( # type: ignore[misc]
cls,
- dataset: Union[Callable[..., GeoDataset], GeoDataset],
+ dataset: Callable[..., GeoDataset] | GeoDataset,
*args,
**kwargs,
) -> Union["PairedGeoDataset", "PairedUnionDataset"]:
@@ -106,9 +107,12 @@ def __init__(
self, dataset: GeoDataset, *args, **kwargs
) -> None: ... # pragma: no cover
+ @overload
+ def __init__(self, dataset: str, *args, **kwargs) -> None: ... # pragma: no cover
+
def __init__(
self,
- dataset: Union[Callable[..., GeoDataset], GeoDataset],
+ dataset: Callable[..., GeoDataset] | GeoDataset | str,
*args,
**kwargs,
) -> None:
@@ -116,6 +120,9 @@ def __init__(
self._args = args
self._kwargs = kwargs
+ if isinstance(dataset, str):
+ dataset = hydra.utils.get_method(dataset)
+
if isinstance(dataset, GeoDataset):
self.dataset = dataset
self._res = dataset.res
@@ -141,8 +148,8 @@ def __init__(
)
def __getitem__( # type: ignore[override]
- self, queries: Tuple[BoundingBox, BoundingBox]
- ) -> Tuple[Dict[str, Any], ...]:
+ self, queries: tuple[BoundingBox, BoundingBox]
+ ) -> tuple[dict[str, Any], ...]:
return self.dataset.__getitem__(queries[0]), self.dataset.__getitem__(
queries[1]
)
@@ -170,11 +177,11 @@ def __and__(self, other: "PairedGeoDataset") -> IntersectionDataset: # type: ig
self, other, collate_fn=utils.pair_collate(concat_samples)
)
- def __or__(self, other: "PairedGeoDataset") -> "PairedUnionDataset": # type: ignore[override]
- """Take the union of two :class:`PairedGeoDataset`.
+ def __or__(self, other: GeoDataset) -> "PairedUnionDataset": # type: ignore[override]
+ """Take the union of two :class:~`torchgeo.datasets.GeoDataset`.
Args:
- other (PairedGeoDataset): Another dataset.
+ other (~torchgeo.datasets.GeoDataset): Another dataset.
Returns:
PairedUnionDataset: A single dataset.
@@ -183,20 +190,42 @@ def __or__(self, other: "PairedGeoDataset") -> "PairedUnionDataset": # type: ig
"""
return PairedUnionDataset(self, other)
- def __getattr__(self, item):
- if item in self.dataset.__dict__:
- return getattr(self.dataset, item) # pragma: no cover
- elif item in self.__dict__:
- return getattr(self, item)
- else:
- raise AttributeError
+ def __getattr__(self, name: str) -> Any:
+ """
+ Called only if the attribute 'name' is not found by usual means.
+ Checks if 'name' exists in the dataset attribute.
+ """
+ # Instead of calling __getattr__ directly, access the dataset attribute directly
+ dataset = super().__getattribute__("dataset")
+ if hasattr(dataset, name):
+ return getattr(dataset, name)
+ # If not found in dataset, raise an AttributeError
+ raise AttributeError(f"{name} cannot be found in self or dataset")
+
+ def __getattribute__(self, name: str) -> Any:
+ """
+ Overrides default attribute access method to prevent recursion.
+ Checks 'self' first and uses __getattr__ for fallback.
+ """
+ try:
+ # First, try to get the attribute from the current instance
+ return super().__getattribute__(name)
+ except AttributeError:
+ # If not found in self, __getattr__ will check in the dataset
+ dataset = super().__getattribute__("dataset")
+ if hasattr(dataset, name):
+ return getattr(dataset, name)
+ # Raise AttributeError if not found in dataset either
+ raise AttributeError(
+ f"'{self.__class__.__name__}' object has no attribute '{name}'"
+ )
def __repr__(self) -> Any:
return self.dataset.__repr__()
@staticmethod
def plot(
- sample: Dict[str, Any],
+ sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> Figure:
@@ -235,7 +264,7 @@ def plot(
def plot_random_sample(
self,
- size: Union[Tuple[int, int], int],
+ size: tuple[int, int] | int,
res: float,
show_titles: bool = True,
suptitle: Optional[str] = None,
@@ -277,6 +306,13 @@ def __init__(
] = merge_samples,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
) -> None:
+
+ # Extract the actual dataset out of the paired dataset (otherwise we'll be pairing the datasets twice!)
+ if isinstance(dataset1, PairedGeoDataset):
+ dataset1 = dataset1.dataset
+ if isinstance(dataset2, PairedGeoDataset):
+ dataset2 = dataset2.dataset
+
super().__init__(dataset1, dataset2, collate_fn, transforms)
new_datasets = []
@@ -291,8 +327,8 @@ def __init__(
self.datasets = new_datasets
def __getitem__( # type: ignore[override]
- self, query: Tuple[BoundingBox, BoundingBox]
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ self, query: tuple[BoundingBox, BoundingBox]
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
"""Retrieve image and metadata indexed by query.
Uses :meth:`torchgeo.datasets.UnionDataset.__getitem__` to send each query of the pair off to get a
@@ -320,6 +356,16 @@ def __or__(self, other: "PairedGeoDataset") -> "PairedUnionDataset": # type: ig
"""
return PairedUnionDataset(self, other)
+ def __getattr__(self, name: str) -> Any:
+ if name in self.__dict__:
+ return getattr(self, name)
+ elif name in self.datasets[0].__dict__:
+ return getattr(self.datasets[0], name) # pragma: no cover
+ elif name in self.datasets[1].__dict__:
+ return getattr(self.datasets[1], name) # pragma: no cover
+ else:
+ raise AttributeError
+
class PairedNonGeoDataset(NonGeoDataset):
"""Custom dataset to act as a wrapper to other datasets to handle paired sampling.
@@ -355,7 +401,7 @@ def __getnewargs__(self):
def __init__(
self,
dataset: Callable[..., NonGeoDataset],
- size: Union[Tuple[int, int], int],
+ size: tuple[int, int] | int,
max_r: int,
season: bool = False,
*args,
@@ -366,7 +412,18 @@ def __init__(
def __init__(
self,
dataset: NonGeoDataset,
- size: Union[Tuple[int, int], int],
+ size: tuple[int, int] | int,
+ max_r: int,
+ season: bool = False,
+ *args,
+ **kwargs,
+ ) -> None: ... # pragma: no cover
+
+ @overload
+ def __init__(
+ self,
+ dataset: str,
+ size: tuple[int, int] | int,
max_r: int,
season: bool = False,
*args,
@@ -375,8 +432,8 @@ def __init__(
def __init__(
self,
- dataset: Union[Callable[..., NonGeoDataset], NonGeoDataset],
- size: Union[Tuple[int, int], int],
+ dataset: Callable[..., NonGeoDataset] | NonGeoDataset | str,
+ size: tuple[int, int] | int,
max_r: int,
season: bool = False,
*args,
@@ -396,6 +453,9 @@ def __init__(
if isinstance(dataset, PairedNonGeoDataset):
raise ValueError("Cannot pair an already paired dataset!")
+ if isinstance(dataset, str):
+ dataset = hydra.utils.get_method(dataset)
+
if isinstance(dataset, NonGeoDataset):
self.dataset = dataset
@@ -420,7 +480,7 @@ def __init__(
self.make_geo_pair = SamplePair(self.size, self.max_r, season=season)
- def __getitem__(self, index: int) -> Tuple[Dict[str, Any], ...]: # type: ignore[override]
+ def __getitem__(self, index: int) -> tuple[dict[str, Any], ...]: # type: ignore[override]
patch = self.dataset[index]
image_a, image_b = self.make_geo_pair(patch["image"])
@@ -450,9 +510,39 @@ def __len__(self) -> int:
def __repr__(self) -> Any:
return self.dataset.__repr__()
+ def __getattr__(self, name: str) -> Any:
+ """
+ Called only if the attribute 'name' is not found by usual means.
+ Checks if 'name' exists in the dataset attribute.
+ """
+ # Instead of calling __getattr__ directly, access the dataset attribute directly
+ dataset = super().__getattribute__("dataset")
+ if hasattr(dataset, name):
+ return getattr(dataset, name)
+ # If not found in dataset, raise an AttributeError
+ raise AttributeError(f"{name} cannot be found in self or dataset")
+
+ def __getattribute__(self, name: str) -> Any:
+ """
+ Overrides default attribute access method to prevent recursion.
+ Checks 'self' first and uses __getattr__ for fallback.
+ """
+ try:
+ # First, try to get the attribute from the current instance
+ return super().__getattribute__(name)
+ except AttributeError:
+ # If not found in self, __getattr__ will check in the dataset
+ dataset = super().__getattribute__("dataset")
+ if hasattr(dataset, name):
+ return getattr(dataset, name)
+ # Raise AttributeError if not found in dataset either
+ raise AttributeError(
+ f"'{self.__class__.__name__}' object has no attribute '{name}'"
+ )
+
@staticmethod
def plot(
- sample: Dict[str, Any],
+ sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> Figure:
@@ -546,7 +636,7 @@ def __init__(
super().__init__(datasets)
- def __getitem__(self, index: int) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ def __getitem__(self, index: int) -> tuple[dict[str, Any], dict[str, Any]]:
"""Retrieve image and metadata indexed by query.
Uses :meth:`torch.utils.data.ConcatDataset.__getitem__` to get the pair of samples from
@@ -604,7 +694,7 @@ def __init__(self, size: int = 64, max_r: int = 20, season: bool = False) -> Non
# Transform to cut samples out at the desired output size.
self.random_crop = RandomCrop(self.size)
- def __call__(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ def __call__(self, x: Tensor) -> tuple[Tensor, Tensor]:
if self.season:
max_width, h, w = self._find_max_width(x[0])
@@ -631,7 +721,7 @@ def __call__(self, x: Tensor) -> Tuple[Tensor, Tensor]:
# Now cut out 2 random samples from within that sampling area and return.
return self.random_crop(sampling_area), self.random_crop(sampling_area)
- def _find_max_width(self, x: Tensor) -> Tuple[int, int, int]:
+ def _find_max_width(self, x: Tensor) -> tuple[int, int, int]:
max_width = self.max_width
w = x.shape[-1]
@@ -647,7 +737,7 @@ def _find_max_width(self, x: Tensor) -> Tuple[int, int, int]:
return max_width, h, w
@staticmethod
- def _get_random_crop_params(img: Tensor, max_width: int) -> Tuple[int, int]:
+ def _get_random_crop_params(img: Tensor, max_width: int) -> tuple[int, int]:
i = torch.randint(0, img.shape[-1] - max_width + 1, size=(1,)).item()
j = torch.randint(0, img.shape[-2] - max_width + 1, size=(1,)).item()
diff --git a/minerva/datasets/ssl4eos12.py b/minerva/datasets/ssl4eos12.py
index 74849618c..3e6a447be 100644
--- a/minerva/datasets/ssl4eos12.py
+++ b/minerva/datasets/ssl4eos12.py
@@ -32,7 +32,7 @@
# =====================================================================================================================
import os
import pickle
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional
import cv2
import lmdb
@@ -69,7 +69,7 @@ class GeoSSL4EOS12Sentinel2(Sentinel2):
filename_glob = "{}.*"
filename_regex = r"""(?PB[^[0-1]?[0-9]|B[^[1]?[0-9][\dA])\..*$"""
date_format = ""
- all_bands = [
+ all_bands: tuple[str, ...] = (
"B1",
"B2",
"B3",
@@ -82,12 +82,12 @@ class GeoSSL4EOS12Sentinel2(Sentinel2):
"B9",
"B11",
"B12",
- ]
- rgb_bands = ["B4", "B3", "B2"]
+ )
+ rgb_bands: tuple[str, str, str] = ("B4", "B3", "B2")
class NonGeoSSL4EOS12Sentinel2(MultiSpectralDataset):
- all_bands = [
+ all_bands: tuple[str, ...] = (
"B1",
"B2",
"B3",
@@ -100,8 +100,8 @@ class NonGeoSSL4EOS12Sentinel2(MultiSpectralDataset):
"B9",
"B11",
"B12",
- ]
- rgb_bands = ["B4", "B3", "B2"]
+ )
+ rgb_bands: tuple[str, str, str] = ("B4", "B3", "B2")
class MinervaSSL4EO(VisionDataset, MinervaNonGeoDataset):
@@ -110,7 +110,7 @@ class MinervaSSL4EO(VisionDataset, MinervaNonGeoDataset):
Source: https://github.com/zhu-xlab/SSL4EO-S12/tree/main/src/benchmark/pretrain_ssl/datasets/SSL4EO
"""
- ALL_BANDS_S2_L2A = [
+ ALL_BANDS_S2_L2A: tuple[str, ...] = (
"B1",
"B2",
"B3",
@@ -123,8 +123,8 @@ class MinervaSSL4EO(VisionDataset, MinervaNonGeoDataset):
"B9",
"B11",
"B12",
- ]
- ALL_BANDS_S2_L1C = [
+ )
+ ALL_BANDS_S2_L1C: tuple[str, ...] = (
"B1",
"B2",
"B3",
@@ -138,9 +138,9 @@ class MinervaSSL4EO(VisionDataset, MinervaNonGeoDataset):
"B10",
"B11",
"B12",
- ]
- RGB_BANDS = ["B4", "B3", "B2"]
- ALL_BANDS_S1_GRD = ["VV", "VH"]
+ )
+ RGB_BANDS: tuple[str, str, str] = ("B4", "B3", "B2")
+ ALL_BANDS_S1_GRD: tuple[str, str] = ("VV", "VH")
# Band statistics: mean & std
# Calculated from 50k data
@@ -215,7 +215,7 @@ def __init__(
lmdb_file: Optional[str] = None,
normalize: bool = False,
mode: str = "s2a",
- bands: Optional[List[str]] = None,
+ bands: Optional[tuple[str, ...]] = None,
dtype: str = "uint8",
is_slurm_job=False,
transforms=None,
@@ -263,7 +263,7 @@ def _init_db(self):
with self.env.begin(write=False) as txn: # type: ignore[unreachable]
self.length = txn.stat()["entries"]
- def __getitem__(self, index: int) -> Dict[str, Union[Tuple[Tensor, ...], Tensor]]:
+ def __getitem__(self, index: int) -> dict[str, tuple[Tensor, ...] | Tensor]:
if self.lmdb_file:
if self.is_slurm_job:
# Delay loading LMDB data until after initialization
@@ -363,7 +363,9 @@ def __getitem__(self, index: int) -> Dict[str, Union[Tuple[Tensor, ...], Tensor]
return {"image": img_4s}
- def get_array(self, patch_id: str, mode: str, bands: Optional[List[str]] = None):
+ def get_array(
+ self, patch_id: str, mode: str, bands: Optional[tuple[str, ...]] = None
+ ):
data_root_patch = os.path.join(self.root, mode, patch_id)
patch_seasons = os.listdir(data_root_patch)
seasons = []
diff --git a/minerva/datasets/utils.py b/minerva/datasets/utils.py
index a2f4c86e0..c64f5e843 100644
--- a/minerva/datasets/utils.py
+++ b/minerva/datasets/utils.py
@@ -46,18 +46,7 @@
# =====================================================================================================================
import pickle
from pathlib import Path
-from typing import (
- Any,
- Callable,
- Dict,
- Iterable,
- List,
- Literal,
- Optional,
- Sequence,
- Tuple,
- Union,
-)
+from typing import Any, Callable, Iterable, Literal, Optional, Sequence, Union
import numpy as np
from nptyping import NDArray
@@ -113,6 +102,16 @@ def __or__(
"""
return MinervaConcatDataset([self, other])
+ def __getattr__(self, name: str) -> Any:
+ if name in self.__dict__:
+ return getattr(self, name)
+ elif name in self.datasets[0].__dict__:
+ return getattr(self.datasets[0], name) # pragma: no cover
+ elif name in self.datasets[1].__dict__:
+ return getattr(self.datasets[1], name) # pragma: no cover
+ else:
+ raise AttributeError
+
# =====================================================================================================================
# METHODS
@@ -129,7 +128,7 @@ def intersect_datasets(datasets: Sequence[GeoDataset]) -> IntersectionDataset:
~torchgeo.datasets.IntersectionDataset: Final dataset object representing an intersection
of all the parsed datasets.
"""
- master_dataset: Union[GeoDataset, IntersectionDataset] = datasets[0]
+ master_dataset: GeoDataset | IntersectionDataset = datasets[0]
for i in range(len(datasets) - 1):
master_dataset = master_dataset & datasets[i + 1]
@@ -140,7 +139,7 @@ def intersect_datasets(datasets: Sequence[GeoDataset]) -> IntersectionDataset:
def unionise_datasets(
datasets: Sequence[GeoDataset],
- transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
) -> UnionDataset:
"""Unionises a list of :class:`~torchgeo.datasets.GeoDataset` together to return a single dataset object.
@@ -156,7 +155,7 @@ def unionise_datasets(
Returns:
~torchgeo.datasets.UnionDataset: Final dataset object representing an union of all the parsed datasets.
"""
- master_dataset: Union[GeoDataset, UnionDataset] = datasets[0]
+ master_dataset: GeoDataset | UnionDataset = datasets[0]
for i in range(len(datasets) - 1):
master_dataset = master_dataset | datasets[i + 1]
@@ -168,7 +167,7 @@ def unionise_datasets(
def concatenate_datasets(
datasets: Sequence[NonGeoDataset],
- transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
) -> MinervaConcatDataset:
"""Unionises a list of :class:`~torchgeo.datasets.GeoDataset` together to return a single dataset object.
@@ -185,20 +184,18 @@ def concatenate_datasets(
~minerva.datasets.MinervaConcatDataset: Final dataset object representing an union of all the parsed datasets.
"""
assert isinstance(datasets[0], MinervaNonGeoDataset)
- master_dataset: Union[MinervaNonGeoDataset, MinervaConcatDataset] = datasets[0]
+ master_dataset: MinervaNonGeoDataset | MinervaConcatDataset = datasets[0]
for i in range(len(datasets) - 1):
master_dataset = master_dataset | datasets[i + 1] # type: ignore[operator]
if hasattr(master_dataset, "transforms"):
- master_dataset.transforms = transforms
+ master_dataset.transforms = transforms # type: ignore[union-attr]
assert isinstance(master_dataset, MinervaConcatDataset)
return master_dataset
-def make_bounding_box(
- roi: Union[Sequence[float], bool] = False
-) -> Optional[BoundingBox]:
+def make_bounding_box(roi: Sequence[float] | bool = False) -> Optional[BoundingBox]:
"""Construct a :class:`~torchgeo.datasets.utils.BoundingBox` object from the corners of the box.
``False`` for no :class:`~torchgeo.datasets.utils.BoundingBox`.
@@ -236,7 +233,7 @@ def load_all_samples(
~numpy.ndarray: 2D array of the class modes within every sample defined by the parsed
:class:`~torch.utils.data.DataLoader`.
"""
- sample_modes: List[List[Tuple[int, int]]] = []
+ sample_modes: list[list[tuple[int, int]]] = []
for sample in tqdm(dataloader):
modes = utils.find_modes(sample[target_key])
sample_modes.append(modes)
@@ -245,8 +242,8 @@ def load_all_samples(
def get_random_sample(
- dataset: GeoDataset, size: Union[Tuple[int, int], int], res: float
-) -> Dict[str, Any]:
+ dataset: GeoDataset, size: tuple[int, int] | int, res: float
+) -> dict[str, Any]:
"""Gets a random sample from the provided dataset of size ``size`` and at ``res`` resolution.
Args:
@@ -262,7 +259,7 @@ def get_random_sample(
def load_dataset_from_cache(
cached_dataset_path: Path,
-) -> Union[NonGeoDataset, GeoDataset]:
+) -> NonGeoDataset | GeoDataset:
"""Load a pickled dataset object in from a cache.
Args:
@@ -280,7 +277,7 @@ def load_dataset_from_cache(
def cache_dataset(
- dataset: Union[GeoDataset, NonGeoDataset], cached_dataset_path: Path
+ dataset: GeoDataset | NonGeoDataset, cached_dataset_path: Path
) -> None:
"""Pickle and cache a dataset object.
@@ -295,7 +292,7 @@ def cache_dataset(
pickle.dump(dataset, fp)
-def masks_or_labels(dataset_params: Dict[str, Any]) -> str:
+def masks_or_labels(dataset_params: dict[str, Any]) -> str:
for key in dataset_params.keys():
if key not in (
"sampler",
diff --git a/minerva/inbuilt_cfgs/example_3rd_party.yaml b/minerva/inbuilt_cfgs/example_3rd_party.yaml
index 4db007b28..1dcd35933 100644
--- a/minerva/inbuilt_cfgs/example_3rd_party.yaml
+++ b/minerva/inbuilt_cfgs/example_3rd_party.yaml
@@ -41,12 +41,11 @@ optim_func: SGD # Name of the optimiser function.
# ---+ Model Parameters +------------------------------------------------------
model_params:
- module: lightly.models
- params:
- input_size: ${input_size}
- n_classes: ${n_classes}
- num_classes: ${n_classes}
- name: resnet-9
+ _target_: lightly.models.resnet.ResNetGenerator
+ input_size: ${input_size}
+ n_classes: ${n_classes}
+ num_classes: ${n_classes}
+ name: resnet-9
# ---+ Optimiser Parameters +--------------------------------------------------
optimiser:
@@ -68,24 +67,16 @@ wandb_log: true # Activates wandb logging.
project: pytest # Define the project name for wandb.
wandb_dir: /test/tmp/wandb # Directory to store wandb logs locally.
-# ---+ Minerva Inbuilt Logging Functions +-------------------------------------
-# task_logger: SupervisedTaskLogger
-# step_logger: SupervisedGeoStepLogger
-# model_io: sup_tg
-
record_int: true # Store integer results in memory.
record_float: true # Store floating point results too. Beware memory overload!
# ---+ Collator +--------------------------------------------------------------
-collator:
- module: torchgeo.datasets
- name: stack_samples
+collator: torchgeo.datasets.stack_samples
# === TASKS ===================================================================
tasks:
fit-train:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: true
record_float: true
@@ -95,45 +86,36 @@ tasks:
# ---+ Dataset Parameters +--------------------------------------------
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 120
+ size: ${patch_size}
+ length: 120
image:
transforms:
ToRGB:
- module: minerva.transforms
-
- images_1:
- module: minerva.datasets.__testing
- name: TstImgDataset
- paths: NAIP
- params:
+ _target_: minerva.transforms.ToRGB
+ subdatasets:
+ images_1:
+ _target_: minerva.datasets.__testing.TstImgDataset
+ paths: NAIP
res: 1.0
- image2:
- module: minerva.datasets.__testing
- name: TstImgDataset
- paths: NAIP
- params:
+ image2:
+ _target_: minerva.datasets.__testing.TstImgDataset
+ paths: NAIP
res: 1.0
mask:
transforms:
SingleLabel:
- module: minerva.transforms
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.transforms.SingleLabel
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
fit-val:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: false
record_float: true
@@ -143,36 +125,29 @@ tasks:
# ---+ Dataset Parameters +--------------------------------------------
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 32
+ size: ${patch_size}
+ length: 32
image:
transforms:
ToRGB:
- module: minerva.transforms
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.transforms.ToRGB
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
transforms:
SingleLabel:
- module: minerva.transforms
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.transforms.SingleLabel
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
test-test:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: false
record_float: true
@@ -182,32 +157,26 @@ tasks:
# ---+ Dataset Parameters +--------------------------------------------
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 32
+ size: ${patch_size}
+ length: 32
image:
transforms:
ToRGB:
- module: minerva.transforms
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.transforms.ToRGB
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
transforms:
SingleLabel:
- module: minerva.transforms
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.transforms.SingleLabel
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
# === PLOTTING OPTIONS ========================================================
plots:
diff --git a/minerva/inbuilt_cfgs/example_CNN_config.yaml b/minerva/inbuilt_cfgs/example_CNN_config.yaml
index 3936e48e4..2a5daf2dc 100644
--- a/minerva/inbuilt_cfgs/example_CNN_config.yaml
+++ b/minerva/inbuilt_cfgs/example_CNN_config.yaml
@@ -24,7 +24,7 @@ model_type: scene-classifier
batch_size: 8 # Number of samples in each batch.
input_size: [4, 32, 32] # patch_size plus leading channel dim.
patch_size: '${to_patch_size: ${input_size}}' # 2D tuple or float.
-n_classes: &n_classes 8 # Number of classes in dataset.
+n_classes: 8 # Number of classes in dataset.
# ---+ Experiment Execution +--------------------------------------------------
max_epochs: 3 # Maximum number of training epochs.
@@ -41,10 +41,9 @@ optim_func: SGD # Name of the optimiser function.
# ---+ Model Parameters +------------------------------------------------------
model_params:
- module:
- params:
- input_size: ${input_size}
- n_classes: *n_classes
+ _target_: minerva.models.CNN
+ input_size: ${input_size}
+ n_classes: ${n_classes}
# ---+ Optimiser Parameters +--------------------------------------------------
optimiser:
@@ -67,22 +66,19 @@ wandb_dir: /test/tmp/wandb # Directory to store wandb logs locally.
# === MODEL IO & LOGGING ======================================================
# ---+ Minerva Inbuilt Logging Functions +-------------------------------------
-task_logger: SupervisedTaskLogger
-model_io: sup_tg
+task_logger: minerva.logger.tasklog.SupervisedTaskLogger
+model_io: minerva.modelio.supervised_torchgeo_io
record_int: true # Store integer results in memory.
record_float: false # Store floating point results too. Beware memory overload!
# ---+ Collator +--------------------------------------------------------------
-collator:
- module: torchgeo.datasets
- name: stack_samples
+collator: torchgeo.datasets.stack_samples
# === TASKS ===================================================================
tasks:
fit-train:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: true
record_float: true
@@ -91,37 +87,30 @@ tasks:
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 320
+ size: ${patch_size}
+ length: 320
image:
transforms:
Normalise:
- module: minerva.transforms
+ _target_: minerva.transforms.Normalise
norm_value: 255
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
transforms:
SingleLabel:
- module: minerva.transforms
+ _target_: minerva.transforms.SingleLabel
mode: modal
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
fit-val:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: false
record_float: true
@@ -130,38 +119,31 @@ tasks:
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 80
+ size: ${patch_size}
+ length: 80
image:
transforms:
Normalise:
- module: minerva.transforms
+ _target_: minerva.transforms.Normalise
norm_value: 255
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
transforms:
SingleLabel:
- module: minerva.transforms
+ _target_: minerva.transforms.SingleLabel
mode: modal
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
test-test:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: false
record_float: true
@@ -170,34 +152,28 @@ tasks:
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 160
+ size: ${patch_size}
+ length: 160
image:
transforms:
Normalise:
- module: minerva.transforms
+ _target_: minerva.transforms.Normalise
norm_value: 255
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
transforms:
SingleLabel:
- module: minerva.transforms
+ _target_: minerva.transforms.SingleLabel
mode: modal
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
# === PLOTTING OPTIONS ========================================================
plots:
diff --git a/minerva/inbuilt_cfgs/example_ChangeDetector.yaml b/minerva/inbuilt_cfgs/example_ChangeDetector.yaml
index 244b3a7e5..e060b3c97 100644
--- a/minerva/inbuilt_cfgs/example_ChangeDetector.yaml
+++ b/minerva/inbuilt_cfgs/example_ChangeDetector.yaml
@@ -34,6 +34,7 @@ pre_train: false # Activate pre-training mode.
fine_tune: false # Activate fine-tuning mode.
torch_compile: false # Wrap model in `torch.compile`.
sample_pairs: false
+change_detection: true
# ---+ Loss and Optimisers +---------------------------------------------------
loss_func: CrossEntropyLoss # Name of the loss function to use.
@@ -42,25 +43,24 @@ optim_func: Adam # Name of the optimiser function.
# ---+ Model Parameters +------------------------------------------------------
model_params:
- module:
- params:
+ _target_: minerva.models.ChangeDetector
+ input_size: ${input_size}
+ n_classes: ${n_classes}
+ encoder_on: true
+ filter_dim: -1
+ fc_dim: 512
+ freeze_backbone: false
+ backbone_args:
+ module: minerva.models
+ name: MinervaPSP
input_size: ${input_size}
- n_classes: ${n_classes}
- encoder_on: true
- filter_dim: -1
- fc_dim: 512
- freeze_backbone: false
- backbone_args:
- module: minerva.models
- name: MinervaPSP
- input_size: ${input_size}
- n_classes: 1
- encoder_name: resnet18
- encoder_weights:
- psp_out_channels: 512
- segmentation_on: false
- classification_on: false
- encoder: false
+ n_classes: 1
+ encoder_name: resnet18
+ encoder_weights:
+ psp_out_channels: 512
+ segmentation_on: false
+ classification_on: false
+ encoder: false
# ---+ Optimiser Parameters +--------------------------------------------------
optimiser:
@@ -83,20 +83,17 @@ wandb_dir: /test/tmp/wandb # Directory to store wandb logs locally.
# === MODEL IO & LOGGING ======================================================
# ---+ Minerva Inbuilt Logging Functions +-------------------------------------
-
+model_io: minerva.modelio.change_detection_io
record_int: true # Store integer results in memory.
record_float: false # Store floating point results too. Beware memory overload!
# ---+ Collator +--------------------------------------------------------------
-collator:
- module: torchgeo.datasets
- name: stack_samples
+collator: torchgeo.datasets.stack_samples
# === TASKS ===================================================================
tasks:
fit-train:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: true
record_float: true
sample_pairs: false
@@ -106,51 +103,35 @@ tasks:
# ---+ Dataset Parameters +--------------------------------------------
dataset_params:
sampler:
- module: torch.utils.data
- name: RandomSampler
- params:
- num_samples: 32
+ _target_: torch.utils.data.RandomSampler
+ num_samples: 32
transforms:
RandomCrop:
- module: kornia.augmentation
+ _target_: kornia.augmentation.RandomCrop
size: ${patch_size}
keepdim: true
image:
transforms:
Normalise:
- module: minerva.transforms
+ _target_: minerva.transforms.Normalise
norm_value: 4095
- SelectChannels:
- module: minerva.transforms
- channels:
- - 1
- - 2
- - 3
- - 7
- - 14
- - 15
- - 16
- - 20
-
- module: torchgeo.datasets
- name: OSCD
- paths: OSCD
- params:
- split: train
- bands: all
- download: true
+
+ _target_: torchgeo.datasets.OSCD
+ root: OSCD
+ split: train
+ bands: [B02, B03, B04, B08]
+ download: true
mask:
transforms:
SingleLabel:
- module: minerva.transforms
+ _target_: minerva.transforms.SingleLabel
mode: centre
test-test:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: false
record_float: true
sample_pairs: false
@@ -160,46 +141,31 @@ tasks:
# ---+ Dataset Parameters +--------------------------------------------
dataset_params:
sampler:
- module: torch.utils.data
- name: RandomSampler
- params:
- num_samples: 16
+ _target_: torch.utils.data.RandomSampler
+ num_samples: 16
transforms:
RandomCrop:
- module: kornia.augmentation
+ _target_: kornia.augmentation.RandomCrop
size: ${patch_size}
keepdim: true
image:
transforms:
Normalise:
- module: minerva.transforms
+ _target_: minerva.transforms.Normalise
norm_value: 4095
- SelectChannels:
- module: minerva.transforms
- channels:
- - 1
- - 2
- - 3
- - 7
- - 14
- - 15
- - 16
- - 20
-
- module: torchgeo.datasets
- name: OSCD
- paths: OSCD
- params:
- split: test
- bands: all
- download: true
+
+ _target_: torchgeo.datasets.OSCD
+ root: OSCD
+ split: test
+ bands: [B02, B03, B04, B08]
+ download: true
mask:
transforms:
SingleLabel:
- module: minerva.transforms
+ _target_: minerva.transforms.SingleLabel
mode: centre
# === PLOTTING OPTIONS ========================================================
diff --git a/minerva/inbuilt_cfgs/example_GSConvNet-II.yaml b/minerva/inbuilt_cfgs/example_GSConvNet-II.yaml
index d04bd52b1..9d2467358 100644
--- a/minerva/inbuilt_cfgs/example_GSConvNet-II.yaml
+++ b/minerva/inbuilt_cfgs/example_GSConvNet-II.yaml
@@ -49,10 +49,9 @@ val_freq: 1 # Validation epoch every ``val_freq`` training epochs.
# ---+ Model Parameters +------------------------------------------------------
model_params:
- module:
- params:
- input_size: ${input_size}
- encoder_weights: imagenet
+ _target_: minerva.models.SimConv
+ input_size: ${input_size}
+ encoder_weights: imagenet
# ---+ Optimiser Parameters +--------------------------------------------------
optimiser:
@@ -92,160 +91,140 @@ record_int: true # Store integer results in memory.
record_float: false # Store floating point results too. Beware memory overload!
# ---+ Collator +--------------------------------------------------------------
-collator:
- module: torchgeo.datasets
- name: stack_samples
+collator: torchgeo.datasets.stack_samples
# === TASKS ===================================================================
tasks:
fit-train:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: true
record_float: true
sample_pairs: true
# ---+ Minerva Inbuilt Logging Functions +-------------------------
- task_logger: SSLTaskLogger
- model_io: ssl_pair_tg
+ task_logger: minerva.logger.tasklog.SSLTaskLogger
+ model_io: minerva.modelio.ssl_pair_torchgeo_io
# ---+ Dataset Parameters +----------------------------------------
dataset_params:
sampler:
- module: torch.utils.data
- name: RandomSampler
- params:
- num_samples: 32
+ _target_: torch.utils.data.RandomSampler
+ num_samples: 32
image:
transforms:
- Normalise:
- module: minerva.transforms
- norm_value: 10000
+ normalise:
+ _target_: minerva.transforms.Normalise
+ norm_value: 255
RandomApply:
p: 0.25
- DetachedColorJitter:
- module: minerva.transforms
+ jitter:
+ _target_: minerva.transforms.DetachedColorJitter
brightness: 0.2
contrast: 0.1
saturation: 0.1
hue: 0.15
- RandomResizedCrop:
- module: kornia.augmentation
+ resize_crop:
+ _target_: kornia.augmentation.RandomResizedCrop
p: 0.2
size: ${patch_size}
cropping_mode: resample
keepdim: true
- RandomHorizontalFlip:
- module: kornia.augmentation
+ horizontal_flip:
+ _target_: kornia.augmentation.RandomHorizontalFlip
p: 0.2
keepdim: true
- RandomGaussianBlur:
- module: kornia.augmentation
+ gaussian_blur:
+ _target_: kornia.augmentation.RandomGaussianBlur
kernel_size: 9
p: 0.2
sigma: [0.01, 0.2]
keepdim: true
- RandomGaussianNoise:
- module: kornia.augmentation
+ gaussian_noise:
+ _target_: kornia.augmentation.RandomGaussianNoise
p: 0.2
std: 0.05
keepdim: true
- RandomErasing:
- module: kornia.augmentation
+ random_erasing:
+ _target_: kornia.augmentation.RandomErasing
p: 0.2
keepdim: true
- module: minerva.datasets
- name: NonGeoSSL4EOS12Sentinel2
- paths: SSL4EO-S12
- params:
- bands: [B2, B3, B4, B8]
- size: ${patch_size}
- max_r: *max_r
+ _target_: minerva.datasets.MinervaSSL4EO
+ root: SSL4EO-S12
+ mode: s2a
+ bands: [B2, B3, B4, B8]
+ size: ${patch_size}
+ max_r: *max_r
+ season: true
+ season_transform: pair
fit-val:
- name: WeightedKNN
- module: minerva.tasks
+ _target_: minerva.tasks.WeightedKNN
train: false
record_float: true
sample_pairs: false
n_classes: 11
# ---+ Minerva Inbuilt Logging Functions +-------------------------
- task_logger: SSLTaskLogger
+ task_logger: minerva.logger.tasklog.SSLTaskLogger
step_logger:
- name: KNNStepLogger
- model_io: sup_tg
+ _target_: minerva.logger.steplog.KNNStepLogger
+ model_io: minerva.modelio.supervised_torchgeo_io
# ---+ Dataset Parameters +----------------------------------------
dataset_params:
features:
sampler:
- module: torch.utils.data
- name: RandomSampler
- params:
- num_samples: 16
+ _target_: torch.utils.data.RandomSampler
+ num_samples: 16
image:
- module: minerva.datasets
- name: DFC2020
- paths: DFC/DFC2020
- params:
- split: val
- use_s2hr: true
- labels: true
+ _target_: minerva.datasets.DFC2020
+ root: DFC/DFC2020
+ split: val
+ use_s2hr: true
+ labels: true
test:
sampler:
- module: torch.utils.data
- name: RandomSampler
- params:
- num_samples: 16
+ _target_: torch.utils.data.RandomSampler
+ num_samples: 16
image:
- module: minerva.datasets
- name: DFC2020
- paths: DFC/DFC2020
- params:
- split: test
- use_s2hr: true
- labels: true
+ _target_: minerva.datasets.DFC2020
+ root: DFC/DFC2020
+ split: test
+ use_s2hr: true
+ labels: true
test-test:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: false
record_float: true
sample_pairs: false
# ---+ Minerva Inbuilt Logging Functions +-------------------------
- task_logger: SupervisedTaskLogger
- model_io: sup_tg
+ task_logger: minerva.logger.tasklog.SupervisedTaskLogger
+ model_io: minerva.modelio.supervised_torchgeo_io
# ---+ Dataset Parameters +----------------------------------------
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 16
+ size: ${patch_size}
+ length: 16
image:
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
# === PLOTTING OPTIONS ========================================================
plots:
diff --git a/minerva/inbuilt_cfgs/example_GeoCLR_config.yaml b/minerva/inbuilt_cfgs/example_GeoCLR_config.yaml
index f474b26fe..3cabddad4 100644
--- a/minerva/inbuilt_cfgs/example_GeoCLR_config.yaml
+++ b/minerva/inbuilt_cfgs/example_GeoCLR_config.yaml
@@ -47,10 +47,9 @@ val_freq: 2 # Validation epoch every ``val_freq`` training epochs.
# ---+ Model Parameters +------------------------------------------------------
model_params:
- module:
- params:
- input_size: ${input_size}
- # any other params...
+ _target_: minerva.models.SimCLR18
+ input_size: ${input_size}
+ # any other params...
# ---+ Optimiser Parameters +--------------------------------------------------
optimiser:
@@ -77,15 +76,12 @@ record_int: true # Store integer results in memory.
record_float: true # Store floating point results too. Beware memory overload!
# ---+ Collator +--------------------------------------------------------------
-collator:
- module: torchgeo.datasets
- name: stack_samples
+collator: torchgeo.datasets.stack_samples
# === TASKS ===================================================================
tasks:
fit-train:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: true
record_float: true
sample_pairs: true
@@ -93,77 +89,70 @@ tasks:
imagery_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/NAIP.yaml}}' # yamllint disable-line rule:line-length
# ---+ Minerva Inbuilt Logging Functions +-------------------------
- task_logger: SSLTaskLogger
+ task_logger: minerva.logger.tasklog.SSLTaskLogger
step_logger:
- name: SSLStepLogger
- model_io: ssl_pair_tg
+ _target_: minerva.logger.steplog.SSLStepLogger
+ model_io: minerva.modelio.ssl_pair_torchgeo_io
# ---+ Dataset Parameters +----------------------------------------
dataset_params:
sampler:
- module: minerva.samplers
- name: RandomPairGeoSampler
+ _target_: minerva.samplers.RandomPairGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 120
- max_r: *max_r
+ size: ${patch_size}
+ length: 120
+ max_r: *max_r
image:
transforms:
- Normalise:
- module: minerva.transforms
+ normalise:
+ _target_: minerva.transforms.Normalise
norm_value: 255
RandomApply:
p: 0.25
- DetachedColorJitter:
- module: minerva.transforms
+ jitter:
+ _target_: minerva.transforms.DetachedColorJitter
brightness: 0.2
contrast: 0.1
saturation: 0.1
hue: 0.15
- RandomResizedCrop:
- module: kornia.augmentation
+ resize_crop:
+ _target_: kornia.augmentation.RandomResizedCrop
p: 0.2
size: ${patch_size}
cropping_mode: resample
keepdim: true
- RandomHorizontalFlip:
- module: kornia.augmentation
+ horizontal_flip:
+ _target_: kornia.augmentation.RandomHorizontalFlip
p: 0.2
keepdim: true
- RandomGaussianBlur:
- module: kornia.augmentation
+ gaussian_blur:
+ _target_: kornia.augmentation.RandomGaussianBlur
kernel_size: 9
p: 0.2
sigma: [0.01, 0.2]
keepdim: true
- RandomGaussianNoise:
- module: kornia.augmentation
+ gaussian_noise:
+ _target_: kornia.augmentation.RandomGaussianNoise
p: 0.2
std: 0.05
keepdim: true
- RandomErasing:
- module: kornia.augmentation
+ random_erasing:
+ _target_: kornia.augmentation.RandomErasing
p: 0.2
keepdim: true
-
- image1:
- module: minerva.datasets.__testing
- name: TstImgDataset
- paths: NAIP
- params:
+ subdatasets:
+ image1:
+ _target_: minerva.datasets.__testing.TstImgDataset
+ paths: NAIP
res: 1.0
- image2:
- module: minerva.datasets.__testing
- name: TstImgDataset
- paths: NAIP
- params:
+ image2:
+ _target_: minerva.datasets.__testing.TstImgDataset
+ paths: NAIP
res: 1.0
fit-val:
- name: WeightedKNN
- module: minerva.tasks
+ _target_: minerva.tasks.WeightedKNN
train: false
sample_pairs: false
n_classes: 8
@@ -172,62 +161,49 @@ tasks:
data_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/Chesapeake7.yaml}}' # yamllint disable-line rule:line-length
# ---+ Minerva Inbuilt Logging Functions +-------------------------
- task_logger: SSLTaskLogger
+ task_logger: minerva.logger.tasklog.SSLTaskLogger
step_logger:
- name: KNNStepLogger
- model_io: ssl_pair_tg
+ _target_: minerva.logger.steplog.KNNStepLogger
+ model_io: minerva.modelio.ssl_pair_torchgeo_io
# ---+ Dataset Parameters +----------------------------------------
dataset_params:
features:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 32
+ size: ${patch_size}
+ length: 32
image:
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
test:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 32
+ size: ${patch_size}
+ length: 32
image:
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
test-test:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: false
record_float: true
sample_pairs: false
@@ -237,32 +213,26 @@ tasks:
data_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/Chesapeake7.yaml}}' # yamllint disable-line rule:line-length
# ---+ Minerva Inbuilt Logging Functions +-------------------------
- task_logger: SupervisedTaskLogger
- model_io: sup_tg
+ task_logger: minerva.logger.tasklog.SupervisedTaskLogger
+ model_io: minerva.modelio.supervised_torchgeo_io
# ---+ Dataset Parameters +----------------------------------------
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 32
+ size: ${patch_size}
+ length: 32
image:
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
# === PLOTTING OPTIONS ========================================================
plots:
diff --git a/minerva/inbuilt_cfgs/example_GeoSimConvNet.yaml b/minerva/inbuilt_cfgs/example_GeoSimConvNet.yaml
index b9536e839..8305e6bab 100644
--- a/minerva/inbuilt_cfgs/example_GeoSimConvNet.yaml
+++ b/minerva/inbuilt_cfgs/example_GeoSimConvNet.yaml
@@ -49,9 +49,8 @@ val_freq: 1 # Validation epoch every ``val_freq`` training epochs.
# ---+ Model Parameters +------------------------------------------------------
model_params:
- module:
- params:
- input_size: ${input_size}
+ _target_: minerva.models.SimConv
+ input_size: ${input_size}
# ---+ Optimiser Parameters +--------------------------------------------------
optimiser:
@@ -59,12 +58,11 @@ optimiser:
lr: ${lr}
# ---+ Scheduler Parameters +--------------------------------------------------
-scheduler_params:
- name: LinearLR
- params:
- start_factor: *lr
- end_factor: 5.0E-4
- total_iters: 2
+scheduler:
+ _target_: torch.optim.lr_scheduler.LinearLR
+ start_factor: ${lr}
+ end_factor: 5.0E-4
+ total_iters: 2
# ---+ Loss Function Parameters +----------------------------------------------
loss_params:
@@ -84,81 +82,73 @@ record_int: true # Store integer results in memory.
record_float: false # Store floating point results too. Beware memory overload!
# ---+ Collator +--------------------------------------------------------------
-collator:
- module: torchgeo.datasets
- name: stack_samples
+collator: torchgeo.datasets.stack_samples
# === TASKS ===================================================================
tasks:
fit-train:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: true
record_float: true
sample_pairs: true
# ---+ Minerva Inbuilt Logging Functions +-------------------------
- task_logger: SSLTaskLogger
- model_io: ssl_pair_tg
+ task_logger: minerva.logger.tasklog.SSLTaskLogger
+ model_io: minerva.modelio.ssl_pair_torchgeo_io
# ---+ Dataset Parameters +----------------------------------------
dataset_params:
sampler:
- module: torch.utils.data
- name: RandomSampler
- params:
- num_samples: 32
+ _target_: torch.utils.data.RandomSampler
+ num_samples: 32
image:
transforms:
- Normalise:
- module: minerva.transforms
- norm_value: 10000
+ normalise:
+ _target_: minerva.transforms.Normalise
+ norm_value: 255
RandomApply:
p: 0.25
- DetachedColorJitter:
- module: minerva.transforms
+ jitter:
+ _target_: minerva.transforms.DetachedColorJitter
brightness: 0.2
contrast: 0.1
saturation: 0.1
hue: 0.15
- RandomResizedCrop:
- module: kornia.augmentation
+ resize_crop:
+ _target_: kornia.augmentation.RandomResizedCrop
p: 0.2
size: ${patch_size}
cropping_mode: resample
keepdim: true
- RandomHorizontalFlip:
- module: kornia.augmentation
+ horizontal_flip:
+ _target_: kornia.augmentation.RandomHorizontalFlip
p: 0.2
keepdim: true
- RandomGaussianBlur:
- module: kornia.augmentation
+ gaussian_blur:
+ _target_: kornia.augmentation.RandomGaussianBlur
kernel_size: 9
p: 0.2
sigma: [0.01, 0.2]
keepdim: true
- RandomGaussianNoise:
- module: kornia.augmentation
+ gaussian_noise:
+ _target_: kornia.augmentation.RandomGaussianNoise
p: 0.2
std: 0.05
keepdim: true
- RandomErasing:
- module: kornia.augmentation
+ random_erasing:
+ _target_: kornia.augmentation.RandomErasing
p: 0.2
keepdim: true
- module: minerva.datasets
- name: NonGeoSSL4EOS12Sentinel2
- paths: SSL4EO-S12
- params:
- bands: [B2, B3, B4, B8]
- size: ${patch_size}
- max_r: *max_r
+ _target_: minerva.datasets.NonGeoSSL4EOS12Sentinel2
+ root: SSL4EO-S12
+ bands: [B2, B3, B4, B8]
+ size: ${patch_size}
+ max_r: *max_r
fit-val:
- name: WeightedKNN
- module: minerva.tasks
+ _target_: minerva.tasks.WeightedKNN
train: false
record_float: true
sample_pairs: false
@@ -167,62 +157,49 @@ tasks:
data_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/Chesapeake7.yaml}}' # yamllint disable-line rule:line-length
# ---+ Minerva Inbuilt Logging Functions +-------------------------
- task_logger: SSLTaskLogger
+ task_logger: minerva.logger.tasklog.SSLTaskLogger
step_logger:
- name: KNNStepLogger
- model_io: ssl_pair_tg
+ _target_: minerva.logger.steplog.KNNStepLogger
+ model_io: minerva.modelio.ssl_pair_torchgeo_io
# ---+ Dataset Parameters +----------------------------------------
dataset_params:
features:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 16
+ size: ${patch_size}
+ length: 16
image:
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
test:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 16
+ size: ${patch_size}
+ length: 16
image:
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
test-test:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: false
record_float: true
sample_pairs: false
@@ -231,32 +208,26 @@ tasks:
data_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/Chesapeake7.yaml}}' # yamllint disable-line rule:line-length
# ---+ Minerva Inbuilt Logging Functions +-------------------------
- task_logger: SSLTaskLogger
- model_io: ssl_pair_tg
+ task_logger: minerva.logger.tasklog.SSLTaskLogger
+ model_io: minerva.modelio.ssl_pair_torchgeo_io
# ---+ Dataset Parameters +----------------------------------------
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 16
+ size: ${patch_size}
+ length: 16
image:
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
# === PLOTTING OPTIONS ========================================================
plots:
diff --git a/minerva/inbuilt_cfgs/example_MultiLabel.yaml b/minerva/inbuilt_cfgs/example_MultiLabel.yaml
index ed33d12fb..b0e120c20 100644
--- a/minerva/inbuilt_cfgs/example_MultiLabel.yaml
+++ b/minerva/inbuilt_cfgs/example_MultiLabel.yaml
@@ -24,7 +24,7 @@ model_type: multilabel-scene-classifier
batch_size: 2 # Number of samples in each batch.
input_size: [4, 120, 120] # patch_size plus leading channel dim.
patch_size: '${to_patch_size: ${input_size}}' # 2D tuple or float.
-n_classes: &n_classes 19 # Number of classes in dataset.
+n_classes: 19 # Number of classes in dataset.
# ---+ Experiment Execution +--------------------------------------------------
max_epochs: 2 # Maximum number of training epochs.
@@ -41,26 +41,25 @@ optim_func: SGD # Name of the optimiser function.
# ---+ Model Parameters +------------------------------------------------------
model_params:
- module:
- params:
+ _target_: minerva.models.FlexiSceneClassifier
+ input_size: ${input_size}
+ n_classes: ${n_classes}
+ encoder_on: true
+ filter_dim: -1
+ fc_dim: 512
+ freeze_backbone: false
+ clamp_outputs: true
+ backbone_args:
+ module: minerva.models
+ name: MinervaPSP
input_size: ${input_size}
- n_classes: *n_classes
- encoder_on: true
- filter_dim: -1
- fc_dim: 512
- freeze_backbone: false
- clamp_outputs: true
- backbone_args:
- module: minerva.models
- name: MinervaPSP
- input_size: ${input_size}
- n_classes: 1
- encoder_name: resnet18
- encoder_weights:
- psp_out_channels: 512
- segmentation_on: false
- classification_on: false
- encoder: false
+ n_classes: 1
+ encoder_name: resnet18
+ encoder_weights:
+ psp_out_channels: 512
+ segmentation_on: false
+ classification_on: false
+ encoder: false
# ---+ Optimiser Parameters +--------------------------------------------------
optimiser:
@@ -88,15 +87,12 @@ record_int: true # Store integer results in memory.
record_float: false # Store floating point results too. Beware memory overload!
# ---+ Collator +--------------------------------------------------------------
-collator:
- module: torchgeo.datasets
- name: stack_samples
+collator: torchgeo.datasets.stack_samples
# === TASKS ===================================================================
tasks:
fit-train:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: true
record_float: true
cache_dataset: false
@@ -106,71 +102,66 @@ tasks:
# ---+ Dataset Parameters +--------------------------------------------
dataset_params:
sampler:
- module: torch.utils.data
- name: RandomSampler
- params:
- num_samples: 32
+ _target_: torch.utils.data.RandomSampler
+ num_samples: 32
image:
transforms:
SelectChannels:
- module: minerva.transforms
+ _target_: minerva.transforms.SelectChannels
channels:
- 1
- 2
- 3
- 7
Normalise:
- module: minerva.transforms
+ _target_: minerva.transforms.Normalise
norm_value: 4095
RandomApply:
p: 0.25
DetachedColorJitter:
- module: minerva.transforms
+ _target_: minerva.transforms.DetachedColorJitter
brightness: 0.2
contrast: 0.1
saturation: 0.1
hue: 0.15
- RandomResizedCrop:
- module: kornia.augmentation
+ resize_crop:
+ _target_: kornia.augmentation.RandomResizedCrop
p: 0.2
size: ${patch_size}
cropping_mode: resample
keepdim: true
- RandomHorizontalFlip:
- module: kornia.augmentation
+ horizontal_flip:
+ _target_: kornia.augmentation.RandomHorizontalFlip
p: 0.2
keepdim: true
- RandomGaussianBlur:
- module: kornia.augmentation
+ gaussian_blur:
+ _target_: kornia.augmentation.RandomGaussianBlur
kernel_size: 9
p: 0.2
sigma: [0.01, 0.2]
keepdim: true
- RandomGaussianNoise:
- module: kornia.augmentation
+ gaussian_noise:
+ _target_: kornia.augmentation.RandomGaussianNoise
p: 0.2
std: 0.05
keepdim: true
- RandomErasing:
- module: kornia.augmentation
+ random_erasing:
+ _target_: kornia.augmentation.RandomErasing
p: 0.2
keepdim: true
- module: torchgeo.datasets
- name: BigEarthNet
- paths: BigEarthNet-mini
- params:
- split: train
- bands: s2
- download: true
- num_classes: *n_classes
+ _target_: torchgeo.datasets.BigEarthNet
+ root: BigEarthNet-mini
+ split: train
+ bands: s2
+ download: true
+ num_classes: ${n_classes}
label:
test-test:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: false
record_float: true
cache_dataset: false
@@ -180,32 +171,28 @@ tasks:
# ---+ Dataset Parameters +--------------------------------------------
dataset_params:
sampler:
- module: torch.utils.data
- name: RandomSampler
- params:
- num_samples: 16
+ _target_: torch.utils.data.RandomSampler
+ num_samples: 16
image:
transforms:
SelectChannels:
- module: minerva.transforms
+ _target_: minerva.transforms.SelectChannels
channels:
- 1
- 2
- 3
- 7
Normalise:
- module: minerva.transforms
+ _target_: minerva.transforms.Normalise
norm_value: 4095
- module: torchgeo.datasets
- name: BigEarthNet
- paths: BigEarthNet-mini
- params:
- split: test
- bands: s2
- download: true
- num_classes: *n_classes
+ _target_: torchgeo.datasets.BigEarthNet
+ root: BigEarthNet-mini
+ split: test
+ bands: s2
+ download: true
+ num_classes: ${n_classes}
label:
diff --git a/minerva/inbuilt_cfgs/example_PSP.yaml b/minerva/inbuilt_cfgs/example_PSP.yaml
index a8483852a..1359f066d 100644
--- a/minerva/inbuilt_cfgs/example_PSP.yaml
+++ b/minerva/inbuilt_cfgs/example_PSP.yaml
@@ -41,16 +41,15 @@ optim_func: SGD # Name of the optimiser function.
# ---+ Model Parameters +------------------------------------------------------
model_params:
- module:
- params:
- input_size: ${input_size}
- n_classes: ${n_classes}
- encoder: true
- segmentation_on: true
- classification_on: true
- upsampling: 8
- aux_params:
- classes: ${n_classes}
+ _target_: minerva.models.MinervaPSP
+ input_size: ${input_size}
+ n_classes: ${n_classes}
+ encoder: true
+ segmentation_on: true
+ classification_on: true
+ upsampling: 8
+ aux_params:
+ classes: ${n_classes}
# ---+ Optimiser Parameters +--------------------------------------------------
optimiser:
@@ -78,15 +77,12 @@ record_int: true # Store integer results in memory.
record_float: false # Store floating point results too. Beware memory overload!
# ---+ Collator +--------------------------------------------------------------
-collator:
- module: torchgeo.datasets
- name: stack_samples
+collator: torchgeo.datasets.stack_samples
# === TASKS ===================================================================
tasks:
fit-train:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: true
record_float: true
@@ -95,68 +91,64 @@ tasks:
# ---+ Dataset Parameters +--------------------------------------------
dataset_params:
sampler:
- module: torch.utils.data
- name: RandomSampler
- params:
- num_samples: 32
+ _target_: torch.utils.data.RandomSampler
+ num_samples: 32
transforms:
RandomResizedCrop:
- module: kornia.augmentation
+ _target_: kornia.augmentation.RandomResizedCrop
p: 1.0
size: ${patch_size}
cropping_mode: resample
keepdim: true
RandomHorizontalFlip:
- module: kornia.augmentation
+ _target_: kornia.augmentation.RandomHorizontalFlip
p: 0.2
keepdim: true
image:
transforms:
- Normalise:
- module: minerva.transforms
+ normalise:
+ _target_: minerva.transforms.Normalise
norm_value: 4095
RandomApply:
p: 0.25
- DetachedColorJitter:
- module: minerva.transforms
+ jitter:
+ _target_: minerva.transforms.DetachedColorJitter
brightness: 0.2
contrast: 0.1
saturation: 0.1
hue: 0.15
- RandomGaussianBlur:
- module: kornia.augmentation
+ gaussian_blur:
+ _target_: kornia.augmentation.RandomGaussianBlur
kernel_size: 9
p: 0.2
sigma: [0.01, 0.2]
keepdim: true
- RandomGaussianNoise:
- module: kornia.augmentation
+ gaussian_noise:
+ _target_: kornia.augmentation.RandomGaussianNoise
p: 0.2
std: 0.05
keepdim: true
- RandomErasing:
- module: kornia.augmentation
+ random_erasing:
+ _target_: kornia.augmentation.RandomErasing
p: 0.2
keepdim: true
- module: minerva.datasets
- name: DFC2020
- paths: DFC/DFC2020
- params:
- split: test
- use_s2hr: true
- labels: true
+ _target_: minerva.datasets.DFC2020
+ root: DFC/DFC2020
+ split: test
+ use_s2hr: true
+ labels: true
mask:
transforms:
MaskResize:
- module: minerva.transforms
+ _target_: minerva.transforms.MaskResize
size: 64
+ interpolation: NEAREST
test-test:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: false
record_float: true
@@ -165,30 +157,27 @@ tasks:
# ---+ Dataset Parameters +--------------------------------------------
dataset_params:
sampler:
- module: torch.utils.data
- name: RandomSampler
- params:
- num_samples: 16
+ _target_: torch.utils.data.RandomSampler
+ num_samples: 16
image:
transforms:
Normalise:
- module: minerva.transforms
+ _target_: minerva.transforms.Normalise
norm_value: 4095
- module: minerva.datasets
- name: DFC2020
- paths: DFC/DFC2020
- params:
- split: test
- use_s2hr: true
- labels: true
+ _target_: minerva.datasets.DFC2020
+ root: DFC/DFC2020
+ split: test
+ use_s2hr: true
+ labels: true
mask:
transforms:
MaskResize:
- module: minerva.transforms
+ _target_: minerva.transforms.MaskResize
size: 64
+ interpolation: NEAREST
# === PLOTTING OPTIONS ========================================================
plots:
diff --git a/minerva/inbuilt_cfgs/example_SceneClassifier.yaml b/minerva/inbuilt_cfgs/example_SceneClassifier.yaml
index 5f35abcdb..716355437 100644
--- a/minerva/inbuilt_cfgs/example_SceneClassifier.yaml
+++ b/minerva/inbuilt_cfgs/example_SceneClassifier.yaml
@@ -24,7 +24,7 @@ model_type: scene-classifier
batch_size: 2 # Number of samples in each batch.
input_size: [4, 64, 64] # patch_size plus leading channel dim.
patch_size: '${to_patch_size: ${input_size}}' # 2D tuple or float.
-n_classes: &n_classes 10 # Number of classes in dataset.
+n_classes: 10 # Number of classes in dataset.
# ---+ Experiment Execution +--------------------------------------------------
max_epochs: 2 # Maximum number of training epochs.
@@ -41,25 +41,24 @@ optim_func: SGD # Name of the optimiser function.
# ---+ Model Parameters +------------------------------------------------------
model_params:
- module:
- params:
+ _target_: minerva.models.FlexiSceneClassifier
+ input_size: ${input_size}
+ n_classes: ${n_classes}
+ encoder_on: true
+ filter_dim: -1
+ fc_dim: 512
+ freeze_backbone: false
+ backbone_args:
+ module: minerva.models
+ name: MinervaPSP
input_size: ${input_size}
- n_classes: *n_classes
- encoder_on: true
- filter_dim: -1
- fc_dim: 512
- freeze_backbone: false
- backbone_args:
- module: minerva.models
- name: MinervaPSP
- input_size: ${input_size}
- n_classes: 1
- encoder_name: resnet18
- encoder_weights:
- psp_out_channels: 512
- segmentation_on: false
- classification_on: false
- encoder: false
+ n_classes: 1
+ encoder_name: resnet18
+ encoder_weights:
+ psp_out_channels: 512
+ segmentation_on: false
+ classification_on: false
+ encoder: false
# ---+ Optimiser Parameters +--------------------------------------------------
optimiser:
@@ -87,15 +86,12 @@ record_int: true # Store integer results in memory.
record_float: false # Store floating point results too. Beware memory overload!
# ---+ Collator +--------------------------------------------------------------
-collator:
- module: torchgeo.datasets
- name: stack_samples
+collator: torchgeo.datasets.stack_samples
# === TASKS ===================================================================
tasks:
fit-train:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: true
record_float: true
@@ -104,67 +100,62 @@ tasks:
# ---+ Dataset Parameters +--------------------------------------------
dataset_params:
sampler:
- module: torch.utils.data
- name: RandomSampler
- params:
- num_samples: 32
+ _target_: torch.utils.data.RandomSampler
+ num_samples: 32
image:
transforms:
- Normalise:
- module: minerva.transforms
- norm_value: 4095
+ normalise:
+ _target_: minerva.transforms.Normalise
+ norm_value: 255
RandomApply:
p: 0.25
- DetachedColorJitter:
- module: minerva.transforms
+ jitter:
+ _target_: minerva.transforms.DetachedColorJitter
brightness: 0.2
contrast: 0.1
saturation: 0.1
hue: 0.15
- RandomResizedCrop:
- module: kornia.augmentation
+ resize_crop:
+ _target_: kornia.augmentation.RandomResizedCrop
p: 0.2
size: ${patch_size}
cropping_mode: resample
keepdim: true
- RandomHorizontalFlip:
- module: kornia.augmentation
+ horizontal_flip:
+ _target_: kornia.augmentation.RandomHorizontalFlip
p: 0.2
keepdim: true
- RandomGaussianBlur:
- module: kornia.augmentation
+ gaussian_blur:
+ _target_: kornia.augmentation.RandomGaussianBlur
kernel_size: 9
p: 0.2
sigma: [0.01, 0.2]
keepdim: true
- RandomGaussianNoise:
- module: kornia.augmentation
+ gaussian_noise:
+ _target_: kornia.augmentation.RandomGaussianNoise
p: 0.2
std: 0.05
keepdim: true
- RandomErasing:
- module: kornia.augmentation
+ random_erasing:
+ _target_: kornia.augmentation.RandomErasing
p: 0.2
keepdim: true
- module: torchgeo.datasets
- name: EuroSAT100
- paths: EuroSAT100
- params:
- split: train
- bands:
- - B02
- - B03
- - B04
- - B08
- download: true
+ _target_: torchgeo.datasets.EuroSAT100
+ root: EuroSAT100
+ split: train
+ bands:
+ - B02
+ - B03
+ - B04
+ - B08
+ download: true
label:
test-test:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: false
record_float: true
@@ -173,28 +164,24 @@ tasks:
# ---+ Dataset Parameters +--------------------------------------------
dataset_params:
sampler:
- module: torch.utils.data
- name: RandomSampler
- params:
- num_samples: 16
+ _target_: torch.utils.data.RandomSampler
+ num_samples: 16
image:
transforms:
Normalise:
- module: minerva.transforms
+ _target_: minerva.transforms.Normalise
norm_value: 4095
- module: torchgeo.datasets
- name: EuroSAT100
- paths: EuroSAT100
- params:
- split: test
- bands:
- - B02
- - B03
- - B04
- - B08
- download: true
+ _target_: torchgeo.datasets.EuroSAT100
+ root: EuroSAT100
+ split: test
+ bands:
+ - B02
+ - B03
+ - B04
+ - B08
+ download: true
label:
diff --git a/minerva/inbuilt_cfgs/example_UNetR_config.yaml b/minerva/inbuilt_cfgs/example_UNetR_config.yaml
index b720f76ba..338488294 100644
--- a/minerva/inbuilt_cfgs/example_UNetR_config.yaml
+++ b/minerva/inbuilt_cfgs/example_UNetR_config.yaml
@@ -24,10 +24,10 @@ model_type: segmentation
batch_size: 8 # Number of samples in each batch.
input_size: [4, 224, 224] # patch_size plus leading channel dim.
patch_size: '${to_patch_size: ${input_size}}' # 2D tuple or float.
-n_classes: &n_classes 8 # Number of classes in dataset.
+n_classes: 8 # Number of classes in dataset.
# ---+ Experiment Execution +--------------------------------------------------
-max_epochs: 5 # Maximum number of training epochs.
+max_epochs: 2 # Maximum number of training epochs.
elim: true # Eliminates empty classes from schema.
balance: true # Balances dataset classes.
pre_train: false # Activate pre-training mode.
@@ -41,12 +41,11 @@ optim_func: SGD # Name of the optimiser function.
# ---+ Model Parameters +------------------------------------------------------
model_params:
- module:
- params:
- input_size: ${input_size}
- n_classes: *n_classes
- backbone_kwargs:
- torch_weights: true
+ _target_: minerva.models.UNetR18
+ input_size: ${input_size}
+ n_classes: ${n_classes}
+ backbone_kwargs:
+ torch_weights: true
# ---+ Optimiser Parameters +--------------------------------------------------
optimiser:
@@ -67,25 +66,16 @@ wandb_log: true # Activates wandb logging.
project: pytest # Define the project name for wandb.
wandb_dir: /test/tmp/wandb # Directory to store wandb logs locally.
-# === MODEL IO & LOGGING ======================================================
-# ---+ Minerva Inbuilt Logging Functions +-------------------------------------
-# task_logger: SupervisedTaskLogger
-# step_logger: SupervisedGeoStepLogger
-# model_io: sup_tg
-
record_int: true # Store integer results in memory.
record_float: false # Store floating point results too. Beware memory overload!
# ---+ Collator +--------------------------------------------------------------
-collator:
- module: torchgeo.datasets
- name: stack_samples
+collator: torchgeo.datasets.stack_samples
# === TASKS ===================================================================
tasks:
fit-train:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: true
record_float: true
@@ -95,33 +85,26 @@ tasks:
# ---+ Dataset Parameters +----------------------------------------
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 320
+ size: ${patch_size}
+ length: 320
image:
transforms:
Normalise:
- module: minerva.transforms
+ _target_: minerva.transforms.Normalise
norm_value: 255
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
fit-val:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: true
record_float: true
@@ -131,33 +114,26 @@ tasks:
# ---+ Dataset Parameters +----------------------------------------
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 80
+ size: ${patch_size}
+ length: 80
image:
transforms:
Normalise:
- module: minerva.transforms
+ _target_: minerva.transforms.Normalise
norm_value: 255
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
test-test:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: true
record_float: true
@@ -167,28 +143,22 @@ tasks:
# ---+ Dataset Parameters +----------------------------------------
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 160
+ size: ${patch_size}
+ length: 160
image:
transforms:
Normalise:
- module: minerva.transforms
+ _target_: minerva.transforms.Normalise
norm_value: 255
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
# === PLOTTING OPTIONS ========================================================
plots:
diff --git a/minerva/inbuilt_cfgs/example_autoencoder_config.yaml b/minerva/inbuilt_cfgs/example_autoencoder_config.yaml
index cb162c699..78e45157d 100644
--- a/minerva/inbuilt_cfgs/example_autoencoder_config.yaml
+++ b/minerva/inbuilt_cfgs/example_autoencoder_config.yaml
@@ -41,12 +41,11 @@ optim_func: SGD # Name of the optimiser function.
# ---+ Model Parameters +------------------------------------------------------
model_params:
- module:
- params:
- input_size: ${input_size}
- n_classes: ${n_classes}
- backbone_kwargs:
- torch_weights: false
+ _target_: minerva.models.UNetR18
+ input_size: ${input_size}
+ n_classes: ${n_classes}
+ backbone_kwargs:
+ torch_weights: false
# ---+ Optimiser Parameters +--------------------------------------------------
optimiser:
@@ -69,24 +68,21 @@ wandb_dir: /test/tmp/wandb # Directory to store wandb logs locally.
# === MODEL IO & LOGGING ======================================================
# ---+ Minerva Inbuilt Logging Functions +-------------------------------------
-task_logger: SupervisedTaskLogger
+task_logger: minerva.logger.tasklog.SupervisedTaskLogger
step_logger:
- name: SupervisedStepLogger
-model_io: autoencoder_io
+ _target_: minerva.logger.steplog.SupervisedStepLogger
+model_io: minerva.modelio.autoencoder_io
record_int: true # Store integer results in memory.
record_float: false # Store floating point results too. Beware memory overload!
# ---+ Collator +--------------------------------------------------------------
-collator:
- module: torchgeo.datasets
- name: stack_samples
+collator: torchgeo.datasets.stack_samples
# === TASKS ===================================================================
tasks:
fit-train:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: true
record_float: true
autoencoder_data_key: mask
@@ -97,34 +93,27 @@ tasks:
# ---+ Dataset Parameters +--------------------------------------------
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 36
+ size: ${patch_size}
+ length: 36
image:
transforms:
Normalise:
- module: minerva.transforms
+ _target_: minerva.transforms.Normalise
norm_value: 255
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
fit-val:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: false
record_float: true
autoencoder_data_key: mask
@@ -135,34 +124,27 @@ tasks:
# ---+ Dataset Parameters +--------------------------------------------
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 18
+ size: ${patch_size}
+ length: 18
image:
transforms:
Normalise:
- module: minerva.transforms
+ _target_: minerva.transforms.Normalise
norm_value: 255
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
test-test:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: false
record_float: true
autoencoder_data_key: mask
@@ -173,30 +155,24 @@ tasks:
# ---+ Dataset Parameters +--------------------------------------------
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 18
+ size: ${patch_size}
+ length: 18
image:
transforms:
Normalise:
- module: minerva.transforms
+ _target_: minerva.transforms.Normalise
norm_value: 255
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
# === PLOTTING OPTIONS ========================================================
plots:
diff --git a/minerva/inbuilt_cfgs/example_config.yaml b/minerva/inbuilt_cfgs/example_config.yaml
index 65991055e..b1370f014 100644
--- a/minerva/inbuilt_cfgs/example_config.yaml
+++ b/minerva/inbuilt_cfgs/example_config.yaml
@@ -14,7 +14,7 @@ cache_dir: tests/tmp/cache
# === HYPERPARAMETERS =========================================================
# ---+ Model Specification +---------------------------------------------------
-# Name of model. Substring before hyphen is model class.
+# Name of model. This is no longer used to determine the model class (see model_params).
model_name: FCN32ResNet18-test
# Type of model. Can be mlp, scene classifier, segmentation, ssl or siamese.
@@ -40,11 +40,10 @@ optim_func: SGD # Name of the optimiser function.
# ---+ Model Parameters +------------------------------------------------------
model_params:
- module:
- params:
- input_size: ${input_size}
- n_classes: *n_classes
- # any other params...
+ _target_: minerva.models.FCN32ResNet18
+ input_size: ${input_size}
+ n_classes: ${n_classes}
+ # any other params...
# ---+ Optimiser Parameters +--------------------------------------------------
optimiser:
@@ -52,12 +51,11 @@ optimiser:
lr: ${lr}
# ---+ Scheduler Parameters +--------------------------------------------------
-scheduler_params:
- name: LinearLR
- params:
- start_factor: 1.0
- end_factor: 0.5
- total_iters: 5
+scheduler:
+ _target_: torch.optim.lr_scheduler.LinearLR
+ start_factor: 1.0
+ end_factor: 0.5
+ total_iters: 5
# ---+ Loss Function Parameters +----------------------------------------------
loss_params:
@@ -75,15 +73,12 @@ project: pytest # Define the project name for wandb.
wandb_dir: /test/tmp/wandb # Directory to store wandb logs locally.
# ---+ Collator +--------------------------------------------------------------
-collator:
- module: torchgeo.datasets
- name: stack_samples
+collator: torchgeo.datasets.stack_samples
# === TASKS ===================================================================
tasks:
fit-train:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: true
record_float: true
@@ -93,40 +88,32 @@ tasks:
# ---+ Dataset Parameters +----------------------------------------
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 120
+ size: ${patch_size}
+ length: 120
image:
transforms: false
- images_1:
- module: minerva.datasets.__testing
- name: TstImgDataset
- paths: NAIP
- params:
+ subdatasets:
+ images_1:
+ _target_: minerva.datasets.__testing.TstImgDataset
+ paths: NAIP
res: 1.0
- image2:
- module: minerva.datasets.__testing
- name: TstImgDataset
- paths: NAIP
- params:
+ image2:
+ _target_: minerva.datasets.__testing.TstImgDataset
+ paths: NAIP
res: 1.0
mask:
transforms: false
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
fit-val:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
train: false
record_float: true
@@ -134,74 +121,59 @@ tasks:
data_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/Chesapeake7.yaml}}' # yamllint disable-line rule:line-length
# ---+ Minerva Inbuilt Logging Functions +-------------------------
- # logger: STGLogger
- # metrics: SPMetrics
- # model_io: sup_tg
+ task_logger: minerva.logger.tasklog.SupervisedTaskLogger
+ model_io: minerva.modelio.supervised_torchgeo_io
# ---+ Dataset Parameters +----------------------------------------
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 32
+ size: ${patch_size}
+ length: 32
image:
transforms: false
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
transforms: false
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
test-test:
- name: StandardEpoch
- module: minerva.tasks
+ _target_: minerva.tasks.StandardEpoch
record_float: true
imagery_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/NAIP.yaml}}' # yamllint disable-line rule:line-length
data_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/Chesapeake7.yaml}}' # yamllint disable-line rule:line-length
# ---+ Minerva Inbuilt Logging Functions +-------------------------
- # logger: STGLogger
- # metrics: SPMetrics
- # model_io: sup_tg
+ task_logger: minerva.logger.tasklog.SupervisedTaskLogger
+ model_io: minerva.modelio.supervised_torchgeo_io
# ---+ Dataset Parameters +----------------------------------------
dataset_params:
sampler:
- module: torchgeo.samplers
- name: RandomGeoSampler
+ _target_: torchgeo.samplers.RandomGeoSampler
roi: false
- params:
- size: ${patch_size}
- length: 32
+ size: ${patch_size}
+ length: 32
image:
transforms: false
- module: minerva.datasets.__testing
- name: TstImgDataset
+ _target_: minerva.datasets.__testing.TstImgDataset
paths: NAIP
- params:
- res: 1.0
+ res: 1.0
mask:
transforms: false
- module: minerva.datasets.__testing
- name: TstMaskDataset
+ _target_: minerva.datasets.__testing.TstMaskDataset
paths: Chesapeake7
- params:
- res: 1.0
+ res: 1.0
# === PLOTTING OPTIONS ========================================================
plots:
diff --git a/minerva/logger/steplog.py b/minerva/logger/steplog.py
index 20cb0b13e..0b3bc9bde 100644
--- a/minerva/logger/steplog.py
+++ b/minerva/logger/steplog.py
@@ -38,7 +38,6 @@
"SupervisedStepLogger",
"SSLStepLogger",
"KNNStepLogger",
- "get_logger",
]
# =====================================================================================================================
@@ -47,16 +46,7 @@
import abc
import math
from abc import ABC
-from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Dict,
- Optional,
- SupportsFloat,
- Tuple,
- Union,
-)
+from typing import TYPE_CHECKING, Any, Callable, Optional, SupportsFloat
import mlflow
import numpy as np
@@ -76,7 +66,7 @@
from wandb.sdk.wandb_run import Run
from minerva.utils import utils
-from minerva.utils.utils import check_substrings_in_string, func_by_str
+from minerva.utils.utils import check_substrings_in_string
# =====================================================================================================================
# GLOBALS
@@ -133,10 +123,10 @@ def __init__(
task_name: str,
n_batches: int,
batch_size: int,
- output_size: Tuple[int, ...],
+ output_size: tuple[int, ...],
record_int: bool = True,
record_float: bool = False,
- writer: Optional[Union[SummaryWriter, Run]] = None,
+ writer: Optional[SummaryWriter | Run] = None,
model_type: str = "",
**kwargs,
) -> None:
@@ -157,8 +147,8 @@ def __init__(
self.model_type = model_type
- self.logs: Dict[str, Any] = {}
- self.results: Dict[str, Any] = {}
+ self.logs: dict[str, Any] = {}
+ self.results: dict[str, Any] = {}
def __call__(
self,
@@ -186,7 +176,7 @@ def log(
loss: Tensor,
z: Optional[Tensor] = None,
y: Optional[Tensor] = None,
- index: Optional[Union[int, BoundingBox]] = None,
+ index: Optional[int | BoundingBox] = None,
*args,
**kwargs,
) -> None:
@@ -243,7 +233,7 @@ def write_metric(
mlflow.log_metric(key, value) # pragma: no cover
@property
- def get_logs(self) -> Dict[str, Any]:
+ def get_logs(self) -> dict[str, Any]:
"""Gets the logs dictionary.
Returns:
@@ -252,7 +242,7 @@ def get_logs(self) -> Dict[str, Any]:
return self.logs
@property
- def get_results(self) -> Dict[str, Any]:
+ def get_results(self) -> dict[str, Any]:
"""Gets the results dictionary.
Returns:
@@ -309,10 +299,10 @@ def __init__(
task_name: str,
n_batches: int,
batch_size: int,
- output_size: Tuple[int, int],
+ output_size: tuple[int, int],
record_int: bool = True,
record_float: bool = False,
- writer: Optional[Union[SummaryWriter, Run]] = None,
+ writer: Optional[SummaryWriter | Run] = None,
model_type: str = "",
n_classes: Optional[int] = None,
**kwargs,
@@ -330,13 +320,13 @@ def __init__(
if n_classes is None:
raise ValueError("`n_classes` must be specified for this type of logger!")
- self.logs: Dict[str, Any] = {
+ self.logs: dict[str, Any] = {
"batch_num": 0,
"total_loss": 0.0,
"total_correct": 0.0,
}
- self.results: Dict[str, Any] = {
+ self.results: dict[str, Any] = {
"y": None,
"z": None,
"probs": None,
@@ -354,7 +344,7 @@ def __init__(
# Allocate memory for the integer values to be recorded.
if self.record_int:
- int_log_shape: Tuple[int, ...]
+ int_log_shape: tuple[int, ...]
if check_substrings_in_string(self.model_type, "scene-classifier"):
if check_substrings_in_string(self.model_type, "multilabel"):
int_log_shape = (self.n_batches, self.batch_size, n_classes)
@@ -375,7 +365,7 @@ def __init__(
# Allocate memory for the floating point values to be recorded.
if self.record_float:
- float_log_shape: Tuple[int, ...]
+ float_log_shape: tuple[int, ...]
if check_substrings_in_string(self.model_type, "scene-classifier"):
float_log_shape = (self.n_batches, self.batch_size, n_classes)
else:
@@ -416,7 +406,7 @@ def log(
loss: Tensor,
z: Optional[Tensor] = None,
y: Optional[Tensor] = None,
- index: Optional[Union[int, BoundingBox]] = None,
+ index: Optional[int | BoundingBox] = None,
*args,
**kwargs,
) -> None:
@@ -543,7 +533,7 @@ def __init__(
batch_size: int,
record_int: bool = True,
record_float: bool = False,
- writer: Optional[Union[SummaryWriter, Run]] = None,
+ writer: Optional[SummaryWriter | Run] = None,
model_type: str = "",
**kwargs,
) -> None:
@@ -558,14 +548,14 @@ def __init__(
**kwargs,
)
- self.logs: Dict[str, Any] = {
+ self.logs: dict[str, Any] = {
"batch_num": 0,
"total_loss": 0.0,
"total_correct": 0.0,
"total_top5": 0.0,
}
- self.results: Dict[str, Any] = {
+ self.results: dict[str, Any] = {
"y": None,
"z": None,
"probs": None,
@@ -580,7 +570,7 @@ def log(
loss: Tensor,
z: Optional[Tensor] = None,
y: Optional[Tensor] = None,
- index: Optional[Union[int, BoundingBox]] = None,
+ index: Optional[int | BoundingBox] = None,
*args,
**kwargs,
) -> None:
@@ -663,10 +653,10 @@ def __init__(
task_name: str,
n_batches: int,
batch_size: int,
- output_size: Tuple[int, int],
+ output_size: tuple[int, int],
record_int: bool = True,
record_float: bool = False,
- writer: Optional[Union[SummaryWriter, Run]] = None,
+ writer: Optional[SummaryWriter | Run] = None,
model_type: str = "",
**kwargs,
) -> None:
@@ -682,7 +672,7 @@ def __init__(
**kwargs,
)
- self.logs: Dict[str, Any] = {
+ self.logs: dict[str, Any] = {
"batch_num": 0,
"total_loss": 0.0,
"avg_loss": 0.0,
@@ -707,7 +697,7 @@ def log(
loss: Tensor,
z: Optional[Tensor] = None,
y: Optional[Tensor] = None,
- index: Optional[Union[int, BoundingBox]] = None,
+ index: Optional[int | BoundingBox] = None,
*args,
**kwargs,
) -> None:
@@ -798,19 +788,3 @@ def log(
# Writes the loss to the writer.
self.write_metric("loss", ls, step_num=global_step_num)
-
-
-# =====================================================================================================================
-# METHODS
-# =====================================================================================================================
-def get_logger(name) -> Callable[..., Any]:
- """Gets the constructor for a step logger to log the results from each step of model fitting during an epoch.
-
- Returns:
- ~typing.Callable[..., ~typing.Any]: The constructor of :class:`~logging.step.log.MinervaStepLogger`
- to be intialised within the epoch.
-
- .. versionadded:: 0.27
- """
- logger: Callable[..., Any] = func_by_str("minerva.logger.steplog", name)
- return logger
diff --git a/minerva/logger/tasklog.py b/minerva/logger/tasklog.py
index 43f4c42dc..f8ad92735 100644
--- a/minerva/logger/tasklog.py
+++ b/minerva/logger/tasklog.py
@@ -42,13 +42,14 @@
# =====================================================================================================================
import abc
from abc import ABC
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING: # pragma: no cover
from torch.utils.tensorboard.writer import SummaryWriter
else: # pragma: no cover
SummaryWriter = None
+import hydra
import numpy as np
from torch import Tensor
from torchgeo.datasets.utils import BoundingBox
@@ -56,7 +57,7 @@
from minerva.utils.utils import check_substrings_in_string
-from .steplog import MinervaStepLogger, get_logger
+from .steplog import MinervaStepLogger
# =====================================================================================================================
@@ -85,8 +86,8 @@ class MinervaTaskLogger(ABC):
__metaclass__ = abc.ABCMeta
- metric_types: List[str] = []
- special_metric_types: List[str] = []
+ metric_types: list[str] = []
+ special_metric_types: list[str] = []
logger_cls: str
def __init__(
@@ -94,11 +95,11 @@ def __init__(
task_name: str,
n_batches: int,
batch_size: int,
- output_size: Tuple[int, ...],
- step_logger_params: Optional[Dict[str, Any]] = None,
+ output_size: tuple[int, ...],
+ step_logger_params: Optional[dict[str, Any]] = None,
record_int: bool = True,
record_float: bool = False,
- writer: Optional[Union[SummaryWriter, Run]] = None,
+ writer: Optional[SummaryWriter | Run] = None,
**params,
) -> None:
super(MinervaTaskLogger, self).__init__()
@@ -119,15 +120,14 @@ def __init__(
self.writer = writer
- if isinstance(step_logger_params, dict):
- self._logger = get_logger(step_logger_params.get("name", self.logger_cls))
- if "params" not in step_logger_params:
- step_logger_params["params"] = {}
-
+ if not isinstance(step_logger_params, dict):
+ step_logger_params = {"_target_": self.logger_cls}
+ elif "_target_" not in step_logger_params:
+ step_logger_params["_target_"] = self.logger_cls
else:
- step_logger_params = {"params": {}}
+ pass
- step_logger_params["params"]["n_classes"] = self.n_classes
+ step_logger_params["n_classes"] = self.n_classes
self.step_logger_params = step_logger_params
@@ -137,7 +137,7 @@ def __init__(
self.metric_types += self.special_metric_types
# Creates a dict to hold the loss and accuracy results from training, validation and testing.
- self.metrics: Dict[str, Any] = {}
+ self.metrics: dict[str, Any] = {}
for metric in self.metric_types:
self.metrics[f"{self.task_name}_{metric}"] = {"x": [], "y": []}
@@ -147,7 +147,8 @@ def _make_logger(self) -> None:
.. note::
Will overwrite ``self.logger`` with new logger.
"""
- self.step_logger: MinervaStepLogger = self._logger(
+ self.step_logger: MinervaStepLogger = hydra.utils.instantiate(
+ self.step_logger_params,
task_name=self.task_name,
n_batches=self.n_batches,
batch_size=self.batch_size,
@@ -156,7 +157,6 @@ def _make_logger(self) -> None:
record_float=self.record_float,
writer=self.writer,
model_type=self.model_type,
- **self.step_logger_params.get("params", {}),
)
def refresh_step_logger(self) -> None:
@@ -207,7 +207,7 @@ def calc_metrics(self, epoch_no: int) -> None:
self.log_epoch_number(epoch_no)
@abc.abstractmethod
- def _calc_metrics(self, logs: Dict[str, Any]) -> None:
+ def _calc_metrics(self, logs: dict[str, Any]) -> None:
"""Updates metrics with epoch results.
Must be defined before use.
@@ -226,7 +226,7 @@ def log_epoch_number(self, epoch_no: int) -> None:
self.metrics[metric]["x"].append(epoch_no)
@property
- def get_metrics(self) -> Dict[str, Any]:
+ def get_metrics(self) -> dict[str, Any]:
"""Get the ``metrics`` dictionary.
Returns:
@@ -235,7 +235,7 @@ def get_metrics(self) -> Dict[str, Any]:
return self.metrics
@property
- def get_logs(self) -> Dict[str, Any]:
+ def get_logs(self) -> dict[str, Any]:
"""Get the logs of each step from the latest epoch of the task.
Returns:
@@ -246,7 +246,7 @@ def get_logs(self) -> Dict[str, Any]:
return self.step_logger.get_logs
@property
- def get_results(self) -> Dict[str, Any]:
+ def get_results(self) -> dict[str, Any]:
"""Get the results of each step from the latest epoch of the task.
Returns:
@@ -266,8 +266,8 @@ def log_null(self) -> None:
self.metrics[metric]["y"].append(np.NAN)
def get_sub_metrics(
- self, pattern: Tuple[str, ...] = ("train", "val")
- ) -> Dict[str, Any]:
+ self, pattern: tuple[str, ...] = ("train", "val")
+ ) -> dict[str, Any]:
"""Gets a subset of the metrics dictionary with keys containing strings in the pattern.
Useful for getting the train and validation metrics for plotting for example.
@@ -315,19 +315,19 @@ class SupervisedTaskLogger(MinervaTaskLogger):
.. versionadded:: 0.27
"""
- metric_types: List[str] = ["loss", "acc", "miou"]
- logger_cls = "SupervisedStepLogger"
+ metric_types: list[str] = ["loss", "acc", "miou"]
+ logger_cls = "minerva.logger.steplog.SupervisedStepLogger"
def __init__(
self,
task_name: str,
n_batches: int,
batch_size: int,
- output_size: Tuple[int, ...],
- step_logger_params: Optional[Dict[str, Any]] = None,
+ output_size: tuple[int, ...],
+ step_logger_params: Optional[dict[str, Any]] = None,
record_int: bool = True,
record_float: bool = False,
- writer: Optional[Union[SummaryWriter, Run]] = None,
+ writer: Optional[SummaryWriter | Run] = None,
model_type: str = "segmentation",
**params,
) -> None:
@@ -344,7 +344,7 @@ def __init__(
**params,
)
- def _calc_metrics(self, logs: Dict[str, Any]) -> None:
+ def _calc_metrics(self, logs: dict[str, Any]) -> None:
"""Updates metrics with epoch results.
Args:
@@ -415,30 +415,34 @@ class SSLTaskLogger(MinervaTaskLogger):
metric_types = ["loss", "acc", "top5_acc"]
special_metric_types = ["collapse_level", "euc_dist"]
- logger_cls = "SSLStepLogger"
+ logger_cls = "minerva.logger.steplog.SSLStepLogger"
def __init__(
self,
task_name: str,
n_batches: int,
batch_size: int,
- output_size: Tuple[int, ...],
- step_logger_params: Optional[Dict[str, Any]] = None,
+ output_size: tuple[int, ...],
+ step_logger_params: Optional[dict[str, Any]] = None,
record_int: bool = True,
record_float: bool = False,
- writer: Optional[Union[SummaryWriter, Run]] = None,
+ writer: Optional[SummaryWriter | Run] = None,
model_type: str = "segmentation",
sample_pairs: bool = False,
**params,
) -> None:
if not step_logger_params:
- step_logger_params = {}
- if "params" not in step_logger_params:
- step_logger_params["params"] = {}
+ step_logger_params = {"_target_": self.logger_cls}
- step_logger_params["params"]["sample_pairs"] = sample_pairs
- step_logger_params["params"]["collapse_level"] = sample_pairs
- step_logger_params["params"]["euclidean"] = sample_pairs
+ step_logger_params["sample_pairs"] = step_logger_params.get(
+ "sample_pairs", sample_pairs
+ )
+ step_logger_params["collapse_level"] = step_logger_params.get(
+ "collapse_level", sample_pairs
+ )
+ step_logger_params["euclidean"] = step_logger_params.get(
+ "euclidean", sample_pairs
+ )
super(SSLTaskLogger, self).__init__(
task_name,
@@ -465,7 +469,7 @@ def __init__(
if not getattr(self.step_logger, "euclidean", False):
del self.metrics[f"{self.task_name}_euc_dist"]
- def _calc_metrics(self, logs: Dict[str, Any]) -> None:
+ def _calc_metrics(self, logs: dict[str, Any]) -> None:
"""Updates metrics with epoch results.
Args:
diff --git a/minerva/loss.py b/minerva/loss.py
index 111d28b55..1c7d13162 100644
--- a/minerva/loss.py
+++ b/minerva/loss.py
@@ -33,11 +33,12 @@
__copyright__ = "Copyright (C) 2024 Harry Baker"
__all__ = ["SegBarlowTwinsLoss", "AuxCELoss"]
+import importlib
+
# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
from typing import Optional
-import importlib
import torch
from torch import Tensor
@@ -91,7 +92,12 @@ class AuxCELoss(Module):
Source: https://github.com/xitongpu/PSPNet/blob/main/src/model/cell.py
"""
- def __init__(self, weight: Optional[Tensor] = None, ignore_index: int = 255, alpha: float = 0.4) -> None:
+ def __init__(
+ self,
+ weight: Optional[Tensor] = None,
+ ignore_index: int = 255,
+ alpha: float = 0.4,
+ ) -> None:
super().__init__()
self.loss = CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
self.alpha = alpha
diff --git a/minerva/modelio.py b/minerva/modelio.py
index 2fe6ac934..e539ec3e5 100644
--- a/minerva/modelio.py
+++ b/minerva/modelio.py
@@ -32,14 +32,16 @@
__license__ = "MIT License"
__copyright__ = "Copyright (C) 2024 Harry Baker"
__all__ = [
- "sup_tg",
- "ssl_pair_tg",
+ "supervised_torchgeo_io",
+ "change_detection_io",
+ "autoencoder_io",
+ "ssl_pair_torchgeo_io",
]
# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Optional, Sequence
import numpy as np
import torch
@@ -48,29 +50,33 @@
from torchgeo.datasets.utils import BoundingBox
from minerva.models import MinervaModel
-from minerva.utils.utils import check_substrings_in_string, mask_to_ohe
+from minerva.utils.utils import (
+ check_substrings_in_string,
+ get_sample_index,
+ mask_to_ohe,
+)
# =====================================================================================================================
# METHODS
# =====================================================================================================================
-def sup_tg(
- batch: Dict[Any, Any],
+def supervised_torchgeo_io(
+ batch: dict[Any, Any],
model: MinervaModel,
device: torch.device, # type: ignore[name-defined]
train: bool,
**kwargs,
-) -> Tuple[
+) -> tuple[
Tensor,
- Union[Tensor, Tuple[Tensor, ...]],
+ Tensor | tuple[Tensor, ...],
Tensor,
- Optional[Union[Sequence[str], Sequence[BoundingBox]]],
+ Optional[Sequence[str] | Sequence[int] | Sequence[BoundingBox]],
]:
"""Provides IO functionality for a supervised model using :mod:`torchgeo` datasets.
Args:
batch (dict[~typing.Any, ~typing.Any]): Batch of data in a :class:`dict`.
- Must have ``"image"``, ``"mask"``/ ``"label"`` and ``"bbox"`` / ``"id"`` keys.
+ Must have ``"image"``, ``"mask"``/ ``"label"`` and ``"bbox"`` / ``"bounds"`` / ``"id"`` keys.
model (MinervaModel): Model being fitted.
device (~torch.device): `torch` device object to send data to (e.g. CUDA device).
train (bool): True to run a step of the model in training mode. False for eval mode.
@@ -94,65 +100,140 @@ def sup_tg(
multilabel = True if check_substrings_in_string(model_type, "multilabel") else False
# Extracts the x and y batches from the dict.
- images: Tensor = batch["image"]
- targets: Tensor = batch[target_key]
+ x: Tensor = batch["image"]
+ y: Tensor = batch[target_key]
# Check that none of the data is NaN or infinity.
if kwargs.get("validate_variables", False):
- assert not images.isnan().any()
- assert not images.isinf().any()
- assert not targets.isnan().any()
- assert not targets.isinf().any()
+ assert not x.isnan().any()
+ assert not x.isinf().any()
+ assert not y.isnan().any()
+ assert not y.isinf().any()
# Re-arranges the x and y batches.
- x_batch: Tensor = images.to(float_dtype) # type: ignore[attr-defined]
- y_batch: Tensor
+ x = x.to(float_dtype) # type: ignore[attr-defined]
# Squeeze out axis 1 if only 1 element wide.
if target_key == "mask":
- targets = targets.squeeze()
+ y = y.squeeze()
- if isinstance(targets, Tensor):
- targets = targets.detach().cpu()
- y_batch = torch.tensor(targets, dtype=torch.float if multilabel else torch.long) # type: ignore[attr-defined]
+ if isinstance(y, Tensor):
+ y = y.detach().cpu()
+ y = y.to(dtype=torch.float if multilabel else torch.long) # type: ignore[attr-defined]
# Transfer to GPU.
- x: Tensor = x_batch.to(device)
- y: Tensor = y_batch.to(device)
+ x = x.to(device)
+ y = y.to(device)
# Runs a step of the epoch.
loss, z = model.step(x, y, train=train)
- # Get the indices of the batch. Either bounding boxes or filenames.
- index: Optional[Union[Sequence[str], Sequence[BoundingBox]]]
- if "bbox" in batch:
- index = batch["bbox"]
- elif "id" in batch:
- index = batch["id"]
- else:
- index = None
+ # Get the indices of the batch. Either bounding boxes, filenames or index number.
+ index: Optional[Sequence[str] | Sequence[int] | Sequence[BoundingBox]] = (
+ get_sample_index(batch)
+ )
+
+ return loss, z, y, index
+
+
+def change_detection_io(
+ batch: dict[Any, Any],
+ model: MinervaModel,
+ device: torch.device, # type: ignore[name-defined]
+ train: bool,
+ **kwargs,
+) -> tuple[
+ Tensor,
+ Tensor | tuple[Tensor, ...],
+ Tensor,
+ Optional[Sequence[str] | Sequence[int] | Sequence[BoundingBox]],
+]:
+ """Provides IO functionality for a change detection model.
+
+ Args:
+ batch (dict[~typing.Any, ~typing.Any]): Batch of data in a :class:`dict`.
+ Must have ``"image1"``, ``"image2"``, ``"mask"``/ ``"label"`` and ``"bbox"`` / ``"bounds"`` / ``"id"`` keys.
+ model (MinervaModel): Model being fitted.
+ device (~torch.device): `torch` device object to send data to (e.g. CUDA device).
+ train (bool): True to run a step of the model in training mode. False for eval mode.
+
+ Kwargs:
+ mix_precision (bool): Use mixed-precision. Will set the floating tensors to 16-bit
+ rather than the default 32-bit.
+ target_key (str): Should be either ``"mask"`` or ``"label"``.
+
+ Returns:
+ tuple[
+ ~torch.Tensor, ~torch.Tensor, ~torch.Tensor, ~typing.Sequence[~torchgeo.datasets.utils.BoundingBox] | None
+ ]:
+ The ``loss``, the model output ``z``, the ground truth ``y`` supplied and the bounding boxes
+ of the input images supplied.
+ """
+ float_dtype = _determine_float_dtype(device, kwargs.get("mix_precision", False))
+ target_key = kwargs.get("target_key", "mask")
+
+ model_type = kwargs.get("model_type", "")
+ multilabel = True if check_substrings_in_string(model_type, "multilabel") else False
+
+ # Extracts the x and y batches from the dict.
+ x1: Tensor = batch["image1"]
+ x2: Tensor = batch["image2"]
+ y: Tensor = batch[target_key]
+
+ # Check that none of the data is NaN or infinity.
+ if kwargs.get("validate_variables", False):
+ assert not x1.isnan().any()
+ assert not x1.isinf().any()
+ assert not x2.isnan().any()
+ assert not x2.isinf().any()
+ assert not y.isnan().any()
+ assert not y.isinf().any()
+
+ # Re-arranges the x and y batches.
+ x1 = x1.to(float_dtype) # type: ignore[attr-defined]
+ x2 = x2.to(float_dtype) # type: ignore[attr-defined]
+
+ # Squeeze out axis 1 if only 1 element wide.
+ if target_key == "mask":
+ y = y.squeeze()
+
+ if isinstance(y, Tensor):
+ y = y.detach().cpu()
+ y = y.to(dtype=torch.float if multilabel else torch.long) # type: ignore[attr-defined]
+
+ # Transfer to GPU.
+ x = torch.stack([x1, x2]).to(device)
+ y = y.to(device)
+
+ # Runs a step of the epoch.
+ loss, z = model.step(x, y, train=train)
+
+ # Get the indices of the batch. Either bounding boxes, filenames or index number.
+ index: Optional[Sequence[str] | Sequence[int] | Sequence[BoundingBox]] = (
+ get_sample_index(batch)
+ )
return loss, z, y, index
def autoencoder_io(
- batch: Dict[Any, Any],
+ batch: dict[Any, Any],
model: MinervaModel,
device: torch.device, # type: ignore[name-defined]
train: bool,
**kwargs,
-) -> Tuple[
+) -> tuple[
Tensor,
- Union[Tensor, Tuple[Tensor, ...]],
+ Tensor | tuple[Tensor, ...],
Tensor,
- Optional[Union[Sequence[str], Sequence[BoundingBox]]],
+ Optional[Sequence[str] | Sequence[int] | Sequence[BoundingBox]],
]:
"""Provides IO functionality for an autoencoder using :mod:`torchgeo` datasets by only using the same data
for input and ground truth.
Args:
batch (dict[~typing.Any, ~typing.Any]): Batch of data in a :class:`dict`.
- Must have ``"image"``, ``"mask"``/ ``"label"`` and ``"bbox"`` / ``"id"`` keys.
+ Must have ``"image"``, ``"mask"``/ ``"label"`` and ``"bbox"`` / ``"bounds"`` / ``"id"`` keys.
model (MinervaModel): Model being fitted.
device (~torch.device): `torch` device object to send data to (e.g. CUDA device).
train (bool): True to run a step of the model in training mode. False for eval mode.
@@ -205,15 +286,13 @@ def autoencoder_io(
)
output_masks: LongTensor = masks
- if isinstance(input_masks, Tensor):
- input_masks = input_masks.detach().cpu().numpy()
-
- if isinstance(output_masks, Tensor):
- output_masks = output_masks.detach().cpu().numpy()
-
# Transfer to GPU and cast to correct dtypes.
- x = torch.tensor(input_masks, dtype=float_dtype, device=device)
- y = torch.tensor(output_masks, dtype=torch.long, device=device)
+ x = torch.tensor(
+ input_masks.detach().cpu().numpy(), dtype=float_dtype, device=device
+ )
+ y = torch.tensor(
+ output_masks.detach().cpu().numpy(), dtype=torch.long, device=device
+ )
elif key == "image":
# Extract the images from the batch, set to float, transfer to GPU and make x and y.
@@ -228,35 +307,31 @@ def autoencoder_io(
# Runs a step of the epoch.
loss, z = model.step(x, y, train=train)
- # Get the indices of the batch. Either bounding boxes or filenames.
- index: Optional[Union[Sequence[str], Sequence[BoundingBox]]]
- if "bbox" in batch:
- index = batch["bbox"]
- elif "id" in batch:
- index = batch["id"]
- else:
- index = None
+ # Get the indices of the batch. Either bounding boxes, filenames or index number.
+ index: Optional[Sequence[str] | Sequence[int] | Sequence[BoundingBox]] = (
+ get_sample_index(batch)
+ )
return loss, z, y, index
-def ssl_pair_tg(
- batch: Tuple[Dict[str, Any], Dict[str, Any]],
+def ssl_pair_torchgeo_io(
+ batch: tuple[dict[str, Any], dict[str, Any]],
model: MinervaModel,
device: torch.device, # type: ignore[name-defined]
train: bool,
**kwargs,
-) -> Tuple[
+) -> tuple[
Tensor,
- Union[Tensor, Tuple[Tensor, ...]],
+ Tensor | tuple[Tensor, ...],
None,
- Optional[Union[Sequence[BoundingBox], Sequence[int]]],
+ Optional[Sequence[BoundingBox] | Sequence[int]],
]:
"""Provides IO functionality for a self-supervised Siamese model using :mod:`torchgeo` datasets.
Args:
batch (tuple[dict[str, ~typing.Any], dict[str, ~typing.Any]]): Pair of batches of data in :class:`dict` (s).
- Must have ``"image"`` and ``"bbox"`` keys.
+ Must have ``"image"`` and ``"bbox"`` / ``"bounds"`` / ``"id"`` keys.
model (MinervaModel): Model being fitted.
device (~torch.device): :mod:`torch` device object to send data to (e.g. ``CUDA`` device).
train (bool): True to run a step of the model in training mode. False for eval mode.
@@ -273,13 +348,9 @@ def ssl_pair_tg(
"""
float_dtype = _determine_float_dtype(device, kwargs.get("mix_precision", False))
- # Extracts the x_i batch from the dict.
- x_i_batch: Tensor = batch[0]["image"]
- x_j_batch: Tensor = batch[1]["image"]
-
- # Ensures images are floats.
- x_i_batch = x_i_batch.to(float_dtype) # type: ignore[attr-defined]
- x_j_batch = x_j_batch.to(float_dtype) # type: ignore[attr-defined]
+ # Extracts both batches from the dict and ensures images are floats.
+ x_i_batch: Tensor = batch[0]["image"].to(float_dtype) # type: ignore[attr-defined]
+ x_j_batch: Tensor = batch[1]["image"].to(float_dtype) # type: ignore[attr-defined]
if kwargs.get("validate_variables", False):
try:
@@ -287,11 +358,8 @@ def ssl_pair_tg(
except AssertionError:
print("WARNING: Batches are the same!")
- # Stacks each side of the pair batches together.
- x_batch = torch.stack([x_i_batch, x_j_batch])
-
- # Transfer to GPU.
- x = x_batch.to(device, non_blocking=True)
+ # Stacks each side of the pair batches together and transfer to GPU.
+ x = torch.stack([x_i_batch, x_j_batch]).to(device, non_blocking=True)
# Check that none of the data is NaN or infinity.
if kwargs.get("validate_variables", False):
@@ -301,12 +369,12 @@ def ssl_pair_tg(
# Runs a step of the epoch.
loss, z = model.step(x, train=train)
- if "bbox" in batch[0].keys():
- return loss, z, None, batch[0]["bbox"] + batch[1]["bbox"]
- elif "id" in batch[0].keys():
- return loss, z, None, batch[0]["id"] + batch[1]["id"]
- else:
+ index_0, index_1 = get_sample_index(batch[0]), get_sample_index(batch[1])
+
+ if index_0 is None or index_1 is None:
return loss, z, None, None
+ else:
+ return loss, z, None, index_0 + index_1
def _determine_float_dtype(device: torch.device, mix_precision: bool) -> torch.dtype:
diff --git a/minerva/models/__depreciated.py b/minerva/models/__depreciated.py
index 2df888790..544a817d5 100644
--- a/minerva/models/__depreciated.py
+++ b/minerva/models/__depreciated.py
@@ -37,7 +37,7 @@
# IMPORTS
# =====================================================================================================================
from collections import OrderedDict
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import Any, Optional, Sequence
import numpy as np
import torch.nn.modules as nn
@@ -80,7 +80,7 @@ def __init__(
criterion: Optional[Any] = None,
input_size: int = 288,
n_classes: int = 8,
- hidden_sizes: Union[Tuple[int, ...], List[int], int] = (256, 144),
+ hidden_sizes: tuple[int, ...] | list[int] | int = (256, 144),
) -> None:
super(MLP, self).__init__(
criterion=criterion, input_size=(input_size,), n_classes=n_classes
@@ -161,14 +161,14 @@ class CNN(MinervaModel):
def __init__(
self,
criterion,
- input_size: Tuple[int, int, int] = (4, 256, 256),
+ input_size: tuple[int, int, int] = (4, 256, 256),
n_classes: int = 8,
- features: Union[Tuple[int, ...], List[int]] = (2, 1, 1),
- fc_sizes: Union[Tuple[int, ...], List[int]] = (128, 64),
- conv_kernel_size: Union[int, Tuple[int, ...]] = 3,
- conv_stride: Union[int, Tuple[int, ...]] = 1,
- max_kernel_size: Union[int, Tuple[int, ...]] = 2,
- max_stride: Union[int, Tuple[int, ...]] = 2,
+ features: tuple[int, ...] | list[int] = (2, 1, 1),
+ fc_sizes: tuple[int, ...] | list[int] = (128, 64),
+ conv_kernel_size: int | tuple[int, ...] = 3,
+ conv_stride: int | tuple[int, ...] = 1,
+ max_kernel_size: int | tuple[int, ...] = 2,
+ max_stride: int | tuple[int, ...] = 2,
conv_do: bool = True,
fc_do: bool = True,
p_conv_do: float = 0.1,
diff --git a/minerva/models/change_detector.py b/minerva/models/change_detector.py
index 5ece466a5..72d816786 100644
--- a/minerva/models/change_detector.py
+++ b/minerva/models/change_detector.py
@@ -38,7 +38,7 @@
# IMPORTS
# =====================================================================================================================
from pathlib import Path
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Optional
import torch
from torch import Tensor
@@ -57,7 +57,7 @@ class ChangeDetector(MinervaModel):
def __init__(
self,
criterion: Optional[Module] = None,
- input_size: Optional[Tuple[int]] = None,
+ input_size: Optional[tuple[int]] = None,
n_classes: int = 1,
scaler: Optional[GradScaler] = None,
fc_dim: int = 512,
@@ -65,8 +65,8 @@ def __init__(
encoder_on: bool = False,
filter_dim: int = 0,
freeze_backbone: bool = False,
- backbone_weight_path: Optional[Union[str, Path]] = None,
- backbone_args: Dict[str, Any] = {},
+ backbone_weight_path: Optional[str | Path] = None,
+ backbone_args: dict[str, Any] = {},
clamp_outputs: bool = False,
) -> None:
super().__init__(criterion, input_size, n_classes, scaler)
@@ -128,8 +128,7 @@ def forward(self, x: Tensor) -> Tensor:
~torch.Tensor: Likelihoods the network places on the
input ``x`` being of each class.
"""
- x = torch.squeeze(x)
- x_0, x_1 = torch.chunk(x, 2, dim=1)
+ x_0, x_1 = x[0], x[1]
f_0 = self.backbone(x_0)
f_1 = self.backbone(x_1)
@@ -140,7 +139,6 @@ def forward(self, x: Tensor) -> Tensor:
f = torch.cat((f_0, f_1), 1)
- # f = f.view(f.size(0), -1)
z: Tensor = self.classification_head(f)
assert isinstance(z, Tensor)
diff --git a/minerva/models/classifiers.py b/minerva/models/classifiers.py
index bb3472798..c95eb42fa 100644
--- a/minerva/models/classifiers.py
+++ b/minerva/models/classifiers.py
@@ -42,7 +42,7 @@
# IMPORTS
# =====================================================================================================================
from pathlib import Path
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Optional
import torch
from torch import Tensor
@@ -61,7 +61,7 @@ class FlexiSceneClassifier(MinervaBackbone):
def __init__(
self,
criterion: Optional[Module] = None,
- input_size: Optional[Tuple[int]] = None,
+ input_size: Optional[tuple[int]] = None,
n_classes: int = 1,
scaler: Optional[GradScaler] = None,
fc_dim: int = 512,
@@ -69,8 +69,8 @@ def __init__(
encoder_on: bool = False,
filter_dim: int = 0,
freeze_backbone: bool = False,
- backbone_weight_path: Optional[Union[str, Path]] = None,
- backbone_args: Dict[str, Any] = {},
+ backbone_weight_path: Optional[str | Path] = None,
+ backbone_args: dict[str, Any] = {},
clamp_outputs: bool = False,
) -> None:
super().__init__(criterion, input_size, n_classes, scaler)
diff --git a/minerva/models/core.py b/minerva/models/core.py
index f06513979..493ea94c4 100644
--- a/minerva/models/core.py
+++ b/minerva/models/core.py
@@ -58,18 +58,7 @@
import warnings
from abc import ABC
from pathlib import Path
-from typing import (
- Any,
- Callable,
- Iterable,
- List,
- Optional,
- Sequence,
- Tuple,
- Type,
- Union,
- overload,
-)
+from typing import Any, Callable, Iterable, Optional, Sequence, Type, overload
import numpy as np
import torch
@@ -81,8 +70,8 @@
from torch.nn.modules import Module
from torch.nn.parallel import DataParallel
from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
+from torch.optim.optimizer import Optimizer
from torchvision.models._api import WeightsEnum
from minerva.utils.utils import func_by_str
@@ -101,7 +90,7 @@ class FilterOutputs(Module):
indexes (int | list[int]): Index(es) of the inputs to pass forward.
"""
- def __init__(self, indexes: Union[int, List[int]]) -> None:
+ def __init__(self, indexes: int | list[int]) -> None:
self.indexes = indexes
def forward(self, inputs: Tensor) -> Tensor:
@@ -134,7 +123,7 @@ class MinervaModel(Module, ABC):
def __init__(
self,
criterion: Optional[Module] = None,
- input_size: Optional[Tuple[int, ...]] = None,
+ input_size: Optional[tuple[int, ...]] = None,
n_classes: Optional[int] = None,
scaler: Optional[GradScaler] = None,
) -> None:
@@ -148,7 +137,7 @@ def __init__(
self.scaler = scaler
# Output shape initialised as None. Should be set by calling determine_output_dim.
- self.output_shape: Optional[Tuple[int, ...]] = None
+ self.output_shape: Optional[tuple[int, ...]] = None
# Optimiser initialised as None as the model parameters created by its init is required to init a
# torch optimiser. The optimiser MUST be set by calling set_optimiser before the model can be trained.
@@ -206,19 +195,19 @@ def update_n_classes(self, n_classes: int) -> None:
@overload
def step(
self, x: Tensor, y: Tensor, train: bool = False
- ) -> Tuple[Tensor, Union[Tensor, Tuple[Tensor, ...]]]: ... # pragma: no cover
+ ) -> tuple[Tensor, Tensor | tuple[Tensor, ...]]: ... # pragma: no cover
@overload
def step(
self, x: Tensor, *, train: bool = False
- ) -> Tuple[Tensor, Union[Tensor, Tuple[Tensor, ...]]]: ... # pragma: no cover
+ ) -> tuple[Tensor, Tensor | tuple[Tensor, ...]]: ... # pragma: no cover
def step(
self,
x: Tensor,
y: Optional[Tensor] = None,
train: bool = False,
- ) -> Tuple[Tensor, Union[Tensor, Tuple[Tensor, ...]]]:
+ ) -> tuple[Tensor, Tensor | tuple[Tensor, ...]]:
"""Generic step of model fitting using a batch of data.
Raises:
@@ -248,7 +237,7 @@ def step(
if train:
self.optimiser.zero_grad()
- z: Union[Tensor, Tuple[Tensor, ...]]
+ z: Tensor | tuple[Tensor, ...]
loss: Tensor
mix_precision: bool = True if self.scaler else False
@@ -300,9 +289,9 @@ class MinervaWrapper(MinervaModel):
def __init__(
self,
- model: Union[Module, Callable[..., Module]],
+ model: Module | Callable[..., Module],
criterion: Optional[Module] = None,
- input_size: Optional[Tuple[int, ...]] = None,
+ input_size: Optional[tuple[int, ...]] = None,
n_classes: Optional[int] = None,
scaler: Optional[GradScaler] = None,
*args,
@@ -386,7 +375,7 @@ class MinervaDataParallel(Module): # pragma: no cover
def __init__(
self,
model: Module,
- paralleliser: Union[Type[DataParallel], Type[DDP]], # type: ignore[type-arg]
+ paralleliser: Type[DataParallel] | Type[DDP], # type: ignore[type-arg]
*args,
**kwargs,
) -> None:
@@ -396,7 +385,7 @@ def __init__(
self.output_shape = model.output_shape
self.n_classes = model.n_classes
- def forward(self, *inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]:
+ def forward(self, *inputs: tuple[Tensor, ...]) -> tuple[Tensor, ...]:
"""Ensures a forward call to the model goes to the actual wrapped model.
Args:
@@ -410,7 +399,7 @@ def forward(self, *inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]:
assert isinstance(z, tuple) and list(map(type, z)) == [Tensor] * len(z)
return z
- def __call__(self, *inputs) -> Tuple[Tensor, ...]:
+ def __call__(self, *inputs) -> tuple[Tensor, ...]:
return self.forward(*inputs)
def __getattr__(self, name):
@@ -515,10 +504,10 @@ def get_torch_weights(weights_name: str) -> Optional[WeightsEnum]:
def get_output_shape(
model: Module,
- image_dim: Union[Sequence[int], int],
+ image_dim: Sequence[int] | int,
sample_pairs: bool = False,
change_detection: bool = False,
-) -> Tuple[int, ...]:
+) -> tuple[int, ...]:
"""Gets the output shape of a model.
Args:
@@ -530,7 +519,7 @@ def get_output_shape(
Returns:
tuple[int, ...]: The shape of the output data from the model.
"""
- _image_dim: Union[Sequence[int], int] = image_dim
+ _image_dim: Sequence[int] | int = image_dim
try:
assert not isinstance(image_dim, int)
if len(image_dim) == 1:
@@ -542,12 +531,9 @@ def get_output_shape(
if not hasattr(_image_dim, "__len__"):
assert isinstance(_image_dim, int)
random_input = torch.rand([4, _image_dim])
- elif sample_pairs:
+ elif sample_pairs or change_detection:
assert isinstance(_image_dim, Iterable)
random_input = torch.rand([2, 4, *_image_dim])
- elif change_detection:
- assert isinstance(_image_dim, Iterable)
- random_input = torch.rand([4, 2 * _image_dim[0], *_image_dim[1:]])
else:
assert isinstance(_image_dim, Iterable)
random_input = torch.rand([4, *_image_dim])
@@ -637,7 +623,7 @@ def is_minerva_subtype(model: Module, subtype: type) -> bool:
def extract_wrapped_model(
- model: Union[MinervaModel, MinervaDataParallel, OptimizedModule]
+ model: MinervaModel | MinervaDataParallel | OptimizedModule,
) -> MinervaModel:
"""
Extracts the actual model object from within :class:`MinervaDataParallel` or
diff --git a/minerva/models/fcn.py b/minerva/models/fcn.py
index d80bb3962..9c32910b3 100644
--- a/minerva/models/fcn.py
+++ b/minerva/models/fcn.py
@@ -52,7 +52,7 @@
# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
-from typing import Any, Dict, Literal, Optional, Sequence, Tuple
+from typing import Any, Literal, Optional, Sequence
import torch
import torch.nn.modules as nn
@@ -100,12 +100,12 @@ class FCN(MinervaBackbone):
def __init__(
self,
criterion: Any,
- input_size: Tuple[int, ...] = (4, 256, 256),
+ input_size: tuple[int, ...] = (4, 256, 256),
n_classes: int = 8,
scaler: Optional[GradScaler] = None,
backbone_weight_path: Optional[str] = None,
freeze_backbone: bool = False,
- backbone_kwargs: Dict[str, Any] = {},
+ backbone_kwargs: dict[str, Any] = {},
) -> None:
super(FCN, self).__init__(
criterion=criterion,
@@ -294,7 +294,7 @@ def __init__(
f"Variant {self.variant} does not match known types"
)
- def forward(self, x: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) -> Tensor:
+ def forward(self, x: tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) -> Tensor:
"""Performs a forward pass of the decoder. Depending on DCN variant, will take multiple inputs
throughout pass from the encoder.
diff --git a/minerva/models/psp.py b/minerva/models/psp.py
index e0c1a2276..d76984f0e 100644
--- a/minerva/models/psp.py
+++ b/minerva/models/psp.py
@@ -38,7 +38,7 @@
# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Optional
import segmentation_models_pytorch as smp
import torch
@@ -106,9 +106,9 @@ def __init__(
psp_dropout: float = 0.2,
in_channels: Optional[int] = None,
n_classes: int = 1,
- activation: Optional[Union[str, Callable[..., Any]]] = None,
+ activation: Optional[str | Callable[..., Any]] = None,
upsampling: int = 8,
- aux_params: Optional[Dict[str, Any]] = None,
+ aux_params: Optional[dict[str, Any]] = None,
backbone_weight_path: Optional[str] = None,
freeze_backbone: bool = False,
encoder: bool = True,
@@ -150,7 +150,7 @@ def __init__(
def make_segmentation_head(
self,
n_classes: int,
- activation: Optional[Union[str, Callable[..., Any]]] = None,
+ activation: Optional[str | Callable[..., Any]] = None,
upsampling: int = 8,
) -> None:
self.segmentation_head = SegmentationHead(
@@ -164,7 +164,7 @@ def make_segmentation_head(
self.encoder_mode = True
self.segmentation_on = True
- def make_classification_head(self, aux_params: Dict[str, Any]) -> None:
+ def make_classification_head(self, aux_params: dict[str, Any]) -> None:
# Makes the classification head.
self.classification_head = ClassificationHead(
in_channels=self.encoder.out_channels[-1], **aux_params
@@ -198,7 +198,7 @@ def freeze_backbone(self, freeze: bool = True) -> None:
"""
self.encoder.requires_grad_(False if freeze else True)
- def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, ...]]:
+ def forward(self, x: Tensor) -> Tensor | tuple[Tensor, ...]:
f = self.encoder(x)
if not self.encoder_mode:
@@ -223,7 +223,7 @@ class MinervaPSP(MinervaWrapper):
def __init__(
self,
criterion: Optional[Module] = None,
- input_size: Optional[Tuple[int, ...]] = None,
+ input_size: Optional[tuple[int, ...]] = None,
n_classes: int = 1,
scaler: Optional[GradScaler] = None,
encoder_name: str = "resnet34",
@@ -232,9 +232,9 @@ def __init__(
psp_out_channels: int = 512,
psp_use_batchnorm: bool = True,
psp_dropout: float = 0.2,
- activation: Optional[Union[str, Callable[..., Any]]] = None,
+ activation: Optional[str | Callable[..., Any]] = None,
upsampling: int = 8,
- aux_params: Optional[Dict[str, Any]] = None,
+ aux_params: Optional[dict[str, Any]] = None,
backbone_weight_path: Optional[str] = None,
freeze_backbone: bool = False,
encoder: bool = False,
@@ -315,7 +315,7 @@ def forward(self, x):
class PSPUNetDecoder(Module):
def __init__(
self,
- encoder_channels: Tuple[int, int, int, int, int, int],
+ encoder_channels: tuple[int, int, int, int, int, int],
n_classes: int,
use_batchnorm: bool = True,
dropout: float = 0.2,
@@ -467,7 +467,7 @@ class MinervaPSPUNet(MinervaWrapper):
def __init__(
self,
criterion: Optional[Module] = None,
- input_size: Optional[Tuple[int, ...]] = None,
+ input_size: Optional[tuple[int, ...]] = None,
n_classes: int = 1,
scaler: Optional[GradScaler] = None,
encoder_name: str = "resnet34",
@@ -476,9 +476,9 @@ def __init__(
psp_out_channels: int = 512,
psp_use_batchnorm: bool = True,
psp_dropout: float = 0.2,
- activation: Optional[Union[str, Callable[..., Any]]] = None,
+ activation: Optional[str | Callable[..., Any]] = None,
upsampling: int = 8,
- aux_params: Optional[Dict[str, Any]] = None,
+ aux_params: Optional[dict[str, Any]] = None,
backbone_weight_path: Optional[str] = None,
freeze_backbone: bool = False,
encoder: bool = False,
diff --git a/minerva/models/resnet.py b/minerva/models/resnet.py
index d9e437c0a..36c67cc49 100644
--- a/minerva/models/resnet.py
+++ b/minerva/models/resnet.py
@@ -46,7 +46,7 @@
# IMPORTS
# =====================================================================================================================
import abc
-from typing import Any, Callable, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Optional, Type
import torch
import torch.nn.modules as nn
@@ -122,14 +122,14 @@ class ResNet(MinervaModel):
def __init__(
self,
- block: Type[Union[BasicBlock, Bottleneck]],
- layers: Union[List[int], Tuple[int, int, int, int]],
+ block: Type[BasicBlock | Bottleneck],
+ layers: list[int] | tuple[int, int, int, int],
in_channels: int = 3,
n_classes: int = 8,
zero_init_residual: bool = False,
groups: int = 1,
width_per_group: int = 64,
- replace_stride_with_dilation: Optional[Tuple[bool, bool, bool]] = None,
+ replace_stride_with_dilation: Optional[tuple[bool, bool, bool]] = None,
norm_layer: Optional[Callable[..., Module]] = None,
encoder: bool = False,
) -> None:
@@ -224,7 +224,7 @@ def __init__(
def _make_layer(
self,
- block: Type[Union[BasicBlock, Bottleneck]],
+ block: Type[BasicBlock | Bottleneck],
planes: int,
blocks: int,
stride: int = 1,
@@ -296,7 +296,7 @@ def _make_layer(
def _forward_impl(
self, x: Tensor
- ) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
+ ) -> Tensor | tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
@@ -320,7 +320,7 @@ def _forward_impl(
def forward(
self, x: Tensor
- ) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
+ ) -> Tensor | tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Performs a forward pass of the :class:`ResNet`.
Can be called directly as a method (e.g. ``model.forward``) or when data is parsed
@@ -390,19 +390,19 @@ class ResNetX(MinervaModel):
"""
__metaclass__ = abc.ABCMeta
- block_type: Union[Type[BasicBlock], Type[Bottleneck]] = BasicBlock
- layer_struct: List[int] = [2, 2, 2, 2]
+ block_type: Type[BasicBlock] | Type[Bottleneck] = BasicBlock
+ layer_struct: list[int] = [2, 2, 2, 2]
weights_name = "ResNet18_Weights.IMAGENET1K_V1"
def __init__(
self,
criterion: Optional[Any] = None,
- input_size: Tuple[int, int, int] = (4, 256, 256),
+ input_size: tuple[int, int, int] = (4, 256, 256),
n_classes: int = 8,
zero_init_residual: bool = False,
groups: int = 1,
width_per_group: int = 64,
- replace_stride_with_dilation: Optional[Tuple[bool, bool, bool]] = None,
+ replace_stride_with_dilation: Optional[tuple[bool, bool, bool]] = None,
norm_layer: Optional[Callable[..., Module]] = None,
encoder: bool = False,
torch_weights: bool = False,
@@ -431,7 +431,7 @@ def __init__(
def forward(
self, x: Tensor
- ) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
+ ) -> Tensor | tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Performs a forward pass of the :class:`ResNet`.
Can be called directly as a method (e.g. :func:`model.forward`) or when data is parsed
@@ -445,9 +445,7 @@ def forward(
initialised as an encoder, returns a tuple of outputs from each ``layer`` 1-4. Else, returns
:class:`~torch.Tensor` of the likelihoods the network places on the input ``x`` being of each class.
"""
- z: Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]] = self.network(
- x
- )
+ z: Tensor | tuple[Tensor, Tensor, Tensor, Tensor, Tensor] = self.network(x)
if isinstance(z, Tensor):
return z
elif isinstance(z, tuple):
@@ -460,7 +458,7 @@ class ResNet18(ResNetX):
by stripping classification layers away.
"""
- layer_struct: List[int] = [2, 2, 2, 2]
+ layer_struct: list[int] = [2, 2, 2, 2]
weights_name = "ResNet18_Weights.IMAGENET1K_V1"
@@ -469,7 +467,7 @@ class ResNet34(ResNetX):
by stripping classification layers away.
"""
- layer_struct: List[int] = [3, 4, 6, 3]
+ layer_struct: list[int] = [3, 4, 6, 3]
weights_name = "ResNet34_Weights.IMAGENET1K_V1"
@@ -508,8 +506,8 @@ class ResNet152(ResNetX):
# =====================================================================================================================
def _preload_weights(
resnet: ResNet,
- weights: Optional[Union[WeightsEnum, Any]],
- input_shape: Tuple[int, int, int],
+ weights: Optional[WeightsEnum | Any],
+ input_shape: tuple[int, int, int],
encoder_on: bool,
) -> ResNet: # pragma: no cover
if not weights:
diff --git a/minerva/models/siamese.py b/minerva/models/siamese.py
index 62d6564cf..7e8b5fd9b 100644
--- a/minerva/models/siamese.py
+++ b/minerva/models/siamese.py
@@ -53,7 +53,7 @@
# IMPORTS
# =====================================================================================================================
import abc
-from typing import Any, Dict, Optional, Sequence, Tuple
+from typing import Any, Optional, Sequence
import numpy as np
import torch
@@ -86,7 +86,7 @@ def __init__(self, *args, **kwargs) -> None:
self.backbone: MinervaModel
self.proj_head: Module
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+ def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Performs a forward pass of the network by using the forward methods of the backbone and
feeding its output into the projection heads.
@@ -106,7 +106,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""
return self.forward_pair(x)
- def forward_pair(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+ def forward_pair(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Performs a forward pass of the network by using the forward methods of the backbone and
feeding its output into the projection heads.
@@ -131,7 +131,7 @@ def forward_pair(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tenso
return g, g_a, g_b, f_a, f_b
@abc.abstractmethod
- def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ def forward_single(self, x: Tensor) -> tuple[Tensor, Tensor]:
"""Performs a forward pass of a single head of the network by using the forward methods of the backbone
and feeding its output into the projection heads.
@@ -171,10 +171,10 @@ class SimCLR(MinervaSiamese):
def __init__(
self,
criterion: Any,
- input_size: Tuple[int, int, int] = (4, 256, 256),
+ input_size: tuple[int, int, int] = (4, 256, 256),
feature_dim: int = 128,
scaler: Optional[GradScaler] = None,
- backbone_kwargs: Dict[str, Any] = {},
+ backbone_kwargs: dict[str, Any] = {},
) -> None:
super(SimCLR, self).__init__(
criterion=criterion, input_size=input_size, scaler=scaler
@@ -196,7 +196,7 @@ def __init__(
nn.Linear(512, feature_dim, bias=False),
)
- def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ def forward_single(self, x: Tensor) -> tuple[Tensor, Tensor]:
"""Performs a forward pass of a single head of the network by using the forward methods of the
:attr:`~SimCLR.backbone` and feeding its output into the :attr:`~SimCLR.proj_head`.
@@ -214,7 +214,7 @@ def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]:
return g, f
- def step(self, x: Tensor, *args, train: bool = False) -> Tuple[Tensor, Tensor]:
+ def step(self, x: Tensor, *args, train: bool = False) -> tuple[Tensor, Tensor]:
"""Overwrites :class:`~models.core.MinervaModel` to account for paired logits.
Raises:
@@ -318,11 +318,11 @@ class SimSiam(MinervaSiamese):
def __init__(
self,
criterion: Any,
- input_size: Tuple[int, int, int] = (4, 256, 256),
+ input_size: tuple[int, int, int] = (4, 256, 256),
feature_dim: int = 128,
pred_dim: int = 512,
scaler: Optional[GradScaler] = None,
- backbone_kwargs: Dict[str, Any] = {},
+ backbone_kwargs: dict[str, Any] = {},
) -> None:
super(SimSiam, self).__init__(
criterion=criterion, input_size=input_size, scaler=scaler
@@ -359,7 +359,7 @@ def __init__(
nn.Linear(pred_dim, feature_dim),
) # output layer
- def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ def forward_single(self, x: Tensor) -> tuple[Tensor, Tensor]:
"""Performs a forward pass of a single head of :class:`SimSiam` by using the forward methods of the backbone
and feeding its output into the :attr:`~SimSiam.proj_head`.
@@ -376,7 +376,7 @@ def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]:
return p, z.detach()
- def step(self, x: Tensor, *args, train: bool = False) -> Tuple[Tensor, Tensor]:
+ def step(self, x: Tensor, *args, train: bool = False) -> tuple[Tensor, Tensor]:
"""Overwrites :class:`~models.core.MinervaModel` to account for paired logits.
Raises:
@@ -480,12 +480,12 @@ class SimConv(MinervaSiamese):
def __init__(
self,
criterion: Any,
- input_size: Tuple[int, int, int] = (4, 256, 256),
+ input_size: tuple[int, int, int] = (4, 256, 256),
feature_dim: int = 2048,
projection_dim: int = 512,
scaler: Optional[GradScaler] = None,
encoder_weights: Optional[str] = None,
- backbone_kwargs: Dict[str, Any] = {},
+ backbone_kwargs: dict[str, Any] = {},
) -> None:
super(SimConv, self).__init__(
criterion=criterion, input_size=input_size, scaler=scaler
@@ -518,7 +518,7 @@ def __init__(
nn.UpsamplingBilinear2d(scale_factor=4),
)
- def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ def forward_single(self, x: Tensor) -> tuple[Tensor, Tensor]:
"""Performs a forward pass of a single head of the network by using the forward methods of the
:attr:`~SimCLR.backbone` and feeding its output into the :attr:`~SimCLR.proj_head`.
@@ -536,7 +536,7 @@ def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]:
return g, f
- def step(self, x: Tensor, *args, train: bool = False) -> Tuple[Tensor, Tensor]:
+ def step(self, x: Tensor, *args, train: bool = False) -> tuple[Tensor, Tensor]:
"""Overwrites :class:`~models.core.MinervaModel` to account for paired logits.
Raises:
diff --git a/minerva/models/unet.py b/minerva/models/unet.py
index 200c514ed..b1fb0f641 100644
--- a/minerva/models/unet.py
+++ b/minerva/models/unet.py
@@ -51,7 +51,7 @@
# IMPORTS
# =====================================================================================================================
import abc
-from typing import Any, Dict, Optional, Sequence, Tuple
+from typing import Any, Optional, Sequence
import torch
import torch.nn.functional as F
@@ -267,7 +267,7 @@ class UNet(MinervaModel):
def __init__(
self,
criterion: Any,
- input_size: Tuple[int, ...] = (4, 256, 256),
+ input_size: tuple[int, ...] = (4, 256, 256),
n_classes: int = 8,
bilinear: bool = False,
scaler: Optional[GradScaler] = None,
@@ -356,13 +356,13 @@ class UNetR(MinervaModel):
def __init__(
self,
criterion: Any,
- input_size: Tuple[int, ...] = (4, 256, 256),
+ input_size: tuple[int, ...] = (4, 256, 256),
n_classes: int = 8,
bilinear: bool = False,
scaler: Optional[GradScaler] = None,
backbone_weight_path: Optional[str] = None,
freeze_backbone: bool = False,
- backbone_kwargs: Dict[str, Any] = {},
+ backbone_kwargs: dict[str, Any] = {},
) -> None:
super(UNetR, self).__init__(
criterion=criterion,
diff --git a/minerva/optimisers.py b/minerva/optimisers.py
index 90bf4cc7b..5a5a10f9f 100644
--- a/minerva/optimisers.py
+++ b/minerva/optimisers.py
@@ -31,7 +31,7 @@
# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
-from typing import Any, Callable, Dict, Iterable, Optional, Union
+from typing import Any, Callable, Iterable, Optional
import torch
from torch.optim.optimizer import Optimizer
@@ -67,7 +67,7 @@ class LARS(Optimizer):
def __init__(
self,
- params: Union[Iterable[Any], Dict[Any, Any]],
+ params: Iterable[Any] | dict[Any, Any],
lr: float,
momentum: float = 0.9,
weight_decay: float = 0.0005,
diff --git a/minerva/pytorchtools.py b/minerva/pytorchtools.py
index 538f8b669..1cb61c848 100644
--- a/minerva/pytorchtools.py
+++ b/minerva/pytorchtools.py
@@ -35,7 +35,7 @@
# IMPORTS
# =====================================================================================================================
from pathlib import Path
-from typing import Callable, Optional, Union
+from typing import Callable, Optional
import numpy as np
import torch
@@ -81,7 +81,7 @@ def __init__(
patience: int = 7,
verbose: bool = False,
delta: float = 0.0,
- path: Union[str, Path] = "checkpoint.pt",
+ path: str | Path = "checkpoint.pt",
trace_func: Callable[..., None] = print,
external_save: bool = False,
):
@@ -92,7 +92,7 @@ def __init__(
self.early_stop: bool = False
self.val_loss_min: float = np.Inf
self.delta: float = delta
- self.path: Union[str, Path] = path
+ self.path: str | Path = path
self.trace_func: Callable[..., None] = trace_func
self.external_save: bool = external_save
self.save_model: bool = False
diff --git a/minerva/samplers.py b/minerva/samplers.py
index 33bce7b3e..9731961e3 100644
--- a/minerva/samplers.py
+++ b/minerva/samplers.py
@@ -41,22 +41,26 @@
"DistributedSamplerWrapper",
"get_greater_bbox",
"get_pair_bboxes",
+ "get_sampler",
]
# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
import random
+import re
from operator import itemgetter
-from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union
+from typing import Any, Iterator, Optional, Sequence
+import hydra
import torch
from torch.utils.data import Dataset, DistributedSampler, Sampler
-from torchgeo.datasets import GeoDataset
+from torchgeo.datasets import GeoDataset, NonGeoDataset
from torchgeo.datasets.utils import BoundingBox
from torchgeo.samplers import BatchGeoSampler, RandomGeoSampler, Units
from torchgeo.samplers.utils import _to_tuple, get_random_bounding_box
+from minerva.datasets.utils import make_bounding_box
from minerva.utils import utils
@@ -90,7 +94,7 @@ class RandomPairGeoSampler(RandomGeoSampler):
def __init__(
self,
dataset: GeoDataset,
- size: Union[Tuple[float, float], float],
+ size: tuple[float, float] | float,
length: int,
roi: Optional[BoundingBox] = None,
units: Units = Units.PIXELS,
@@ -99,7 +103,7 @@ def __init__(
super().__init__(dataset, size, length, roi, units)
self.max_r = max_r
- def __iter__(self) -> Iterator[Tuple[BoundingBox, BoundingBox]]: # type: ignore[override]
+ def __iter__(self) -> Iterator[tuple[BoundingBox, BoundingBox]]: # type: ignore[override]
"""Return a pair of :class:`~torchgeo.datasets.utils.BoundingBox` indices of a dataset
that are geospatially close.
@@ -155,7 +159,7 @@ class RandomPairBatchGeoSampler(BatchGeoSampler):
def __init__(
self,
dataset: GeoDataset,
- size: Union[Tuple[float, float], float],
+ size: tuple[float, float] | float,
batch_size: int,
length: int,
roi: Optional[BoundingBox] = None,
@@ -176,7 +180,7 @@ def __init__(
else:
raise ValueError(f"{tiles_per_batch=} is not a multiple of {batch_size=}")
- def __iter__(self) -> Iterator[List[Tuple[BoundingBox, BoundingBox]]]: # type: ignore[override]
+ def __iter__(self) -> Iterator[list[tuple[BoundingBox, BoundingBox]]]: # type: ignore[override]
"""Return the indices of a dataset.
Returns:
@@ -208,7 +212,9 @@ def __len__(self) -> int:
def get_greater_bbox(
- bbox: BoundingBox, r: float, size: Union[float, int, Sequence[float]]
+ bbox: BoundingBox,
+ r: float,
+ size: float | int | Sequence[float],
) -> BoundingBox:
"""Return a bounding box at ``r`` distance around the first box.
@@ -246,10 +252,10 @@ def get_greater_bbox(
def get_pair_bboxes(
bounds: BoundingBox,
- size: Union[Tuple[float, float], float],
+ size: tuple[float, float] | float,
res: float,
max_r: float,
-) -> Tuple[BoundingBox, BoundingBox]:
+) -> tuple[BoundingBox, BoundingBox]:
"""Samples a pair of bounding boxes geo-spatially close to each other.
Args:
@@ -345,7 +351,7 @@ class DatasetFromSampler(Dataset): # type: ignore[type-arg]
def __init__(self, sampler: Sampler[Any]):
"""Initialisation for :class:`DatasetFromSampler`."""
self.sampler = sampler
- self.sampler_list: Optional[List[Sampler[Any]]] = None
+ self.sampler_list: Optional[list[Sampler[Any]]] = None
def __getitem__(self, index: int) -> Any:
"""Gets element of the dataset.
@@ -366,3 +372,43 @@ def __len__(self) -> int:
int: Length of the dataset
"""
return len(self.sampler) # type: ignore[arg-type]
+
+
+# =====================================================================================================================
+# METHODS
+# =====================================================================================================================
+def get_sampler(
+ params: dict[str, Any],
+ dataset: GeoDataset | NonGeoDataset,
+ batch_size: Optional[int] = None,
+) -> Sampler[Any]:
+    """Use :func:`hydra.utils.instantiate` to get the sampler using config parameters.
+
+ Args:
+ params (dict[str, ~typing.Any]): Sampler parameters. Must include the ``_target_`` key pointing to
+ the sampler class.
+        dataset (~torchgeo.datasets.GeoDataset | ~torchgeo.datasets.NonGeoDataset): Dataset to sample.
+ batch_size (int): Optional; Batch size to sample if using a batch sampler.
+ Use if you need to overwrite the config parameters due to distributed computing as the batch size
+            needs to be modified to split the batch across devices. Defaults to ``None``.
+
+ Returns:
+ ~torch.utils.data.Sampler: Sampler requested by config parameters.
+ """
+
+ batch_sampler = True if re.search(r"Batch", params["_target_"]) else False
+ if batch_sampler and batch_size is not None:
+ params["batch_size"] = batch_size
+
+ if "roi" in params:
+ sampler = hydra.utils.instantiate(
+ params, dataset=dataset, roi=make_bounding_box(params["roi"])
+ )
+ else:
+ if "torchgeo" in params["_target_"]:
+ sampler = hydra.utils.instantiate(params, dataset=dataset)
+ else:
+ sampler = hydra.utils.instantiate(params, data_source=dataset)
+
+ assert isinstance(sampler, Sampler)
+ return sampler
diff --git a/minerva/scheduler.py b/minerva/scheduler.py
index ff8b33684..c49f0fe9a 100644
--- a/minerva/scheduler.py
+++ b/minerva/scheduler.py
@@ -37,7 +37,6 @@
# IMPORTS
# =====================================================================================================================
import warnings
-from typing import List
import numpy as np
from torch.optim.lr_scheduler import LRScheduler
@@ -81,7 +80,7 @@ def __init__(
self.n_periods = n_periods
super().__init__(optimizer, last_epoch, verbose)
- def get_lr(self) -> List[float]: # type: ignore[override]
+ def get_lr(self) -> list[float]: # type: ignore[override]
if not hasattr(self, "_get_lr_called_within_step"):
warnings.warn(
"To get the last learning rate computed by the scheduler, "
diff --git a/minerva/tasks/core.py b/minerva/tasks/core.py
index 67faac873..fcd6d84d2 100644
--- a/minerva/tasks/core.py
+++ b/minerva/tasks/core.py
@@ -41,7 +41,7 @@
import abc
from abc import ABC
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence
if TYPE_CHECKING: # pragma: no cover
from torch.utils.tensorboard.writer import SummaryWriter
@@ -66,7 +66,7 @@
wrap_model,
)
from minerva.utils import utils, visutils
-from minerva.utils.utils import fallback_params, func_by_str
+from minerva.utils.utils import fallback_params
# =====================================================================================================================
@@ -158,19 +158,19 @@ class MinervaTask(ABC):
"""
logger_cls: str = "SupervisedTaskLogger"
- model_io_name: str = "sup_tg"
+ model_io_name: str = "minerva.modelio.supervised_torchgeo_io"
def __init__(
self,
name: str,
- model: Union[MinervaModel, MinervaDataParallel, OptimizedModule],
+ model: MinervaModel | MinervaDataParallel | OptimizedModule,
device: torch.device,
exp_fn: Path,
gpu: int = 0,
rank: int = 0,
world_size: int = 1,
- writer: Optional[Union[SummaryWriter, Run]] = None,
- backbone_weight_path: Optional[Union[str, Path]] = None,
+ writer: Optional[SummaryWriter | Run] = None,
+ backbone_weight_path: Optional[str | Path] = None,
record_int: bool = True,
record_float: bool = False,
train: bool = False,
@@ -181,9 +181,12 @@ def __init__(
self.model = model
# Gets the datasets, number of batches, class distribution and the modfied parameters for the experiment.
- loaders, n_batches, class_dist, task_params, = make_loaders(
- rank, world_size, task_name=name, **global_params
- )
+ (
+ loaders,
+ n_batches,
+ class_dist,
+ task_params,
+ ) = make_loaders(rank, world_size, task_name=name, **global_params)
# If there are multiple modes and therefore number of batches, just take the value of the first one.
if isinstance(n_batches, dict):
@@ -226,6 +229,9 @@ def __init__(
self.sample_pairs = fallback_params(
"sample_pairs", self.params, self.global_params
)
+ self.change_detection = fallback_params(
+ "change_detection", self.params, self.global_params
+ )
self.n_classes = fallback_params("n_classes", self.params, self.global_params)
@@ -259,7 +265,7 @@ def __init__(
self.model.to(self.device)
# To eliminate classes, we're going to have to do a fair bit of rebuilding of the model...
- if self.elim:
+ if self.elim and self.train:
# Update the stored number of classes within the model and
# then rebuild the classification layers that are dependent on the number of classes.
self.model.update_n_classes(self.n_classes)
@@ -319,18 +325,24 @@ def make_criterion(self) -> Any:
_weights = Tensor(weights)
# Use hydra to instantiate the loss function with the weights and return.
- return hydra.utils.instantiate(fallback_params("loss_params", self.params, self.global_params), weight=_weights)
+ return hydra.utils.instantiate(
+ fallback_params("loss_params", self.params, self.global_params),
+ weight=_weights,
+ )
else:
# Use hydra to instantiate the loss function based of the config, without weights.
- return hydra.utils.instantiate(fallback_params("loss_params", self.params, self.global_params))
+ return hydra.utils.instantiate(
+ fallback_params("loss_params", self.params, self.global_params)
+ )
def make_optimiser(self) -> None:
"""Creates a :mod:`torch` optimiser based on config parameters and sets optimiser."""
# Constructs and sets the optimiser for the model based on supplied config parameters.
optimiser = hydra.utils.instantiate(
- fallback_params("optimiser", self.params, self.global_params), params=self.model.parameters()
+ fallback_params("optimiser", self.params, self.global_params),
+ params=self.model.parameters(),
)
self.model.set_optimiser(optimiser)
@@ -350,8 +362,7 @@ def make_logger(self) -> MinervaTaskLogger:
"""
# Gets constructor of the metric logger from name in the config.
- _logger_cls = func_by_str(
- "minerva.logger.tasklog",
+ _logger_cls = hydra.utils.get_class(
utils.fallback_params(
"task_logger", self.params, self.global_params, self.logger_cls
),
@@ -382,8 +393,7 @@ def get_io_func(self) -> Callable[..., Any]:
Returns:
~typing.Callable[..., ~typing.Any]: Model IO function requested from parameters.
"""
- io_func: Callable[..., Any] = func_by_str(
- "minerva.modelio",
+ io_func: Callable[..., Any] = hydra.utils.get_method(
utils.fallback_params(
"model_io", self.params, self.global_params, self.model_io_name
),
@@ -394,7 +404,7 @@ def get_io_func(self) -> Callable[..., Any]:
def step(self) -> None: # pragma: no cover
raise NotImplementedError
- def _generic_step(self, epoch_no: int) -> Optional[Dict[str, Any]]:
+ def _generic_step(self, epoch_no: int) -> Optional[dict[str, Any]]:
self.local_step_num = 0
self.step()
@@ -421,11 +431,11 @@ def __call__(self, epoch_no: int) -> Any:
return self._generic_step(epoch_no)
@property
- def get_logs(self) -> Dict[str, Any]:
+ def get_logs(self) -> dict[str, Any]:
return self.logger.get_logs
@property
- def get_metrics(self) -> Dict[str, Any]:
+ def get_metrics(self) -> dict[str, Any]:
return self.logger.get_metrics
def log_null(self, epoch_no: int) -> None:
@@ -445,8 +455,8 @@ def print_epoch_results(self, epoch_no: int) -> None:
def plot(
self,
- results: Dict[str, Any],
- metrics: Optional[Dict[str, Any]] = None,
+ results: dict[str, Any],
+ metrics: Optional[dict[str, Any]] = None,
save: bool = True,
show: bool = False,
) -> None:
@@ -491,7 +501,9 @@ def plot(
colours=utils.fallback_params("colours", self.params, self.global_params),
save=save,
show=show,
- model_name=utils.fallback_params("model_name", self.params, self.global_params),
+ model_name=utils.fallback_params(
+ "model_name", self.params, self.global_params
+ ),
timestamp=self.global_params["timestamp"],
results_dir=self.task_dir,
task_cfg=self.params,
@@ -545,7 +557,7 @@ def __repr__(self) -> str:
# =====================================================================================================================
# METHODS
# =====================================================================================================================
-def get_task(task_name: str, task_module: str = "minerva.tasks", *args, **params) -> MinervaTask:
+def get_task(task_name: str, *args, **params) -> MinervaTask:
"""Get the requested :class:`MinervaTask` by name.
Args:
@@ -555,7 +567,7 @@ def get_task(task_name: str, task_module: str = "minerva.tasks", *args, **params
Returns:
MinervaTask: Constructed :class:`MinervaTask` object.
"""
- _task = func_by_str(task_module, task_name)
+ _task = hydra.utils.get_class(task_name)
task = _task(*args, **params)
assert isinstance(task, MinervaTask)
diff --git a/minerva/tasks/epoch.py b/minerva/tasks/epoch.py
index 19363a8b3..98e64fb33 100644
--- a/minerva/tasks/epoch.py
+++ b/minerva/tasks/epoch.py
@@ -60,7 +60,7 @@ class StandardEpoch(MinervaTask):
.. versionadded:: 0.27
"""
- logger_cls = "SupervisedTaskLogger"
+ logger_cls = "minerva.logger.tasklog.SupervisedTaskLogger"
def step(self) -> None:
# Initialises a progress bar for the epoch.
@@ -72,7 +72,7 @@ def step(self) -> None:
self.model.eval()
# Ensure gradients will not be calculated if this is not a training task.
- with torch.no_grad() if not self.train else nullcontext():
+ with torch.no_grad() if not self.train else nullcontext(): # type: ignore[attr-defined]
# Core of the epoch.
for batch in self.loaders:
@@ -85,7 +85,7 @@ def step(self) -> None:
)
if self.local_step_num % self.log_rate == 0:
- if dist.is_available() and dist.is_initialized(): # type: ignore[attr-defined] # pragma: no cover
+ if dist.is_available() and dist.is_initialized(): # type: ignore[attr-defined] # pragma: no cover # noqa: E501
loss = results[0].data.clone()
dist.all_reduce(loss.div_(dist.get_world_size())) # type: ignore[attr-defined]
results = (loss, *results[1:])
diff --git a/minerva/tasks/knn.py b/minerva/tasks/knn.py
index 984faab7f..9c650a139 100644
--- a/minerva/tasks/knn.py
+++ b/minerva/tasks/knn.py
@@ -41,7 +41,7 @@
# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional
import torch
import torch.distributed as dist
@@ -137,19 +137,19 @@ class WeightedKNN(MinervaTask):
.. versionadded:: 0.27
"""
- logger_cls = "SSLTaskLogger"
+ logger_cls = "minerva.logger.tasklog.SSLTaskLogger"
def __init__(
self,
name: str,
- model: Union[MinervaModel, MinervaDataParallel],
+ model: MinervaModel | MinervaDataParallel,
device: torch.device,
exp_fn: Path,
gpu: int = 0,
rank: int = 0,
world_size: int = 1,
- writer: Optional[Union[SummaryWriter, Run]] = None,
- backbone_weight_path: Optional[Union[str, Path]] = None,
+ writer: Optional[SummaryWriter | Run] = None,
+ backbone_weight_path: Optional[str | Path] = None,
record_int: bool = True,
record_float: bool = False,
k: int = 5,
@@ -174,10 +174,12 @@ def __init__(
self.temp = temp
self.k = k
- def generate_feature_bank(self) -> Tuple[Tensor, Tensor]:
+ def generate_feature_bank(self) -> tuple[Tensor, Tensor]:
feature_list = []
target_list = []
+ assert isinstance(self.loaders, dict)
+
for batch in tqdm(self.loaders["features"]):
val_data: Tensor = batch["image"].to(self.device, non_blocking=True)
val_target: Tensor = batch["mask"].to(self.device, non_blocking=True)
@@ -231,6 +233,8 @@ def step(self) -> None:
# Generate feature bank and target bank.
feature_bank, feature_labels = self.generate_feature_bank()
+ assert isinstance(self.loaders, dict)
+
# Loop test data to predict the label by weighted KNN search.
for batch in tqdm(self.loaders["test"]):
test_data: Tensor = batch["image"].to(self.device, non_blocking=True)
diff --git a/minerva/tasks/tsne.py b/minerva/tasks/tsne.py
index 38bd37f26..4db2bd922 100644
--- a/minerva/tasks/tsne.py
+++ b/minerva/tasks/tsne.py
@@ -38,7 +38,7 @@
# IMPORTS
# =====================================================================================================================
from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Optional
import torch
from torch import Tensor
@@ -46,6 +46,7 @@
from wandb.sdk.wandb_run import Run
from minerva.models import MinervaDataParallel, MinervaModel
+from minerva.utils.utils import get_sample_index
from minerva.utils.visutils import plot_embedding
from .core import MinervaTask
@@ -64,14 +65,14 @@ class TSNEVis(MinervaTask):
def __init__(
self,
name: str,
- model: Union[MinervaModel, MinervaDataParallel],
+ model: MinervaModel | MinervaDataParallel,
device: torch.device,
exp_fn: Path,
gpu: int = 0,
rank: int = 0,
world_size: int = 1,
- writer: Union[SummaryWriter, Run, None] = None,
- backbone_weight_path: Optional[Union[str, Path]] = None,
+ writer: Optional[SummaryWriter | Run] = None,
+ backbone_weight_path: Optional[str | Path] = None,
record_int: bool = True,
record_float: bool = False,
**params,
@@ -79,7 +80,7 @@ def __init__(
backbone = model.get_backbone() # type: ignore[assignment, operator]
# Set dummy optimiser. It won't be used as this is a test.
- backbone.set_optimiser(torch.optim.SGD(backbone.parameters(), lr=1.0e-3))
+ backbone.set_optimiser(torch.optim.SGD(backbone.parameters(), lr=1.0e-3)) # type: ignore[attr-defined]
super().__init__(
name,
@@ -103,7 +104,8 @@ def step(self) -> None:
Passes these embeddings to :mod:`visutils` to train a TSNE algorithm and then visual the cluster.
"""
# Get a batch of data.
- data = next(iter(self.loaders))
+ assert isinstance(self.loaders, torch.utils.data.DataLoader)
+ data: dict[str, Any] = next(iter(self.loaders))
# Make sure the model is in evaluation mode.
self.model.eval()
@@ -118,7 +120,7 @@ def step(self) -> None:
plot_embedding(
embeddings.detach().cpu(),
- data["bbox"],
+ get_sample_index(data), # type: ignore[arg-type]
self.global_params["data_root"],
self.params["dataset_params"],
show=True,
diff --git a/minerva/trainer.py b/minerva/trainer.py
index 0f35b99fa..88a9cf470 100644
--- a/minerva/trainer.py
+++ b/minerva/trainer.py
@@ -39,11 +39,12 @@
# IMPORTS
# =====================================================================================================================
import os
+import re
import warnings
from copy import deepcopy
from pathlib import Path
from platform import python_version
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional
import hydra
import packaging
@@ -66,7 +67,6 @@
MinervaDataParallel,
MinervaModel,
MinervaOnnxModel,
- MinervaWrapper,
extract_wrapped_model,
wrap_model,
)
@@ -237,7 +237,7 @@ def __init__(
rank: int = 0,
world_size: int = 1,
verbose: bool = True,
- wandb_run: Optional[Union[Run, RunDisabled]] = None,
+ wandb_run: Optional[Run | RunDisabled] = None,
**params,
) -> None:
assert not isinstance(wandb_run, RunDisabled)
@@ -264,7 +264,7 @@ def __init__(
utils.print_config(params) # type: ignore[arg-type]
# Now that we have pretty printed the config, it is easier to handle as a dict.
- self.params: Dict[str, Any] = OmegaConf.to_object(params) # type: ignore[assignment]
+ self.params: dict[str, Any] = OmegaConf.to_object(params) # type: ignore[assignment]
assert isinstance(self.params, dict)
# Set variables for checkpointing the experiment or loading from a previous checkpoint.
@@ -324,7 +324,7 @@ def __init__(
# Makes a directory for this experiment.
utils.mkexpdir(self.params["exp_name"])
- self.writer: Optional[Union[SummaryWriter, Run]] = None
+ self.writer: Optional[SummaryWriter | Run] = None
if self.params.get("wandb_log", False):
# Sets the `wandb` run object (or None).
self.writer = wandb_run
@@ -337,9 +337,10 @@ def __init__(
else: # pragma: no cover
self.writer = None
- self.model: Union[
- MinervaModel, MinervaDataParallel, MinervaBackbone, OptimizedModule
- ]
+ self.model: (
+ MinervaModel | MinervaDataParallel | MinervaBackbone | OptimizedModule
+ )
+
if Path(self.params.get("pre_train_name", "none")).suffix == ".onnx":
# Loads model from `onnx` format.
self.model = self.load_onnx_model()
@@ -351,7 +352,7 @@ def __init__(
self.model = self.make_model()
# Determines the output shape of the model.
- sample_pairs: Union[bool, Any] = self.sample_pairs
+ sample_pairs: bool | Any = self.sample_pairs
if not isinstance(sample_pairs, bool):
sample_pairs = False
self.params["sample_pairs"] = False
@@ -409,7 +410,9 @@ def __init__(
else:
pass
- self.checkpoint_path = self.exp_fn / (self.params["exp_name"] + "-checkpoint.pt")
+ self.checkpoint_path = self.exp_fn / (
+ self.params["exp_name"] + "-checkpoint.pt"
+ )
self.backbone_path = self.exp_fn / (self.params["exp_name"] + "-backbone.pt")
self.print("Checkpoint will be saved to " + str(self.checkpoint_path))
@@ -455,22 +458,19 @@ def _setup_writer(self) -> None:
if isinstance(self.writer, Run):
self.writer.watch(self.model)
- def get_input_size(self) -> Tuple[int, ...]:
+ def get_input_size(self) -> tuple[int, ...]:
"""Determines the input size of the model.
Returns:
tuple[int, ...]: :class:`tuple` describing the input shape of the model.
"""
- input_shape: Optional[Tuple[int, ...]] = self.model.input_size # type: ignore
+ input_shape: Optional[tuple[int, ...]] = self.model.input_size # type: ignore
assert input_shape is not None
- input_size: Tuple[int, ...] = (self.batch_size, *input_shape)
+ input_size: tuple[int, ...] = (self.batch_size, *input_shape)
- if self.sample_pairs:
+ if self.sample_pairs or self.change_detection:
input_size = (2, *input_size)
- if self.change_detection:
- input_size = (input_size[0], 2 * input_size[1], *input_size[2:])
-
return input_size
def get_model_cache_path(self) -> Path:
@@ -497,43 +497,39 @@ def make_model(self) -> MinervaModel:
Returns:
MinervaModel: Initialised model.
"""
- model_params: Dict[str, Any] = deepcopy(self.params["model_params"])
+ model_params: dict[str, Any] = deepcopy(self.params["model_params"])
if OmegaConf.is_config(model_params):
model_params = OmegaConf.to_object(model_params) # type: ignore[assignment]
- module = model_params.pop("module", "minerva.models")
- if not module:
- module = "minerva.models"
- is_minerva = True if module == "minerva.models" else False
-
- # Gets the model requested by config parameters.
- _model = utils.func_by_str(module, self.params["model_name"].split("-")[0])
+ is_minerva = True if re.search(r"minerva", model_params["_target_"]) else False
if self.fine_tune:
# Add the path to the pre-trained weights to the model params.
model_params["backbone_weight_path"] = f"{self.get_weights_path()}.pt"
- params = model_params.get("params", {})
- if "n_classes" in params.keys():
+ if "n_classes" in model_params.keys():
# Updates the number of classes in case it has been altered by class balancing.
- params["n_classes"] = self.params["n_classes"]
+ model_params["n_classes"] = self.params["n_classes"]
- if "num_classes" in params.keys():
+ if "num_classes" in model_params.keys():
# Updates the number of classes in case it has been altered by class balancing.
- params["num_classes"] = self.params["n_classes"]
+ model_params["num_classes"] = self.params["n_classes"]
if self.params.get("mix_precision", False):
- params["scaler"] = torch.cuda.amp.grad_scaler.GradScaler()
+ model_params["scaler"] = torch.cuda.amp.grad_scaler.GradScaler()
# Initialise model.
model: MinervaModel
if is_minerva:
- model = _model(self.make_criterion(), **params)
+ model = hydra.utils.instantiate(
+ model_params, criterion=self.make_criterion()
+ )
else:
- model = MinervaWrapper(
- _model,
- self.make_criterion(),
- **params,
+ model_params["model"] = hydra.utils.get_method(model_params["_target_"])
+ model_params["_target_"] = "minerva.models.MinervaWrapper"
+ model = hydra.utils.instantiate(
+ model_params,
+ criterion=self.make_criterion(),
)
if self.params.get("reload", False):
@@ -562,7 +558,10 @@ def load_onnx_model(self) -> MinervaModel:
package="onnx2torch",
)
- model_params = self.params["model_params"].get("params", {})
+ model_params = deepcopy(self.params["model_params"])
+
+ if "_target_" in model_params:
+ del model_params["_target_"]
onnx_model = convert(onnx_load(f"{self.get_weights_path()}.onnx"))
model = MinervaOnnxModel(onnx_model, self.make_criterion(), **model_params)
@@ -620,11 +619,10 @@ def fit(self) -> None:
}
)
- tasks: Dict[str, MinervaTask] = {}
+ tasks: dict[str, MinervaTask] = {}
for mode in fit_params.keys():
tasks[mode] = get_task(
- fit_params[mode]["name"],
- fit_params[mode].get("module", "minerva.tasks"),
+ fit_params[mode]["_target_"],
mode,
self.model,
self.device,
@@ -637,7 +635,9 @@ def fit(self) -> None:
**self.params,
)
- if tasks[mode].params.get("elim", False):
+ # Update the number of classes from the task
+ # if class elimination and training is active in the task.
+ if tasks[mode].elim and tasks[mode].train:
self.params["n_classes"] = tasks[mode].n_classes
while self.epoch_no < self.max_epochs:
@@ -649,28 +649,34 @@ def fit(self) -> None:
# Conduct training or validation epoch.
for mode in tasks.keys():
# Only run a validation epoch at set frequency of epochs. Goes to next epoch if not.
- if (
- utils.check_substrings_in_string(mode, "val")
- and (self.epoch_no) % self.val_freq != 0
- ):
- tasks[mode].log_null(self.epoch_no - 1)
- break
+ if utils.check_substrings_in_string(mode, "val"):
+ tasks[mode].model = self.model
+
+ if self.epoch_no % self.val_freq != 0:
+ tasks[mode].log_null(self.epoch_no - 1)
+ break
if tasks[mode].train:
self.model.train()
else:
self.model.eval()
- results: Optional[Dict[str, Any]]
-
- results = tasks[mode](self.epoch_no - 1)
+ results: Optional[dict[str, Any]] = tasks[mode](self.epoch_no - 1)
# Print epoch results.
if self.gpu == 0:
tasks[mode].print_epoch_results(self.epoch_no - 1)
- if not self.stopper and self.checkpoint_experiment:
+ if (
+ not self.stopper
+ and self.checkpoint_experiment
+ and utils.check_substrings_in_string(mode, "train")
+ ):
self.save_checkpoint()
+ # Update Trainer's copy of the model from the training task.
+ if utils.check_substrings_in_string(mode, "train"):
+ self.model = tasks[mode].model
+
# Sends validation loss to the stopper and updates early stop bool.
if (
utils.check_substrings_in_string(mode, "val")
@@ -692,7 +698,7 @@ def fit(self) -> None:
self.print("\nEarly stopping triggered")
# Create a subset of metrics for plotting model history.
- fit_metrics: Dict[str, Any] = {}
+ fit_metrics: dict[str, Any] = {}
for _mode in tasks.keys():
fit_metrics = {**fit_metrics, **tasks[_mode].get_metrics}
@@ -742,8 +748,7 @@ def test(self, save: bool = True, show: bool = False) -> None:
self.params["plot_last_epoch"] = True
for task_name in test_params.keys():
task = get_task(
- test_params[task_name]["name"],
- test_params[task_name].get("module", "minerva.tasks"),
+ test_params[task_name]["_target_"],
task_name,
self.model,
self.device,
@@ -872,6 +877,11 @@ def load_checkpoint(self) -> None:
if self.model.scheduler is not None and "scheduler_state_dict" in checkpoint:
self.model.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+ # Calculate the output dimensions of the model.
+ self.model.determine_output_dim(
+ sample_pairs=self.sample_pairs, change_detection=self.change_detection
+ )
+
# Transfer to GPU.
self.model.to(self.device)
@@ -960,7 +970,7 @@ def close(self) -> None:
# Saves model state dict to PyTorch file.
self.save_model_weights()
- def save_model_weights(self, fn: Optional[Union[str, Path]] = None) -> None:
+ def save_model_weights(self, fn: Optional[str | Path] = None) -> None:
"""Saves model state dict to :mod:`torch` file.
Args:
@@ -973,9 +983,7 @@ def save_model_weights(self, fn: Optional[Union[str, Path]] = None) -> None:
torch.save(model.state_dict(), f"{fn}.pt")
- def save_model(
- self, fn: Optional[Union[Path, str]] = None, fmt: str = "pt"
- ) -> None:
+ def save_model(self, fn: Optional[Path | str] = None, fmt: str = "pt") -> None:
"""Saves the model object itself to :mod:`torch` file.
Args:
diff --git a/minerva/transforms.py b/minerva/transforms.py
index 7768d091e..bbb857581 100644
--- a/minerva/transforms.py
+++ b/minerva/transforms.py
@@ -55,20 +55,9 @@
import re
from copy import deepcopy
from pathlib import Path
-from typing import (
- Any,
- Callable,
- Dict,
- List,
- Literal,
- Optional,
- Sequence,
- Tuple,
- Union,
- cast,
- overload,
-)
+from typing import Any, Callable, Literal, Optional, Sequence, cast, overload
+import hydra
import numpy as np
import rasterio
import torch
@@ -79,18 +68,14 @@
from torchvision.transforms import (
ColorJitter,
ConvertImageDtype,
+ InterpolationMode,
Normalize,
RandomApply,
Resize,
)
from torchvision.transforms.v2 import functional as ft
-from minerva.utils.utils import (
- find_tensor_mode,
- func_by_str,
- get_centre_pixel_value,
- mask_transform,
-)
+from minerva.utils.utils import find_tensor_mode, get_centre_pixel_value, mask_transform
# =====================================================================================================================
@@ -106,7 +91,7 @@ class ClassTransform:
transform (dict[int, int]): Mapping from one labelling schema to another.
"""
- def __init__(self, transform: Dict[int, int]) -> None:
+ def __init__(self, transform: dict[int, int]) -> None:
self.transform = transform
def __call__(self, mask: LongTensor) -> LongTensor:
@@ -134,14 +119,14 @@ class PairCreate:
def __init__(self) -> None:
pass
- def __call__(self, sample: Any) -> Tuple[Any, Any]:
+ def __call__(self, sample: Any) -> tuple[Any, Any]:
return self.forward(sample)
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
@staticmethod
- def forward(sample: Any) -> Tuple[Any, Any]:
+ def forward(sample: Any) -> tuple[Any, Any]:
"""Takes a sample and returns it and a copy as a :class:`tuple` pair.
Args:
@@ -220,7 +205,7 @@ def __init__(
super().__init__(mean, std, inplace)
- def _calc_mean_std(self) -> Tuple[List[float], List[float]]:
+ def _calc_mean_std(self) -> tuple[list[float], list[float]]:
per_img_means = []
per_img_stds = []
for query in self.sampler:
@@ -236,7 +221,7 @@ def _calc_mean_std(self) -> Tuple[List[float], List[float]]:
return per_band_mean, per_band_std
- def _get_tile_mean_std(self, query: BoundingBox) -> Tuple[List[float], List[float]]:
+ def _get_tile_mean_std(self, query: BoundingBox) -> tuple[list[float], list[float]]:
hits = self.dataset.index.intersection(tuple(query), objects=True)
filepaths = cast(list[str], [hit.object for hit in hits])
@@ -245,8 +230,8 @@ def _get_tile_mean_std(self, query: BoundingBox) -> Tuple[List[float], List[floa
f"query: {query} not found in index with bounds: {self.dataset.bounds}"
)
- means: List[float]
- stds: List[float]
+ means: list[float]
+ stds: list[float]
if self.dataset.separate_files:
filename_regex = re.compile(self.dataset.filename_regex, re.VERBOSE)
@@ -277,9 +262,9 @@ def _get_tile_mean_std(self, query: BoundingBox) -> Tuple[List[float], List[floa
def _get_image_mean_std(
self,
- filepaths: List[str],
+ filepaths: list[str],
band_indexes: Optional[Sequence[int]] = None,
- ) -> Tuple[List[float], List[float]]:
+ ) -> tuple[list[float], list[float]]:
stats = [self._get_meta_mean_std(fp, band_indexes) for fp in filepaths]
means = list(np.mean([stat[0] for stat in stats], axis=0))
@@ -289,7 +274,7 @@ def _get_image_mean_std(
def _get_meta_mean_std(
self, filepath, band_indexes: Optional[Sequence[int]] = None
- ) -> Tuple[List[float], List[float]]:
+ ) -> tuple[list[float], list[float]]:
# Open the Tiff file and get the statistics from the meta (min, max, mean, std).
means = []
stds = []
@@ -374,7 +359,7 @@ class ToRGB:
"""
- def __init__(self, channels: Optional[Tuple[int, int, int]] = None) -> None:
+ def __init__(self, channels: Optional[tuple[int, int, int]] = None) -> None:
self.channels = channels
def __call__(self, img: Tensor) -> Tensor:
@@ -426,7 +411,7 @@ class SelectChannels:
channels (list[int]): Channel indices to keep.
"""
- def __init__(self, channels: List[int]) -> None:
+ def __init__(self, channels: list[int]) -> None:
self.channels = channels
def __call__(self, img: Tensor) -> Tensor:
@@ -501,8 +486,8 @@ class MinervaCompose:
Args:
transforms (~typing.Sequence[~typing.Callable[..., ~typing.Any]] | ~typing.Callable[..., ~typing.Any]):
List of transforms to compose.
- key (str): Optional; For use with :mod:`torchgeo` samples and must be assigned a value if using.
- The key of the data type in the sample dict to transform.
+ change_detection (bool): Flag for if transforming a change detection dataset which has
+ ``"image1"`` and ``"image2"`` keys rather than ``"image"``.
Example:
>>> transforms.MinervaCompose([
@@ -514,21 +499,26 @@ class MinervaCompose:
def __init__(
self,
- transforms: Union[
- List[Callable[..., Any]],
- Callable[..., Any],
- Dict[str, Union[List[Callable[..., Any]], Callable[..., Any]]],
- ],
+ transforms: (
+ list[Callable[..., Any]]
+ | Callable[..., Any]
+ | dict[str, list[Callable[..., Any]] | Callable[..., Any]]
+ ),
+ change_detection: bool = False,
) -> None:
- self.transforms: Union[
- List[Callable[..., Any]], Dict[str, List[Callable[..., Any]]]
- ]
+ self.transforms: list[Callable[..., Any]] | dict[str, list[Callable[..., Any]]]
+
+ self.change_detection = change_detection
if isinstance(transforms, Sequence):
self.transforms = list(transforms)
elif callable(transforms):
self.transforms = [transforms]
elif isinstance(transforms, dict):
+ if self.change_detection and "image" in transforms:
+ transforms["image1"] = transforms["image"]
+ transforms["image2"] = transforms["image"]
+ del transforms["image"]
self.transforms = transforms # type: ignore[assignment]
assert isinstance(self.transforms, dict)
for key in transforms.keys():
@@ -546,12 +536,10 @@ def __call__(self, sample: Tensor) -> Tensor: ... # pragma: no cover
@overload
def __call__(
- self, sample: Dict[str, Any]
- ) -> Dict[str, Any]: ... # pragma: no cover
+ self, sample: dict[str, Any]
+ ) -> dict[str, Any]: ... # pragma: no cover
- def __call__(
- self, sample: Union[Tensor, Dict[str, Any]]
- ) -> Union[Tensor, Dict[str, Any]]:
+ def __call__(self, sample: Tensor | dict[str, Any]) -> Tensor | dict[str, Any]:
if isinstance(sample, Tensor):
assert not isinstance(self.transforms, dict)
return self._transform_input(sample, self.transforms)
@@ -563,11 +551,24 @@ def __call__(
# Assumes the keys must be "image" and "mask"
# We need to apply these first before applying any seperate transforms for modalities.
if "both" in self.transforms:
+ if self.change_detection:
+ # Transform images1 with new random states (if applicable).
+ sample["image1"] = self._transform_input(
+ sample["image1"], self.transforms["both"]
+ )
- # Transform images with new random states (if applicable).
- sample["image"] = self._transform_input(
- sample["image"], self.transforms["both"]
- )
+ # Transform image2, reapplying the same random states used for image1.
+ sample["image2"] = self._transform_input(
+ sample["image2"],
+ self.transforms["both"],
+ reapply=True,
+ )
+
+ else:
+ # Transform images with new random states (if applicable).
+ sample["image"] = self._transform_input(
+ sample["image"], self.transforms["both"]
+ )
# We'll have to convert the masks to float for these transforms to work
# so need to store the current dtype to cast back to after.
@@ -596,7 +597,7 @@ def __call__(
@staticmethod
def _transform_input(
- img: Tensor, transforms: List[Callable[..., Any]], reapply: bool = False
+ img: Tensor, transforms: list[Callable[..., Any]], reapply: bool = False
) -> Tensor:
if isinstance(transforms, Sequence):
for t in transforms:
@@ -615,17 +616,17 @@ def _transform_input(
def _add(
self,
- new_transform: Union[
- "MinervaCompose",
- Sequence[Callable[..., Any]],
- Callable[..., Any],
- Dict[str, Union[Sequence[Callable[..., Any]], Callable[..., Any]]],
- ],
- ) -> Union[Dict[str, List[Callable[..., Any]]], List[Callable[..., Any]]]:
+ new_transform: (
+ "MinervaCompose"
+ | Sequence[Callable[..., Any]]
+ | Callable[..., Any]
+ | dict[str, Sequence[Callable[..., Any]] | Callable[..., Any]]
+ ),
+ ) -> dict[str, list[Callable[..., Any]]] | list[Callable[..., Any]]:
def add_transforms(
- _new_transform: Union[Sequence[Callable[..., Any]], Callable[..., Any]],
- old_transform: List[Callable[..., Any]],
- ) -> List[Callable[..., Any]]:
+ _new_transform: Sequence[Callable[..., Any]] | Callable[..., Any],
+ old_transform: list[Callable[..., Any]],
+ ) -> list[Callable[..., Any]]:
if isinstance(_new_transform, Sequence):
old_transform.extend(_new_transform)
return old_transform
@@ -665,12 +666,12 @@ def add_transforms(
def __add__(
self,
- new_transform: Union[
- "MinervaCompose",
- Sequence[Callable[..., Any]],
- Callable[..., Any],
- Dict[str, Union[Sequence[Callable[..., Any]], Callable[..., Any]]],
- ],
+ new_transform: (
+ "MinervaCompose"
+ | Sequence[Callable[..., Any]]
+ | Callable[..., Any]
+ | dict[str, Sequence[Callable[..., Any]] | Callable[..., Any]]
+ ),
) -> "MinervaCompose":
new_compose = deepcopy(self)
new_compose.transforms = self._add(new_transform)
@@ -678,12 +679,12 @@ def __add__(
def __iadd__(
self,
- new_transform: Union[
- "MinervaCompose",
- Sequence[Callable[..., Any]],
- Callable[..., Any],
- Dict[str, Union[Sequence[Callable[..., Any]], Callable[..., Any]]],
- ],
+ new_transform: (
+ "MinervaCompose"
+ | Sequence[Callable[..., Any]]
+ | Callable[..., Any]
+ | dict[str, Sequence[Callable[..., Any]] | Callable[..., Any]]
+ ),
) -> "MinervaCompose":
self.transforms = self._add(new_transform)
return self
@@ -738,13 +739,13 @@ def __init__(self, from_key: str, to_key: str) -> None:
self.from_key = from_key
self.to_key = to_key
- def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
return self.forward(sample)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.from_key} -> {self.to_key})"
- def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+ def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
"""Sets the ``to_key`` of ``sample`` to the ``from_key`` and returns.
Args:
@@ -766,12 +767,14 @@ class SeasonTransform:
season (str): How to handle what seasons to return:
* ``pair``: Randomly pick 2 seasons to return that will form a pair.
* ``random``: Randomly pick a single season to return.
+
+ .. versionadded:: 0.28
"""
def __init__(self, season: str = "random") -> None:
self.season = season
- def __call__(self, x: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]:
+ def __call__(self, x: Tensor) -> tuple[Tensor, Tensor] | Tensor:
if self.season == "pair":
season1 = np.random.choice([0, 1, 2, 3])
@@ -799,6 +802,8 @@ class ConvertDtypeFromStr(ConvertImageDtype):
Args:
dtype (str): A tensor type as :class:`str`.
+
+ .. versionadded:: 0.28
"""
def __init__(self, dtype: str) -> None:
@@ -806,7 +811,20 @@ def __init__(self, dtype: str) -> None:
class MaskResize(Resize):
- """Wrapper of :class:`torchvision.transforms.Resize` for use with masks that have no channel dimension."""
+ """Wrapper of :class:`torchvision.transforms.Resize` for use with masks that have no channel dimension.
+
+ .. versionadded:: 0.28
+ """
+
+ def __init__(
+ self,
+ size,
+ interpolation: str = "NEAREST",
+ max_size=None,
+ antialias: bool = True,
+ ) -> None:
+ interpolation_mode = getattr(InterpolationMode, interpolation)
+ super().__init__(size, interpolation_mode, max_size, antialias)
def forward(self, img: Tensor) -> Tensor:
"""
@@ -818,7 +836,7 @@ def forward(self, img: Tensor) -> Tensor:
"""
org_shape = img.shape
- tmp_shape: Tuple[int, int, int, int]
+ tmp_shape: tuple[int, int, int, int]
if len(org_shape) == 4:
# Mask already has shape [N,C,H,W] so no need to modify shape for Resize.
@@ -836,35 +854,61 @@ def forward(self, img: Tensor) -> Tensor:
return torch.squeeze(super().forward(torch.reshape(img, tmp_shape)))
+class AdjustGamma:
+ """Callable version of :meth:`torchvision.transforms.functional.adjust_gamma`
+
+ Args:
+ gamma (float): Optional; Gamma factor.
+ gain (float): Optional; Gain scalar.
+
+ .. versionadded:: 0.28
+ """
+
+ def __init__(self, gamma: float = 1.0, gain: float = 1.0) -> None:
+ self.gamma = gamma
+ self.gain = gain
+
+ def forward(self, img: Tensor) -> Tensor:
+ img = ft.adjust_gamma(img, gamma=self.gamma, gain=self.gain)
+ assert isinstance(img, Tensor)
+ return img
+
+ def __call__(self, img: Tensor) -> Tensor:
+ return self.forward(img)
+
+
# =====================================================================================================================
# METHODS
# =====================================================================================================================
-def _construct_random_transforms(random_params: Dict[str, Any]) -> Any:
+def _construct_random_transforms(random_params: dict[str, Any]) -> Any:
p = random_params.pop("p", 0.5)
random_transforms = []
for ran_name in random_params:
- random_transforms.append(get_transform(ran_name, random_params[ran_name]))
+ random_transforms.append(get_transform(random_params[ran_name]))
return RandomApply(random_transforms, p=p)
def init_auto_norm(
- dataset: RasterDataset, params: Dict[str, Any] = {}
+ dataset: RasterDataset,
+ length: int = 128,
+ roi: Optional[BoundingBox] = None,
+ inplace=False,
) -> RasterDataset:
"""Uses :class:~`minerva.transforms.AutoNorm` to automatically find the mean and standard deviation of `dataset`
to create a normalisation transform that is then added to the existing transforms of `dataset`.
Args:
dataset (RasterDataset): Dataset to find and apply the normalisation conditions to.
- params (Dict[str, Any]): Parameters for :class:~`minerva.transforms.AutoNorm`.
+ length (int), roi (~torchgeo.datasets.utils.BoundingBox | None), inplace (bool): Parameters for :class:~`minerva.transforms.AutoNorm`.
Returns:
RasterDataset: `dataset` with an additional :class:~`minerva.transforms.AutoNorm` transform
added to it's :attr:~`torchgeo.datasets.RasterDataset.transforms` attribute.
"""
# Creates the AutoNorm transform by sampling `dataset` for its mean and standard deviation stats.
- auto_norm = AutoNorm(dataset, **params)
+ auto_norm = AutoNorm(dataset, length=length, roi=roi, inplace=inplace)
if dataset.transforms is None:
dataset.transforms = MinervaCompose(auto_norm)
@@ -887,35 +931,24 @@ def init_auto_norm(
return dataset
-def get_transform(name: str, transform_params: Dict[str, Any]) -> Callable[..., Any]:
+def get_transform(transform_params: dict[str, Any]) -> Callable[..., Any]:
"""Creates a transform object based on config parameters.
Args:
- name (str): Name of transform object to import e.g :class:`~torchvision.transforms.RandomResizedCrop`.
transform_params (dict[str, ~typing.Any]): Arguements to construct transform with.
- Should also include ``"module"`` key defining the import path to the transform object.
+ Should also include ``"_target_"`` key defining the import path to the transform object.
Returns:
Initialised transform object specified by config parameters.
- .. note::
- If ``transform_params`` contains no ``"module"`` key, it defaults to ``torchvision.transforms``.
-
Example:
- >>> name = "RandomResizedCrop"
- >>> params = {"module": "torchvision.transforms", "size": 128}
- >>> transform = get_transform(name, params)
+ >>> params = {"_target_": "torchvision.transforms.RandomResizedCrop", "size": 128}
+ >>> transform = get_transform(params)
Raises:
TypeError: If created transform object is itself not :class:`~typing.Callable`.
"""
- params = transform_params.copy()
- module = params.pop("module", "torchvision.transforms")
-
- # Gets the transform requested by config parameters.
- _transform: Callable[..., Any] = func_by_str(module, name)
-
- transform: Callable[..., Any] = _transform(**params)
+ transform: Callable[..., Any] = hydra.utils.instantiate(transform_params)
if callable(transform):
return transform
else:
@@ -923,7 +956,8 @@ def get_transform(name: str, transform_params: Dict[str, Any]) -> Callable[...,
def make_transformations(
- transform_params: Union[Dict[str, Any], Literal[False]]
+ transform_params: dict[str, Any] | Literal[False],
+ change_detection: bool = False,
) -> Optional[MinervaCompose]:
"""Constructs a transform or series of transforms based on parameters provided.
@@ -931,11 +965,13 @@ def make_transformations(
transform_params (dict[str, ~typing.Any] | ~typing.Literal[False]): Parameters defining transforms desired.
The name of each transform should be the key, while the kwargs for the transform should
be the value of that key as a dict.
+ change_detection (bool): Flag for if transforming a change detection dataset which has
+ ``"image1"`` and ``"image2"`` keys rather than ``"image"``.
Example:
>>> transform_params = {
- >>> "CenterCrop": {"module": "torchvision.transforms", "size": 128},
- >>> "RandomHorizontalFlip": {"module": "torchvision.transforms", "p": 0.7}
+ >>> "crop": {"_target_": "torchvision.transforms.CenterCrop", "size": 128},
+ >>> "flip": {"_target_": "torchvision.transforms.RandomHorizontalFlip", "p": 0.7}
>>> }
>>> transforms = make_transformations(transform_params)
@@ -945,7 +981,7 @@ def make_transformations(
If multiple transforms are defined, a Compose object of Transform objects is returned.
"""
- def construct(type_transform_params: Dict[str, Any]) -> List[Callable[..., Any]]:
+ def construct(type_transform_params: dict[str, Any]) -> list[Callable[..., Any]]:
type_transformations = []
# Get each transform.
@@ -959,9 +995,7 @@ def construct(type_transform_params: Dict[str, Any]) -> List[Callable[..., Any]]
continue
else:
- type_transformations.append(
- get_transform(_name, type_transform_params[_name])
- )
+ type_transformations.append(get_transform(type_transform_params[_name]))
return type_transformations
@@ -971,7 +1005,7 @@ def construct(type_transform_params: Dict[str, Any]) -> List[Callable[..., Any]]
if all(transform_params.values()) is None:
return None
- transformations: Dict[str, Any] = {}
+ transformations: dict[str, Any] = {}
for name in transform_params:
if name in ("image", "mask", "label", "both"):
@@ -980,6 +1014,8 @@ def construct(type_transform_params: Dict[str, Any]) -> List[Callable[..., Any]]
else:
transformations[name] = construct(transform_params[name])
else:
- return MinervaCompose(construct(transform_params))
+ return MinervaCompose(
+ construct(transform_params), change_detection=change_detection
+ )
- return MinervaCompose(transformations)
+ return MinervaCompose(transformations, change_detection=change_detection)
diff --git a/minerva/utils/runner.py b/minerva/utils/runner.py
index edba8f1ba..7f74b18e6 100644
--- a/minerva/utils/runner.py
+++ b/minerva/utils/runner.py
@@ -54,7 +54,7 @@
import signal
import subprocess
from pathlib import Path
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Optional
import requests
import torch
@@ -62,7 +62,7 @@
import torch.multiprocessing as mp
import wandb
import yaml
-from omegaconf import DictConfig, OmegaConf, ListConfig
+from omegaconf import DictConfig, ListConfig, OmegaConf
from wandb.sdk.lib import RunDisabled
from wandb.sdk.wandb_run import Run
@@ -120,14 +120,14 @@ def _config_load_resolver(path: str):
return cfg
-def _construct_patch_size(input_size: Tuple[int, int, int]) -> ListConfig:
+def _construct_patch_size(input_size: tuple[int, int, int]) -> ListConfig:
return ListConfig(input_size[-2:])
def setup_wandb_run(
gpu: int,
cfg: DictConfig,
-) -> Tuple[Optional[Union[Run, RunDisabled]], DictConfig]:
+) -> tuple[Optional[Run | RunDisabled], DictConfig]:
"""Sets up a :mod:`wandb` logger for either every process, the master process or not if not logging.
Note:
@@ -148,7 +148,7 @@ def setup_wandb_run(
~wandb.sdk.wandb_run.Run | ~wandb.sdk.lib.RunDisabled | None: The :mod:`wandb` run object
for this process or ``None`` if ``log_all=False`` and ``rank!=0``.
"""
- run: Optional[Union[Run, RunDisabled]] = None
+ run: Optional[Run | RunDisabled] = None
if cfg.get("wandb_log", False) or cfg.get("project", None):
try:
if cfg.get("log_all", False) and cfg.world_size > 1:
@@ -290,7 +290,7 @@ def config_args(cfg: DictConfig) -> DictConfig:
def _run_preamble(
gpu: int,
- run: Callable[[int, Optional[Union[Run, RunDisabled]], DictConfig], Any],
+ run: Callable[[int, Optional[Run | RunDisabled], DictConfig], Any],
cfg: DictConfig,
) -> None: # pragma: no cover
# Calculates the global rank of this process.
@@ -301,7 +301,7 @@ def _run_preamble(
if cfg.world_size > 1:
dist.init_process_group( # type: ignore[attr-defined]
- backend="nccl",
+ backend=cfg.get("dist_backend", "gloo"),
init_method=cfg.dist_url,
world_size=cfg.world_size,
rank=cfg.rank,
@@ -316,7 +316,7 @@ def _run_preamble(
def distributed_run(
- run: Callable[[int, Optional[Union[Run, RunDisabled]], DictConfig], Any]
+ run: Callable[[int, Optional[Run | RunDisabled], DictConfig], Any]
) -> Callable[..., Any]:
"""Runs the supplied function and arguments with distributed computing according to arguments.
@@ -334,7 +334,9 @@ def distributed_run(
OmegaConf.register_new_resolver("cfg_load", _config_load_resolver, replace=True)
OmegaConf.register_new_resolver("eval", eval, replace=True)
- OmegaConf.register_new_resolver("to_patch_size", _construct_patch_size, replace=True)
+ OmegaConf.register_new_resolver(
+ "to_patch_size", _construct_patch_size, replace=True
+ )
@functools.wraps(run)
def inner_decorator(cfg: DictConfig):
@@ -360,7 +362,7 @@ def inner_decorator(cfg: DictConfig):
def run_trainer(
- gpu: int, wandb_run: Optional[Union[Run, RunDisabled]], cfg: DictConfig
+ gpu: int, wandb_run: Optional[Run | RunDisabled], cfg: DictConfig
) -> None:
trainer = Trainer(
diff --git a/minerva/utils/utils.py b/minerva/utils/utils.py
index b8c563867..8e701b1ad 100644
--- a/minerva/utils/utils.py
+++ b/minerva/utils/utils.py
@@ -107,7 +107,6 @@
# ---+ Inbuilt +-------------------------------------------------------------------------------------------------------
import cmath
import functools
-import glob
import hashlib
import importlib
import inspect
@@ -126,7 +125,7 @@
from types import ModuleType
from typing import Any, Callable
from typing import Counter as CounterType
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union, overload
+from typing import Iterable, Optional, Sequence, overload
# ---+ 3rd Party +-----------------------------------------------------------------------------------------------------
import numpy as np
@@ -172,8 +171,8 @@
# DECORATORS
# =====================================================================================================================
def return_updated_kwargs(
- func: Callable[..., Tuple[Any, ...]]
-) -> Callable[..., Tuple[Any, ...]]:
+ func: Callable[..., tuple[Any, ...]]
+) -> Callable[..., tuple[Any, ...]]:
"""Decorator that allows the `kwargs` supplied to the wrapped function to be returned with updated values.
Assumes that the wrapped function returns a :class:`dict` in the last position of the
@@ -212,7 +211,7 @@ def pair_collate(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
"""
@functools.wraps(func)
- def wrapper(samples: Iterable[Tuple[Any, Any]]) -> Tuple[Any, Any]:
+ def wrapper(samples: Iterable[tuple[Any, Any]]) -> tuple[Any, Any]:
a, b = tuple(zip(*samples))
return func(a), func(b)
@@ -227,7 +226,7 @@ class Wrapper:
def __init__(self, *args, **kwargs) -> None:
self.wrap = cls(*args, **kwargs)
- def __call__(self, pair: Tuple[Any, Any]) -> Tuple[Any, Any]:
+ def __call__(self, pair: tuple[Any, Any]) -> tuple[Any, Any]:
a, b = pair
return self.wrap.__call__(a), self.wrap.__call__(b)
@@ -258,26 +257,26 @@ class Wrapper:
def __init__(self, *args, **kwargs) -> None:
self.wrap: Callable[
[
- Union[Dict[str, Any], Tensor],
+ dict[str, Any] | Tensor,
],
- Dict[str, Any],
+ dict[str, Any],
] = cls(*args, **kwargs)
self.keys = keys
@overload
def __call__(
- self, batch: Dict[str, Any]
- ) -> Dict[str, Any]: ... # pragma: no cover
+ self, batch: dict[str, Any]
+ ) -> dict[str, Any]: ... # pragma: no cover
@overload
- def __call__(self, batch: Tensor) -> Dict[str, Any]: ... # pragma: no cover
+ def __call__(self, batch: Tensor) -> dict[str, Any]: ... # pragma: no cover
- def __call__(self, batch: Union[Dict[str, Any], Tensor]) -> Dict[str, Any]:
+ def __call__(self, batch: dict[str, Any] | Tensor) -> dict[str, Any]:
if isinstance(batch, Tensor):
return self.wrap(batch)
elif isinstance(batch, dict) and isinstance(self.keys, Sequence):
- aug_batch: Dict[str, Any] = {}
+ aug_batch: dict[str, Any] = {}
for key in self.keys:
aug_batch[key] = self.wrap(batch.pop(key))
@@ -311,7 +310,7 @@ class Wrapper:
def __init__(self, *args, **kwargs) -> None:
self.wrap = cls(*args, **kwargs)
- def __getitem__(self, queries: Any = None) -> Tuple[Any, Any]:
+ def __getitem__(self, queries: Any = None) -> tuple[Any, Any]:
return self.wrap[queries[0]], self.wrap[queries[1]]
def __getattr__(self, item):
@@ -399,7 +398,7 @@ def _optional_import(
def _optional_import(
module: str, *, name: Optional[str] = None, package: Optional[str] = None
-) -> Union[ModuleType, Callable[..., Any]]:
+) -> ModuleType | Callable[..., Any]:
try:
_module: ModuleType = importlib.import_module(module)
return _module if name is None else getattr(_module, name)
@@ -462,7 +461,7 @@ def is_notebook() -> bool:
return True
-def get_cuda_device(device_sig: Union[int, str] = "cuda:0") -> _device:
+def get_cuda_device(device_sig: int | str = "cuda:0") -> _device:
"""Finds and returns the ``CUDA`` device, if one is available. Else, returns CPU as device.
Assumes there is at most only one ``CUDA`` device.
@@ -478,7 +477,7 @@ def get_cuda_device(device_sig: Union[int, str] = "cuda:0") -> _device:
return device
-def exist_delete_check(fn: Union[str, Path]) -> None:
+def exist_delete_check(fn: str | Path) -> None:
"""Checks if given file exists then deletes if true.
Args:
@@ -491,7 +490,7 @@ def exist_delete_check(fn: Union[str, Path]) -> None:
Path(fn).unlink(missing_ok=True)
-def mkexpdir(name: str, results_dir: Union[Path, str] = "results") -> None:
+def mkexpdir(name: str, results_dir: Path | str = "results") -> None:
"""Makes a new directory below the results directory with name provided. If directory already exists,
no action is taken.
@@ -521,7 +520,7 @@ def set_seeds(seed: int) -> None:
torch.cuda.random.manual_seed_all(seed)
-def check_dict_key(dictionary: Dict[Any, Any], key: Any) -> bool:
+def check_dict_key(dictionary: dict[Any, Any], key: Any) -> bool:
"""Checks if a key exists in a dictionary and if it is ``None`` or ``False``.
Args:
@@ -583,7 +582,7 @@ def transform_coordinates(
y: Sequence[float],
src_crs: CRS,
new_crs: CRS = WGS84,
-) -> Tuple[Sequence[float], Sequence[float]]: ... # pragma: no cover
+) -> tuple[Sequence[float], Sequence[float]]: ... # pragma: no cover
@overload
@@ -592,7 +591,7 @@ def transform_coordinates(
y: float,
src_crs: CRS,
new_crs: CRS = WGS84,
-) -> Tuple[Sequence[float], Sequence[float]]: ... # pragma: no cover
+) -> tuple[Sequence[float], Sequence[float]]: ... # pragma: no cover
@overload
@@ -601,21 +600,21 @@ def transform_coordinates(
y: Sequence[float],
src_crs: CRS,
new_crs: CRS = WGS84,
-) -> Tuple[Sequence[float], Sequence[float]]: ... # pragma: no cover
+) -> tuple[Sequence[float], Sequence[float]]: ... # pragma: no cover
@overload
def transform_coordinates(
x: float, y: float, src_crs: CRS, new_crs: CRS = WGS84
-) -> Tuple[float, float]: ... # pragma: no cover
+) -> tuple[float, float]: ... # pragma: no cover
def transform_coordinates(
- x: Union[Sequence[float], float],
- y: Union[Sequence[float], float],
+ x: Sequence[float] | float,
+ y: Sequence[float] | float,
src_crs: CRS,
new_crs: CRS = WGS84,
-) -> Union[Tuple[Sequence[float], Sequence[float]], Tuple[float, float]]:
+) -> tuple[Sequence[float], Sequence[float]] | tuple[float, float]:
"""Transforms co-ordinates from one :class:`~rasterio.crs.CRS` to another.
Args:
@@ -641,7 +640,7 @@ def transform_coordinates(
y = check_len(y, x)
# Transform co-ordinates from source to new CRS and returns a tuple of (x, y)
- co_ordinates: Tuple[Sequence[float], Sequence[float]] = rt.warp.transform( # type: ignore
+ co_ordinates: tuple[Sequence[float], Sequence[float]] = rt.warp.transform( # type: ignore
src_crs=src_crs, dst_crs=new_crs, xs=x, ys=y
)
@@ -715,9 +714,9 @@ def deg_to_dms(deg: float, axis: str = "lat") -> str:
def dec2deg(
- dec_co: Union[Sequence[float], NDArray[Shape["*"], Float]], # noqa: F722
+ dec_co: Sequence[float] | NDArray[Shape["*"], Float], # noqa: F722
axis: str = "lat",
-) -> List[str]:
+) -> list[str]:
"""Wrapper for :func:`deg_to_dms`.
Args:
@@ -727,14 +726,14 @@ def dec2deg(
Returns:
list[str]: List of formatted strings in degrees, minutes and seconds.
"""
- deg_co: List[str] = []
+ deg_co: list[str] = []
for co in dec_co:
deg_co.append(deg_to_dms(co, axis=axis))
return deg_co
-def get_centre_loc(bounds: BoundingBox) -> Tuple[float, float]:
+def get_centre_loc(bounds: BoundingBox) -> tuple[float, float]:
"""Gets the centre co-ordinates of the parsed bounding box.
Args:
@@ -775,7 +774,7 @@ def get_centre_pixel_value(x: Tensor) -> Any:
raise ValueError()
-def lat_lon_to_loc(lat: Union[str, float], lon: Union[str, float]) -> str:
+def lat_lon_to_loc(lat: str | float, lon: str | float) -> str:
"""Takes a latitude - longitude co-ordinate and returns a string of the semantic location.
Args:
@@ -804,7 +803,7 @@ def lat_lon_to_loc(lat: Union[str, float], lon: Union[str, float]) -> str:
location = query.raw["properties"] # type: ignore
# Attempts to add possible fields to address of the location. Not all will be present for every query.
- locs: List[str] = []
+ locs: list[str] = []
try:
locs.append(location["city"])
except KeyError:
@@ -889,8 +888,8 @@ def mask_to_ohe(mask: LongTensor, n_classes: int) -> LongTensor:
def class_weighting(
- class_dist: List[Tuple[int, int]], normalise: bool = False
-) -> Dict[int, float]:
+ class_dist: list[tuple[int, int]], normalise: bool = False
+) -> dict[int, float]:
"""Constructs weights for each class defined by the distribution provided.
Note:
@@ -912,7 +911,7 @@ def class_weighting(
n_samples += mode[1]
# Constructs class weights. Each weight is 1 / number of samples for that class.
- class_weights: Dict[int, float] = {}
+ class_weights: dict[int, float] = {}
if normalise:
for mode in class_dist:
class_weights[mode[0]] = n_samples / mode[1]
@@ -924,8 +923,8 @@ def class_weighting(
def find_empty_classes(
- class_dist: List[Tuple[int, int]], class_names: Dict[int, str]
-) -> List[int]:
+ class_dist: list[tuple[int, int]], class_names: dict[int, str]
+) -> list[int]:
"""Finds which classes defined by config files are not present in the dataset.
Args:
@@ -936,7 +935,7 @@ def find_empty_classes(
Returns:
list[int]: List of classes not found in ``class_dist`` and are thus empty/ not present in dataset.
"""
- empty: List[int] = []
+ empty: list[int] = []
# Checks which classes are not present in class_dist
for label in class_names.keys():
@@ -948,10 +947,10 @@ def find_empty_classes(
def eliminate_classes(
- empty_classes: Union[List[int], Tuple[int, ...], NDArray[Any, Int]],
- old_classes: Dict[int, str],
- old_cmap: Optional[Dict[int, str]] = None,
-) -> Tuple[Dict[int, str], Dict[int, int], Optional[Dict[int, str]]]:
+ empty_classes: list[int] | tuple[int, ...] | NDArray[Any, Int],
+ old_classes: dict[int, str],
+ old_cmap: Optional[dict[int, str]] = None,
+) -> tuple[dict[int, str], dict[int, int], Optional[dict[int, str]]]:
"""Eliminates empty classes from the class text label and class colour dictionaries and re-normalise.
This should ensure that the remaining list of classes is still a linearly spaced list of numbers.
@@ -1020,7 +1019,7 @@ def eliminate_classes(
return reordered_classes, conversion, reordered_colours
-def class_transform(label: int, matrix: Dict[int, int]) -> int:
+def class_transform(label: int, matrix: dict[int, int]) -> int:
"""Transforms labels from one schema to another mapped by a supplied dictionary.
Args:
@@ -1035,20 +1034,20 @@ def class_transform(label: int, matrix: Dict[int, int]) -> int:
@overload
def mask_transform( # type: ignore[overload-overlap]
- array: NDArray[Any, Int], matrix: Dict[int, int]
+ array: NDArray[Any, Int], matrix: dict[int, int]
) -> NDArray[Any, Int]: ... # pragma: no cover
@overload
def mask_transform(
- array: LongTensor, matrix: Dict[int, int]
+ array: LongTensor, matrix: dict[int, int]
) -> LongTensor: ... # pragma: no cover
def mask_transform(
- array: Union[NDArray[Any, Int], LongTensor],
- matrix: Dict[int, int],
-) -> Union[NDArray[Any, Int], LongTensor]:
+ array: NDArray[Any, Int] | LongTensor,
+ matrix: dict[int, int],
+) -> NDArray[Any, Int] | LongTensor:
"""Transforms all labels of an N-dimensional array from one schema to another mapped by a supplied dictionary.
Args:
@@ -1065,11 +1064,11 @@ def mask_transform(
def check_test_empty(
- pred: Union[Sequence[int], NDArray[Any, Int]],
- labels: Union[Sequence[int], NDArray[Any, Int]],
- class_labels: Optional[Dict[int, str]] = None,
+ pred: Sequence[int] | NDArray[Any, Int],
+ labels: Sequence[int] | NDArray[Any, Int],
+ class_labels: Optional[dict[int, str]] = None,
p_dist: bool = True,
-) -> Tuple[NDArray[Any, Int], NDArray[Any, Int], Dict[int, str]]:
+) -> tuple[NDArray[Any, Int], NDArray[Any, Int], dict[int, str]]:
"""Checks if any of the classes in the dataset were not present in both the predictions and ground truth labels.
Returns corrected and re-ordered predictions, labels and class labels.
@@ -1120,8 +1119,8 @@ def check_test_empty(
def class_dist_transform(
- class_dist: List[Tuple[int, int]], matrix: Dict[int, int]
-) -> List[Tuple[int, int]]:
+ class_dist: list[tuple[int, int]], matrix: dict[int, int]
+) -> list[tuple[int, int]]:
"""Transforms the class distribution from an old schema to a new one.
Args:
@@ -1132,14 +1131,14 @@ def class_dist_transform(
Returns:
list[tuple[int, int]]: Class distribution updated to new labels.
"""
- new_class_dist: List[Tuple[int, int]] = []
+ new_class_dist: list[tuple[int, int]] = []
for mode in class_dist:
new_class_dist.append((class_transform(mode[0], matrix), mode[1]))
return new_class_dist
-def class_frac(patch: pd.Series) -> Dict[Any, Any]:
+def class_frac(patch: pd.Series) -> dict[Any, Any]:
"""Computes the fractional sizes of the classes of the given :term:`patch` and returns a
:class:`dict` of the results.
@@ -1150,7 +1149,7 @@ def class_frac(patch: pd.Series) -> Dict[Any, Any]:
Mapping: Dictionary-like object with keys as class numbers and associated values
of fractional size of class plus a key-value pair for the :term:`patch` ID.
"""
- new_columns: Dict[Any, Any] = dict(patch.to_dict())
+ new_columns: dict[Any, Any] = dict(patch.to_dict())
counts = 0
for mode in patch["MODES"]:
counts += mode[1]
@@ -1173,7 +1172,7 @@ def cloud_cover(scene: NDArray[Any, Any]) -> Any:
return np.sum(scene) / scene.size
-def threshold_scene_select(df: DataFrame, thres: float = 0.3) -> List[str]:
+def threshold_scene_select(df: DataFrame, thres: float = 0.3) -> list[str]:
"""Selects all scenes in a :term:`patch` with a cloud cover less than the threshold provided.
Args:
@@ -1191,9 +1190,9 @@ def threshold_scene_select(df: DataFrame, thres: float = 0.3) -> List[str]:
def find_best_of(
patch_id: str,
manifest: DataFrame,
- selector: Callable[[DataFrame], List[str]] = threshold_scene_select,
+ selector: Callable[[DataFrame], list[str]] = threshold_scene_select,
**kwargs,
-) -> List[str]:
+) -> list[str]:
"""Finds the scenes sorted by cloud cover using selector function supplied.
Args:
@@ -1234,9 +1233,9 @@ def timestamp_now(fmt: str = "%d-%m-%Y_%H%M") -> str:
def find_modes(
labels: Iterable[int],
plot: bool = False,
- classes: Optional[Dict[int, str]] = None,
- cmap_dict: Optional[Dict[int, str]] = None,
-) -> List[Tuple[int, int]]:
+ classes: Optional[dict[int, str]] = None,
+ cmap_dict: Optional[dict[int, str]] = None,
+) -> list[tuple[int, int]]:
"""Finds the modal distribution of the classes within the labels provided.
Can plot the results as a pie chart if ``plot=True``.
@@ -1249,7 +1248,7 @@ def find_modes(
list[tuple[int, int]]: Modal distribution of classes in input in order of most common class.
"""
# Finds the distribution of the classes within the data
- class_dist: List[Tuple[int, int]] = Counter(
+ class_dist: list[tuple[int, int]] = Counter(
np.array(labels).flatten()
).most_common()
@@ -1264,10 +1263,10 @@ def find_modes(
def modes_from_manifest(
manifest: DataFrame,
- classes: Dict[int, str],
+ classes: dict[int, str],
plot: bool = False,
- cmap_dict: Optional[Dict[int, str]] = None,
-) -> List[Tuple[int, int]]:
+ cmap_dict: Optional[dict[int, str]] = None,
+) -> list[tuple[int, int]]:
"""Uses the dataset manifest to calculate the fractional size of the classes.
Args:
@@ -1295,7 +1294,7 @@ def count_samples(cls):
class_counter[classification] = count
except KeyError:
continue
- class_dist: List[Tuple[int, int]] = class_counter.most_common()
+ class_dist: list[tuple[int, int]] = class_counter.most_common()
if plot:
# Plots a pie chart of the distribution of the classes within the given list of patches
@@ -1325,7 +1324,7 @@ def func_by_str(module_path: str, func: str) -> Callable[..., Any]:
return func
-def check_len(param: Any, comparator: Any) -> Union[Any, Sequence[Any]]:
+def check_len(param: Any, comparator: Any) -> Any | Sequence[Any]:
"""Checks the length of one object against a comparator object.
Args:
@@ -1383,8 +1382,8 @@ def calc_grad(model: Module) -> Optional[float]:
def print_class_dist(
- class_dist: List[Tuple[int, int]],
- class_labels: Optional[Dict[int, str]] = None,
+ class_dist: list[tuple[int, int]],
+ class_labels: Optional[dict[int, str]] = None,
) -> None:
"""Prints the supplied ``class_dist`` in a pretty table format using :mod:`tabulate`.
@@ -1435,7 +1434,7 @@ def calc_frac(count: float, total: float) -> str:
def batch_flatten(
- x: Union[NDArray[Any, Any], ArrayLike]
+ x: NDArray[Any, Any] | ArrayLike
) -> NDArray[Shape["*"], Any]: # noqa: F722
"""Flattens the supplied array with :func:`numpy`.
@@ -1455,9 +1454,9 @@ def batch_flatten(
def make_classification_report(
- pred: Union[Sequence[int], NDArray[Any, Int]],
- labels: Union[Sequence[int], NDArray[Any, Int]],
- class_labels: Optional[Dict[int, str]] = None,
+ pred: Sequence[int] | NDArray[Any, Int],
+ labels: Sequence[int] | NDArray[Any, Int],
+ class_labels: Optional[dict[int, str]] = None,
print_cr: bool = True,
p_dist: bool = False,
) -> DataFrame:
@@ -1564,9 +1563,9 @@ def calc_contrastive_acc(z: Tensor) -> Tensor:
def run_tensorboard(
exp_name: str,
- path: Union[str, List[str], Tuple[str, ...], Path] = "",
+ path: str | list[str] | tuple[str, ...] | Path = "",
env_name: str = "env",
- host_num: Union[str, int] = 6006,
+ host_num: str | int = 6006,
_testing: bool = False,
) -> Optional[int]:
"""Runs the :mod:`TensorBoard` logs and hosts on a local webpage.
@@ -1627,11 +1626,11 @@ def run_tensorboard(
def compute_roc_curves(
probs: NDArray[Any, Float],
- labels: Union[Sequence[int], NDArray[Any, Int]],
- class_labels: List[int],
+ labels: Sequence[int] | NDArray[Any, Int],
+ class_labels: list[int],
micro: bool = True,
macro: bool = True,
-) -> Tuple[Dict[Any, float], Dict[Any, float], Dict[Any, float]]:
+) -> tuple[dict[Any, float], dict[Any, float], dict[Any, float]]:
"""Computes the false-positive rate, true-positive rate and AUCs for each class using a one-vs-all approach.
The micro and macro averages are for each of these variables is also computed.
@@ -1660,13 +1659,13 @@ def compute_roc_curves(
# Dicts to hold the false-positive rate, true-positive rate and Area Under Curves
# of each class and micro, macro averages.
- fpr: Dict[Any, Any] = {}
- tpr: Dict[Any, Any] = {}
- roc_auc: Dict[Any, Any] = {}
+ fpr: dict[Any, Any] = {}
+ tpr: dict[Any, Any] = {}
+ roc_auc: dict[Any, Any] = {}
# Holds a list of the classes that were in the targets supplied to the model.
# Avoids warnings about empty targets from sklearn!
- populated_classes: List[int] = []
+ populated_classes: list[int] = []
print("Computing class ROC curves")
@@ -1844,8 +1843,8 @@ def calc_norm_euc_dist(a: Tensor, b: Tensor) -> Tensor:
def fallback_params(
key: str,
- params_a: Dict[str, Any],
- params_b: Dict[str, Any],
+ params_a: dict[str, Any],
+ params_b: dict[str, Any],
fallback: Optional[Any] = None,
) -> Any:
"""Search for a value associated with ``key`` from
@@ -1868,9 +1867,9 @@ def fallback_params(
def compile_dataset_paths(
- data_dir: Union[Path, str],
- in_paths: Union[List[Union[Path, str]], Union[Path, str]],
-) -> List[str]:
+ data_dir: Path | str,
+ in_paths: list[Path | str] | Path | str,
+) -> list[str]:
"""Ensures that a list of paths is returned with the data directory prepended, even if a single string is supplied
Args:
@@ -1885,15 +1884,15 @@ def compile_dataset_paths(
else:
out_paths = [universal_path(data_dir) / in_paths]
- compiled_paths = []
+ # Check if each path exists. If not, make the path.
for path in out_paths:
- compiled_paths.extend(glob.glob(str(path), recursive=True))
+ path.mkdir(parents=True, exist_ok=True)
- # For each path, get the absolute path, convert to string and return.
- return [str(Path(path).absolute()) for path in compiled_paths]
+ # For each path, get the absolute path, make the path if it does not exist then convert to string and return.
+ return [str(Path(path).absolute()) for path in out_paths]
-def make_hash(obj: Dict[Any, Any]) -> str:
+def make_hash(obj: dict[Any, Any]) -> str:
"""Make a deterministic MD5 hash of a serialisable object using JSON.
Source: https://death.andgravity.com/stable-hashing
@@ -1958,3 +1957,61 @@ def closest_factors(n):
best_pair = (best_pair[1], best_pair[0])
return best_pair
+
+
+def get_sample_index(sample: dict[str, Any]) -> Optional[Any]:
+ """Get the index for a sample with unknown index key.
+
+ Will try:
+ * ``bbox`` (:mod:`torchgeo` < 0.6.0) for :class:`~torchgeo.datasets.GeoDataset`
+ * ``bounds`` (:mod:`torchgeo` >= 0.6.0) for :class:`~torchgeo.datasets.GeoDataset`
+ * ``id`` for :class:`~torchgeo.datasets.NonGeoDataset`
+
+ Args:
+ sample (dict[str, ~typing.Any]): Sample dictionary to find index in.
+
+ Returns:
+ None | ~typing.Any: Sample index or ``None`` if not found.
+
+ .. versionadded:: 0.28
+ """
+ if "bbox" in sample:
+ index = sample["bbox"]
+ elif "bounds" in sample:
+ index = sample["bounds"]
+ elif "id" in sample:
+ index = sample["id"]
+ else:
+ index = None
+
+ return index
+
+
+def compare_models(model_1: Module, model_2: Module) -> None:
+ """Compare two models weight-by-weight.
+
+ Source: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/5
+
+ Args:
+ model_1 (torch.nn.Module): First model.
+ model_2 (torch.nn.Module): Second model.
+
+ Raises:
+ AssertionError: If models do not match exactly.
+
+ .. versionadded:: 0.28
+ """
+ models_differ = 0
+ for key_item_1, key_item_2 in zip(
+ model_1.state_dict().items(), model_2.state_dict().items()
+ ):
+ if torch.equal(key_item_1[1], key_item_2[1]):
+ pass
+ else:
+ models_differ += 1
+ if key_item_1[0] == key_item_2[0]:
+ print("Mismatch found at", key_item_1[0])
+ else:
+ raise AssertionError
+ if models_differ == 0:
+ print("Models match perfectly! :)")
diff --git a/minerva/utils/visutils.py b/minerva/utils/visutils.py
index 8201c7fcd..dd36c6153 100644
--- a/minerva/utils/visutils.py
+++ b/minerva/utils/visutils.py
@@ -64,7 +64,7 @@
import os
import random
from pathlib import Path
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Optional, Sequence
import imageio
import matplotlib as mlp
@@ -126,7 +126,7 @@ def de_interlace(x: Sequence[Any], f: int) -> NDArray[Any, Any]:
Returns:
~numpy.ndarray[~typing.Any]: De-interlaced array. Each source array is now sequentially connected.
"""
- new_x: List[NDArray[Any, Any]] = []
+ new_x: list[NDArray[Any, Any]] = []
for i in range(f):
x_i = []
for j in np.arange(start=i, stop=len(x), step=f):
@@ -137,12 +137,12 @@ def de_interlace(x: Sequence[Any], f: int) -> NDArray[Any, Any]:
def dec_extent_to_deg(
- shape: Tuple[int, int],
+ shape: tuple[int, int],
bounds: BoundingBox,
src_crs: CRS,
new_crs: CRS = WGS84,
spacing: int = 32,
-) -> Tuple[Tuple[int, int, int, int], NDArray[Any, Float], NDArray[Any, Float]]:
+) -> tuple[tuple[int, int, int, int], NDArray[Any, Float], NDArray[Any, Float]]:
"""Gets the extent of the image with ``shape`` and with ``bounds`` in latitude, longitude of system ``new_crs``.
Args:
@@ -194,7 +194,7 @@ def dec_extent_to_deg(
def get_mlp_cmap(
- cmap_style: Optional[Union[Colormap, str]] = None, n_classes: Optional[int] = None
+ cmap_style: Optional[Colormap | str] = None, n_classes: Optional[int] = None
) -> Optional[Colormap]:
"""Creates a cmap from query
@@ -225,8 +225,8 @@ def get_mlp_cmap(
def discrete_heatmap(
data: NDArray[Shape["*, *"], Int], # noqa: F722
- classes: Union[List[str], Tuple[str, ...]],
- cmap_style: Optional[Union[str, ListedColormap]] = None,
+ classes: list[str] | tuple[str, ...],
+ cmap_style: Optional[str | ListedColormap] = None,
block_size: int = 32,
) -> None:
"""Plots a heatmap with a discrete colour bar. Designed for Radiant Earth MLHub 256x256 SENTINEL images.
@@ -329,16 +329,16 @@ def labelled_rgb_image(
mask: NDArray[Shape["*, *"], Int], # noqa: F722
bounds: BoundingBox,
src_crs: CRS,
- path: Union[str, Path],
+ path: str | Path,
name: str,
- classes: Union[List[str], Tuple[str, ...]],
- cmap_style: Optional[Union[str, ListedColormap]] = None,
+ classes: list[str] | tuple[str, ...],
+ cmap_style: Optional[str | ListedColormap] = None,
new_crs: Optional[CRS] = WGS84,
block_size: int = 32,
alpha: float = 0.5,
show: bool = True,
save: bool = True,
- figdim: Tuple[Union[int, float], Union[int, float]] = (8.02, 10.32),
+ figdim: tuple[int | float, int | float] = (8.02, 10.32),
) -> Path:
"""Produces a layered image of an RGB image, and it's associated label mask heat map alpha blended on top.
@@ -363,7 +363,7 @@ def labelled_rgb_image(
str: Path to figure save location.
"""
# Checks that the mask and image shapes will align.
- mask_shape: Tuple[int, int] = mask.shape # type: ignore[assignment]
+ mask_shape: tuple[int, int] = mask.shape # type: ignore[assignment]
assert mask_shape == image.shape[:2]
assert new_crs is not None
@@ -475,14 +475,14 @@ def make_gif(
masks: NDArray[Shape["*, *, *"], Any], # noqa: F722
bounds: BoundingBox,
src_crs: CRS,
- classes: Union[List[str], Tuple[str, ...]],
+ classes: list[str] | tuple[str, ...],
gif_name: str,
- path: Union[str, Path],
- cmap_style: Optional[Union[str, ListedColormap]] = None,
+ path: str | Path,
+ cmap_style: Optional[str | ListedColormap] = None,
fps: float = 1.0,
new_crs: Optional[CRS] = WGS84,
alpha: float = 0.5,
- figdim: Tuple[Union[int, float], Union[int, float]] = (8.02, 10.32),
+ figdim: tuple[int | float, int | float] = (8.02, 10.32),
) -> None:
"""Wrapper to :func:`labelled_rgb_image` to make a GIF for a patch out of scenes.
@@ -548,19 +548,19 @@ def make_gif(
def prediction_plot(
- sample: Dict[str, Any],
+ sample: dict[str, Any],
sample_id: str,
- classes: Dict[int, str],
+ classes: dict[int, str],
src_crs: CRS,
new_crs: CRS = WGS84,
path: str = "",
- cmap_style: Optional[Union[str, ListedColormap]] = None,
+ cmap_style: Optional[str | ListedColormap] = None,
exp_id: Optional[str] = None,
- fig_dim: Optional[Tuple[Union[int, float], Union[int, float]]] = None,
+ fig_dim: Optional[tuple[int | float, int | float]] = None,
block_size: int = 32,
show: bool = True,
save: bool = True,
- fn_prefix: Optional[Union[str, Path]] = None,
+ fn_prefix: Optional[str | Path] = None,
) -> None:
"""
Produces a figure containing subplots of the predicted label mask, the ground truth label mask
@@ -704,21 +704,21 @@ def prediction_plot(
def seg_plot(
- z: Union[List[int], NDArray[Any, Any]],
- y: Union[List[int], NDArray[Any, Any]],
- ids: List[str],
- index: Union[Sequence[Any], NDArray[Any, Any]],
- data_dir: Union[Path, str],
- dataset_params: Dict[str, Any],
- classes: Dict[int, str],
- colours: Dict[int, str],
- fn_prefix: Optional[Union[str, Path]],
+ z: list[int] | NDArray[Any, Any],
+ y: list[int] | NDArray[Any, Any],
+ ids: list[str],
+ index: Sequence[Any] | NDArray[Any, Any],
+ data_dir: Path | str,
+ dataset_params: dict[str, Any],
+ classes: dict[int, str],
+ colours: dict[int, str],
+ fn_prefix: Optional[str | Path],
frac: float = 0.05,
- fig_dim: Optional[Tuple[Union[int, float], Union[int, float]]] = (9.3, 10.5),
+ fig_dim: Optional[tuple[int | float, int | float]] = (9.3, 10.5),
model_name: str = "",
path: str = "",
max_pixel_value: int = 255,
- cache_dir: Optional[Union[str, Path]] = None,
+ cache_dir: Optional[str | Path] = None,
) -> None:
"""Custom function for pre-processing the outputs from image segmentation testing for data visualisation.
@@ -840,10 +840,10 @@ def seg_plot(
def plot_subpopulations(
- class_dist: List[Tuple[int, int]],
- class_names: Optional[Dict[int, str]] = None,
- cmap_dict: Optional[Dict[int, str]] = None,
- filename: Optional[Union[str, Path]] = None,
+ class_dist: list[tuple[int, int]],
+ class_names: Optional[dict[int, str]] = None,
+ cmap_dict: Optional[dict[int, str]] = None,
+ filename: Optional[str | Path] = None,
save: bool = True,
show: bool = False,
) -> None:
@@ -867,7 +867,7 @@ def plot_subpopulations(
counts = []
# List to hold colours of classes in the correct order.
- colours: Optional[List[str]] = []
+ colours: Optional[list[str]] = []
if class_names is None:
class_numbers = [x[0] for x in class_dist]
@@ -920,8 +920,8 @@ def plot_subpopulations(
def plot_history(
- metrics: Dict[str, Any],
- filename: Optional[Union[str, Path]] = None,
+ metrics: dict[str, Any],
+ filename: Optional[str | Path] = None,
save: bool = True,
show: bool = False,
) -> None:
@@ -972,12 +972,12 @@ def plot_history(
def make_confusion_matrix(
- pred: Union[List[int], NDArray[Any, Int]],
- labels: Union[List[int], NDArray[Any, Int]],
- classes: Dict[int, str],
- filename: Optional[Union[str, Path]] = None,
+ pred: list[int] | NDArray[Any, Int],
+ labels: list[int] | NDArray[Any, Int],
+ classes: dict[int, str],
+ filename: Optional[str | Path] = None,
cmap_style: str = "Blues",
- figsize: Tuple[int, int] = (2, 2),
+ figsize: tuple[int, int] = (2, 2),
show: bool = True,
save: bool = False,
) -> None:
@@ -1029,12 +1029,12 @@ def make_confusion_matrix(
def make_multilabel_confusion_matrix(
- preds: Union[List[int], NDArray[Any, Int]],
- labels: Union[List[int], NDArray[Any, Int]],
- classes: Dict[int, str],
- filename: Optional[Union[str, Path]] = None,
+ preds: list[int] | NDArray[Any, Int],
+ labels: list[int] | NDArray[Any, Int],
+ classes: dict[int, str],
+ filename: Optional[str | Path] = None,
cmap_style: str = "Blues",
- figsize: Tuple[int, int] = (2, 2),
+ figsize: tuple[int, int] = (2, 2),
show: bool = True,
save: bool = False,
) -> None:
@@ -1105,12 +1105,12 @@ def make_multilabel_confusion_matrix(
def make_roc_curves(
probs: ArrayLike,
- labels: Union[Sequence[int], NDArray[Any, Int]],
- class_names: Dict[int, str],
- colours: Dict[int, str],
+ labels: Sequence[int] | NDArray[Any, Int],
+ class_names: dict[int, str],
+ colours: dict[int, str],
micro: bool = True,
macro: bool = True,
- filename: Optional[Union[str, Path]] = None,
+ filename: Optional[str | Path] = None,
show: bool = False,
save: bool = True,
) -> None:
@@ -1206,15 +1206,15 @@ def make_roc_curves(
def plot_embedding(
embeddings: Any,
- index: Union[Sequence[BoundingBox], Sequence[int]],
- data_dir: Union[Path, str],
- dataset_params: Dict[str, Any],
+ index: Sequence[BoundingBox] | Sequence[int],
+ data_dir: Path | str,
+ dataset_params: dict[str, Any],
title: Optional[str] = None,
show: bool = False,
save: bool = True,
- filename: Optional[Union[Path, str]] = None,
+ filename: Optional[Path | str] = None,
max_pixel_value: int = 255,
- cache_dir: Optional[Union[Path, str]] = None,
+ cache_dir: Optional[Path | str] = None,
) -> None:
"""Using TSNE Clustering, visualises the embeddings from a model.
@@ -1316,8 +1316,8 @@ def plot_embedding(
def format_plot_names(
- model_name: str, timestamp: str, path: Union[Sequence[str], str, Path]
-) -> Dict[str, Path]:
+ model_name: str, timestamp: str, path: Sequence[str] | str | Path
+) -> dict[str, Path]:
"""Creates unique filenames of plots in a standardised format.
Args:
@@ -1357,23 +1357,23 @@ def standard_format(plot_type: str, *sub_dir) -> str:
def plot_results(
- plots: Dict[str, bool],
- z: Optional[Union[List[int], NDArray[Any, Int]]] = None,
- y: Optional[Union[List[int], NDArray[Any, Int]]] = None,
- metrics: Optional[Dict[str, Any]] = None,
- ids: Optional[List[str]] = None,
+ plots: dict[str, bool],
+ z: Optional[list[int] | NDArray[Any, Int]] = None,
+ y: Optional[list[int] | NDArray[Any, Int]] = None,
+ metrics: Optional[dict[str, Any]] = None,
+ ids: Optional[list[str]] = None,
index: Optional[NDArray[Any, Any]] = None,
- probs: Optional[Union[List[float], NDArray[Any, Float]]] = None,
+ probs: Optional[list[float] | NDArray[Any, Float]] = None,
embeddings: Optional[NDArray[Any, Any]] = None,
- class_names: Optional[Dict[int, str]] = None,
- colours: Optional[Dict[int, str]] = None,
+ class_names: Optional[dict[int, str]] = None,
+ colours: Optional[dict[int, str]] = None,
save: bool = True,
show: bool = False,
model_name: Optional[str] = None,
timestamp: Optional[str] = None,
- results_dir: Optional[Union[Sequence[str], str, Path]] = None,
- task_cfg: Optional[Dict[str, Any]] = None,
- global_cfg: Optional[Dict[str, Any]] = None,
+ results_dir: Optional[Sequence[str] | str | Path] = None,
+ task_cfg: Optional[dict[str, Any]] = None,
+ global_cfg: Optional[dict[str, Any]] = None,
) -> None:
"""Orchestrates the creation of various plots from the results of a model fitting.
diff --git a/pyproject.toml b/pyproject.toml
index 7c84e03cb..b3fb51d41 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,6 @@ classifiers = [
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.10",
- "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3 :: Only",
"License :: OSI Approved :: MIT License",
"Development Status :: 4 - Beta",
@@ -38,6 +37,7 @@ dependencies = [
"numba>=0.57.0; python_version>='3.11'",
"numpy",
"overload",
+ "opencv-python",
"pandas",
"psutil",
"pyyaml",
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index d4a7d2bbb..e1de41ede 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -1,50 +1,50 @@
-argcomplete==3.4.0
+argcomplete==3.5.0
fastapi>=0.101.0
fiona>=1.9.1
geopy==2.4.1
GitPython>=3.1.35
hydra-core==1.3.2
-imagecodecs==2024.1.1
-imageio==2.34.2
+imagecodecs==2024.6.1
+imageio==2.35.1
inputimeout==1.0.4
-ipykernel==6.29.4
-kornia==0.7.2
-lightly==1.5.7
-lmdb==1.4.1
-matplotlib==3.9.0
-mlflow-skinny==2.14.0
+ipykernel==6.29.5
+kornia==0.7.3
+lightly==1.5.12
+lmdb==1.5.1
+matplotlib==3.9.2
+mlflow-skinny==2.16.0
nptyping==2.5.0
numba>=0.57.0 # not directly required but pinned to ensure Python 3.11 compatibility.
numpy==1.26.4
-onnx==1.16.1
-onnx2torch==1.5.14
-opencv-python==4.9.0.80
+onnx==1.16.2
+onnx2torch==1.5.15
+opencv-python-headless==4.10.0.84
overload==1.1
pandas==2.2.2
-patch-ng==1.17.4
+patch-ng==1.18.0
protobuf>=3.19.5 # not directly required, pinned by Snyk to avoid a vulnerability.
psutil==6.0.0
-pyyaml==6.0.1
+pyyaml==6.0.2
rasterio>=1.3.6
requests==2.32.3
-scikit-learn==1.5.0
-segmentation-models-pytorch==0.3.3
-setuptools==70.1.0
+scikit-learn==1.5.1
+segmentation-models-pytorch==0.3.4
+setuptools==74.0.0
starlette==0.37.2 # not directly required, pinned by Dependabot to avoid a vulnerability.
tabulate==0.9.0
-tensorflow==2.16.2
+tensorflow==2.17.0
tifffile==2024.1.30
-timm==0.9.2
-torch==2.3.1
+timm==0.9.7
+torch==2.4.0
torcheval==0.0.7
-torchgeo==0.5.2
+torchgeo==0.6.0
torchinfo==1.8.0
-torchvision==0.18.1
+torchvision==0.19.0
tornado>=6.3.3
-tqdm==4.66.4
-types-PyYAML==6.0.12.20240311
-types-requests==2.32.0.20240602
+tqdm==4.66.5
+types-PyYAML==6.0.12.20240808
+types-requests==2.32.0.20240712
types-tabulate==0.9.0.20240106
-wandb==0.17.4
+wandb==0.17.8
Werkzeug>=2.2.3 # Patches a potential security vulnerability.
wheel>=0.38.0 # not directly required, pinned by Snyk to avoid a vulnerability.
diff --git a/requirements/requirements_dev.txt b/requirements/requirements_dev.txt
index 65a266767..e9e00a78f 100644
--- a/requirements/requirements_dev.txt
+++ b/requirements/requirements_dev.txt
@@ -1,63 +1,63 @@
-argcomplete==3.4.0
+argcomplete==3.5.0
certifi>=2022.12.7 # not directly required, pinned by Snyk to avoid a vulnerability.
fastapi>=0.101.0
fiona>=1.9.1
-flake8==7.1.0
+flake8==7.1.1
geopy==2.4.1
GitPython>=3.1.35
hydra-core==1.3.2
-imagecodecs==2024.1.1
-imageio==2.34.2
+imagecodecs==2024.6.1
+imageio==2.35.1
inputimeout==1.0.4
internet-sabotage3==0.1.6
-ipykernel==6.29.4
-kornia==0.7.2
-lightly==1.5.7
-lmdb==1.4.1
-matplotlib==3.9.0
-mlflow-skinny==2.14.0
-mypy==1.10.1
-myst_parser==3.0.1
+ipykernel==6.29.5
+kornia==0.7.3
+lightly==1.5.12
+lmdb==1.5.1
+matplotlib==3.9.2
+mlflow-skinny==2.16.0
+mypy==1.11.2
+myst_parser==4.0.0
nptyping==2.5.0
numba>=0.57.0 # not directly required but pinned to ensure Python 3.11 compatibility.
numpy==1.26.4
-onnx==1.16.1
-onnx2torch==1.5.14
-opencv-python==4.9.0.80
+onnx==1.16.2
+onnx2torch==1.5.15
+opencv-python-headless==4.10.0.84
overload==1.1
pandas==2.2.2
-patch-ng==1.17.4
-pre-commit==3.7.1
+patch-ng==1.18.0
+pre-commit==3.8.0
protobuf>=3.19.5 # not directly required, pinned by Snyk to avoid a vulnerability.
psutil==6.0.0
pygments>=2.7.4 # not directly required, pinned by Snyk to avoid a vulnerability.
pytest==7.4.4
pytest-cov==5.0.0
pytest-lazy-fixture==0.6.3
-pyyaml==6.0.1
+pyyaml==6.0.2
rasterio>=1.3.6
requests==2.32.3
-scikit-learn==1.5.0
-segmentation-models-pytorch==0.3.3
-setuptools==70.1.0
-sphinx==7.3.7
+scikit-learn==1.5.1
+segmentation-models-pytorch==0.3.4
+setuptools==72.1.0
+sphinx==7.4.4
sphinx-rtd-theme==2.0.0
starlette>=0.25.0 # not directly required, pinned by Dependabot to avoid a vulnerability.
tabulate==0.9.0
-tensorflow==2.16.2
+tensorflow==2.17.0
tifffile==2024.1.30
-timm==0.9.2
-torch==2.3.1
+timm==0.9.7
+torch==2.4.0
torcheval==0.0.7
-torchgeo==0.5.2
+torchgeo==0.6.0
torchinfo==1.8.0
-torchvision==0.18.1
+torchvision==0.19.0
tornado>=6.3.3
-tox==4.15.1
-tqdm==4.66.4
-types-PyYAML==6.0.12.20240311
-types-requests==2.32.0.20240602
+tox==4.18.0
+tqdm==4.66.5
+types-PyYAML==6.0.12.20240808
+types-requests==2.32.0.20240712
types-tabulate==0.9.0.20240106
-wandb==0.17.4
+wandb==0.17.8
Werkzeug>=2.2.3 # Patches a potential security vulnerability.
wheel>=0.38.0 # not directly required, pinned by Snyk to avoid a vulnerability.
diff --git a/scripts/MinervaExp.py b/scripts/MinervaExp.py
index fc40d93f5..f8ff132a4 100644
--- a/scripts/MinervaExp.py
+++ b/scripts/MinervaExp.py
@@ -39,7 +39,7 @@
# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
-from typing import Optional, Union
+from typing import Optional
import hydra
from omegaconf import DictConfig
@@ -58,9 +58,7 @@
config_name=DEFAULT_CONFIG_NAME,
)
@runner.distributed_run
-def main(
- gpu: int, wandb_run: Optional[Union[Run, RunDisabled]], cfg: DictConfig
-) -> None:
+def main(gpu: int, wandb_run: Optional[Run | RunDisabled], cfg: DictConfig) -> None:
# Due to the nature of multiprocessing and its interaction with hydra, wandb and SLURM,
# the actual code excuted in the job is contained in `run_trainer` in `runner`.
diff --git a/scripts/MinervaPipe.py b/scripts/MinervaPipe.py
index 5a4ced905..6a95f6b91 100644
--- a/scripts/MinervaPipe.py
+++ b/scripts/MinervaPipe.py
@@ -36,7 +36,7 @@
import shlex
import subprocess
import sys
-from typing import Any, Dict
+from typing import Any
import yaml
@@ -46,7 +46,7 @@
# =====================================================================================================================
def main(config_path: str):
with open(config_path) as f:
- config: Dict[str, Any] = yaml.safe_load(f)
+ config: dict[str, Any] = yaml.safe_load(f)
for key in config.keys():
print(
diff --git a/scripts/RunTensorBoard.py b/scripts/RunTensorBoard.py
index 589b1a7d9..3559429ad 100644
--- a/scripts/RunTensorBoard.py
+++ b/scripts/RunTensorBoard.py
@@ -33,7 +33,7 @@
# IMPORTS
# =====================================================================================================================
import argparse
-from typing import List, Optional, Union
+from typing import Optional
from minerva.utils import utils
@@ -42,7 +42,7 @@
# MAIN
# =====================================================================================================================
def main(
- path: Optional[Union[str, List[str]]] = None,
+ path: Optional[str | list[str]] = None,
env_name: str = "env2",
exp_name: Optional[str] = None,
host_num: int = 6006,
diff --git a/scripts/TorchWeightDownloader.py b/scripts/TorchWeightDownloader.py
index 2d076b7b8..b0f6f2f5c 100644
--- a/scripts/TorchWeightDownloader.py
+++ b/scripts/TorchWeightDownloader.py
@@ -21,7 +21,7 @@
"""Loads :mod:`torch` weights from Torch Hub into cache.
Attributes:
- resnets (List[str]): List of tags for ``pytorch`` resnet weights to download.
+ resnets (list[str]): List of tags for ``pytorch`` resnet weights to download.
"""
# =====================================================================================================================
diff --git a/setup.cfg b/setup.cfg
index a48cea02c..ea926089e 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -13,7 +13,6 @@ classifiers =
Programming Language :: Python :: 3.12
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.10
- Programming Language :: Python :: 3.9
Programming Language :: Python :: 3 :: Only
License :: OSI Approved :: MIT License
Development Status :: 4 - Beta
@@ -34,6 +33,7 @@ install_requires =
psutil
geopy
overload
+ opencv-python
nptyping
lightly
argcomplete
diff --git a/tests/conftest.py b/tests/conftest.py
index 393afd48a..cbfa31a45 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -41,7 +41,7 @@
import os
import shutil
from pathlib import Path
-from typing import Any, Dict, Generator, Tuple
+from typing import Any, Generator
import hydra
import numpy as np
@@ -199,7 +199,7 @@ def std_batch_size() -> int:
@pytest.fixture
-def std_n_classes(exp_classes: Dict[int, str]) -> int:
+def std_n_classes(exp_classes: dict[int, str]) -> int:
return len(exp_classes)
@@ -214,12 +214,12 @@ def x_entropy_loss() -> nn.CrossEntropyLoss:
@pytest.fixture
-def small_patch_size() -> Tuple[int, int]:
+def small_patch_size() -> tuple[int, int]:
return (32, 32)
@pytest.fixture
-def rgbi_input_size() -> Tuple[int, int, int]:
+def rgbi_input_size() -> tuple[int, int, int]:
return (4, 32, 32)
@@ -230,7 +230,7 @@ def exp_mlp(x_entropy_loss: nn.CrossEntropyLoss) -> MLP:
@pytest.fixture
def exp_cnn(
- x_entropy_loss: nn.CrossEntropyLoss, rgbi_input_size: Tuple[int, int, int]
+ x_entropy_loss: nn.CrossEntropyLoss, rgbi_input_size: tuple[int, int, int]
) -> CNN:
return CNN(x_entropy_loss, rgbi_input_size)
@@ -238,7 +238,7 @@ def exp_cnn(
@pytest.fixture
def exp_fcn(
x_entropy_loss: nn.CrossEntropyLoss,
- rgbi_input_size: Tuple[int, int, int],
+ rgbi_input_size: tuple[int, int, int],
std_n_classes: int,
) -> FCN32ResNet18:
return FCN32ResNet18(x_entropy_loss, rgbi_input_size, std_n_classes)
@@ -246,14 +246,14 @@ def exp_fcn(
@pytest.fixture
def exp_simconv(
- rgbi_input_size: Tuple[int, int, int],
+ rgbi_input_size: tuple[int, int, int],
) -> SimConv:
return SimConv(SegBarlowTwinsLoss(), rgbi_input_size, feature_dim=128)
@pytest.fixture
def random_mask(
- small_patch_size: Tuple[int, int], std_n_classes: int
+ small_patch_size: tuple[int, int], std_n_classes: int
) -> NDArray[Shape["32, 32"], Int]:
mask = np.random.randint(0, std_n_classes - 1, size=small_patch_size)
assert isinstance(mask, np.ndarray)
@@ -262,20 +262,20 @@ def random_mask(
@pytest.fixture
def random_image(
- small_patch_size: Tuple[int, int]
+ small_patch_size: tuple[int, int]
) -> NDArray[Shape["32, 32, 3"], Float]:
return np.random.rand(*small_patch_size, 3)
@pytest.fixture
def random_rgbi_image(
- small_patch_size: Tuple[int, int]
+ small_patch_size: tuple[int, int]
) -> NDArray[Shape["32, 32, 4"], Float]:
return np.random.rand(*small_patch_size, 4)
@pytest.fixture
-def random_rgbi_tensor(rgbi_input_size: Tuple[int, int, int]) -> Tensor:
+def random_rgbi_tensor(rgbi_input_size: tuple[int, int, int]) -> Tensor:
return torch.rand(rgbi_input_size)
@@ -305,27 +305,27 @@ def flipped_rgb_img() -> Tensor:
@pytest.fixture
-def simple_sample(simple_rgb_img: Tensor, simple_mask: LongTensor) -> Dict[str, Tensor]:
+def simple_sample(simple_rgb_img: Tensor, simple_mask: LongTensor) -> dict[str, Tensor]:
return {"image": simple_rgb_img, "mask": simple_mask}
@pytest.fixture
def flipped_simple_sample(
flipped_rgb_img: Tensor, flipped_simple_mask: LongTensor
-) -> Dict[str, Tensor]:
+) -> dict[str, Tensor]:
return {"image": flipped_rgb_img, "mask": flipped_simple_mask}
@pytest.fixture
def random_rgbi_batch(
- rgbi_input_size: Tuple[int, int, int], std_batch_size: int
+ rgbi_input_size: tuple[int, int, int], std_batch_size: int
) -> Tensor:
return torch.rand((std_batch_size, *rgbi_input_size))
@pytest.fixture
def random_tensor_mask(
- std_n_classes: int, small_patch_size: Tuple[int, int]
+ std_n_classes: int, small_patch_size: tuple[int, int]
) -> LongTensor:
mask = torch.randint(0, std_n_classes - 1, size=small_patch_size, dtype=torch.long)
assert isinstance(mask, LongTensor)
@@ -334,7 +334,7 @@ def random_tensor_mask(
@pytest.fixture
def random_mask_batch(
- std_batch_size: int, std_n_classes: int, rgbi_input_size: Tuple[int, int, int]
+ std_batch_size: int, std_n_classes: int, rgbi_input_size: tuple[int, int, int]
) -> LongTensor:
mask = torch.randint(
0,
@@ -368,7 +368,7 @@ def bounds_for_test_img() -> BoundingBox:
@pytest.fixture
-def exp_classes() -> Dict[int, str]:
+def exp_classes() -> dict[int, str]:
return {
0: "No Data",
1: "Water",
@@ -382,7 +382,7 @@ def exp_classes() -> Dict[int, str]:
@pytest.fixture
-def exp_cmap_dict() -> Dict[int, str]:
+def exp_cmap_dict() -> dict[int, str]:
return {
0: "#000000", # Transparent
1: "#00c5ff", # Light Blue
@@ -411,7 +411,7 @@ def flipped_simple_mask() -> LongTensor:
@pytest.fixture
-def example_matrix() -> Dict[int, int]:
+def example_matrix() -> dict[int, int]:
return {1: 1, 3: 3, 4: 2, 5: 0}
@@ -421,35 +421,31 @@ def simple_bbox() -> BoundingBox:
@pytest.fixture
-def exp_dataset_params() -> Dict[str, Any]:
+def exp_dataset_params() -> dict[str, Any]:
return {
"image": {
"transforms": {"AutoNorm": {"length": 12}},
- "module": "minerva.datasets.__testing",
- "name": "TstImgDataset",
+ "_target_": "minerva.datasets.__testing.TstImgDataset",
"paths": "NAIP",
- "params": {"res": 1.0, "crs": 26918},
+ "res": 1.0,
+ "crs": 26918,
},
"mask": {
"transforms": False,
- "module": "minerva.datasets.__testing",
- "name": "TstMaskDataset",
+ "_target_": "minerva.datasets.__testing.TstMaskDataset",
"paths": "Chesapeake7",
- "params": {"res": 1.0},
+ "res": 1.0,
},
}
@pytest.fixture
-def exp_sampler_params(small_patch_size: Tuple[int, int]):
+def exp_sampler_params(small_patch_size: tuple[int, int]):
return {
- "module": "torchgeo.samplers",
- "name": "RandomGeoSampler",
+ "_target_": "torchgeo.samplers.RandomGeoSampler",
"roi": False,
- "params": {
- "size": small_patch_size,
- "length": 120,
- },
+ "size": small_patch_size,
+ "length": 120,
}
@@ -478,7 +474,7 @@ def default_dataset(
@pytest.fixture
def default_image_dataset(
- default_config: DictConfig, exp_dataset_params: Dict[str, Any]
+ default_config: DictConfig, exp_dataset_params: dict[str, Any]
) -> RasterDataset:
del exp_dataset_params["mask"]
dataset, _ = make_dataset(
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B1.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B1.tif
new file mode 100644
index 000000000..8fdc0fda9
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B1.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B11.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B11.tif
new file mode 100644
index 000000000..8257c4f6f
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B11.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B12.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B12.tif
new file mode 100644
index 000000000..d00a4de7e
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B12.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B2.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B2.tif
new file mode 100644
index 000000000..950ca2e7b
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B2.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B3.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B3.tif
new file mode 100644
index 000000000..cd6acd4d0
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B3.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B4.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B4.tif
new file mode 100644
index 000000000..a0a7530be
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B4.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B5.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B5.tif
new file mode 100644
index 000000000..00a1207ff
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B5.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B6.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B6.tif
new file mode 100644
index 000000000..42941b2d5
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B6.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B7.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B7.tif
new file mode 100644
index 000000000..fcd2a5317
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B7.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B8.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B8.tif
new file mode 100644
index 000000000..25a6621e3
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B8.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B8A.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B8A.tif
new file mode 100644
index 000000000..bfdc215fb
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B8A.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B9.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B9.tif
new file mode 100644
index 000000000..5335622a3
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B9.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/metadata.json b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/metadata.json
new file mode 100644
index 000000000..081ca2731
--- /dev/null
+++ b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/metadata.json
@@ -0,0 +1 @@
+{"AOT_RETRIEVAL_ACCURACY": 0, "CLOUDY_PIXEL_PERCENTAGE": 2.812705, "CLOUD_COVERAGE_ASSESSMENT": 2.812705, "CLOUD_SHADOW_PERCENTAGE": 0.015249, "DARK_FEATURES_PERCENTAGE": 0.468881, "DATASTRIP_ID": "S2B_OPER_MSI_L2A_DS_SGS__20200424T123330_S20200424T094123_N02.14", "DATATAKE_IDENTIFIER": "GS2B_20200424T094029_016364_N02.14", "DATATAKE_TYPE": "INS-NOBS", "DEGRADED_MSI_DATA_PERCENTAGE": 0, "FORMAT_CORRECTNESS": "PASSED", "GENERAL_QUALITY": "PASSED", "GENERATION_TIME": 1587731610000, "GEOMETRIC_QUALITY": "PASSED", "GRANULE_ID": "L2A_T33TYN_A016364_20200424T094123", "HIGH_PROBA_CLOUDS_PERCENTAGE": 0.015214, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B1": 106.543897547236, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B10": 105.92977627649, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B11": 106.2294370992, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B12": 106.526973972026, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B2": 105.427079405054, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B3": 105.700015654406, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B4": 105.929384543531, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B5": 106.065332002664, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B6": 106.186455739266, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B7": 106.317289469222, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8": 105.564528810271, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8A": 106.398164642186, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B9": 106.660211986834, "MEAN_INCIDENCE_ZENITH_ANGLE_B1": 8.67687308176335, "MEAN_INCIDENCE_ZENITH_ANGLE_B10": 8.54845080104398, "MEAN_INCIDENCE_ZENITH_ANGLE_B11": 8.60029719555954, "MEAN_INCIDENCE_ZENITH_ANGLE_B12": 8.66269723035736, "MEAN_INCIDENCE_ZENITH_ANGLE_B2": 8.50990628623815, "MEAN_INCIDENCE_ZENITH_ANGLE_B3": 8.53383114161826, "MEAN_INCIDENCE_ZENITH_ANGLE_B4": 8.56085011020257, "MEAN_INCIDENCE_ZENITH_ANGLE_B5": 8.5810054217592, "MEAN_INCIDENCE_ZENITH_ANGLE_B6": 8.60351046174368, "MEAN_INCIDENCE_ZENITH_ANGLE_B7": 8.62846298114179, "MEAN_INCIDENCE_ZENITH_ANGLE_B8": 8.52063436830031, "MEAN_INCIDENCE_ZENITH_ANGLE_B8A": 8.64956945753696, "MEAN_INCIDENCE_ZENITH_ANGLE_B9": 8.70970015807593, 
"MEAN_SOLAR_AZIMUTH_ANGLE": 156.022265244126, "MEAN_SOLAR_ZENITH_ANGLE": 36.2587132733646, "MEDIUM_PROBA_CLOUDS_PERCENTAGE": 0.024197, "MGRS_TILE": "33TYN", "NODATA_PIXEL_PERCENTAGE": 24.501736, "NOT_VEGETATED_PERCENTAGE": 33.184448, "PROCESSING_BASELINE": "02.14", "PRODUCT_ID": "S2B_MSIL2A_20200424T094029_N0214_R036_T33TYN_20200424T123330", "RADIATIVE_TRANSFER_ACCURACY": 0, "RADIOMETRIC_QUALITY": "PASSED", "REFLECTANCE_CONVERSION_CORRECTION": 0.990710191784445, "SATURATED_DEFECTIVE_PIXEL_PERCENTAGE": 0, "SENSING_ORBIT_DIRECTION": "DESCENDING", "SENSING_ORBIT_NUMBER": 36, "SENSOR_QUALITY": "PASSED", "SNOW_ICE_PERCENTAGE": 4e-06, "SOLAR_IRRADIANCE_B1": 1874.3, "SOLAR_IRRADIANCE_B10": 365.41, "SOLAR_IRRADIANCE_B11": 247.08, "SOLAR_IRRADIANCE_B12": 87.75, "SOLAR_IRRADIANCE_B2": 1959.75, "SOLAR_IRRADIANCE_B3": 1824.93, "SOLAR_IRRADIANCE_B4": 1512.79, "SOLAR_IRRADIANCE_B5": 1425.78, "SOLAR_IRRADIANCE_B6": 1291.13, "SOLAR_IRRADIANCE_B7": 1175.57, "SOLAR_IRRADIANCE_B8": 1041.28, "SOLAR_IRRADIANCE_B8A": 953.93, "SOLAR_IRRADIANCE_B9": 817.58, "SPACECRAFT_NAME": "Sentinel-2B", "THIN_CIRRUS_PERCENTAGE": 2.773294, "UNCLASSIFIED_PERCENTAGE": 0.328074, "VEGETATION_PERCENTAGE": 58.864659, "WATER_PERCENTAGE": 4.325977, "WATER_VAPOUR_RETRIEVAL_ACCURACY": 0, "system:asset_size": 1340927025, "system:band_names": ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12", "AOT", "WVP", "SCL", "TCI_R", "TCI_G", "TCI_B", "MSK_CLDPRB", "MSK_SNWPRB", "QA10", "QA20", "QA60"], "system:bands": {"B11": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_B": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B12": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 
5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "AOT": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "QA10": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B8A": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "QA20": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 4294967295}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "MSK_CLDPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B1": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "QA60": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "B2": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "WVP": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B3": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B4": 
{"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "TCI_R": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B5": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B6": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B7": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B8": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B9": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "MSK_SNWPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "SCL": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_G": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}}, "system:footprint": {"type": "LinearRing", "coordinates": [[19.135125920080306, 47.77916186962477], 
[19.135029727720475, 47.779255442991314], [18.17974674339065, 47.80973453124768], [18.175369534026473, 47.80439727501413], [18.158478123290493, 47.76381793259735], [18.135659921426363, 47.70397065763017], [18.02588276983751, 47.409953370306475], [17.84226898482384, 46.89305421884395], [17.83483726398006, 46.87163662286434], [17.82687335431474, 46.839969201428644], [17.825690373350913, 46.83243260435472], [17.82561662238212, 46.831025740554466], [17.82597229677631, 46.830892495549335], [19.058775213444516, 46.79376963860806], [19.058911741293866, 46.793834509947615], [19.077589698914174, 47.04020634864976], [19.096514577054162, 47.28654137136117], [19.115691869935954, 47.53285992686103], [19.135125920080306, 47.77916186962477]]}, "system:id": "COPERNICUS/S2_SR/20200424T094029_20200424T094123_T33TYN", "system:index": "20200424T094029_20200424T094123_T33TYN", "system:time_end": 1587721647687, "system:time_start": 1587721647687, "system:version": 1587930521476594}
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B1.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B1.tif
new file mode 100644
index 000000000..9dff310b4
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B1.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B11.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B11.tif
new file mode 100644
index 000000000..9efd6e287
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B11.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B12.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B12.tif
new file mode 100644
index 000000000..47d64de43
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B12.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B2.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B2.tif
new file mode 100644
index 000000000..733b98ac7
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B2.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B3.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B3.tif
new file mode 100644
index 000000000..972fc9af0
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B3.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B4.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B4.tif
new file mode 100644
index 000000000..37dc66eb0
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B4.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B5.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B5.tif
new file mode 100644
index 000000000..f3363bd95
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B5.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B6.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B6.tif
new file mode 100644
index 000000000..b19a03a30
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B6.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B7.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B7.tif
new file mode 100644
index 000000000..e988bcfba
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B7.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B8.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B8.tif
new file mode 100644
index 000000000..ccc50c8ba
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B8.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B8A.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B8A.tif
new file mode 100644
index 000000000..27ee7369f
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B8A.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B9.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B9.tif
new file mode 100644
index 000000000..b07964457
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B9.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/metadata.json b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/metadata.json
new file mode 100644
index 000000000..d3c3117e5
--- /dev/null
+++ b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/metadata.json
@@ -0,0 +1 @@
+{"AOT_RETRIEVAL_ACCURACY": 0, "CLOUDY_PIXEL_PERCENTAGE": 0.208477, "CLOUD_COVERAGE_ASSESSMENT": 0.208477, "CLOUD_SHADOW_PERCENTAGE": 0.078845, "DARK_FEATURES_PERCENTAGE": 0.266203, "DATASTRIP_ID": "S2B_OPER_MSI_L2A_DS_EPAE_20200812T120400_S20200812T094034_N02.14", "DATATAKE_IDENTIFIER": "GS2B_20200812T094039_017937_N02.14", "DATATAKE_TYPE": "INS-NOBS", "DEGRADED_MSI_DATA_PERCENTAGE": 0, "FORMAT_CORRECTNESS": "PASSED", "GENERAL_QUALITY": "PASSED", "GENERATION_TIME": 1597233840000, "GEOMETRIC_QUALITY": "PASSED", "GRANULE_ID": "L2A_T33TYN_A017937_20200812T094034", "HIGH_PROBA_CLOUDS_PERCENTAGE": 0.137471, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B1": 106.429324062005, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B10": 105.852008867325, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B11": 106.077757429308, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B12": 106.335669133722, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B2": 105.380300664506, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B3": 105.619830676885, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B4": 105.833171389503, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B5": 105.955584495271, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B6": 106.063210971005, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B7": 106.180783262801, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8": 105.500006984303, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8A": 106.287212521103, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B9": 106.474000451992, "MEAN_INCIDENCE_ZENITH_ANGLE_B1": 8.65065979233341, "MEAN_INCIDENCE_ZENITH_ANGLE_B10": 8.5230203611065, "MEAN_INCIDENCE_ZENITH_ANGLE_B11": 8.5659857030015, "MEAN_INCIDENCE_ZENITH_ANGLE_B12": 8.6286564275332, "MEAN_INCIDENCE_ZENITH_ANGLE_B2": 8.48649538224668, "MEAN_INCIDENCE_ZENITH_ANGLE_B3": 8.50238336028187, "MEAN_INCIDENCE_ZENITH_ANGLE_B4": 8.53311449180501, "MEAN_INCIDENCE_ZENITH_ANGLE_B5": 8.55333421195006, "MEAN_INCIDENCE_ZENITH_ANGLE_B6": 8.57592089704101, "MEAN_INCIDENCE_ZENITH_ANGLE_B7": 8.60096440676022, "MEAN_INCIDENCE_ZENITH_ANGLE_B8": 8.48915559090344, "MEAN_INCIDENCE_ZENITH_ANGLE_B8A": 8.62827228110442, "MEAN_INCIDENCE_ZENITH_ANGLE_B9": 8.68825044784009, 
"MEAN_SOLAR_AZIMUTH_ANGLE": 152.478041660485, "MEAN_SOLAR_ZENITH_ANGLE": 35.147934745323, "MEDIUM_PROBA_CLOUDS_PERCENTAGE": 0.070437, "MGRS_TILE": "33TYN", "NODATA_PIXEL_PERCENTAGE": 23.649484, "NOT_VEGETATED_PERCENTAGE": 20.454341, "PROCESSING_BASELINE": "02.14", "PRODUCT_ID": "S2B_MSIL2A_20200812T094039_N0214_R036_T33TYN_20200812T120400", "RADIATIVE_TRANSFER_ACCURACY": 0, "RADIOMETRIC_QUALITY": "PASSED", "REFLECTANCE_CONVERSION_CORRECTION": 0.97302905813138, "SATURATED_DEFECTIVE_PIXEL_PERCENTAGE": 0, "SENSING_ORBIT_DIRECTION": "DESCENDING", "SENSING_ORBIT_NUMBER": 36, "SENSOR_QUALITY": "PASSED", "SNOW_ICE_PERCENTAGE": 0.000183, "SOLAR_IRRADIANCE_B1": 1874.3, "SOLAR_IRRADIANCE_B10": 365.41, "SOLAR_IRRADIANCE_B11": 247.08, "SOLAR_IRRADIANCE_B12": 87.75, "SOLAR_IRRADIANCE_B2": 1959.75, "SOLAR_IRRADIANCE_B3": 1824.93, "SOLAR_IRRADIANCE_B4": 1512.79, "SOLAR_IRRADIANCE_B5": 1425.78, "SOLAR_IRRADIANCE_B6": 1291.13, "SOLAR_IRRADIANCE_B7": 1175.57, "SOLAR_IRRADIANCE_B8": 1041.28, "SOLAR_IRRADIANCE_B8A": 953.93, "SOLAR_IRRADIANCE_B9": 817.58, "SPACECRAFT_NAME": "Sentinel-2B", "THIN_CIRRUS_PERCENTAGE": 0.000569, "UNCLASSIFIED_PERCENTAGE": 0.25633, "VEGETATION_PERCENTAGE": 74.069816, "WATER_PERCENTAGE": 4.665802, "WATER_VAPOUR_RETRIEVAL_ACCURACY": 0, "system:asset_size": 1331946979, "system:band_names": ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12", "AOT", "WVP", "SCL", "TCI_R", "TCI_G", "TCI_B", "MSK_CLDPRB", "MSK_SNWPRB", "QA10", "QA20", "QA60"], "system:bands": {"B11": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_B": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B12": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 
5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "AOT": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "QA10": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B8A": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "QA20": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 4294967295}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "MSK_CLDPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B1": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "QA60": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "B2": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "WVP": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B3": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B4": 
{"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "TCI_R": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B5": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B6": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B7": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B8": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B9": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "MSK_SNWPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "SCL": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_G": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}}, "system:footprint": {"type": "LinearRing", "coordinates": [[19.135125920080306, 47.77916186962477], 
[19.135029727720475, 47.779255442991314], [18.16693852790548, 47.8100889727091], [18.164267652937216, 47.80632443625982], [18.158840869103887, 47.79620697829836], [18.088526955702534, 47.61022107161706], [17.903701479756446, 47.09670164788835], [17.84229285719976, 46.92329249637601], [17.827427767869775, 46.88046144909253], [17.81606866971548, 46.844018609587174], [17.813931590571993, 46.833263753347495], [17.81382991163962, 46.831316282941685], [17.814185583479816, 46.83118308142532], [19.058775213444516, 46.79376963860806], [19.058911741293866, 46.793834509947615], [19.077589698914174, 47.04020634864976], [19.096514577054162, 47.28654137136117], [19.115691869935954, 47.53285992686103], [19.135125920080306, 47.77916186962477]]}, "system:id": "COPERNICUS/S2_SR/20200812T094039_20200812T094034_T33TYN", "system:index": "20200812T094039_20200812T094034_T33TYN", "system:time_end": 1597225656684, "system:time_start": 1597225656684, "system:version": 1597431921198064}
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B1.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B1.tif
new file mode 100644
index 000000000..09c23f408
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B1.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B11.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B11.tif
new file mode 100644
index 000000000..262d05a60
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B11.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B12.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B12.tif
new file mode 100644
index 000000000..42914e101
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B12.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B2.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B2.tif
new file mode 100644
index 000000000..8bbca800c
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B2.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B3.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B3.tif
new file mode 100644
index 000000000..54d75d141
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B3.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B4.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B4.tif
new file mode 100644
index 000000000..07d895a7d
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B4.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B5.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B5.tif
new file mode 100644
index 000000000..f88d453d9
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B5.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B6.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B6.tif
new file mode 100644
index 000000000..b6d0815e3
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B6.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B7.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B7.tif
new file mode 100644
index 000000000..8ae56682c
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B7.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B8.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B8.tif
new file mode 100644
index 000000000..579d2a9a8
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B8.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B8A.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B8A.tif
new file mode 100644
index 000000000..bed885ace
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B8A.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B9.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B9.tif
new file mode 100644
index 000000000..8973a2922
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B9.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/metadata.json b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/metadata.json
new file mode 100644
index 000000000..83ac26bd5
--- /dev/null
+++ b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/metadata.json
@@ -0,0 +1 @@
+{"AOT_RETRIEVAL_ACCURACY": 0, "CLOUDY_PIXEL_PERCENTAGE": 0.453934, "CLOUD_COVERAGE_ASSESSMENT": 0.45393399999999995, "CLOUD_SHADOW_PERCENTAGE": 1.138567, "DARK_FEATURES_PERCENTAGE": 15.463047, "DATASTRIP_ID": "S2A_OPER_MSI_L2A_DS_EPAE_20201118T115530_S20201118T095400_N02.14", "DATATAKE_IDENTIFIER": "GS2A_20201118T095311_028247_N02.14", "DATATAKE_TYPE": "INS-NOBS", "DEGRADED_MSI_DATA_PERCENTAGE": 0, "FORMAT_CORRECTNESS": "PASSED", "GENERAL_QUALITY": "PASSED", "GENERATION_TIME": 1605700530000, "GEOMETRIC_QUALITY": "PASSED", "GRANULE_ID": "L2A_T33TYN_A028247_20201118T095400", "HIGH_PROBA_CLOUDS_PERCENTAGE": 0.248884, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B1": 281.681280146803, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B10": 281.572710595024, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B11": 281.081166730384, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B12": 281.107900743002, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B2": 282.19659426411, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B3": 281.893354312561, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B4": 281.719087546553, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B5": 281.69787113297, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B6": 281.699110683518, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B7": 281.697525241078, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8": 282.017178016916, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8A": 281.726600099165, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B9": 281.718430321399, "MEAN_INCIDENCE_ZENITH_ANGLE_B1": 5.67907654400748, "MEAN_INCIDENCE_ZENITH_ANGLE_B10": 5.41638149201038, "MEAN_INCIDENCE_ZENITH_ANGLE_B11": 5.51137814408818, "MEAN_INCIDENCE_ZENITH_ANGLE_B12": 5.63006844946331, "MEAN_INCIDENCE_ZENITH_ANGLE_B2": 5.31075209346664, "MEAN_INCIDENCE_ZENITH_ANGLE_B3": 5.36458098961822, "MEAN_INCIDENCE_ZENITH_ANGLE_B4": 5.43516937389945, "MEAN_INCIDENCE_ZENITH_ANGLE_B5": 5.48025769954631, "MEAN_INCIDENCE_ZENITH_ANGLE_B6": 5.52449969165148, "MEAN_INCIDENCE_ZENITH_ANGLE_B7": 5.57253782358173, "MEAN_INCIDENCE_ZENITH_ANGLE_B8": 5.33541219024446, "MEAN_INCIDENCE_ZENITH_ANGLE_B8A": 5.62367424140266, "MEAN_INCIDENCE_ZENITH_ANGLE_B9": 5.73796871755139, 
"MEAN_SOLAR_AZIMUTH_ANGLE": 171.248871380871, "MEAN_SOLAR_ZENITH_ANGLE": 67.1303819169954, "MEDIUM_PROBA_CLOUDS_PERCENTAGE": 0.169226, "MGRS_TILE": "33TYN", "NODATA_PIXEL_PERCENTAGE": 4e-05, "NOT_VEGETATED_PERCENTAGE": 31.151235, "PROCESSING_BASELINE": "02.14", "PRODUCT_ID": "S2A_MSIL2A_20201118T095311_N0214_R079_T33TYN_20201118T115530", "RADIATIVE_TRANSFER_ACCURACY": 0, "RADIOMETRIC_QUALITY": "PASSED", "REFLECTANCE_CONVERSION_CORRECTION": 1.02242117671629, "SATURATED_DEFECTIVE_PIXEL_PERCENTAGE": 0, "SENSING_ORBIT_DIRECTION": "DESCENDING", "SENSING_ORBIT_NUMBER": 79, "SENSOR_QUALITY": "PASSED", "SNOW_ICE_PERCENTAGE": 6e-05, "SOLAR_IRRADIANCE_B1": 1884.69, "SOLAR_IRRADIANCE_B10": 367.15, "SOLAR_IRRADIANCE_B11": 245.59, "SOLAR_IRRADIANCE_B12": 85.25, "SOLAR_IRRADIANCE_B2": 1959.66, "SOLAR_IRRADIANCE_B3": 1823.24, "SOLAR_IRRADIANCE_B4": 1512.06, "SOLAR_IRRADIANCE_B5": 1424.64, "SOLAR_IRRADIANCE_B6": 1287.61, "SOLAR_IRRADIANCE_B7": 1162.08, "SOLAR_IRRADIANCE_B8": 1041.63, "SOLAR_IRRADIANCE_B8A": 955.32, "SOLAR_IRRADIANCE_B9": 812.92, "SPACECRAFT_NAME": "Sentinel-2A", "THIN_CIRRUS_PERCENTAGE": 0.035823, "UNCLASSIFIED_PERCENTAGE": 12.55997, "VEGETATION_PERCENTAGE": 34.667516, "WATER_PERCENTAGE": 4.56567, "WATER_VAPOUR_RETRIEVAL_ACCURACY": 0, "system:asset_size": 1736727953, "system:band_names": ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12", "AOT", "WVP", "SCL", "TCI_R", "TCI_G", "TCI_B", "MSK_CLDPRB", "MSK_SNWPRB", "QA10", "QA20", "QA60"], "system:bands": {"B11": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_B": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B12": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], 
"crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "AOT": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "QA10": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B8A": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "QA20": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 4294967295}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "MSK_CLDPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B1": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "QA60": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "B2": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "WVP": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B3": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B4": 
{"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "TCI_R": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B5": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B6": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B7": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B8": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B9": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "MSK_SNWPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "SCL": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_G": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}}, "system:footprint": {"type": "LinearRing", "coordinates": [[17.67141078171894, 47.82263288510892], 
[17.671405837212518, 47.82261820708107], [17.646411626097294, 47.32925082218025], [17.62207290404578, 46.83583851015872], [17.62212427253828, 46.83579584772863], [17.622167538661174, 46.8357491351654], [17.622188932533497, 46.8357458205179], [19.058775213444516, 46.79376963860806], [19.0588384118566, 46.793803980922334], [19.05890750034179, 46.79383271512413], [19.058912745577004, 46.79384734936178], [19.077589698914174, 47.04020634864976], [19.096514577054162, 47.28654137136117], [19.115691869935954, 47.53285992686103], [19.135125920080306, 47.77916186962477], [19.13507488572603, 47.77920526079178], [19.135032249348438, 47.77925250976796], [19.135010549350152, 47.77925610053018], [17.67154409192475, 47.82269745296916], [17.671480450885944, 47.82266239915652], [17.67141078171894, 47.82263288510892]]}, "system:id": "COPERNICUS/S2_SR/20201118T095311_20201118T095400_T33TYN", "system:index": "20201118T095311_20201118T095400_T33TYN", "system:time_end": 1605693452958, "system:time_start": 1605693452958, "system:version": 1605833526635574}
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B1.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B1.tif
new file mode 100644
index 000000000..253b906f3
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B1.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B11.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B11.tif
new file mode 100644
index 000000000..ef83ac19d
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B11.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B12.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B12.tif
new file mode 100644
index 000000000..ccdd8fb01
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B12.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B2.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B2.tif
new file mode 100644
index 000000000..17bf01d6a
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B2.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B3.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B3.tif
new file mode 100644
index 000000000..225ce192c
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B3.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B4.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B4.tif
new file mode 100644
index 000000000..3228fe9dc
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B4.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B5.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B5.tif
new file mode 100644
index 000000000..dd1c11bba
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B5.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B6.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B6.tif
new file mode 100644
index 000000000..f2f2a950b
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B6.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B7.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B7.tif
new file mode 100644
index 000000000..d9146ba9d
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B7.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B8.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B8.tif
new file mode 100644
index 000000000..41cdc9f90
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B8.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B8A.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B8A.tif
new file mode 100644
index 000000000..55bb31b14
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B8A.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B9.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B9.tif
new file mode 100644
index 000000000..e01001e63
Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B9.tif differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/metadata.json b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/metadata.json
new file mode 100644
index 000000000..fa970f6ee
--- /dev/null
+++ b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/metadata.json
@@ -0,0 +1 @@
+{"AOT_RETRIEVAL_ACCURACY": 0, "CLOUDY_PIXEL_PERCENTAGE": 0.901432, "CLOUD_COVERAGE_ASSESSMENT": 0.901432, "CLOUD_SHADOW_PERCENTAGE": 0.84547, "DARK_FEATURES_PERCENTAGE": 6.795183, "DATASTRIP_ID": "S2B_OPER_MSI_L2A_DS_VGS2_20210218T123009_S20210218T094028_N02.14", "DATATAKE_IDENTIFIER": "GS2B_20210218T094029_020654_N02.14", "DATATAKE_TYPE": "INS-NOBS", "DEGRADED_MSI_DATA_PERCENTAGE": 0, "FORMAT_CORRECTNESS": "PASSED", "GENERAL_QUALITY": "PASSED", "GENERATION_TIME": 1613651409000, "GEOMETRIC_QUALITY": "PASSED", "GRANULE_ID": "L2A_T33TYN_A020654_20210218T094028", "HIGH_PROBA_CLOUDS_PERCENTAGE": 0.16992, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B1": 106.191528469216, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B10": 105.765178013311, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B11": 105.899702597322, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B12": 106.125463551834, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B2": 105.346122751905, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B3": 105.53713441503, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B4": 105.695362490083, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B5": 105.795932240219, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B6": 105.881764502021, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B7": 105.977799431462, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8": 105.437391144718, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8A": 106.062903343031, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B9": 106.246892369833, "MEAN_INCIDENCE_ZENITH_ANGLE_B1": 8.60382616209157, "MEAN_INCIDENCE_ZENITH_ANGLE_B10": 8.4747035379831, "MEAN_INCIDENCE_ZENITH_ANGLE_B11": 8.51482581279656, "MEAN_INCIDENCE_ZENITH_ANGLE_B12": 8.5809124485055, "MEAN_INCIDENCE_ZENITH_ANGLE_B2": 8.42926736088918, "MEAN_INCIDENCE_ZENITH_ANGLE_B3": 8.44752826229191, "MEAN_INCIDENCE_ZENITH_ANGLE_B4": 8.48181484372435, "MEAN_INCIDENCE_ZENITH_ANGLE_B5": 8.50214317942103, "MEAN_INCIDENCE_ZENITH_ANGLE_B6": 8.52486382484645, "MEAN_INCIDENCE_ZENITH_ANGLE_B7": 8.55005538439596, "MEAN_INCIDENCE_ZENITH_ANGLE_B8": 8.4342391916749, "MEAN_INCIDENCE_ZENITH_ANGLE_B8A": 8.57753521260584, "MEAN_INCIDENCE_ZENITH_ANGLE_B9": 8.64032193814476, 
"MEAN_SOLAR_AZIMUTH_ANGLE": 159.489777929019, "MEAN_SOLAR_ZENITH_ANGLE": 61.0015939922247, "MEDIUM_PROBA_CLOUDS_PERCENTAGE": 0.169342, "MGRS_TILE": "33TYN", "NODATA_PIXEL_PERCENTAGE": 22.564258, "NOT_VEGETATED_PERCENTAGE": 63.721645, "PROCESSING_BASELINE": "02.14", "PRODUCT_ID": "S2B_MSIL2A_20210218T094029_N0214_R036_T33TYN_20210218T123009", "RADIATIVE_TRANSFER_ACCURACY": 0, "RADIOMETRIC_QUALITY": "PASSED", "REFLECTANCE_CONVERSION_CORRECTION": 1.02538502693364, "SATURATED_DEFECTIVE_PIXEL_PERCENTAGE": 0, "SENSING_ORBIT_DIRECTION": "DESCENDING", "SENSING_ORBIT_NUMBER": 36, "SENSOR_QUALITY": "PASSED", "SNOW_ICE_PERCENTAGE": 0.140262, "SOLAR_IRRADIANCE_B1": 1874.3, "SOLAR_IRRADIANCE_B10": 365.41, "SOLAR_IRRADIANCE_B11": 247.08, "SOLAR_IRRADIANCE_B12": 87.75, "SOLAR_IRRADIANCE_B2": 1959.75, "SOLAR_IRRADIANCE_B3": 1824.93, "SOLAR_IRRADIANCE_B4": 1512.79, "SOLAR_IRRADIANCE_B5": 1425.78, "SOLAR_IRRADIANCE_B6": 1291.13, "SOLAR_IRRADIANCE_B7": 1175.57, "SOLAR_IRRADIANCE_B8": 1041.28, "SOLAR_IRRADIANCE_B8A": 953.93, "SOLAR_IRRADIANCE_B9": 817.58, "SPACECRAFT_NAME": "Sentinel-2B", "THIN_CIRRUS_PERCENTAGE": 0.56217, "UNCLASSIFIED_PERCENTAGE": 8.106547, "VEGETATION_PERCENTAGE": 14.548491, "WATER_PERCENTAGE": 4.940968, "WATER_VAPOUR_RETRIEVAL_ACCURACY": 0, "system:asset_size": 1352573025, "system:band_names": ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12", "AOT", "WVP", "SCL", "TCI_R", "TCI_G", "TCI_B", "MSK_CLDPRB", "MSK_SNWPRB", "QA10", "QA20", "QA60"], "system:bands": {"B11": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_B": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B12": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 
5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "AOT": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "QA10": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B8A": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "QA20": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 4294967295}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "MSK_CLDPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B1": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "QA60": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "B2": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "WVP": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B3": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B4": 
{"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "TCI_R": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B5": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B6": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B7": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B8": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B9": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "MSK_SNWPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "SCL": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_G": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}}, "system:footprint": {"type": "LinearRing", "coordinates": [[17.797345040594656, 46.83204619803086], 
[17.797362873719553, 46.83199454391213], [17.797885329067725, 46.83160945562662], [17.79795167729368, 46.83158136738409], [19.058775213444516, 46.79376963860806], [19.058911741293866, 46.793834509947615], [19.077589698914174, 47.04020634864976], [19.096514577054162, 47.28654137136117], [19.115691869935954, 47.53285992686103], [19.135125920080306, 47.77916186962477], [19.135029727720475, 47.779255442991314], [18.15182335365654, 47.81050538734816], [18.15174592944439, 47.810493350914705], [18.148367147337964, 47.808395609286016], [18.09547502744622, 47.673746790872656], [17.914556442526802, 47.17958035145654], [17.825145976296444, 46.92695559117315], [17.797373483050553, 46.832592829927364], [17.797345040594656, 46.83204619803086]]}, "system:id": "COPERNICUS/S2_SR/20210218T094029_20210218T094028_T33TYN", "system:index": "20210218T094029_20210218T094028_T33TYN", "system:time_end": 1613641650888, "system:time_start": 1613641650888, "system:version": 1614206891239703}
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B1.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B1.tif
deleted file mode 100644
index 8b4b55754..000000000
Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B1.tif and /dev/null differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B11.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B11.tif
deleted file mode 100644
index 656051ca3..000000000
Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B11.tif and /dev/null differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B12.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B12.tif
deleted file mode 100644
index 472d8b414..000000000
Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B12.tif and /dev/null differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B2.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B2.tif
deleted file mode 100644
index b0ef24579..000000000
Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B2.tif and /dev/null differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B3.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B3.tif
deleted file mode 100644
index eb5c25c62..000000000
Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B3.tif and /dev/null differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B4.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B4.tif
deleted file mode 100644
index d86fd7fd6..000000000
Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B4.tif and /dev/null differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B5.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B5.tif
deleted file mode 100644
index 2f2514b78..000000000
Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B5.tif and /dev/null differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B6.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B6.tif
deleted file mode 100644
index 68c6dc87c..000000000
Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B6.tif and /dev/null differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B7.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B7.tif
deleted file mode 100644
index 4403ab65a..000000000
Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B7.tif and /dev/null differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B8.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B8.tif
deleted file mode 100644
index 343b888f6..000000000
Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B8.tif and /dev/null differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B8A.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B8A.tif
deleted file mode 100644
index f257d7612..000000000
Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B8A.tif and /dev/null differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B9.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B9.tif
deleted file mode 100644
index e643fa0ec..000000000
Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B9.tif and /dev/null differ
diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/metadata.json b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/metadata.json
deleted file mode 100644
index 56095b8f3..000000000
--- a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/metadata.json
+++ /dev/null
@@ -1 +0,0 @@
-{"AOT_RETRIEVAL_ACCURACY": 0, "CLOUDY_PIXEL_PERCENTAGE": 1.380343, "CLOUD_COVERAGE_ASSESSMENT": 1.380343, "CLOUD_SHADOW_PERCENTAGE": 0.045924, "DARK_FEATURES_PERCENTAGE": 0.207592, "DATASTRIP_ID": "S2A_OPER_MSI_L2A_DS_EPAE_20200614T233609_S20200614T183418_N02.14", "DATATAKE_IDENTIFIER": "GS2A_20200614T181931_026007_N02.14", "DATATAKE_TYPE": "INS-NOBS", "DEGRADED_MSI_DATA_PERCENTAGE": 0, "FORMAT_CORRECTNESS": "PASSED", "GENERAL_QUALITY": "PASSED", "GENERATION_TIME": 1592177769000, "GEOMETRIC_QUALITY": "PASSED", "GRANULE_ID": "L2A_T11SMS_A026007_20200614T183418", "HIGH_PROBA_CLOUDS_PERCENTAGE": 1.025103, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B1": 103.669795569119, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B10": 103.06629594126, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B11": 103.34224573122, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B12": 103.61285978569, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B2": 102.600210808613, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B3": 102.835706432612, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B4": 103.047500410431, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B5": 103.169579176377, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B6": 103.279959829067, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B7": 103.399267251377, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8": 102.718211170632, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8A": 103.517410582706, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B9": 103.782954154122, "MEAN_INCIDENCE_ZENITH_ANGLE_B1": 10.241410125066, "MEAN_INCIDENCE_ZENITH_ANGLE_B10": 10.1279749581719, "MEAN_INCIDENCE_ZENITH_ANGLE_B11": 10.1660724760143, "MEAN_INCIDENCE_ZENITH_ANGLE_B12": 10.2124812156495, "MEAN_INCIDENCE_ZENITH_ANGLE_B2": 10.0913569044097, "MEAN_INCIDENCE_ZENITH_ANGLE_B3": 10.1094453559007, "MEAN_INCIDENCE_ZENITH_ANGLE_B4": 10.1331032363887, "MEAN_INCIDENCE_ZENITH_ANGLE_B5": 10.1488031307943, "MEAN_INCIDENCE_ZENITH_ANGLE_B6": 10.1664615257693, "MEAN_INCIDENCE_ZENITH_ANGLE_B7": 10.1861200861194, "MEAN_INCIDENCE_ZENITH_ANGLE_B8": 10.0993961975397, "MEAN_INCIDENCE_ZENITH_ANGLE_B8A": 10.217514876513, "MEAN_INCIDENCE_ZENITH_ANGLE_B9": 10.2674475409991, 
"MEAN_SOLAR_AZIMUTH_ANGLE": 115.516292013745, "MEAN_SOLAR_ZENITH_ANGLE": 19.0927215025527, "MEDIUM_PROBA_CLOUDS_PERCENTAGE": 0.35524, "MGRS_TILE": "11SMS", "NODATA_PIXEL_PERCENTAGE": 62.915778, "NOT_VEGETATED_PERCENTAGE": 49.45268, "PROCESSING_BASELINE": "02.14", "PRODUCT_ID": "S2A_MSIL2A_20200614T181931_N0214_R127_T11SMS_20200614T233609", "RADIATIVE_TRANSFER_ACCURACY": 0, "RADIOMETRIC_QUALITY": "PASSED", "REFLECTANCE_CONVERSION_CORRECTION": 0.969929694485884, "SATURATED_DEFECTIVE_PIXEL_PERCENTAGE": 0, "SENSING_ORBIT_DIRECTION": "DESCENDING", "SENSING_ORBIT_NUMBER": 127, "SENSOR_QUALITY": "PASSED", "SNOW_ICE_PERCENTAGE": 0.002174, "SOLAR_IRRADIANCE_B1": 1884.69, "SOLAR_IRRADIANCE_B10": 367.15, "SOLAR_IRRADIANCE_B11": 245.59, "SOLAR_IRRADIANCE_B12": 85.25, "SOLAR_IRRADIANCE_B2": 1959.66, "SOLAR_IRRADIANCE_B3": 1823.24, "SOLAR_IRRADIANCE_B4": 1512.06, "SOLAR_IRRADIANCE_B5": 1424.64, "SOLAR_IRRADIANCE_B6": 1287.61, "SOLAR_IRRADIANCE_B7": 1162.08, "SOLAR_IRRADIANCE_B8": 1041.63, "SOLAR_IRRADIANCE_B8A": 955.32, "SOLAR_IRRADIANCE_B9": 812.92, "SPACECRAFT_NAME": "Sentinel-2A", "THIN_CIRRUS_PERCENTAGE": 0, "UNCLASSIFIED_PERCENTAGE": 0.570473, "VEGETATION_PERCENTAGE": 23.253819, "WATER_PERCENTAGE": 25.086993, "WATER_VAPOUR_RETRIEVAL_ACCURACY": 0, "system:asset_size": 670484430, "system:band_names": ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12", "AOT", "WVP", "SCL", "TCI_R", "TCI_G", "TCI_B", "MSK_CLDPRB", "MSK_SNWPRB", "QA10", "QA20", "QA60"], "system:bands": {"B11": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "TCI_B": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "B12": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 
5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "AOT": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "QA10": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "B8A": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "QA20": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 4294967295}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "MSK_CLDPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "B1": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32611", "crs_transform": [60, 0, 399960, 0, -60, 3700020]}, "QA60": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32611", "crs_transform": [60, 0, 399960, 0, -60, 3700020]}, "B2": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "WVP": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "B3": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "B4": 
{"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "TCI_R": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "B5": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "B6": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "B7": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "B8": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "B9": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32611", "crs_transform": [60, 0, 399960, 0, -60, 3700020]}, "MSK_SNWPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "SCL": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "TCI_G": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}}, "system:footprint": {"type": "LinearRing", "coordinates": [[-116.89627078890739, 32.44924419582073], 
[-116.89626737417575, 32.44925898927161], [-116.8951045650362, 33.43953070750837], [-116.89514832775068, 33.439572205468394], [-116.89518786695115, 33.4396207947493], [-117.19903171329149, 33.43950534084795], [-117.19907137449762, 33.43947601948668], [-117.19912234182809, 33.4394638167782], [-117.19913391715312, 33.43944350739506], [-117.19977701964083, 33.43781570659065], [-117.20298056753498, 33.42698251736538], [-117.45058635492418, 32.49597355187842], [-117.45371990599305, 32.48405448940864], [-117.46060922124882, 32.45750750785478], [-117.4618614719186, 32.452630471418175], [-117.46248822151223, 32.44991226060029], [-117.46248078368934, 32.44846000384899], [-117.46243723209412, 32.4484187108336], [-117.46239789227033, 32.448370362912655], [-116.8963741294445, 32.44917659271196], [-116.89632497529956, 32.44921309854787], [-116.89627078890739, 32.44924419582073]]}, "system:id": "COPERNICUS/S2_SR/20200614T181931_20200614T183418_T11SMS", "system:index": "20200614T181931_20200614T183418_T11SMS", "system:time_end": 1592159723525, "system:time_start": 1592159723525, "system:version": 1592325576909226}
diff --git a/tests/test_datasets/test_collators.py b/tests/test_datasets/test_collators.py
index 31666f396..b35218bcb 100644
--- a/tests/test_datasets/test_collators.py
+++ b/tests/test_datasets/test_collators.py
@@ -37,8 +37,9 @@
# IMPORTS
# =====================================================================================================================
from collections import defaultdict
-from typing import Any, Dict, List, Union
+from typing import Any
+import pytest
import torch
from numpy.testing import assert_array_equal
from torch import Tensor
@@ -50,12 +51,15 @@
# =====================================================================================================================
# TESTS
# =====================================================================================================================
-def test_get_collator() -> None:
- collator_params_1 = {"module": "torchgeo.datasets.utils", "name": "stack_samples"}
- collator_params_2 = {"name": "stack_sample_pairs"}
-
- assert callable(mdt.get_collator(collator_params_1))
- assert callable(mdt.get_collator(collator_params_2))
+@pytest.mark.parametrize(
+ "target",
+ (
+ "torchgeo.datasets.utils.stack_samples",
+ "minerva.datasets.collators.stack_sample_pairs",
+ ),
+)
+def test_get_collator(target: str) -> None:
+ assert callable(mdt.get_collator(target))
def test_stack_sample_pairs() -> None:
@@ -67,13 +71,13 @@ def test_stack_sample_pairs() -> None:
mask_2 = torch.randint(0, 8, (52, 52)) # type: ignore[attr-defined]
bbox_2 = [BoundingBox(0, 1, 0, 1, 0, 1)]
- sample_1: Dict[str, Union[Tensor, List[Any]]] = {
+ sample_1: dict[str, Tensor | list[Any]] = {
"image": image_1,
"mask": mask_1,
"bbox": bbox_1,
}
- sample_2: Dict[str, Union[Tensor, List[Any]]] = {
+ sample_2: dict[str, Tensor | list[Any]] = {
"image": image_2,
"mask": mask_2,
"bbox": bbox_2,
diff --git a/tests/test_datasets/test_factory.py b/tests/test_datasets/test_factory.py
index a3ac47387..37d720346 100644
--- a/tests/test_datasets/test_factory.py
+++ b/tests/test_datasets/test_factory.py
@@ -39,7 +39,7 @@
# =====================================================================================================================
from copy import deepcopy
from pathlib import Path
-from typing import Any, Dict
+from typing import Any
import pandas as pd
import pytest
@@ -56,12 +56,14 @@
# =====================================================================================================================
# TESTS
# =====================================================================================================================
-def test_make_dataset(exp_dataset_params: Dict[str, Any], data_root: Path) -> None:
+def test_make_dataset(exp_dataset_params: dict[str, Any], data_root: Path) -> None:
dataset_params2 = {
"image": {
- "image_1": exp_dataset_params["image"],
- "image_2": exp_dataset_params["image"],
+ "subdatasets": {
+ "image_1": exp_dataset_params["image"],
+ "image_2": exp_dataset_params["image"],
+ },
},
"mask": exp_dataset_params["mask"],
}
@@ -80,8 +82,12 @@ def test_make_dataset(exp_dataset_params: Dict[str, Any], data_root: Path) -> No
assert isinstance(subdatasets_4[0], UnionDataset)
dataset_params3 = dataset_params2
- dataset_params3["image"]["image_1"]["transforms"] = {"AutoNorm": {"length": 12}}
- dataset_params3["image"]["image_2"]["transforms"] = {"AutoNorm": {"length": 12}}
+ dataset_params3["image"]["subdatasets"]["image_1"]["transforms"] = {
+ "AutoNorm": {"length": 12}
+ }
+ dataset_params3["image"]["subdatasets"]["image_2"]["transforms"] = {
+ "AutoNorm": {"length": 12}
+ }
dataset_params3["image"]["transforms"] = {"AutoNorm": {"length": 12}}
dataset_5, subdatasets_5 = mdt.make_dataset(data_root, dataset_params3, cache=False)
@@ -111,7 +117,7 @@ def test_make_dataset(exp_dataset_params: Dict[str, Any], data_root: Path) -> No
@pytest.mark.parametrize("sample_pairs", (False, True))
def test_caching_datasets(
- exp_dataset_params: Dict[str, Any],
+ exp_dataset_params: dict[str, Any],
data_root: Path,
cache_dir: Path,
sample_pairs: bool,
@@ -160,47 +166,38 @@ def test_caching_datasets(
[
(
{
- "module": "torchgeo.samplers",
- "name": "RandomBatchGeoSampler",
+ "_target_": "torchgeo.samplers.RandomBatchGeoSampler",
"roi": False,
- "params": {
- "size": 224,
- "length": 4096,
- },
+ "size": 224,
+ "length": 4096,
},
{},
),
(
{
- "module": "minerva.samplers",
- "name": "RandomPairGeoSampler",
+ "_target_": "minerva.samplers.RandomPairGeoSampler",
"roi": False,
- "params": {
- "size": 224,
- "length": 4096,
- },
+ "size": 224,
+ "length": 4096,
},
{"sample_pairs": True},
),
(
{
- "module": "torchgeo.samplers",
- "name": "RandomBatchGeoSampler",
+ "_target_": "torchgeo.samplers.RandomBatchGeoSampler",
"roi": False,
- "params": {
- "size": 224,
- "length": 4096,
- },
+ "size": 224,
+ "length": 4096,
},
{"world_size": 2},
),
],
)
def test_construct_dataloader(
- exp_dataset_params: Dict[str, Any],
+ exp_dataset_params: dict[str, Any],
data_root: Path,
- sampler_params: Dict[str, Any],
- kwargs: Dict[str, Any],
+ sampler_params: dict[str, Any],
+ kwargs: dict[str, Any],
) -> None:
batch_size = 256
@@ -236,9 +233,10 @@ def test_make_loaders(default_config: DictConfig) -> None:
old_params_2 = OmegaConf.to_object(deepcopy(default_config))
assert isinstance(old_params_2, dict)
dataset_params = old_params_2["tasks"]["fit-val"]["dataset_params"].copy()
- old_params_2["tasks"]["fit-val"]["dataset_params"] = {}
- old_params_2["tasks"]["fit-val"]["dataset_params"]["val-1"] = dataset_params
- old_params_2["tasks"]["fit-val"]["dataset_params"]["val-2"] = dataset_params
+ old_params_2["tasks"]["fit-val"]["dataset_params"] = {
+ "val-1": dataset_params,
+ "val-2": dataset_params,
+ }
loaders, n_batches, class_dist, params = mdt.make_loaders( # type: ignore[arg-type]
**old_params_2,
@@ -255,9 +253,8 @@ def test_make_loaders(default_config: DictConfig) -> None:
def test_get_manifest(
data_root: Path,
- exp_dataset_params: Dict[str, Any],
- exp_loader_params: Dict[str, Any],
- exp_sampler_params: Dict[str, Any],
+ exp_dataset_params: dict[str, Any],
+ exp_sampler_params: dict[str, Any],
) -> None:
manifest_path = Path("tests", "tmp", "cache", "Chesapeake7_Manifest.csv")
diff --git a/tests/test_datasets/test_paired.py b/tests/test_datasets/test_paired.py
index cf916f7cf..e52a38049 100644
--- a/tests/test_datasets/test_paired.py
+++ b/tests/test_datasets/test_paired.py
@@ -38,8 +38,6 @@
# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
-from typing import Tuple
-
import matplotlib.pyplot as plt
import pytest
from rasterio.crs import CRS
@@ -94,7 +92,6 @@ def test_paired_geodatasets(img_root: Path) -> None:
assert isinstance(dataset.crs, CRS)
assert isinstance(getattr(dataset, "crs"), CRS)
assert isinstance(dataset.dataset, TstImgDataset)
- assert isinstance(dataset.__getattr__("dataset"), TstImgDataset)
with pytest.raises(AttributeError):
_ = dataset.roi
@@ -163,7 +160,7 @@ def test_paired_nongeodatasets(data_root: Path) -> None:
def test_paired_concat_datasets(
- data_root: Path, small_patch_size: Tuple[int, int]
+ data_root: Path, small_patch_size: tuple[int, int]
) -> None:
def dataset_test(_dataset) -> None:
for sub_dataset in _dataset.datasets:
diff --git a/tests/test_logging.py b/tests/test_logging.py
index f265ee440..eda99d2bb 100644
--- a/tests/test_logging.py
+++ b/tests/test_logging.py
@@ -41,7 +41,7 @@
# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any
import numpy as np
import torch
@@ -62,7 +62,7 @@
from minerva.logger.steplog import SupervisedStepLogger
from minerva.logger.tasklog import SSLTaskLogger, SupervisedTaskLogger
from minerva.loss import SegBarlowTwinsLoss
-from minerva.modelio import ssl_pair_tg, sup_tg
+from minerva.modelio import ssl_pair_torchgeo_io, supervised_torchgeo_io
from minerva.models import FCN16ResNet18, MinervaSiamese, SimCLR18, SimConv
from minerva.utils import utils
@@ -80,7 +80,7 @@ def test_SupervisedStepLogger(
std_n_batches: int,
std_n_classes: int,
std_batch_size: int,
- small_patch_size: Tuple[int, int],
+ small_patch_size: tuple[int, int],
default_device: torch.device,
train: bool,
model_type: str,
@@ -134,29 +134,29 @@ def test_SupervisedStepLogger(
record_float=True,
writer=writer,
model_type=model_type,
- step_logger_params={"params": {"n_classes": std_n_classes}},
+ step_logger_params={"n_classes": std_n_classes},
)
- correct_loss: Dict[str, List[float]] = {"x": [], "y": []}
- correct_acc: Dict[str, List[float]] = {"x": [], "y": []}
- correct_miou: Dict[str, List[float]] = {"x": [], "y": []}
+ correct_loss: dict[str, list[float]] = {"x": [], "y": []}
+ correct_acc: dict[str, list[float]] = {"x": [], "y": []}
+ correct_miou: dict[str, list[float]] = {"x": [], "y": []}
for epoch_no in range(n_epochs):
- data: List[Dict[str, Union[Tensor, List[Any]]]] = []
+ data: list[dict[str, Tensor | list[Any]]] = []
for i in range(std_n_batches):
images = torch.rand(size=(std_batch_size, 4, *small_patch_size))
masks = torch.randint( # type: ignore[attr-defined]
0, std_n_classes, (std_batch_size, *small_patch_size)
)
bboxes = [simple_bbox] * std_batch_size
- batch: Dict[str, Union[Tensor, List[Any]]] = {
+ batch: dict[str, Tensor | list[Any]] = {
"image": images,
"mask": masks,
"bbox": bboxes,
}
data.append(batch)
- logger.step(i, i, *sup_tg(batch, model, device=default_device, train=train)) # type: ignore[arg-type]
+ logger.step(i, i, *supervised_torchgeo_io(batch, model, device=default_device, train=train)) # type: ignore[arg-type] # noqa: E501
logs = logger.get_logs
assert logs["batch_num"] == std_n_batches - 1
@@ -184,7 +184,7 @@ def test_SupervisedStepLogger(
(std_n_batches, std_batch_size, *output_shape), dtype=np.uint8
)
for i in range(std_n_batches):
- mask: Union[Tensor, List[Any]] = data[i]["mask"]
+ mask: Tensor | list[Any] = data[i]["mask"]
assert isinstance(mask, Tensor)
y[i] = mask.cpu().numpy()
@@ -239,7 +239,7 @@ def test_SSLStepLogger(
simple_bbox: BoundingBox,
std_n_batches: int,
std_batch_size: int,
- small_patch_size: Tuple[int, int],
+ small_patch_size: tuple[int, int],
default_device: torch.device,
model_cls: MinervaSiamese,
model_type: str,
@@ -284,9 +284,9 @@ def test_SSLStepLogger(
sample_pairs=True,
)
- correct_loss: Dict[str, List[float]] = {"x": [], "y": []}
- correct_collapse_level: Dict[str, List[float]] = {"x": [], "y": []}
- correct_euc_dist: Dict[str, List[float]] = {"x": [], "y": []}
+ correct_loss: dict[str, list[float]] = {"x": [], "y": []}
+ correct_collapse_level: dict[str, list[float]] = {"x": [], "y": []}
+ correct_euc_dist: dict[str, list[float]] = {"x": [], "y": []}
for epoch_no in range(n_epochs):
for i in range(std_n_batches):
@@ -297,7 +297,7 @@ def test_SSLStepLogger(
logger.step(
i,
i,
- *ssl_pair_tg((batch, batch), model, device=default_device, train=train), # type: ignore[arg-type]
+ *ssl_pair_torchgeo_io((batch, batch), model, device=default_device, train=train), # type: ignore[arg-type] # noqa: E501
)
logs = logger.get_logs
diff --git a/tests/test_modelio.py b/tests/test_modelio.py
index e7f77d300..b1164e6b7 100644
--- a/tests/test_modelio.py
+++ b/tests/test_modelio.py
@@ -37,7 +37,7 @@
# IMPORTS
# =====================================================================================================================
import importlib
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any
import torch
import torch.nn.modules as nn
@@ -53,20 +53,20 @@
from torch import Tensor
from torchgeo.datasets.utils import BoundingBox
-from minerva.modelio import autoencoder_io, ssl_pair_tg, sup_tg
+from minerva.modelio import autoencoder_io, ssl_pair_torchgeo_io, supervised_torchgeo_io
from minerva.models import FCN32ResNet18, SimCLR34
# =====================================================================================================================
# TESTS
# =====================================================================================================================
-def test_sup_tg(
+def test_supervised_torchgeo_io(
simple_bbox: BoundingBox,
random_rgbi_batch: Tensor,
random_mask_batch: Tensor,
std_batch_size: int,
std_n_classes: int,
- rgbi_input_size: Tuple[int, int, int],
+ rgbi_input_size: tuple[int, int, int],
default_device: torch.device,
) -> None:
criterion = nn.CrossEntropyLoss()
@@ -76,13 +76,13 @@ def test_sup_tg(
for train in (True, False):
bboxes = [simple_bbox] * std_batch_size
- batch: Dict[str, Union[Tensor, List[Any]]] = {
+ batch: dict[str, Tensor | list[Any]] = {
"image": random_rgbi_batch,
"mask": random_mask_batch,
"bbox": bboxes,
}
- results = sup_tg(batch, model, default_device, train)
+ results = supervised_torchgeo_io(batch, model, default_device, train)
assert isinstance(results[0], Tensor)
assert isinstance(results[1], Tensor)
@@ -95,10 +95,10 @@ def test_sup_tg(
assert results[3] == batch["bbox"]
-def test_ssl_pair_tg(
+def test_ssl_pair_torchgeo_io(
simple_bbox: BoundingBox,
std_batch_size: int,
- rgbi_input_size: Tuple[int, int, int],
+ rgbi_input_size: tuple[int, int, int],
default_device: torch.device,
) -> None:
criterion = NTXentLoss(0.5)
@@ -123,7 +123,7 @@ def test_ssl_pair_tg(
"bbox": bboxes_2,
}
- results = ssl_pair_tg((batch_1, batch_2), model, default_device, train)
+ results = ssl_pair_torchgeo_io((batch_1, batch_2), model, default_device, train)
assert isinstance(results[0], Tensor)
assert isinstance(results[1], Tensor)
@@ -138,7 +138,7 @@ def test_mask_autoencoder_io(
simple_bbox: BoundingBox,
std_batch_size: int,
std_n_classes: int,
- rgbi_input_size: Tuple[int, int, int],
+ rgbi_input_size: tuple[int, int, int],
default_device: torch.device,
) -> None:
criterion = nn.CrossEntropyLoss()
@@ -152,7 +152,7 @@ def test_mask_autoencoder_io(
images = torch.rand(size=(std_batch_size, *rgbi_input_size))
masks = torch.randint(0, 8, (std_batch_size, *rgbi_input_size[1:])) # type: ignore[attr-defined]
bboxes = [simple_bbox] * std_batch_size
- batch: Dict[str, Union[Tensor, List[Any]]] = {
+ batch: dict[str, Tensor | list[Any]] = {
"image": images,
"mask": masks,
"bbox": bboxes,
@@ -191,7 +191,7 @@ def test_image_autoencoder_io(
random_rgbi_batch: Tensor,
random_mask_batch: Tensor,
std_batch_size: int,
- rgbi_input_size: Tuple[int, int, int],
+ rgbi_input_size: tuple[int, int, int],
default_device: torch.device,
) -> None:
criterion = nn.CrossEntropyLoss()
@@ -203,7 +203,7 @@ def test_image_autoencoder_io(
for train in (True, False):
bboxes = [simple_bbox] * std_batch_size
- batch: Dict[str, Union[Tensor, List[Any]]] = {
+ batch: dict[str, Tensor | list[Any]] = {
"image": random_rgbi_batch,
"mask": random_mask_batch,
"bbox": bboxes,
diff --git a/tests/test_models/test_core.py b/tests/test_models/test_core.py
index 6a2dd5841..181e504ea 100644
--- a/tests/test_models/test_core.py
+++ b/tests/test_models/test_core.py
@@ -39,7 +39,6 @@
import importlib
import os
from platform import python_version
-from typing import Tuple
import internet_sabotage
import numpy as np
@@ -114,7 +113,7 @@ def test_minerva_model(x_entropy_loss, std_n_classes: int, std_n_batches: int) -
assert z.size() == (std_n_batches, std_n_classes)
-def test_minerva_backbone(rgbi_input_size: Tuple[int, int, int]) -> None:
+def test_minerva_backbone(rgbi_input_size: tuple[int, int, int]) -> None:
loss_func = NTXentLoss(0.3)
model = SimCLR18(loss_func, input_size=rgbi_input_size)
@@ -124,7 +123,7 @@ def test_minerva_backbone(rgbi_input_size: Tuple[int, int, int]) -> None:
def test_minerva_wrapper(
x_entropy_loss,
- small_patch_size: Tuple[int, int],
+ small_patch_size: tuple[int, int],
std_n_classes: int,
std_n_batches: int,
) -> None:
diff --git a/tests/test_models/test_fcn.py b/tests/test_models/test_fcn.py
index 15cb845de..2cf91cfee 100644
--- a/tests/test_models/test_fcn.py
+++ b/tests/test_models/test_fcn.py
@@ -36,8 +36,6 @@
# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
-from typing import Tuple
-
import pytest
import torch
from torch import Tensor
@@ -86,7 +84,7 @@ def test_fcn(
random_mask_batch: Tensor,
std_batch_size: int,
std_n_classes: int,
- rgbi_input_size: Tuple[int, int, int],
+ rgbi_input_size: tuple[int, int, int],
) -> None:
model: MinervaModel = model_cls(x_entropy_loss, input_size=rgbi_input_size)
optimiser = torch.optim.SGD(model.parameters(), lr=1.0e-3)
diff --git a/tests/test_models/test_resnets.py b/tests/test_models/test_resnets.py
index bb4f0fdd6..d6a1f34ca 100644
--- a/tests/test_models/test_resnets.py
+++ b/tests/test_models/test_resnets.py
@@ -36,8 +36,6 @@
# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
-from typing import Tuple
-
import pytest
import torch
from torch import LongTensor, Tensor
@@ -94,7 +92,7 @@ def test_resnets(
random_scene_classification_batch: LongTensor,
std_batch_size: int,
std_n_classes: int,
- rgbi_input_size: Tuple[int, int, int],
+ rgbi_input_size: tuple[int, int, int],
) -> None:
model: MinervaModel = model_cls(
x_entropy_loss, input_size=rgbi_input_size, zero_init_residual=zero_init
@@ -121,7 +119,7 @@ def test_replace_stride(
random_scene_classification_batch: LongTensor,
std_batch_size: int,
std_n_classes: int,
- rgbi_input_size: Tuple[int, int, int],
+ rgbi_input_size: tuple[int, int, int],
) -> None:
for model in (
ResNet50(x_entropy_loss, input_size=rgbi_input_size),
@@ -144,7 +142,7 @@ def test_replace_stride(
def test_resnet_encoder(
x_entropy_loss,
random_rgbi_batch: Tensor,
- rgbi_input_size: Tuple[int, int, int],
+ rgbi_input_size: tuple[int, int, int],
) -> None:
encoder = ResNet18(x_entropy_loss, input_size=rgbi_input_size, encoder=True)
optimiser = torch.optim.SGD(encoder.parameters(), lr=1.0e-3)
@@ -157,7 +155,7 @@ def test_resnet_encoder(
assert len(encoder(random_rgbi_batch)) == 5
-def test_preload_weights(rgbi_input_size: Tuple[int, int, int]) -> None:
+def test_preload_weights(rgbi_input_size: tuple[int, int, int]) -> None:
resnet = ResNet(BasicBlock, [2, 2, 2, 2])
new_resnet = _preload_weights(resnet, None, rgbi_input_size, encoder_on=False)
diff --git a/tests/test_models/test_unet.py b/tests/test_models/test_unet.py
index 18ee9df1a..0c4880b92 100644
--- a/tests/test_models/test_unet.py
+++ b/tests/test_models/test_unet.py
@@ -36,8 +36,6 @@
# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
-from typing import Tuple
-
import pytest
import torch
from torch import Tensor
@@ -62,7 +60,7 @@ def unet_test(
y: Tensor,
batch_size: int,
n_classes: int,
- input_size: Tuple[int, int, int],
+ input_size: tuple[int, int, int],
) -> None:
optimiser = torch.optim.SGD(model.parameters(), lr=1.0e-3)
model.set_optimiser(optimiser)
@@ -97,7 +95,7 @@ def test_unetrs(
random_mask_batch: Tensor,
std_batch_size: int,
std_n_classes: int,
- rgbi_input_size: Tuple[int, int, int],
+ rgbi_input_size: tuple[int, int, int],
) -> None:
model: MinervaModel = model_cls(x_entropy_loss, rgbi_input_size)
@@ -117,7 +115,7 @@ def test_unet(
random_mask_batch: Tensor,
std_batch_size: int,
std_n_classes: int,
- rgbi_input_size: Tuple[int, int, int],
+ rgbi_input_size: tuple[int, int, int],
) -> None:
model = UNet(x_entropy_loss, input_size=rgbi_input_size)
unet_test(
diff --git a/tests/test_samplers.py b/tests/test_samplers.py
index e01d697ed..42cc08cd3 100644
--- a/tests/test_samplers.py
+++ b/tests/test_samplers.py
@@ -38,7 +38,7 @@
# =====================================================================================================================
from collections import defaultdict
from pathlib import Path
-from typing import Any, Dict
+from typing import Any
import pytest
from torch.utils.data import DataLoader
@@ -60,7 +60,7 @@ def test_randompairgeosampler(img_root: Path) -> None:
dataset = PairedGeoDataset(TstImgDataset, str(img_root), res=1.0)
sampler = RandomPairGeoSampler(dataset, size=32, length=32, max_r=52)
- loader: DataLoader[Dict[str, Any]] = DataLoader(
+ loader: DataLoader[dict[str, Any]] = DataLoader(
dataset, batch_size=8, sampler=sampler, collate_fn=stack_sample_pairs
)
@@ -78,7 +78,7 @@ def test_randompairbatchgeosampler(img_root: Path) -> None:
sampler = RandomPairBatchGeoSampler(
dataset, size=32, length=32, batch_size=8, max_r=52, tiles_per_batch=1
)
- loader: DataLoader[Dict[str, Any]] = DataLoader(
+ loader: DataLoader[dict[str, Any]] = DataLoader(
dataset, batch_sampler=sampler, collate_fn=stack_sample_pairs
)
diff --git a/tests/test_tasks/test_tasks_core.py b/tests/test_tasks/test_tasks_core.py
index f2924fc7f..353fb3b04 100644
--- a/tests/test_tasks/test_tasks_core.py
+++ b/tests/test_tasks/test_tasks_core.py
@@ -45,4 +45,4 @@
# =====================================================================================================================
def test_get_task():
with pytest.raises(TypeError):
- _ = get_task("MinervaTask")
+ _ = get_task("minerva.tasks.MinervaTask")
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 7015b0a8f..e025728bc 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -38,7 +38,7 @@
import shutil
from copy import deepcopy
from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional
import hydra
import pytest
@@ -56,9 +56,7 @@
# TESTS
# =====================================================================================================================
@runner.distributed_run
-def run_trainer(
- gpu: int, wandb_run: Optional[Union[Run, RunDisabled]], cfg: DictConfig
-):
+def run_trainer(gpu: int, wandb_run: Optional[Run | RunDisabled], cfg: DictConfig):
params = deepcopy(cfg)
params["calc_norm"] = True
@@ -145,10 +143,17 @@ def test_trainer_3(default_config: DictConfig) -> None:
params1 = deepcopy(default_config)
trainer1 = Trainer(0, **params1)
- trainer1.save_model(fn=trainer1.get_model_cache_path())
+
+ pre_train_path = trainer1.get_model_cache_path()
+ trainer1.save_model(fn=pre_train_path, fmt="onnx")
params2 = deepcopy(default_config)
- OmegaConf.update(params2, "pre_train_name", params1["model_name"], force_add=True)
+ OmegaConf.update(
+ params2,
+ "pre_train_name",
+ str(pre_train_path.with_suffix(".onnx")),
+ force_add=True,
+ )
params2["fine_tune"] = True
params2["max_epochs"] = 2
params2["elim"] = False
@@ -177,8 +182,8 @@ def test_trainer_3(default_config: DictConfig) -> None:
def test_trainer_4(
inbuilt_cfg_root: Path,
cfg_name: str,
- cfg_args: Dict[str, Any],
- kwargs: Dict[str, Any],
+ cfg_args: dict[str, Any],
+ kwargs: dict[str, Any],
) -> None:
with hydra.initialize(version_base="1.3", config_path=str(inbuilt_cfg_root)):
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 62d210c54..96a3c1f54 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -36,7 +36,7 @@
# =====================================================================================================================
# IMPORTS
# =====================================================================================================================
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
import pytest
import torch
@@ -77,7 +77,7 @@
],
)
def test_class_transform(
- example_matrix: Dict[int, int], input_mask: LongTensor, output: LongTensor
+ example_matrix: dict[int, int], input_mask: LongTensor, output: LongTensor
) -> None:
transform = ClassTransform(example_matrix)
@@ -276,7 +276,7 @@ def test_dublicator(
],
)
def test_tg_to_torch(
- transform, keys: Optional[List[str]], args: Any, in_img, expected
+ transform, keys: Optional[list[str]], args: Any, in_img, expected
) -> None:
transformation = (utils.tg_to_torch(transform, keys=keys))(args)
@@ -376,22 +376,21 @@ def test_init_auto_norm(default_image_dataset: RasterDataset, transforms) -> Non
and transforms is not None
):
with pytest.raises(TypeError):
- _ = init_auto_norm(default_image_dataset, params)
+ _ = init_auto_norm(default_image_dataset, **params)
else:
- dataset = init_auto_norm(default_image_dataset, params)
+ dataset = init_auto_norm(default_image_dataset, **params)
assert isinstance(dataset, RasterDataset)
assert isinstance(dataset.transforms.transforms[-1], AutoNorm) # type: ignore[union-attr]
def test_get_transform() -> None:
- name = "RandomResizedCrop"
- params = {"module": "torchvision.transforms", "size": 128}
- transform = get_transform(name, params)
+ params = {"_target_": "torchvision.transforms.RandomResizedCrop", "size": 128}
+ transform = get_transform(params)
assert callable(transform)
with pytest.raises(TypeError):
- _ = get_transform("DataFrame", {"module": "pandas"})
+ _ = get_transform({"_target_": "pandas.DataFrame"})
@pytest.mark.parametrize(
@@ -399,35 +398,50 @@ def test_get_transform() -> None:
[
(
{
- "CenterCrop": {"module": "torchvision.transforms", "size": 128},
- "RandomHorizontalFlip": {"module": "torchvision.transforms", "p": 0.7},
+ "crop": {"_target_": "torchvision.transforms.CenterCrop", "size": 128},
+ "flip": {
+ "_target_": "torchvision.transforms.RandomHorizontalFlip",
+ "p": 0.7,
+ },
},
"mask",
),
(
{
"RandomApply": {
- "CenterCrop": {"module": "torchvision.transforms", "size": 128},
+ "crop": {
+ "_target_": "torchvision.transforms.CenterCrop",
+ "size": 128,
+ },
"p": 0.3,
},
- "RandomHorizontalFlip": {"module": "torchvision.transforms", "p": 0.7},
+ "flip": {
+ "_target_": "torchvision.transforms.RandomHorizontalFlip",
+ "p": 0.7,
+ },
},
"image",
),
(
{
- "CenterCrop": {"module": "torchvision.transforms", "size": 128},
+ "crop": {"_target_": "torchvision.transforms.CenterCrop", "size": 128},
"RandomApply": {
- "CenterCrop": {"module": "torchvision.transforms", "size": 128},
+ "crop": {
+ "_target_": "torchvision.transforms.CenterCrop",
+ "size": 128,
+ },
"p": 0.3,
},
- "RandomHorizontalFlip": {"module": "torchvision.transforms", "p": 0.7},
+ "flip": {
+ "_target_": "torchvision.transforms.RandomHorizontalFlip",
+ "p": 0.7,
+ },
},
"image",
),
],
)
-def test_make_transformations(params: Dict[str, Any], key: str) -> None:
+def test_make_transformations(params: dict[str, Any], key: str) -> None:
if params:
transforms = make_transformations({key: params})
assert callable(transforms)
diff --git a/tests/test_utils/test_runner.py b/tests/test_utils/test_runner.py
index 0d07128e1..980f48eab 100644
--- a/tests/test_utils/test_runner.py
+++ b/tests/test_utils/test_runner.py
@@ -39,7 +39,7 @@
import os
import subprocess
import time
-from typing import Optional, Union
+from typing import Optional
import pytest
import requests
@@ -111,7 +111,7 @@ def test_config_env_vars(default_config: DictConfig) -> None:
@runner.distributed_run
def _run_func(
- gpu: int, wandb_run: Optional[Union[Run, RunDisabled]], cfg: DictConfig
+ gpu: int, wandb_run: Optional[Run | RunDisabled], cfg: DictConfig
) -> None:
time.sleep(0.5)
return
diff --git a/tests/test_utils/test_utils.py b/tests/test_utils/test_utils.py
index 34c6e6f1f..f88d9298e 100644
--- a/tests/test_utils/test_utils.py
+++ b/tests/test_utils/test_utils.py
@@ -45,7 +45,7 @@
from collections import defaultdict
from datetime import datetime
from pathlib import Path
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any
import numpy as np
import pandas as pd
@@ -85,7 +85,7 @@ def test_is_notebook() -> None:
def test_return_updated_kwargs() -> None:
@utils.return_updated_kwargs
- def example_func(*args, **kwargs) -> Tuple[Any, Dict[str, Any]]:
+ def example_func(*args, **kwargs) -> tuple[Any, dict[str, Any]]:
_ = (
kwargs["update_1"] * kwargs["update_3"]
- kwargs["static_2"] / args[1] * args[0]
@@ -196,13 +196,13 @@ def test_ohe_labels() -> None:
assert_array_equal(correct_targets, targets)
-def test_empty_classes(exp_classes: Dict[int, str]) -> None:
+def test_empty_classes(exp_classes: dict[int, str]) -> None:
class_distribution = [(3, 321), (4, 112), (1, 671), (5, 456)]
assert utils.find_empty_classes(class_distribution, exp_classes) == [0, 2, 6, 7]
def test_eliminate_classes(
- exp_classes: Dict[int, str], exp_cmap_dict: Dict[int, str]
+ exp_classes: dict[int, str], exp_cmap_dict: dict[int, str]
) -> None:
empty = [0, 2, 7]
new_classes = {
@@ -246,12 +246,12 @@ def test_eliminate_classes(
],
)
def test_check_test_empty(
- exp_classes: Dict[int, str],
- in_labels: List[int],
- in_pred: List[int],
- out_labels: List[int],
- out_pred: List[int],
- out_classes: Dict[int, str],
+ exp_classes: dict[int, str],
+ in_labels: list[int],
+ in_pred: list[int],
+ out_labels: list[int],
+ out_pred: list[int],
+ out_classes: dict[int, str],
) -> None:
results = utils.check_test_empty(in_pred, in_labels, exp_classes)
@@ -260,7 +260,7 @@ def test_check_test_empty(
assert results[2] == out_classes
-def test_find_modes(exp_classes: Dict[int, str]) -> None:
+def test_find_modes(exp_classes: dict[int, str]) -> None:
labels = [1, 1, 3, 5, 1, 4, 1, 5, 3, 3]
class_dist = utils.find_modes(labels, plot=True)
@@ -370,12 +370,12 @@ def test_batch_flatten(x, exp_len: int) -> None:
],
)
def test_transform_coordinates(
- x: Union[List[float], float],
- y: Union[List[float], float],
+ x: list[float] | float,
+ y: list[float] | float,
src_crs: CRS,
dest_crs: CRS,
- exp_x: Union[List[float], float],
- exp_y: Union[List[float], float],
+ exp_x: list[float] | float,
+ exp_y: list[float] | float,
) -> None:
out_x, out_y = utils.transform_coordinates(x, y, src_crs, dest_crs)
@@ -451,9 +451,7 @@ def test_get_centre_loc() -> None:
(-77.844504, 166.707506, "McMurdo Station"), # McMurdo Station, Antartica.
],
)
-def test_lat_lon_to_loc(
- lat: Union[float, str], lon: Union[float, str], loc: str
-) -> None:
+def test_lat_lon_to_loc(lat: float | str, lon: float | str, loc: str) -> None:
try:
requests.head("http://www.google.com/", timeout=1.0)
except (requests.ConnectionError, requests.ReadTimeout):
@@ -546,7 +544,7 @@ def test_find_best_of() -> None:
assert scene == ["2018_06_21"]
-def test_modes_from_manifest(exp_classes: Dict[int, str]) -> None:
+def test_modes_from_manifest(exp_classes: dict[int, str]) -> None:
df = pd.DataFrame()
class_dist = [
@@ -608,7 +606,7 @@ def test_compute_roc_curves() -> None:
labels = [0, 3, 2, 1, 3, 2, 1, 0]
class_labels = [0, 1, 2, 3]
- fpr: Dict[Any, NDArray[Any, Any]] = {
+ fpr: dict[Any, NDArray[Any, Any]] = {
0: np.array([0.0, 0.0, 0.0, 0.5, 5.0 / 6.0, 1.0]),
1: np.array([0.0, 0.0, 1.0 / 6.0, 1.0 / 6.0, 1.0]),
2: np.array([0.0, 0.0, 1.0 / 6.0, 0.5, 0.5, 1.0]),
@@ -637,7 +635,7 @@ def test_compute_roc_curves() -> None:
"macro": np.array([0.0, 1.0 / 6.0, 0.5, 5.0 / 6.0, 1.0]),
}
- tpr: Dict[Any, NDArray[Any, Any]] = {
+ tpr: dict[Any, NDArray[Any, Any]] = {
0: np.array([0.0, 0.5, 1.0, 1.0, 1.0, 1.0]),
1: np.array([0.0, 0.5, 0.5, 1.0, 1.0]),
2: np.array([0.0, 0.5, 0.5, 0.5, 1.0, 1.0]),
diff --git a/tests/test_utils/test_visutils.py b/tests/test_utils/test_visutils.py
index e2820a3f0..8542a8616 100644
--- a/tests/test_utils/test_visutils.py
+++ b/tests/test_utils/test_visutils.py
@@ -40,7 +40,7 @@
import shutil
import tempfile
from pathlib import Path
-from typing import Any, Dict, Tuple
+from typing import Any
import matplotlib as mlp
import numpy as np
@@ -116,7 +116,7 @@ def test_get_mlp_cmap() -> None:
def test_discrete_heatmap(
- random_mask, exp_classes: Dict[int, str], exp_cmap_dict: Dict[int, str]
+ random_mask, exp_classes: dict[int, str], exp_cmap_dict: dict[int, str]
) -> None:
cmap = ListedColormap(exp_cmap_dict.values()) # type: ignore
visutils.discrete_heatmap(random_mask, list(exp_classes.values()), cmap_style=cmap)
@@ -151,8 +151,8 @@ def test_labelled_rgb_image(
random_mask,
random_image,
bounds_for_test_img,
- exp_classes: Dict[int, str],
- exp_cmap_dict: Dict[int, str],
+ exp_classes: dict[int, str],
+ exp_cmap_dict: dict[int, str],
) -> None:
path = tempfile.gettempdir()
name = "pretty_pic"
@@ -174,7 +174,7 @@ def test_labelled_rgb_image(
def test_make_gif(
- bounds_for_test_img, exp_classes: Dict[int, str], exp_cmap_dict: Dict[int, str]
+ bounds_for_test_img, exp_classes: dict[int, str], exp_cmap_dict: dict[int, str]
) -> None:
dates = ["2018-01-15", "2018-07-03", "2018-11-30"]
images = np.random.rand(3, 32, 32, 3)
@@ -209,7 +209,7 @@ def test_prediction_plot(
random_image,
random_mask,
bounds_for_test_img,
- exp_classes: Dict[int, str],
+ exp_classes: dict[int, str],
results_dir: Path,
) -> None:
pred = np.random.randint(0, 8, size=(32, 32))
@@ -229,9 +229,9 @@ def test_seg_plot(
results_root: Path,
data_root: Path,
default_dataset: GeoDataset,
- exp_dataset_params: Dict[str, Any],
- exp_classes: Dict[int, str],
- exp_cmap_dict: Dict[int, str],
+ exp_dataset_params: dict[str, Any],
+ exp_classes: dict[int, str],
+ exp_cmap_dict: dict[int, str],
cache_dir: Path,
monkeypatch,
) -> None:
@@ -271,7 +271,7 @@ def test_seg_plot(
def test_plot_subpopulations(
- exp_classes: Dict[int, str], exp_cmap_dict: Dict[int, str]
+ exp_classes: dict[int, str], exp_cmap_dict: dict[int, str]
) -> None:
class_dist = [(1, 25000), (0, 1300), (2, 100), (3, 2)]
@@ -305,7 +305,7 @@ def test_plot_history() -> None:
filename.unlink(missing_ok=True)
-def test_make_confusion_matrix(exp_classes: Dict[int, str]) -> None:
+def test_make_confusion_matrix(exp_classes: dict[int, str]) -> None:
batch_size = 2
patch_size = (32, 32)
@@ -346,11 +346,11 @@ def test_format_names() -> None:
def test_plot_results(
default_dataset: GeoDataset,
- exp_classes: Dict[int, str],
- exp_cmap_dict: Dict[int, str],
+ exp_classes: dict[int, str],
+ exp_cmap_dict: dict[int, str],
results_dir: Path,
- default_config: Dict[str, Any],
- small_patch_size: Tuple[int, int],
+ default_config: dict[str, Any],
+ small_patch_size: tuple[int, int],
std_batch_size: int,
std_n_classes: int,
) -> None:
@@ -420,7 +420,7 @@ def test_plot_results(
def test_plot_embeddings(
results_root: Path,
default_dataset: GeoDataset,
- exp_dataset_params: Dict[str, Any],
+ exp_dataset_params: dict[str, Any],
data_root: Path,
cache_dir: Path,
) -> None:
diff --git a/tox.ini b/tox.ini
index bfea17c61..6e9b2d9ec 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,6 +1,6 @@
[tox]
minversion = 3.8.0
-envlist = minerva-{312, 311, 310, 39}
+envlist = minerva-{312, 311, 310}
isolated_build = true
skip_missing_interpreters = true
@@ -9,7 +9,6 @@ python =
3.12: minerva-312
3.11: minerva-311
3.10: minerva-310
- 3.9: minerva-39
[testenv]
skip_install = true
@@ -33,7 +32,3 @@ deps = -r{toxinidir}/requirements/requirements_dev.txt
[testenv:minerva-310]
basepython = python3.10
deps = -r{toxinidir}/requirements/requirements_dev.txt
-
-[testenv:minerva-39]
-basepython = python3.9
-deps = -r{toxinidir}/requirements/requirements_dev.txt