diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 60d06fde0..8f5474d0f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,7 +15,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ['3.9', '3.10', '3.11', '3.12'] + python-version: ['3.10', '3.11', '3.12'] exclude: - os: windows-latest python-version: '3.12' diff --git a/.gitignore b/.gitignore index a4d961a62..3f9bbe3c3 100644 --- a/.gitignore +++ b/.gitignore @@ -152,8 +152,7 @@ lightning_logs/ **/results/* **/cache/* **/outputs/* -tests/tmp/cache/* -tests/tmp/results/* +tests/tmp/* user_cfgs/* inbuilt_cfgs/config.yml @@ -162,8 +161,14 @@ inbuilt_cfgs/mf_config.yml .vscode/settings.json # Exclusions +!docs/ +docs/* +!docs/images/ +docs/images/* !docs/images/*.png + !tests/fixtures/data/**/*.png !tests/fixtures/data/**/*.txt !tests/fixtures/data/**/*.csv !tests/fixtures/data/**/*.tif +!tests/fixtures/data/**/*.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 772a54862..4917e48f1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,7 @@ repos: - id: requirements-txt-fixer - repo: https://github.com/psf/black - rev: 24.4.2 + rev: 24.8.0 hooks: - id: black diff --git a/README.md b/README.md index 4beccd94b..3b6c917e4 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ datasets with upcoming support for [torchvision](https://pytorch.org/vision/stab Required Python modules for `minerva` are stated in the `setup.cfg`. -`minerva` currently only supports `python` 3.9 -- 3.12. +`minerva` currently only supports `python` 3.10 -- 3.12.

(back to top)

diff --git a/minerva/datasets/collators.py b/minerva/datasets/collators.py index 7d00b88f3..2534e2b15 100644 --- a/minerva/datasets/collators.py +++ b/minerva/datasets/collators.py @@ -39,45 +39,36 @@ # ===================================================================================================================== # IMPORTS # ===================================================================================================================== -from typing import Any, Callable, Dict, Iterable, Optional, Tuple +from typing import Any, Callable, Iterable +from hydra.utils import get_method from torchgeo.datasets.utils import stack_samples -from minerva.utils import utils - # ===================================================================================================================== # METHODS # ===================================================================================================================== def get_collator( - collator_params: Optional[Dict[str, str]] = None + collator_target: str = "torchgeo.datasets.stack_samples", ) -> Callable[..., Any]: """Gets the function defined in parameters to collate samples together to form a batch. Args: - collator_params (dict[str, str]): Optional; Dictionary that must contain keys for - ``'module'`` and ``'name'`` of the collation function. Defaults to ``config['collator']``. + collator_target (str): Dot based import path for collator method. + Defaults to :meth:`torchgeo.datasets.stack_samples` Returns: - ~typing.Callable[..., ~typing.Any]: Collation function found from parameters given. + ~typing.Callable[..., ~typing.Any]: Collation function found from target path given. 
""" collator: Callable[..., Any] - if collator_params is not None: - module = collator_params.pop("module", "") - if module == "": - collator = globals()[collator_params["name"]] - else: - collator = utils.func_by_str(module, collator_params["name"]) - else: - collator = stack_samples - + collator = get_method(collator_target) assert callable(collator) return collator def stack_sample_pairs( - samples: Iterable[Tuple[Dict[Any, Any], Dict[Any, Any]]] -) -> Tuple[Dict[Any, Any], Dict[Any, Any]]: + samples: Iterable[tuple[dict[Any, Any], dict[Any, Any]]] +) -> tuple[dict[Any, Any], dict[Any, Any]]: """Takes a list of paired sample dicts and stacks them into a tuple of batches of sample dicts. Args: diff --git a/minerva/datasets/dfc.py b/minerva/datasets/dfc.py index bbc5bf277..49f8255f7 100644 --- a/minerva/datasets/dfc.py +++ b/minerva/datasets/dfc.py @@ -38,7 +38,7 @@ # ===================================================================================================================== from glob import glob from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Optional import numpy as np import pandas as pd @@ -84,7 +84,7 @@ class BaseSenS12MS(NonGeoDataset): S2_BANDS_MR = [5, 6, 7, 9, 12, 13] S2_BANDS_LR = [1, 10, 11] - splits: List[str] = [] + splits: list[str] = [] igbp = False @@ -133,13 +133,13 @@ def __init__( # Make sure parent dir exists. assert self.root.exists() - self.samples: List[Dict[str, str]] + self.samples: list[dict[str, str]] def load_sample( self, - sample: Dict[str, str], + sample: dict[str, str], index: int, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Util function for reading data from single sample. Args: @@ -192,7 +192,7 @@ def get_ninputs(self) -> int: n_inputs += 2 return n_inputs - def get_display_channels(self) -> Tuple[List[int], int]: + def get_display_channels(self) -> tuple[list[int], int]: """Select channels for preview images. 
Returns: @@ -303,7 +303,7 @@ def load_lc(self, path: str) -> Tensor: return lc - def __getitem__(self, index: int) -> Dict[str, Any]: + def __getitem__(self, index: int) -> dict[str, Any]: """Get a single example from the dataset""" # Get and load sample from index file. @@ -393,11 +393,11 @@ def __init__( def plot( self, - sample: Dict[str, Tensor], + sample: dict[str, Tensor], show_titles: bool = True, suptitle: Optional[str] = None, - classes: Optional[Dict[int, str]] = None, - colours: Optional[Dict[int, str]] = None, + classes: Optional[dict[int, str]] = None, + colours: Optional[dict[int, str]] = None, ) -> Figure: """Plot a sample from the dataset. @@ -420,10 +420,15 @@ def plot( # Reorder image from BGR to RGB. from kornia.color import BgrToRgb + from minerva.transforms import AdjustGamma + bgr_to_rgb = BgrToRgb() + adjust_gamma = AdjustGamma(gamma=0.8) + + # Reorder channels from BGR to RGB and adjust the gamma ready for plotting. + image = adjust_gamma(bgr_to_rgb(sample["image"][:3])).permute(1, 2, 0).numpy() - image = bgr_to_rgb(sample["image"][:3]) - image = image.permute(1, 2, 0).numpy() + block_size_factor = 8 # Use inbuilt class colours and classes mappings. if classes is None or colours is None: @@ -454,7 +459,17 @@ def plot( # Plot the image. axs[0].imshow(image) - axs[0].axis("off") + + # Sets tick intervals to block size. + axs[0].set_xticks( + np.arange(0, image.shape[0] + 1, image.shape[0] // block_size_factor) + ) + axs[0].set_yticks( + np.arange(0, image.shape[1] + 1, image.shape[1] // block_size_factor) + ) + + # Add grid overlay. + axs[0].grid(which="both", color="#CCCCCC", linestyle=":") # Plot the ground truth mask and predicted mask. mask_plot: Axes @@ -462,22 +477,54 @@ def plot( mask_plot = axs[1].imshow( mask, cmap=cmap, vmin=vmin, vmax=vmax, interpolation="none" ) - axs[1].axis("off") + + # Sets tick intervals to block size. 
+ axs[1].set_xticks( + np.arange(0, mask.shape[0] + 1, mask.shape[0] // block_size_factor) + ) + axs[1].set_yticks( + np.arange(0, mask.shape[1] + 1, mask.shape[1] // block_size_factor) + ) + + # Add grid overlay. + axs[1].grid(which="both", color="#CCCCCC", linestyle=":") + if showing_prediction: axs[2].imshow( pred, cmap=cmap, vmin=vmin, vmax=vmax, interpolation="none" ) - axs[2].axis("off") + + # Sets tick intervals to block size. + axs[2].set_xticks( + np.arange(0, pred.shape[0] + 1, pred.shape[0] // block_size_factor) + ) + axs[2].set_yticks( + np.arange(0, pred.shape[1] + 1, pred.shape[1] // block_size_factor) + ) + + # Add grid overlay. + axs[2].grid(which="both", color="#CCCCCC", linestyle=":") + elif showing_prediction: mask_plot = axs[1].imshow( pred, cmap=cmap, vmin=vmin, vmax=vmax, interpolation="none" ) - axs[1].axis("off") + + # Sets tick intervals to block size. + axs[1].set_xticks( + np.arange(0, pred.shape[0] + 1, pred.shape[0] // block_size_factor) + ) + axs[1].set_yticks( + np.arange(0, pred.shape[1] + 1, pred.shape[1] // block_size_factor) + ) + + # Add grid overlay. + axs[1].grid(which="both", color="#CCCCCC", linestyle=":") if showing_mask or showing_prediction: # Plots colour bar onto figure. 
clb = fig.colorbar( - mask_plot, + mask_plot, # type: ignore[arg-type] ax=axs, location="top", ticks=np.arange(0, len(colours)), diff --git a/minerva/datasets/factory.py b/minerva/datasets/factory.py index eaaedf027..c3e4b3adb 100644 --- a/minerva/datasets/factory.py +++ b/minerva/datasets/factory.py @@ -49,10 +49,10 @@ import re from copy import deepcopy from datetime import timedelta -from inspect import signature from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union +from typing import Any, Iterable, Optional +import hydra import numpy as np import pandas as pd import torch @@ -60,17 +60,15 @@ from omegaconf import OmegaConf from pandas import DataFrame from rasterio.crs import CRS -from torch.utils.data import DataLoader, Sampler +from torch.utils.data import DataLoader from torchgeo.datasets import GeoDataset, NonGeoDataset, RasterDataset -from torchgeo.samplers import BatchGeoSampler, GeoSampler from torchgeo.samplers.utils import _to_tuple -from minerva.samplers import DistributedSamplerWrapper +from minerva.samplers import DistributedSamplerWrapper, get_sampler from minerva.transforms import MinervaCompose, init_auto_norm, make_transformations from minerva.utils import universal_path, utils from .collators import get_collator, stack_sample_pairs -from .paired import PairedGeoDataset, PairedNonGeoDataset from .utils import ( MinervaConcatDataset, cache_dataset, @@ -78,7 +76,6 @@ intersect_datasets, load_all_samples, load_dataset_from_cache, - make_bounding_box, masks_or_labels, unionise_datasets, ) @@ -88,92 +85,67 @@ # METHODS # ===================================================================================================================== def create_subdataset( - dataset_class: Union[Callable[..., GeoDataset], Callable[..., NonGeoDataset]], - paths: Union[str, Iterable[str]], - subdataset_params: Dict[Literal["params"], Dict[str, Any]], + paths: str | Iterable[str], + subdataset_params: 
dict[str, Any], transformations: Optional[Any], sample_pairs: bool = False, -) -> Union[GeoDataset, NonGeoDataset]: +) -> GeoDataset | NonGeoDataset: """Creates a sub-dataset based on the parameters supplied. Args: - dataset_class (Callable[..., ~typing.datasets.GeoDataset]): Constructor for the sub-dataset. paths (str | ~typing.Iterable[str]): Paths to where the data for the dataset is located. subdataset_params (dict[Literal[params], dict[str, ~typing.Any]]): Parameters for the sub-dataset. transformations (~typing.Any): Transformations to apply to this sub-dataset. sample_pairs (bool): Will configure the dataset for paired sampling. Defaults to False. Returns: - ~torchgeo.datasets.GeoDataset: Subdataset requested. + ~torchgeo.datasets.GeoDataset | ~torchgeo.datasets.NonGeoDataset: Subdataset requested. """ - copy_params = deepcopy(subdataset_params) + params = deepcopy(subdataset_params) - if "crs" in copy_params["params"]: - copy_params["params"]["crs"] = CRS.from_epsg(copy_params["params"]["crs"]) + crs = None + if "crs" in params: + crs = CRS.from_epsg(params["crs"]) - if sample_pairs: - if "paths" in signature(dataset_class).parameters: - return PairedGeoDataset( - dataset_class, # type: ignore[arg-type] - paths=paths, - transforms=transformations, - **copy_params["params"], - ) + subdataset: GeoDataset | NonGeoDataset + if "paths" in params: + if sample_pairs: + params["dataset"] = params["_target_"] + params["_target_"] = "minerva.datasets.paired.PairedGeoDataset" + + subdataset = hydra.utils.instantiate( + params, paths=paths, transforms=transformations, crs=crs + ) + + elif "root" in params: + if isinstance(paths, list): + paths = paths[0] + assert isinstance(paths, str) + + if sample_pairs: + params["dataset"] = params["_target_"] + params["_target_"] = "minerva.datasets.paired.PairedNonGeoDataset" + + subdataset = hydra.utils.instantiate( + params, root=paths, transforms=transformations + ) - elif "season_transform" in 
signature(dataset_class).parameters: - if isinstance(paths, list): - paths = paths[0] - assert isinstance(paths, str) - del copy_params["params"]["season_transform"] - return PairedNonGeoDataset( - dataset_class, # type: ignore[arg-type] - root=paths, - transforms=transformations, - season=True, - season_transform="pair", - **copy_params["params"], - ) - elif "root" in signature(dataset_class).parameters: - if isinstance(paths, list): - paths = paths[0] - assert isinstance(paths, str) - return PairedNonGeoDataset( - dataset_class, # type: ignore[arg-type] - root=paths, - transforms=transformations, - **copy_params["params"], - ) - else: - raise TypeError else: - if "paths" in signature(dataset_class).parameters: - return dataset_class( - paths=paths, - transforms=transformations, - **copy_params["params"], - ) - elif "root" in signature(dataset_class).parameters: - if isinstance(paths, list): - paths = paths[0] - assert isinstance(paths, str) - return dataset_class( - root=paths, - transforms=transformations, - **copy_params["params"], - ) - else: - raise TypeError + raise TypeError + + return subdataset def get_subdataset( - data_directory: Union[Iterable[str], str, Path], - dataset_params: Dict[str, Any], + data_directory: Iterable[str] | str | Path, + dataset_params: dict[str, Any], key: str, transformations: Optional[Any], sample_pairs: bool = False, cache: bool = True, - cache_dir: Union[str, Path] = "", -) -> Union[GeoDataset, NonGeoDataset]: + cache_dir: str | Path = "", + auto_norm: bool = False, +) -> GeoDataset | NonGeoDataset: """Get a subdataset based on the parameters specified. If ``cache==True``, this will attempt to load a cached version of the dataset instance. @@ -196,17 +168,13 @@ def get_subdataset( # Get the params for this sub-dataset. sub_dataset_params = dataset_params[key] - # Get the constructor for the class of dataset defined in params. 
- _sub_dataset: Callable[..., GeoDataset] = utils.func_by_str( - module_path=sub_dataset_params["module"], func=sub_dataset_params["name"] - ) - # Construct the path to the sub-dataset's files. sub_dataset_paths = utils.compile_dataset_paths( - universal_path(data_directory), sub_dataset_params["paths"] + universal_path(data_directory), + sub_dataset_params.get("paths", sub_dataset_params.get("root")), ) - sub_dataset: Optional[Union[GeoDataset, NonGeoDataset]] + sub_dataset: Optional[GeoDataset | NonGeoDataset] if cache or sub_dataset_params.get("cache_dataset"): this_hash = utils.make_hash(sub_dataset_params) @@ -230,7 +198,6 @@ def get_subdataset( if rank == 0: print(f"\nCreating dataset on {rank}...") sub_dataset = create_subdataset( - _sub_dataset, sub_dataset_paths, sub_dataset_params, transformations, @@ -255,7 +222,6 @@ def get_subdataset( else: print("\nCreating dataset...") sub_dataset = create_subdataset( - _sub_dataset, sub_dataset_paths, sub_dataset_params, transformations, @@ -267,7 +233,6 @@ def get_subdataset( else: sub_dataset = create_subdataset( - _sub_dataset, sub_dataset_paths, sub_dataset_params, transformations, @@ -275,16 +240,31 @@ def get_subdataset( ) assert sub_dataset is not None + + if auto_norm: + if isinstance(sub_dataset, RasterDataset): + init_auto_norm( + sub_dataset, + length=sub_dataset_params.get("length"), + roi=sub_dataset_params.get("roi"), + ) + else: + raise TypeError( # pragma: no cover + "AutoNorm only supports normalisation of data " + + f"from RasterDatasets, not {type(sub_dataset)}!" 
+ ) + return sub_dataset def make_dataset( - data_directory: Union[Iterable[str], str, Path], - dataset_params: Dict[Any, Any], + data_directory: Iterable[str] | str | Path, + dataset_params: dict[Any, Any], sample_pairs: bool = False, + change_detection: bool = False, cache: bool = True, - cache_dir: Union[str, Path] = "", -) -> Tuple[Any, List[Any]]: + cache_dir: str | Path = "", +) -> tuple[Any, list[Any]]: """Constructs a dataset object from ``n`` sub-datasets given by the parameters supplied. Args: @@ -293,6 +273,8 @@ def make_dataset( dataset_params (dict[~typing.Any, ~typing.Any]): Dictionary of parameters defining each sub-datasets to be used. sample_pairs (bool): Optional; ``True`` if paired sampling. This will ensure paired samples are handled correctly in the datasets. + change_detection (bool): Flag for a change detection dataset which has + ``"image1"`` and ``"image2"`` keys rather than ``"image"``. cache (bool): Cache the dataset or load from cache if pre-existing. Defaults to True. cache_dir (str | ~pathlib.Path): Path to the directory to save the cached dataset (if ``cache==True``). Defaults to CWD. @@ -303,9 +285,7 @@ def make_dataset( """ # --+ MAKE SUB-DATASETS +=========================================================================================+ # List to hold all the sub-datasets defined by dataset_params to be intersected together into a single dataset. 
- sub_datasets: Union[ - List[GeoDataset], List[Union[NonGeoDataset, MinervaConcatDataset]] - ] = [] + sub_datasets: list[GeoDataset] | list[NonGeoDataset | MinervaConcatDataset] = [] if OmegaConf.is_config(dataset_params): dataset_params = OmegaConf.to_object(dataset_params) # type: ignore[assignment] @@ -339,74 +319,75 @@ def make_dataset( add_target_transforms = type_dataset_params["transforms"] continue - type_subdatasets: Union[List[GeoDataset], List[NonGeoDataset]] = [] + type_subdatasets: list[GeoDataset] | list[NonGeoDataset] = [] multi_datasets_exist = False - auto_norm = None + auto_norm = False master_transforms: Optional[Any] = None - for area_key in type_dataset_params.keys(): + + for sub_type_key in type_dataset_params.keys(): # If any of these keys are present, this must be a parameter set for a singular dataset at this level. - if area_key in ("module", "name", "params", "paths"): + if sub_type_key in ("_target_", "paths", "root"): multi_datasets_exist = False continue # If there are transforms specified, make them. These could cover a single dataset or many. - elif area_key == "transforms": - if isinstance(type_dataset_params[area_key], dict): - transform_params = type_dataset_params[area_key] - auto_norm = transform_params.get("AutoNorm") + elif sub_type_key == "transforms": + if isinstance(type_dataset_params[sub_type_key], dict): + transform_params = type_dataset_params[sub_type_key] + auto_norm = transform_params.get("AutoNorm", False) # If transforms aren't specified for a particular modality of the sample, # assume they're for the same type as the dataset. 
if ( not ("image", "mask", "label") - in type_dataset_params[area_key].keys() + in type_dataset_params[sub_type_key].keys() ): - transform_params = {type_key: type_dataset_params[area_key]} + transform_params = {type_key: type_dataset_params[sub_type_key]} else: transform_params = False - master_transforms = make_transformations(transform_params) + master_transforms = make_transformations( + transform_params, change_detection=change_detection + ) # Assuming that these keys are names of datasets. - else: + elif sub_type_key == "subdatasets": multi_datasets_exist = True - if isinstance(type_dataset_params[area_key].get("transforms"), dict): - transform_params = type_dataset_params[area_key]["transforms"] - auto_norm = transform_params.get("AutoNorm") + if isinstance( + type_dataset_params[sub_type_key].get("transforms"), dict + ): + transform_params = type_dataset_params[sub_type_key]["transforms"] + auto_norm = transform_params.get("AutoNorm", False) else: transform_params = False - transformations = make_transformations({type_key: transform_params}) + transformations = make_transformations( + {type_key: transform_params}, change_detection=change_detection + ) # Send the params for this area key back through this function to make the sub-dataset. - sub_dataset = get_subdataset( - data_directory, - type_dataset_params, - area_key, - transformations, - sample_pairs=sample_pairs, - cache=cache, - cache_dir=cache_dir, - ) + for area_key in type_dataset_params[sub_type_key]: + sub_dataset = get_subdataset( + data_directory, + type_dataset_params[sub_type_key], + area_key, + transformations, + sample_pairs=sample_pairs, + cache=cache, + cache_dir=cache_dir, + auto_norm=auto_norm, + ) - # Performs an auto-normalisation initialisation which finds the mean and std of the dataset - # to make a transform, then adds the transform to the dataset's existing transforms. 
- if auto_norm: - if isinstance(sub_dataset, RasterDataset): - init_auto_norm(sub_dataset, auto_norm) - else: - raise TypeError( # pragma: no cover - "AutoNorm only supports normalisation of data " - + f"from RasterDatasets, not {type(sub_dataset)}!" - ) + # Reset back to False. + auto_norm = False - # Reset back to None. - auto_norm = None + type_subdatasets.append(sub_dataset) # type: ignore[arg-type] - type_subdatasets.append(sub_dataset) # type: ignore[arg-type] + else: + continue # Unionise all the sub-datsets of this modality together. if multi_datasets_exist: @@ -425,20 +406,11 @@ def make_dataset( sample_pairs=sample_pairs, cache=cache, cache_dir=cache_dir, + auto_norm=auto_norm, ) - # Performs an auto-normalisation initialisation which finds the mean and std of the dataset - # to make a transform, then adds the transform to the dataset's existing transforms. - if auto_norm: - if isinstance(sub_dataset, RasterDataset): - init_auto_norm(sub_dataset, auto_norm) - - # Reset back to None. - auto_norm = None - else: - raise TypeError( # pragma: no cover - f"AutoNorm only supports normalisation of data from RasterDatasets, not {type(sub_dataset)}!" - ) + # Reset back to False. 
+ auto_norm = False sub_datasets.append(sub_dataset) # type: ignore[arg-type] @@ -450,14 +422,16 @@ def make_dataset( if add_target_transforms is not None: target_key = masks_or_labels(dataset_params) - target_transforms = make_transformations({target_key: add_target_transforms}) + target_transforms = make_transformations( + {target_key: add_target_transforms}, change_detection=change_detection + ) if hasattr(dataset, "transforms"): if isinstance(dataset.transforms, MinervaCompose): assert target_transforms is not None - dataset.transforms += target_transforms + dataset.transforms += target_transforms # type: ignore[union-attr] else: - dataset.transforms = target_transforms + dataset.transforms = target_transforms # type: ignore[union-attr] else: raise TypeError( f"dataset of type {type(dataset)} has no ``transforms`` atttribute!" @@ -465,14 +439,15 @@ def make_dataset( if add_multi_modal_transforms is not None: multi_modal_transforms = make_transformations( - {"both": add_multi_modal_transforms} + {"both": add_multi_modal_transforms}, + change_detection=change_detection, ) if hasattr(dataset, "transforms"): if isinstance(dataset.transforms, MinervaCompose): assert multi_modal_transforms is not None - dataset.transforms += multi_modal_transforms + dataset.transforms += multi_modal_transforms # type: ignore[union-attr] else: - dataset.transforms = multi_modal_transforms + dataset.transforms = multi_modal_transforms # type: ignore[union-attr] else: raise TypeError( f"dataset of type {type(dataset)} has no ``transforms`` atttribute!" 
@@ -482,17 +457,18 @@ def make_dataset( def construct_dataloader( - data_directory: Union[Iterable[str], str, Path], - dataset_params: Dict[str, Any], - sampler_params: Dict[str, Any], - dataloader_params: Dict[str, Any], + data_directory: Iterable[str] | str | Path, + dataset_params: dict[str, Any], + sampler_params: dict[str, Any], + dataloader_params: dict[str, Any], batch_size: int, - collator_params: Optional[Dict[str, Any]] = None, + collator_target: str = "torchgeo.datasets.stack_samples", rank: int = 0, world_size: int = 1, sample_pairs: bool = False, + change_detection: bool = False, cache: bool = True, - cache_dir: Union[Path, str] = "", + cache_dir: Path | str = "", ) -> DataLoader[Iterable[Any]]: """Constructs a :class:`~torch.utils.data.DataLoader` object from the parameters provided for the datasets, sampler, collator and transforms. @@ -505,12 +481,14 @@ def construct_dataloader( to sample from the dataset. dataloader_params (dict[str, ~typing.Any]): Dictionary of parameters for the DataLoader itself. batch_size (int): Number of samples per (global) batch. - collator_params (dict[str, ~typing.Any]): Optional; Dictionary of parameters defining the function to collate + collator_target (str): Import target path for collator function to collate and stack samples from the sampler. rank (int): Optional; The rank of this process for distributed computing. world_size (int): Optional; The total number of processes within a distributed run. sample_pairs (bool): Optional; True if paired sampling. This will wrap the collation function for paired samples. + change_detection (bool): Flag for if using a change detection dataset which has + ``"image1"`` and ``"image2"`` keys rather than ``"image"``. Returns: ~torch.utils.data.DataLoader: Object to handle the returning of batched samples from the dataset. 
@@ -519,42 +497,29 @@ def construct_dataloader( data_directory, dataset_params, sample_pairs=sample_pairs, + change_detection=change_detection, cache=cache, cache_dir=cache_dir, ) # --+ MAKE SAMPLERS +=============================================================================================+ - _sampler: Callable[..., Union[BatchGeoSampler, GeoSampler]] = utils.func_by_str( - module_path=sampler_params["module"], func=sampler_params["name"] - ) + per_device_batch_size = None - batch_sampler = True if re.search(r"Batch", sampler_params["name"]) else False + batch_sampler = True if re.search(r"Batch", sampler_params["_target_"]) else False if batch_sampler: - sampler_params["params"]["batch_size"] = batch_size - + per_device_batch_size = sampler_params.get("batch_size", batch_size) if dist.is_available() and dist.is_initialized(): # type: ignore[attr-defined] - assert ( - sampler_params["params"]["batch_size"] % world_size == 0 - ) # pragma: no cover + assert per_device_batch_size % world_size == 0 # pragma: no cover per_device_batch_size = ( - sampler_params["params"]["batch_size"] // world_size + per_device_batch_size // world_size ) # pragma: no cover - sampler_params["params"][ - "batch_size" - ] = per_device_batch_size # pragma: no cover - - sampler: Sampler[Any] - if "roi" in signature(_sampler).parameters: - sampler = _sampler( - subdatasets[0], - roi=make_bounding_box(sampler_params["roi"]), - **sampler_params["params"], - ) - else: - sampler = _sampler(subdatasets[0], **sampler_params["params"]) + + sampler = get_sampler( + sampler_params, subdatasets[0], batch_size=per_device_batch_size + ) # --+ MAKE DATALOADERS +==========================================================================================+ - collator = get_collator(collator_params) + collator = get_collator(collator_target) # Add batch size from top-level parameters to the dataloader parameters. 
dataloader_params["batch_size"] = batch_size @@ -588,11 +553,11 @@ def construct_dataloader( def _add_class_transform( - class_matrix: Dict[int, int], dataset_params: Dict[str, Any], target_key: str -) -> Dict[str, Any]: + class_matrix: dict[int, int], dataset_params: dict[str, Any], target_key: str +) -> dict[str, Any]: class_transform = { "ClassTransform": { - "module": "minerva.transforms", + "_target_": "minerva.transforms.ClassTransform", "transform": class_matrix, } } @@ -615,12 +580,13 @@ def _make_loader( dataset_params, sampler_params, dataloader_params, - collator_params, + collator_target, class_matrix, batch_size, model_type, elim, sample_pairs, + change_detection, cache, ): target_key = None @@ -640,10 +606,11 @@ def _make_loader( sampler_params, dataloader_params, batch_size, - collator_params=collator_params, + collator_target=collator_target, rank=rank, world_size=world_size, sample_pairs=sample_pairs, + change_detection=change_detection, cache=cache, cache_dir=cache_dir, ) @@ -651,9 +618,9 @@ def _make_loader( # Calculates number of batches. assert hasattr(loaders.dataset, "__len__") n_batches = int( - sampler_params["params"].get( + sampler_params.get( "length", - sampler_params["params"].get("num_samples", len(loaders.dataset)), + sampler_params.get("num_samples", len(loaders.dataset)), ) / batch_size ) @@ -667,11 +634,11 @@ def make_loaders( p_dist: bool = False, task_name: Optional[str] = None, **params, -) -> Tuple[ - Union[Dict[str, DataLoader[Iterable[Any]]], DataLoader[Iterable[Any]]], - Union[Dict[str, int], int], - List[Tuple[int, int]], - Dict[Any, Any], +) -> tuple[ + dict[str, DataLoader[Iterable[Any]]] | DataLoader[Iterable[Any]], + dict[str, int] | int, + list[tuple[int, int]], + dict[Any, Any], ]: """Constructs train, validation and test datasets and places into :class:`~torch.utils.data.DataLoader` objects. 
@@ -696,9 +663,7 @@ def make_loaders( sampler_params (dict[str, ~typing.Any]): Parameters to construct the samplers for each mode of model fitting. transform_params (dict[str, ~typing.Any]): Parameters to construct the transforms for each dataset. See documentation for the structure of these. - collator (dict[str, ~typing.Any]): Defines the collator to use that will collate samples together into batches. - Contains the ``module`` key to define the import path and the ``name`` key - for name of the collation function. + collator (str): Optional; Defines the collator to use that will collate samples together into batches. sample_pairs (bool): Activates paired sampling for Siamese models. Only used for ``train`` datasets. Returns: @@ -717,14 +682,14 @@ def make_loaders( cache_dir = params["cache_dir"] # Gets out the parameters for the DataLoaders from params. - dataloader_params: Dict[Any, Any] = deepcopy( + dataloader_params: dict[Any, Any] = deepcopy( utils.fallback_params("loader_params", task_params, params) ) if OmegaConf.is_config(dataloader_params): dataloader_params = OmegaConf.to_object(dataloader_params) # type: ignore[assignment] - dataset_params: Dict[str, Any] = utils.fallback_params( + dataset_params: dict[str, Any] = utils.fallback_params( "dataset_params", task_params, params ) @@ -734,7 +699,7 @@ def make_loaders( batch_size: int = utils.fallback_params("batch_size", task_params, params) model_type = utils.fallback_params("model_type", task_params, params) - class_dist: List[Tuple[int, int]] = [(0, 0)] + class_dist: list[tuple[int, int]] = [(0, 0)] classes = utils.fallback_params("classes", data_config, params, None) cmap_dict = utils.fallback_params("colours", data_config, params, None) @@ -745,28 +710,26 @@ def make_loaders( if n_classes: classes = {i: f"class {i}" for i in range(n_classes)} - new_classes: Dict[int, str] = {} - new_colours: Dict[int, str] = {} - class_matrix: Dict[int, int] = {} + new_classes: dict[int, str] = {} + new_colours: 
dict[int, str] = {} + class_matrix: dict[int, int] = {} + + sample_pairs = utils.fallback_params("sample_pairs", task_params, params, False) - sample_pairs: Union[bool, Any] = utils.fallback_params( - "sample_pairs", task_params, params, False + change_detection = utils.fallback_params( + "change_detection", task_params, params, False ) - if not isinstance(sample_pairs, bool): # pragma: no cover - sample_pairs = False elim = utils.fallback_params("elim", task_params, params, False) cache = utils.fallback_params("cache_dataset", task_params, params, True) - n_batches: Union[Dict[str, int], int] - loaders: Union[Dict[str, DataLoader[Iterable[Any]]], DataLoader[Iterable[Any]]] + n_batches: dict[str, int] | int + loaders: dict[str, DataLoader[Iterable[Any]]] | DataLoader[Iterable[Any]] - collator_params = deepcopy(utils.fallback_params("collator", task_params, params)) - if OmegaConf.is_config(collator_params): - collator_params = OmegaConf.to_object(collator_params) + collator_target = utils.fallback_params("collator", task_params, params, None) if "sampler" in dataset_params.keys(): - sampler_params: Dict[str, Any] = dataset_params["sampler"] + sampler_params: dict[str, Any] = dataset_params["sampler"] if not utils.check_substrings_in_string(model_type, "siamese"): new_classes, class_matrix, new_colours, class_dist = get_data_specs( @@ -778,7 +741,8 @@ def make_loaders( dataset_params, sampler_params, dataloader_params, - collator_params, + collator_target, + change_detection=change_detection, elim=elim, ) @@ -791,12 +755,13 @@ def make_loaders( dataset_params, sampler_params, dataloader_params, - collator_params, + collator_target, class_matrix, batch_size, model_type, elim=elim, sample_pairs=sample_pairs, + change_detection=change_detection, cache=cache, ) @@ -806,7 +771,7 @@ def make_loaders( loaders = {} for mode in dataset_params.keys(): - mode_sampler_params: Dict[str, Any] = dataset_params[mode]["sampler"] + mode_sampler_params: dict[str, Any] = 
dataset_params[mode]["sampler"] if ( not utils.check_substrings_in_string(model_type, "siamese") @@ -821,7 +786,8 @@ def make_loaders( dataset_params[mode], mode_sampler_params, dataloader_params, - collator_params, + collator_target, + change_detection=change_detection, elim=elim, ) @@ -835,12 +801,13 @@ def make_loaders( dataset_params[mode], mode_sampler_params, dataloader_params, - collator_params, + collator_target, class_matrix, batch_size, model_type, elim=elim, sample_pairs=sample_pairs if mode == "train" else False, + change_detection=change_detection, cache=cache, ) @@ -883,15 +850,16 @@ def make_loaders( def get_data_specs( - manifest_name: Union[str, Path], - classes: Dict[int, str], - cmap_dict: Dict[int, str], - cache_dir: Optional[Union[str, Path]] = None, - data_dir: Optional[Union[str, Path]] = None, - dataset_params: Optional[Dict[str, Any]] = None, - sampler_params: Optional[Dict[str, Any]] = None, - dataloader_params: Optional[Dict[str, Any]] = None, - collator_params: Optional[Dict[str, Any]] = None, + manifest_name: str | Path, + classes: dict[int, str], + cmap_dict: dict[int, str], + cache_dir: Optional[str | Path] = None, + data_dir: Optional[str | Path] = None, + dataset_params: Optional[dict[str, Any]] = None, + sampler_params: Optional[dict[str, Any]] = None, + dataloader_params: Optional[dict[str, Any]] = None, + collator_target: str = "torchgeo.datasets.stack_samples", + change_detection: bool = False, elim: bool = True, ): # Load manifest from cache for this dataset. 
@@ -901,7 +869,8 @@ def get_data_specs( dataset_params, sampler_params, dataloader_params, - collator_params=collator_params, + collator_target=collator_target, + change_detection=change_detection, ) class_dist = utils.modes_from_manifest(manifest, classes) @@ -923,12 +892,13 @@ def get_data_specs( def get_manifest( - manifest_path: Union[str, Path], - data_dir: Optional[Union[str, Path]] = None, - dataset_params: Optional[Dict[str, Any]] = None, - sampler_params: Optional[Dict[str, Any]] = None, - loader_params: Optional[Dict[str, Any]] = None, - collator_params: Optional[Dict[str, Any]] = None, + manifest_path: str | Path, + data_dir: Optional[str | Path] = None, + dataset_params: Optional[dict[str, Any]] = None, + sampler_params: Optional[dict[str, Any]] = None, + loader_params: Optional[dict[str, Any]] = None, + collator_target: str = "torchgeo.datasets.stack_samples", + change_detection: bool = False, ) -> DataFrame: """Attempts to return the :class:`~pandas.DataFrame` located at ``manifest_path``. @@ -948,6 +918,8 @@ def get_manifest( manifest_path (str | ~pathlib.Path): Path (including filename and extension) to the manifest saved as a ``csv``. task_name (str): Optional; Name of the task to which the dataset to create a manifest of belongs to. + change_detection (bool): Flag for if using a change detection dataset which has + ``"image1"`` and ``"image2"`` keys rather than ``"image"``. Returns: ~pandas.DataFrame: Manifest either loaded from ``manifest_path`` or created from parameters in :data:`CONFIG`. 
@@ -969,7 +941,8 @@ def get_manifest( dataset_params, sampler_params, loader_params, - collator_params=collator_params, + collator_target=collator_target, + change_detection=change_detection, ) print(f"MANIFEST TO FILE -----> {manifest_path}") @@ -983,11 +956,12 @@ def get_manifest( def make_manifest( - data_dir: Union[str, Path], - dataset_params: Dict[str, Any], - sampler_params: Dict[str, Any], - loader_params: Dict[str, Any], - collator_params: Optional[Dict[str, Any]] = None, + data_dir: str | Path, + dataset_params: dict[str, Any], + sampler_params: dict[str, Any], + loader_params: dict[str, Any], + collator_target: str = "torchgeo.datasets.stack_samples", + change_detection: bool = False, ) -> DataFrame: """Constructs a manifest of the dataset detailing each sample therein. @@ -996,12 +970,14 @@ def make_manifest( Args: mf_config (dict[~typing.Any, ~typing.Any]): Config to use to construct the manifest with. task_name (str): Optional; Name of the task to which the dataset to create a manifest of belongs to. + change_detection (bool): Flag for if using a change detection dataset which has + ``"image1"`` and ``"image2"`` keys rather than ``"image"``. Returns: ~pandas.DataFrame: The completed manifest as a :class:`~pandas.DataFrame`. 
""" - def delete_transforms(params: Dict[str, Any]) -> None: + def delete_transforms(params: dict[str, Any]) -> None: assert target_key is not None if params[target_key] is None: return @@ -1025,24 +1001,22 @@ def delete_transforms(params: Dict[str, Any]) -> None: if OmegaConf.is_config(_dataset_params): _dataset_params = OmegaConf.to_object(_dataset_params) # type: ignore[assignment] - if _sampler_params["name"] in ( + if _sampler_params["_target_"].rpartition(".")[2] in ( "RandomGeoSampler", "RandomPairGeoSampler", "RandomBatchGeoSampler", "RandomPairBatchGeoSampler", ): - _sampler_params["module"] = "torchgeo.samplers" - _sampler_params["name"] = "GridGeoSampler" - _sampler_params["params"]["stride"] = [ - 0.9 * x for x in _to_tuple(_sampler_params["params"]["size"]) + _sampler_params["_target_"] = "torchgeo.samplers.GridGeoSampler" + _sampler_params["stride"] = [ + 0.9 * x for x in _to_tuple(_sampler_params["size"]) ] - if _sampler_params["name"] == "RandomSampler": - _sampler_params["name"] = "SequentialSampler" - _sampler_params["params"] = {} + if _sampler_params["_target_"].rpartition(".")[2] == "RandomSampler": + _sampler_params = {"_target_": "torch.utils.data.sampler.SequentialSampler"} - if "length" in _sampler_params["params"]: - del _sampler_params["params"]["length"] + if "length" in _sampler_params: + del _sampler_params["length"] # Ensure there are no errant `ClassTransform` transforms in the parameters from previous runs. # A `ClassTransform` can only be defined with a correct manifest so we cannot use an old one to @@ -1063,7 +1037,8 @@ def delete_transforms(params: Dict[str, Any]) -> None: _sampler_params, loader_params, batch_size=1, # To prevent issues with stacking different sized patches, set batch size to 1. 
- collator_params=collator_params, + collator_target=collator_target, + change_detection=change_detection, cache=False, ) diff --git a/minerva/datasets/multispectral.py b/minerva/datasets/multispectral.py index 9ece53eb2..de2749845 100644 --- a/minerva/datasets/multispectral.py +++ b/minerva/datasets/multispectral.py @@ -56,14 +56,14 @@ class MultiSpectralDataset(VisionDataset, MinervaNonGeoDataset): """Generic dataset class for multi-spectral images that works within :mod:`torchgeo`""" - all_bands: List[str] = [] - rgb_bands: List[str] = [] + all_bands: tuple[str, ...] = () + rgb_bands: tuple[str, ...] = () def __init__( self, root: str, transforms: Optional[Callable[..., Any]] = None, - bands: Optional[List[str]] = None, + bands: Optional[tuple[str, ...]] = None, as_type=np.float32, ) -> None: super().__init__(root, transform=transforms, target_transform=None) diff --git a/minerva/datasets/paired.py b/minerva/datasets/paired.py index 278069561..5b2b4e53f 100644 --- a/minerva/datasets/paired.py +++ b/minerva/datasets/paired.py @@ -44,8 +44,9 @@ # ===================================================================================================================== import random from inspect import signature -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union, overload +from typing import Any, Callable, Optional, Sequence, Union, overload +import hydra import matplotlib.pyplot as plt import torch from matplotlib.figure import Figure @@ -82,7 +83,7 @@ class PairedGeoDataset(RasterDataset): def __new__( # type: ignore[misc] cls, - dataset: Union[Callable[..., GeoDataset], GeoDataset], + dataset: Callable[..., GeoDataset] | GeoDataset, *args, **kwargs, ) -> Union["PairedGeoDataset", "PairedUnionDataset"]: @@ -106,9 +107,12 @@ def __init__( self, dataset: GeoDataset, *args, **kwargs ) -> None: ... # pragma: no cover + @overload + def __init__(self, dataset: str, *args, **kwargs) -> None: ... 
# pragma: no cover + def __init__( self, - dataset: Union[Callable[..., GeoDataset], GeoDataset], + dataset: Callable[..., GeoDataset] | GeoDataset | str, *args, **kwargs, ) -> None: @@ -116,6 +120,9 @@ def __init__( self._args = args self._kwargs = kwargs + if isinstance(dataset, str): + dataset = hydra.utils.get_method(dataset) + if isinstance(dataset, GeoDataset): self.dataset = dataset self._res = dataset.res @@ -141,8 +148,8 @@ def __init__( ) def __getitem__( # type: ignore[override] - self, queries: Tuple[BoundingBox, BoundingBox] - ) -> Tuple[Dict[str, Any], ...]: + self, queries: tuple[BoundingBox, BoundingBox] + ) -> tuple[dict[str, Any], ...]: return self.dataset.__getitem__(queries[0]), self.dataset.__getitem__( queries[1] ) @@ -170,11 +177,11 @@ def __and__(self, other: "PairedGeoDataset") -> IntersectionDataset: # type: ig self, other, collate_fn=utils.pair_collate(concat_samples) ) - def __or__(self, other: "PairedGeoDataset") -> "PairedUnionDataset": # type: ignore[override] - """Take the union of two :class:`PairedGeoDataset`. + def __or__(self, other: GeoDataset) -> "PairedUnionDataset": # type: ignore[override] + """Take the union of two :class:~`torchgeo.datasets.GeoDataset`. Args: - other (PairedGeoDataset): Another dataset. + other (~torchgeo.datasets.GeoDataset): Another dataset. Returns: PairedUnionDataset: A single dataset. @@ -183,20 +190,42 @@ def __or__(self, other: "PairedGeoDataset") -> "PairedUnionDataset": # type: ig """ return PairedUnionDataset(self, other) - def __getattr__(self, item): - if item in self.dataset.__dict__: - return getattr(self.dataset, item) # pragma: no cover - elif item in self.__dict__: - return getattr(self, item) - else: - raise AttributeError + def __getattr__(self, name: str) -> Any: + """ + Called only if the attribute 'name' is not found by usual means. + Checks if 'name' exists in the dataset attribute. 
+ """ + # Instead of calling __getattr__ directly, access the dataset attribute directly + dataset = super().__getattribute__("dataset") + if hasattr(dataset, name): + return getattr(dataset, name) + # If not found in dataset, raise an AttributeError + raise AttributeError(f"{name} cannot be found in self or dataset") + + def __getattribute__(self, name: str) -> Any: + """ + Overrides default attribute access method to prevent recursion. + Checks 'self' first and uses __getattr__ for fallback. + """ + try: + # First, try to get the attribute from the current instance + return super().__getattribute__(name) + except AttributeError: + # If not found in self, __getattr__ will check in the dataset + dataset = super().__getattribute__("dataset") + if hasattr(dataset, name): + return getattr(dataset, name) + # Raise AttributeError if not found in dataset either + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) def __repr__(self) -> Any: return self.dataset.__repr__() @staticmethod def plot( - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> Figure: @@ -235,7 +264,7 @@ def plot( def plot_random_sample( self, - size: Union[Tuple[int, int], int], + size: tuple[int, int] | int, res: float, show_titles: bool = True, suptitle: Optional[str] = None, @@ -277,6 +306,13 @@ def __init__( ] = merge_samples, transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, ) -> None: + + # Extract the actual dataset out of the paired dataset (otherwise we'll be pairing the datasets twice!) 
+ if isinstance(dataset1, PairedGeoDataset): + dataset1 = dataset1.dataset + if isinstance(dataset2, PairedGeoDataset): + dataset2 = dataset2.dataset + super().__init__(dataset1, dataset2, collate_fn, transforms) new_datasets = [] @@ -291,8 +327,8 @@ def __init__( self.datasets = new_datasets def __getitem__( # type: ignore[override] - self, query: Tuple[BoundingBox, BoundingBox] - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + self, query: tuple[BoundingBox, BoundingBox] + ) -> tuple[dict[str, Any], dict[str, Any]]: """Retrieve image and metadata indexed by query. Uses :meth:`torchgeo.datasets.UnionDataset.__getitem__` to send each query of the pair off to get a @@ -320,6 +356,16 @@ def __or__(self, other: "PairedGeoDataset") -> "PairedUnionDataset": # type: ig """ return PairedUnionDataset(self, other) + def __getattr__(self, name: str) -> Any: + if name in self.__dict__: + return getattr(self, name) + elif name in self.datasets[0].__dict__: + return getattr(self.datasets[0], name) # pragma: no cover + elif name in self.datasets[1].__dict__: + return getattr(self.datasets[1], name) # pragma: no cover + else: + raise AttributeError + class PairedNonGeoDataset(NonGeoDataset): """Custom dataset to act as a wrapper to other datasets to handle paired sampling. @@ -355,7 +401,7 @@ def __getnewargs__(self): def __init__( self, dataset: Callable[..., NonGeoDataset], - size: Union[Tuple[int, int], int], + size: tuple[int, int] | int, max_r: int, season: bool = False, *args, @@ -366,7 +412,18 @@ def __init__( def __init__( self, dataset: NonGeoDataset, - size: Union[Tuple[int, int], int], + size: tuple[int, int] | int, + max_r: int, + season: bool = False, + *args, + **kwargs, + ) -> None: ... 
# pragma: no cover + + @overload + def __init__( + self, + dataset: str, + size: tuple[int, int] | int, max_r: int, season: bool = False, *args, @@ -375,8 +432,8 @@ def __init__( def __init__( self, - dataset: Union[Callable[..., NonGeoDataset], NonGeoDataset], - size: Union[Tuple[int, int], int], + dataset: Callable[..., NonGeoDataset] | NonGeoDataset | str, + size: tuple[int, int] | int, max_r: int, season: bool = False, *args, @@ -396,6 +453,9 @@ def __init__( if isinstance(dataset, PairedNonGeoDataset): raise ValueError("Cannot pair an already paired dataset!") + if isinstance(dataset, str): + dataset = hydra.utils.get_method(dataset) + if isinstance(dataset, NonGeoDataset): self.dataset = dataset @@ -420,7 +480,7 @@ def __init__( self.make_geo_pair = SamplePair(self.size, self.max_r, season=season) - def __getitem__(self, index: int) -> Tuple[Dict[str, Any], ...]: # type: ignore[override] + def __getitem__(self, index: int) -> tuple[dict[str, Any], ...]: # type: ignore[override] patch = self.dataset[index] image_a, image_b = self.make_geo_pair(patch["image"]) @@ -450,9 +510,39 @@ def __len__(self) -> int: def __repr__(self) -> Any: return self.dataset.__repr__() + def __getattr__(self, name: str) -> Any: + """ + Called only if the attribute 'name' is not found by usual means. + Checks if 'name' exists in the dataset attribute. + """ + # Instead of calling __getattr__ directly, access the dataset attribute directly + dataset = super().__getattribute__("dataset") + if hasattr(dataset, name): + return getattr(dataset, name) + # If not found in dataset, raise an AttributeError + raise AttributeError(f"{name} cannot be found in self or dataset") + + def __getattribute__(self, name: str) -> Any: + """ + Overrides default attribute access method to prevent recursion. + Checks 'self' first and uses __getattr__ for fallback. 
+ """ + try: + # First, try to get the attribute from the current instance + return super().__getattribute__(name) + except AttributeError: + # If not found in self, __getattr__ will check in the dataset + dataset = super().__getattribute__("dataset") + if hasattr(dataset, name): + return getattr(dataset, name) + # Raise AttributeError if not found in dataset either + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) + @staticmethod def plot( - sample: Dict[str, Any], + sample: dict[str, Any], show_titles: bool = True, suptitle: Optional[str] = None, ) -> Figure: @@ -546,7 +636,7 @@ def __init__( super().__init__(datasets) - def __getitem__(self, index: int) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def __getitem__(self, index: int) -> tuple[dict[str, Any], dict[str, Any]]: """Retrieve image and metadata indexed by query. Uses :meth:`torch.utils.data.ConcatDataset.__getitem__` to get the pair of samples from @@ -604,7 +694,7 @@ def __init__(self, size: int = 64, max_r: int = 20, season: bool = False) -> Non # Transform to cut samples out at the desired output size. self.random_crop = RandomCrop(self.size) - def __call__(self, x: Tensor) -> Tuple[Tensor, Tensor]: + def __call__(self, x: Tensor) -> tuple[Tensor, Tensor]: if self.season: max_width, h, w = self._find_max_width(x[0]) @@ -631,7 +721,7 @@ def __call__(self, x: Tensor) -> Tuple[Tensor, Tensor]: # Now cut out 2 random samples from within that sampling area and return. 
return self.random_crop(sampling_area), self.random_crop(sampling_area) - def _find_max_width(self, x: Tensor) -> Tuple[int, int, int]: + def _find_max_width(self, x: Tensor) -> tuple[int, int, int]: max_width = self.max_width w = x.shape[-1] @@ -647,7 +737,7 @@ def _find_max_width(self, x: Tensor) -> Tuple[int, int, int]: return max_width, h, w @staticmethod - def _get_random_crop_params(img: Tensor, max_width: int) -> Tuple[int, int]: + def _get_random_crop_params(img: Tensor, max_width: int) -> tuple[int, int]: i = torch.randint(0, img.shape[-1] - max_width + 1, size=(1,)).item() j = torch.randint(0, img.shape[-2] - max_width + 1, size=(1,)).item() diff --git a/minerva/datasets/ssl4eos12.py b/minerva/datasets/ssl4eos12.py index 74849618c..3e6a447be 100644 --- a/minerva/datasets/ssl4eos12.py +++ b/minerva/datasets/ssl4eos12.py @@ -32,7 +32,7 @@ # ===================================================================================================================== import os import pickle -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional import cv2 import lmdb @@ -69,7 +69,7 @@ class GeoSSL4EOS12Sentinel2(Sentinel2): filename_glob = "{}.*" filename_regex = r"""(?PB[^[0-1]?[0-9]|B[^[1]?[0-9][\dA])\..*$""" date_format = "" - all_bands = [ + all_bands: tuple[str, ...] = ( "B1", "B2", "B3", @@ -82,12 +82,12 @@ class GeoSSL4EOS12Sentinel2(Sentinel2): "B9", "B11", "B12", - ] - rgb_bands = ["B4", "B3", "B2"] + ) + rgb_bands: tuple[str, str, str] = ("B4", "B3", "B2") class NonGeoSSL4EOS12Sentinel2(MultiSpectralDataset): - all_bands = [ + all_bands: tuple[str, ...] 
= ( "B1", "B2", "B3", @@ -100,8 +100,8 @@ class NonGeoSSL4EOS12Sentinel2(MultiSpectralDataset): "B9", "B11", "B12", - ] - rgb_bands = ["B4", "B3", "B2"] + ) + rgb_bands: tuple[str, str, str] = ("B4", "B3", "B2") class MinervaSSL4EO(VisionDataset, MinervaNonGeoDataset): @@ -110,7 +110,7 @@ class MinervaSSL4EO(VisionDataset, MinervaNonGeoDataset): Source: https://github.com/zhu-xlab/SSL4EO-S12/tree/main/src/benchmark/pretrain_ssl/datasets/SSL4EO """ - ALL_BANDS_S2_L2A = [ + ALL_BANDS_S2_L2A: tuple[str, ...] = ( "B1", "B2", "B3", @@ -123,8 +123,8 @@ class MinervaSSL4EO(VisionDataset, MinervaNonGeoDataset): "B9", "B11", "B12", - ] - ALL_BANDS_S2_L1C = [ + ) + ALL_BANDS_S2_L1C: tuple[str, ...] = ( "B1", "B2", "B3", @@ -138,9 +138,9 @@ class MinervaSSL4EO(VisionDataset, MinervaNonGeoDataset): "B10", "B11", "B12", - ] - RGB_BANDS = ["B4", "B3", "B2"] - ALL_BANDS_S1_GRD = ["VV", "VH"] + ) + RGB_BANDS: tuple[str, str, str] = ("B4", "B3", "B2") + ALL_BANDS_S1_GRD: tuple[str, str] = ("VV", "VH") # Band statistics: mean & std # Calculated from 50k data @@ -215,7 +215,7 @@ def __init__( lmdb_file: Optional[str] = None, normalize: bool = False, mode: str = "s2a", - bands: Optional[List[str]] = None, + bands: Optional[tuple[str, ...]] = None, dtype: str = "uint8", is_slurm_job=False, transforms=None, @@ -263,7 +263,7 @@ def _init_db(self): with self.env.begin(write=False) as txn: # type: ignore[unreachable] self.length = txn.stat()["entries"] - def __getitem__(self, index: int) -> Dict[str, Union[Tuple[Tensor, ...], Tensor]]: + def __getitem__(self, index: int) -> dict[str, tuple[Tensor, ...] 
| Tensor]: if self.lmdb_file: if self.is_slurm_job: # Delay loading LMDB data until after initialization @@ -363,7 +363,9 @@ def __getitem__(self, index: int) -> Dict[str, Union[Tuple[Tensor, ...], Tensor] return {"image": img_4s} - def get_array(self, patch_id: str, mode: str, bands: Optional[List[str]] = None): + def get_array( + self, patch_id: str, mode: str, bands: Optional[tuple[str, ...]] = None + ): data_root_patch = os.path.join(self.root, mode, patch_id) patch_seasons = os.listdir(data_root_patch) seasons = [] diff --git a/minerva/datasets/utils.py b/minerva/datasets/utils.py index a2f4c86e0..c64f5e843 100644 --- a/minerva/datasets/utils.py +++ b/minerva/datasets/utils.py @@ -46,18 +46,7 @@ # ===================================================================================================================== import pickle from pathlib import Path -from typing import ( - Any, - Callable, - Dict, - Iterable, - List, - Literal, - Optional, - Sequence, - Tuple, - Union, -) +from typing import Any, Callable, Iterable, Literal, Optional, Sequence, Union import numpy as np from nptyping import NDArray @@ -113,6 +102,16 @@ def __or__( """ return MinervaConcatDataset([self, other]) + def __getattr__(self, name: str) -> Any: + if name in self.__dict__: + return getattr(self, name) + elif name in self.datasets[0].__dict__: + return getattr(self.datasets[0], name) # pragma: no cover + elif name in self.datasets[1].__dict__: + return getattr(self.datasets[1], name) # pragma: no cover + else: + raise AttributeError + # ===================================================================================================================== # METHODS @@ -129,7 +128,7 @@ def intersect_datasets(datasets: Sequence[GeoDataset]) -> IntersectionDataset: ~torchgeo.datasets.IntersectionDataset: Final dataset object representing an intersection of all the parsed datasets. 
""" - master_dataset: Union[GeoDataset, IntersectionDataset] = datasets[0] + master_dataset: GeoDataset | IntersectionDataset = datasets[0] for i in range(len(datasets) - 1): master_dataset = master_dataset & datasets[i + 1] @@ -140,7 +139,7 @@ def intersect_datasets(datasets: Sequence[GeoDataset]) -> IntersectionDataset: def unionise_datasets( datasets: Sequence[GeoDataset], - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, ) -> UnionDataset: """Unionises a list of :class:`~torchgeo.datasets.GeoDataset` together to return a single dataset object. @@ -156,7 +155,7 @@ def unionise_datasets( Returns: ~torchgeo.datasets.UnionDataset: Final dataset object representing an union of all the parsed datasets. """ - master_dataset: Union[GeoDataset, UnionDataset] = datasets[0] + master_dataset: GeoDataset | UnionDataset = datasets[0] for i in range(len(datasets) - 1): master_dataset = master_dataset | datasets[i + 1] @@ -168,7 +167,7 @@ def unionise_datasets( def concatenate_datasets( datasets: Sequence[NonGeoDataset], - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, ) -> MinervaConcatDataset: """Unionises a list of :class:`~torchgeo.datasets.GeoDataset` together to return a single dataset object. @@ -185,20 +184,18 @@ def concatenate_datasets( ~minerva.datasets.MinervaConcatDataset: Final dataset object representing an union of all the parsed datasets. 
""" assert isinstance(datasets[0], MinervaNonGeoDataset) - master_dataset: Union[MinervaNonGeoDataset, MinervaConcatDataset] = datasets[0] + master_dataset: MinervaNonGeoDataset | MinervaConcatDataset = datasets[0] for i in range(len(datasets) - 1): master_dataset = master_dataset | datasets[i + 1] # type: ignore[operator] if hasattr(master_dataset, "transforms"): - master_dataset.transforms = transforms + master_dataset.transforms = transforms # type: ignore[union-attr] assert isinstance(master_dataset, MinervaConcatDataset) return master_dataset -def make_bounding_box( - roi: Union[Sequence[float], bool] = False -) -> Optional[BoundingBox]: +def make_bounding_box(roi: Sequence[float] | bool = False) -> Optional[BoundingBox]: """Construct a :class:`~torchgeo.datasets.utils.BoundingBox` object from the corners of the box. ``False`` for no :class:`~torchgeo.datasets.utils.BoundingBox`. @@ -236,7 +233,7 @@ def load_all_samples( ~numpy.ndarray: 2D array of the class modes within every sample defined by the parsed :class:`~torch.utils.data.DataLoader`. """ - sample_modes: List[List[Tuple[int, int]]] = [] + sample_modes: list[list[tuple[int, int]]] = [] for sample in tqdm(dataloader): modes = utils.find_modes(sample[target_key]) sample_modes.append(modes) @@ -245,8 +242,8 @@ def load_all_samples( def get_random_sample( - dataset: GeoDataset, size: Union[Tuple[int, int], int], res: float -) -> Dict[str, Any]: + dataset: GeoDataset, size: tuple[int, int] | int, res: float +) -> dict[str, Any]: """Gets a random sample from the provided dataset of size ``size`` and at ``res`` resolution. Args: @@ -262,7 +259,7 @@ def get_random_sample( def load_dataset_from_cache( cached_dataset_path: Path, -) -> Union[NonGeoDataset, GeoDataset]: +) -> NonGeoDataset | GeoDataset: """Load a pickled dataset object in from a cache. 
Args: @@ -280,7 +277,7 @@ def load_dataset_from_cache( def cache_dataset( - dataset: Union[GeoDataset, NonGeoDataset], cached_dataset_path: Path + dataset: GeoDataset | NonGeoDataset, cached_dataset_path: Path ) -> None: """Pickle and cache a dataset object. @@ -295,7 +292,7 @@ def cache_dataset( pickle.dump(dataset, fp) -def masks_or_labels(dataset_params: Dict[str, Any]) -> str: +def masks_or_labels(dataset_params: dict[str, Any]) -> str: for key in dataset_params.keys(): if key not in ( "sampler", diff --git a/minerva/inbuilt_cfgs/example_3rd_party.yaml b/minerva/inbuilt_cfgs/example_3rd_party.yaml index 4db007b28..1dcd35933 100644 --- a/minerva/inbuilt_cfgs/example_3rd_party.yaml +++ b/minerva/inbuilt_cfgs/example_3rd_party.yaml @@ -41,12 +41,11 @@ optim_func: SGD # Name of the optimiser function. # ---+ Model Parameters +------------------------------------------------------ model_params: - module: lightly.models - params: - input_size: ${input_size} - n_classes: ${n_classes} - num_classes: ${n_classes} - name: resnet-9 + _target_: lightly.models.resnet.ResNetGenerator + input_size: ${input_size} + n_classes: ${n_classes} + num_classes: ${n_classes} + name: resnet-9 # ---+ Optimiser Parameters +-------------------------------------------------- optimiser: @@ -68,24 +67,16 @@ wandb_log: true # Activates wandb logging. project: pytest # Define the project name for wandb. wandb_dir: /test/tmp/wandb # Directory to store wandb logs locally. -# ---+ Minerva Inbuilt Logging Functions +------------------------------------- -# task_logger: SupervisedTaskLogger -# step_logger: SupervisedGeoStepLogger -# model_io: sup_tg - record_int: true # Store integer results in memory. record_float: true # Store floating point results too. Beware memory overload! 
# ---+ Collator +-------------------------------------------------------------- -collator: - module: torchgeo.datasets - name: stack_samples +collator: torchgeo.datasets.stack_samples # === TASKS =================================================================== tasks: fit-train: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: true record_float: true @@ -95,45 +86,36 @@ tasks: # ---+ Dataset Parameters +-------------------------------------------- dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 120 + size: ${patch_size} + length: 120 image: transforms: ToRGB: - module: minerva.transforms - - images_1: - module: minerva.datasets.__testing - name: TstImgDataset - paths: NAIP - params: + _target_: minerva.transforms.ToRGB + subdatasets: + images_1: + _target_: minerva.datasets.__testing.TstImgDataset + paths: NAIP res: 1.0 - image2: - module: minerva.datasets.__testing - name: TstImgDataset - paths: NAIP - params: + image2: + _target_: minerva.datasets.__testing.TstImgDataset + paths: NAIP res: 1.0 mask: transforms: SingleLabel: - module: minerva.transforms - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.transforms.SingleLabel + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 fit-val: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: false record_float: true @@ -143,36 +125,29 @@ tasks: # ---+ Dataset Parameters +-------------------------------------------- dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 32 + size: ${patch_size} + length: 32 image: transforms: ToRGB: - module: minerva.transforms - module: minerva.datasets.__testing - name: 
TstImgDataset + _target_: minerva.transforms.ToRGB + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: transforms: SingleLabel: - module: minerva.transforms - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.transforms.SingleLabel + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 test-test: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: false record_float: true @@ -182,32 +157,26 @@ tasks: # ---+ Dataset Parameters +-------------------------------------------- dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 32 + size: ${patch_size} + length: 32 image: transforms: ToRGB: - module: minerva.transforms - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.transforms.ToRGB + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: transforms: SingleLabel: - module: minerva.transforms - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.transforms.SingleLabel + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 # === PLOTTING OPTIONS ======================================================== plots: diff --git a/minerva/inbuilt_cfgs/example_CNN_config.yaml b/minerva/inbuilt_cfgs/example_CNN_config.yaml index 3936e48e4..2a5daf2dc 100644 --- a/minerva/inbuilt_cfgs/example_CNN_config.yaml +++ b/minerva/inbuilt_cfgs/example_CNN_config.yaml @@ -24,7 +24,7 @@ model_type: scene-classifier batch_size: 8 # Number of samples in each batch. input_size: [4, 32, 32] # patch_size plus leading channel dim. patch_size: '${to_patch_size: ${input_size}}' # 2D tuple or float. -n_classes: &n_classes 8 # Number of classes in dataset. 
+n_classes: 8 # Number of classes in dataset. # ---+ Experiment Execution +-------------------------------------------------- max_epochs: 3 # Maximum number of training epochs. @@ -41,10 +41,9 @@ optim_func: SGD # Name of the optimiser function. # ---+ Model Parameters +------------------------------------------------------ model_params: - module: - params: - input_size: ${input_size} - n_classes: *n_classes + _target_: minerva.models.CNN + input_size: ${input_size} + n_classes: ${n_classes} # ---+ Optimiser Parameters +-------------------------------------------------- optimiser: @@ -67,22 +66,19 @@ wandb_dir: /test/tmp/wandb # Directory to store wandb logs locally. # === MODEL IO & LOGGING ====================================================== # ---+ Minerva Inbuilt Logging Functions +------------------------------------- -task_logger: SupervisedTaskLogger -model_io: sup_tg +task_logger: minerva.logger.tasklog.SupervisedTaskLogger +model_io: minerva.modelio.supervised_torchgeo_io record_int: true # Store integer results in memory. record_float: false # Store floating point results too. Beware memory overload! 
# ---+ Collator +-------------------------------------------------------------- -collator: - module: torchgeo.datasets - name: stack_samples +collator: torchgeo.datasets.stack_samples # === TASKS =================================================================== tasks: fit-train: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: true record_float: true @@ -91,37 +87,30 @@ tasks: dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 320 + size: ${patch_size} + length: 320 image: transforms: Normalise: - module: minerva.transforms + _target_: minerva.transforms.Normalise norm_value: 255 - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: transforms: SingleLabel: - module: minerva.transforms + _target_: minerva.transforms.SingleLabel mode: modal - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 fit-val: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: false record_float: true @@ -130,38 +119,31 @@ tasks: dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 80 + size: ${patch_size} + length: 80 image: transforms: Normalise: - module: minerva.transforms + _target_: minerva.transforms.Normalise norm_value: 255 - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: transforms: SingleLabel: - module: minerva.transforms + _target_: minerva.transforms.SingleLabel mode: modal - module: minerva.datasets.__testing - name: 
TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 test-test: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: false record_float: true @@ -170,34 +152,28 @@ tasks: dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 160 + size: ${patch_size} + length: 160 image: transforms: Normalise: - module: minerva.transforms + _target_: minerva.transforms.Normalise norm_value: 255 - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: transforms: SingleLabel: - module: minerva.transforms + _target_: minerva.transforms.SingleLabel mode: modal - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 # === PLOTTING OPTIONS ======================================================== plots: diff --git a/minerva/inbuilt_cfgs/example_ChangeDetector.yaml b/minerva/inbuilt_cfgs/example_ChangeDetector.yaml index 244b3a7e5..e060b3c97 100644 --- a/minerva/inbuilt_cfgs/example_ChangeDetector.yaml +++ b/minerva/inbuilt_cfgs/example_ChangeDetector.yaml @@ -34,6 +34,7 @@ pre_train: false # Activate pre-training mode. fine_tune: false # Activate fine-tuning mode. torch_compile: false # Wrap model in `torch.compile`. sample_pairs: false +change_detection: true # ---+ Loss and Optimisers +--------------------------------------------------- loss_func: CrossEntropyLoss # Name of the loss function to use. @@ -42,25 +43,24 @@ optim_func: Adam # Name of the optimiser function. 
# ---+ Model Parameters +------------------------------------------------------ model_params: - module: - params: + _target_: minerva.models.ChangeDetector + input_size: ${input_size} + n_classes: ${n_classes} + encoder_on: true + filter_dim: -1 + fc_dim: 512 + freeze_backbone: false + backbone_args: + module: minerva.models + name: MinervaPSP input_size: ${input_size} - n_classes: ${n_classes} - encoder_on: true - filter_dim: -1 - fc_dim: 512 - freeze_backbone: false - backbone_args: - module: minerva.models - name: MinervaPSP - input_size: ${input_size} - n_classes: 1 - encoder_name: resnet18 - encoder_weights: - psp_out_channels: 512 - segmentation_on: false - classification_on: false - encoder: false + n_classes: 1 + encoder_name: resnet18 + encoder_weights: + psp_out_channels: 512 + segmentation_on: false + classification_on: false + encoder: false # ---+ Optimiser Parameters +-------------------------------------------------- optimiser: @@ -83,20 +83,17 @@ wandb_dir: /test/tmp/wandb # Directory to store wandb logs locally. # === MODEL IO & LOGGING ====================================================== # ---+ Minerva Inbuilt Logging Functions +------------------------------------- - +model_io: minerva.modelio.change_detection_io record_int: true # Store integer results in memory. record_float: false # Store floating point results too. Beware memory overload! 
# ---+ Collator +-------------------------------------------------------------- -collator: - module: torchgeo.datasets - name: stack_samples +collator: torchgeo.datasets.stack_samples # === TASKS =================================================================== tasks: fit-train: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: true record_float: true sample_pairs: false @@ -106,51 +103,35 @@ tasks: # ---+ Dataset Parameters +-------------------------------------------- dataset_params: sampler: - module: torch.utils.data - name: RandomSampler - params: - num_samples: 32 + _target_: torch.utils.data.RandomSampler + num_samples: 32 transforms: RandomCrop: - module: kornia.augmentation + _target_: kornia.augmentation.RandomCrop size: ${patch_size} keepdim: true image: transforms: Normalise: - module: minerva.transforms + _target_: minerva.transforms.Normalise norm_value: 4095 - SelectChannels: - module: minerva.transforms - channels: - - 1 - - 2 - - 3 - - 7 - - 14 - - 15 - - 16 - - 20 - - module: torchgeo.datasets - name: OSCD - paths: OSCD - params: - split: train - bands: all - download: true + + _target_: torchgeo.datasets.OSCD + root: OSCD + split: train + bands: [B02, B03, B04, B08] + download: true mask: transforms: SingleLabel: - module: minerva.transforms + _target_: minerva.transforms.SingleLabel mode: centre test-test: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: false record_float: true sample_pairs: false @@ -160,46 +141,31 @@ tasks: # ---+ Dataset Parameters +-------------------------------------------- dataset_params: sampler: - module: torch.utils.data - name: RandomSampler - params: - num_samples: 16 + _target_: torch.utils.data.RandomSampler + num_samples: 16 transforms: RandomCrop: - module: kornia.augmentation + _target_: kornia.augmentation.RandomCrop size: ${patch_size} keepdim: true image: transforms: Normalise: - module: minerva.transforms + _target_: 
minerva.transforms.Normalise norm_value: 4095 - SelectChannels: - module: minerva.transforms - channels: - - 1 - - 2 - - 3 - - 7 - - 14 - - 15 - - 16 - - 20 - - module: torchgeo.datasets - name: OSCD - paths: OSCD - params: - split: test - bands: all - download: true + + _target_: torchgeo.datasets.OSCD + root: OSCD + split: test + bands: [B02, B03, B04, B08] + download: true mask: transforms: SingleLabel: - module: minerva.transforms + _target_: minerva.transforms.SingleLabel mode: centre # === PLOTTING OPTIONS ======================================================== diff --git a/minerva/inbuilt_cfgs/example_GSConvNet-II.yaml b/minerva/inbuilt_cfgs/example_GSConvNet-II.yaml index d04bd52b1..9d2467358 100644 --- a/minerva/inbuilt_cfgs/example_GSConvNet-II.yaml +++ b/minerva/inbuilt_cfgs/example_GSConvNet-II.yaml @@ -49,10 +49,9 @@ val_freq: 1 # Validation epoch every ``val_freq`` training epochs. # ---+ Model Parameters +------------------------------------------------------ model_params: - module: - params: - input_size: ${input_size} - encoder_weights: imagenet + _target_: minerva.models.SimConv + input_size: ${input_size} + encoder_weights: imagenet # ---+ Optimiser Parameters +-------------------------------------------------- optimiser: @@ -92,160 +91,140 @@ record_int: true # Store integer results in memory. record_float: false # Store floating point results too. Beware memory overload! 
# ---+ Collator +-------------------------------------------------------------- -collator: - module: torchgeo.datasets - name: stack_samples +collator: torchgeo.datasets.stack_samples # === TASKS =================================================================== tasks: fit-train: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: true record_float: true sample_pairs: true # ---+ Minerva Inbuilt Logging Functions +------------------------- - task_logger: SSLTaskLogger - model_io: ssl_pair_tg + task_logger: minerva.logger.tasklog.SSLTaskLogger + model_io: minerva.modelio.ssl_pair_torchgeo_io # ---+ Dataset Parameters +---------------------------------------- dataset_params: sampler: - module: torch.utils.data - name: RandomSampler - params: - num_samples: 32 + _target_: torch.utils.data.RandomSampler + num_samples: 32 image: transforms: - Normalise: - module: minerva.transforms - norm_value: 10000 + normalise: + _target_: minerva.transforms.Normalise + norm_value: 255 RandomApply: p: 0.25 - DetachedColorJitter: - module: minerva.transforms + jitter: + _target_: minerva.transforms.DetachedColorJitter brightness: 0.2 contrast: 0.1 saturation: 0.1 hue: 0.15 - RandomResizedCrop: - module: kornia.augmentation + resize_crop: + _target_: kornia.augmentation.RandomResizedCrop p: 0.2 size: ${patch_size} cropping_mode: resample keepdim: true - RandomHorizontalFlip: - module: kornia.augmentation + horizontal_flip: + _target_: kornia.augmentation.RandomHorizontalFlip p: 0.2 keepdim: true - RandomGaussianBlur: - module: kornia.augmentation + gaussian_blur: + _target_: kornia.augmentation.RandomGaussianBlur kernel_size: 9 p: 0.2 sigma: [0.01, 0.2] keepdim: true - RandomGaussianNoise: - module: kornia.augmentation + gaussian_noise: + _target_: kornia.augmentation.RandomGaussianNoise p: 0.2 std: 0.05 keepdim: true - RandomErasing: - module: kornia.augmentation + random_erasing: + _target_: kornia.augmentation.RandomErasing p: 0.2 keepdim: 
true - module: minerva.datasets - name: NonGeoSSL4EOS12Sentinel2 - paths: SSL4EO-S12 - params: - bands: [B2, B3, B4, B8] - size: ${patch_size} - max_r: *max_r + _target_: minerva.datasets.MinervaSSL4EO + root: SSL4EO-S12 + mode: s2a + bands: [B2, B3, B4, B8] + size: ${patch_size} + max_r: *max_r + season: true + season_transform: pair fit-val: - name: WeightedKNN - module: minerva.tasks + _target_: minerva.tasks.WeightedKNN train: false record_float: true sample_pairs: false n_classes: 11 # ---+ Minerva Inbuilt Logging Functions +------------------------- - task_logger: SSLTaskLogger + task_logger: minerva.logger.tasklog.SSLTaskLogger step_logger: - name: KNNStepLogger - model_io: sup_tg + _target_: minerva.logger.steplog.KNNStepLogger + model_io: minerva.modelio.supervised_torchgeo_io # ---+ Dataset Parameters +---------------------------------------- dataset_params: features: sampler: - module: torch.utils.data - name: RandomSampler - params: - num_samples: 16 + _target_: torch.utils.data.RandomSampler + num_samples: 16 image: - module: minerva.datasets - name: DFC2020 - paths: DFC/DFC2020 - params: - split: val - use_s2hr: true - labels: true + _target_: minerva.datasets.DFC2020 + root: DFC/DFC2020 + split: val + use_s2hr: true + labels: true test: sampler: - module: torch.utils.data - name: RandomSampler - params: - num_samples: 16 + _target_: torch.utils.data.RandomSampler + num_samples: 16 image: - module: minerva.datasets - name: DFC2020 - paths: DFC/DFC2020 - params: - split: test - use_s2hr: true - labels: true + _target_: minerva.datasets.DFC2020 + root: DFC/DFC2020 + split: test + use_s2hr: true + labels: true test-test: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: false record_float: true sample_pairs: false # ---+ Minerva Inbuilt Logging Functions +------------------------- - task_logger: SupervisedTaskLogger - model_io: sup_tg + task_logger: minerva.logger.tasklog.SupervisedTaskLogger + model_io: 
minerva.modelio.supervised_torchgeo_io # ---+ Dataset Parameters +---------------------------------------- dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 16 + size: ${patch_size} + length: 16 image: - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 # === PLOTTING OPTIONS ======================================================== plots: diff --git a/minerva/inbuilt_cfgs/example_GeoCLR_config.yaml b/minerva/inbuilt_cfgs/example_GeoCLR_config.yaml index f474b26fe..3cabddad4 100644 --- a/minerva/inbuilt_cfgs/example_GeoCLR_config.yaml +++ b/minerva/inbuilt_cfgs/example_GeoCLR_config.yaml @@ -47,10 +47,9 @@ val_freq: 2 # Validation epoch every ``val_freq`` training epochs. # ---+ Model Parameters +------------------------------------------------------ model_params: - module: - params: - input_size: ${input_size} - # any other params... + _target_: minerva.models.SimCLR18 + input_size: ${input_size} + # any other params... # ---+ Optimiser Parameters +-------------------------------------------------- optimiser: @@ -77,15 +76,12 @@ record_int: true # Store integer results in memory. record_float: true # Store floating point results too. Beware memory overload! 
# ---+ Collator +-------------------------------------------------------------- -collator: - module: torchgeo.datasets - name: stack_samples +collator: torchgeo.datasets.stack_samples # === TASKS =================================================================== tasks: fit-train: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: true record_float: true sample_pairs: true @@ -93,77 +89,70 @@ tasks: imagery_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/NAIP.yaml}}' # yamllint disable-line rule:line-length # ---+ Minerva Inbuilt Logging Functions +------------------------- - task_logger: SSLTaskLogger + task_logger: minerva.logger.tasklog.SSLTaskLogger step_logger: - name: SSLStepLogger - model_io: ssl_pair_tg + _target_: minerva.logger.steplog.SSLStepLogger + model_io: minerva.modelio.ssl_pair_torchgeo_io # ---+ Dataset Parameters +---------------------------------------- dataset_params: sampler: - module: minerva.samplers - name: RandomPairGeoSampler + _target_: minerva.samplers.RandomPairGeoSampler roi: false - params: - size: ${patch_size} - length: 120 - max_r: *max_r + size: ${patch_size} + length: 120 + max_r: *max_r image: transforms: - Normalise: - module: minerva.transforms + normalise: + _target_: minerva.transforms.Normalise norm_value: 255 RandomApply: p: 0.25 - DetachedColorJitter: - module: minerva.transforms + jitter: + _target_: minerva.transforms.DetachedColorJitter brightness: 0.2 contrast: 0.1 saturation: 0.1 hue: 0.15 - RandomResizedCrop: - module: kornia.augmentation + resize_crop: + _target_: kornia.augmentation.RandomResizedCrop p: 0.2 size: ${patch_size} cropping_mode: resample keepdim: true - RandomHorizontalFlip: - module: kornia.augmentation + horizontal_flip: + _target_: kornia.augmentation.RandomHorizontalFlip p: 0.2 keepdim: true - RandomGaussianBlur: - module: kornia.augmentation + gaussian_blur: + _target_: kornia.augmentation.RandomGaussianBlur kernel_size: 9 p: 0.2 sigma: 
[0.01, 0.2] keepdim: true - RandomGaussianNoise: - module: kornia.augmentation + gaussian_noise: + _target_: kornia.augmentation.RandomGaussianNoise p: 0.2 std: 0.05 keepdim: true - RandomErasing: - module: kornia.augmentation + random_erasing: + _target_: kornia.augmentation.RandomErasing p: 0.2 keepdim: true - - image1: - module: minerva.datasets.__testing - name: TstImgDataset - paths: NAIP - params: + subdatasets: + image1: + _target_: minerva.datasets.__testing.TstImgDataset + paths: NAIP res: 1.0 - image2: - module: minerva.datasets.__testing - name: TstImgDataset - paths: NAIP - params: + image2: + _target_: minerva.datasets.__testing.TstImgDataset + paths: NAIP res: 1.0 fit-val: - name: WeightedKNN - module: minerva.tasks + _target_: minerva.tasks.WeightedKNN train: false sample_pairs: false n_classes: 8 @@ -172,62 +161,49 @@ tasks: data_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/Chesapeake7.yaml}}' # yamllint disable-line rule:line-length # ---+ Minerva Inbuilt Logging Functions +------------------------- - task_logger: SSLTaskLogger + task_logger: minerva.logger.tasklog.SSLTaskLogger step_logger: - name: KNNStepLogger - model_io: ssl_pair_tg + _target_: minerva.logger.steplog.KNNStepLogger + model_io: minerva.modelio.ssl_pair_torchgeo_io # ---+ Dataset Parameters +---------------------------------------- dataset_params: features: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 32 + size: ${patch_size} + length: 32 image: - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 test: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: 
torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 32 + size: ${patch_size} + length: 32 image: - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 test-test: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: false record_float: true sample_pairs: false @@ -237,32 +213,26 @@ tasks: data_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/Chesapeake7.yaml}}' # yamllint disable-line rule:line-length # ---+ Minerva Inbuilt Logging Functions +------------------------- - task_logger: SupervisedTaskLogger - model_io: sup_tg + task_logger: minerva.logger.tasklog.SupervisedTaskLogger + model_io: minerva.modelio.supervised_torchgeo_io # ---+ Dataset Parameters +---------------------------------------- dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 32 + size: ${patch_size} + length: 32 image: - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 # === PLOTTING OPTIONS ======================================================== plots: diff --git a/minerva/inbuilt_cfgs/example_GeoSimConvNet.yaml b/minerva/inbuilt_cfgs/example_GeoSimConvNet.yaml index b9536e839..8305e6bab 100644 --- a/minerva/inbuilt_cfgs/example_GeoSimConvNet.yaml +++ b/minerva/inbuilt_cfgs/example_GeoSimConvNet.yaml @@ -49,9 +49,8 @@ val_freq: 1 # Validation 
epoch every ``val_freq`` training epochs. # ---+ Model Parameters +------------------------------------------------------ model_params: - module: - params: - input_size: ${input_size} + _target_: minerva.models.SimConv + input_size: ${input_size} # ---+ Optimiser Parameters +-------------------------------------------------- optimiser: @@ -59,12 +58,11 @@ optimiser: lr: ${lr} # ---+ Scheduler Parameters +-------------------------------------------------- -scheduler_params: - name: LinearLR - params: - start_factor: *lr - end_factor: 5.0E-4 - total_iters: 2 +scheduler: + _target_: torch.optim.lr_scheduler.LinearLR + start_factor: ${lr} + end_factor: 5.0E-4 + total_iters: 2 # ---+ Loss Function Parameters +---------------------------------------------- loss_params: @@ -84,81 +82,73 @@ record_int: true # Store integer results in memory. record_float: false # Store floating point results too. Beware memory overload! # ---+ Collator +-------------------------------------------------------------- -collator: - module: torchgeo.datasets - name: stack_samples +collator: torchgeo.datasets.stack_samples # === TASKS =================================================================== tasks: fit-train: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: true record_float: true sample_pairs: true # ---+ Minerva Inbuilt Logging Functions +------------------------- - task_logger: SSLTaskLogger - model_io: ssl_pair_tg + task_logger: minerva.logger.tasklog.SSLTaskLogger + model_io: minerva.modelio.ssl_pair_torchgeo_io # ---+ Dataset Parameters +---------------------------------------- dataset_params: sampler: - module: torch.utils.data - name: RandomSampler - params: - num_samples: 32 + _target_: torch.utils.data.RandomSampler + num_samples: 32 image: transforms: - Normalise: - module: minerva.transforms - norm_value: 10000 + normalise: + _target_: minerva.transforms.Normalise + norm_value: 255 RandomApply: p: 0.25 - DetachedColorJitter: - 
module: minerva.transforms + jitter: + _target_: minerva.transforms.DetachedColorJitter brightness: 0.2 contrast: 0.1 saturation: 0.1 hue: 0.15 - RandomResizedCrop: - module: kornia.augmentation + resize_crop: + _target_: kornia.augmentation.RandomResizedCrop p: 0.2 size: ${patch_size} cropping_mode: resample keepdim: true - RandomHorizontalFlip: - module: kornia.augmentation + horizontal_flip: + _target_: kornia.augmentation.RandomHorizontalFlip p: 0.2 keepdim: true - RandomGaussianBlur: - module: kornia.augmentation + gaussian_blur: + _target_: kornia.augmentation.RandomGaussianBlur kernel_size: 9 p: 0.2 sigma: [0.01, 0.2] keepdim: true - RandomGaussianNoise: - module: kornia.augmentation + gaussian_noise: + _target_: kornia.augmentation.RandomGaussianNoise p: 0.2 std: 0.05 keepdim: true - RandomErasing: - module: kornia.augmentation + random_erasing: + _target_: kornia.augmentation.RandomErasing p: 0.2 keepdim: true - module: minerva.datasets - name: NonGeoSSL4EOS12Sentinel2 - paths: SSL4EO-S12 - params: - bands: [B2, B3, B4, B8] - size: ${patch_size} - max_r: *max_r + _target_: minerva.datasets.NonGeoSSL4EOS12Sentinel2 + root: SSL4EO-S12 + bands: [B2, B3, B4, B8] + size: ${patch_size} + max_r: *max_r fit-val: - name: WeightedKNN - module: minerva.tasks + _target_: minerva.tasks.WeightedKNN train: false record_float: true sample_pairs: false @@ -167,62 +157,49 @@ tasks: data_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/Chesapeake7.yaml}}' # yamllint disable-line rule:line-length # ---+ Minerva Inbuilt Logging Functions +------------------------- - task_logger: SSLTaskLogger + task_logger: minerva.logger.tasklog.SSLTaskLogger step_logger: - name: KNNStepLogger - model_io: ssl_pair_tg + _target_: minerva.logger.steplog.KNNStepLogger + model_io: minerva.modelio.ssl_pair_torchgeo_io # ---+ Dataset Parameters +---------------------------------------- dataset_params: features: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: 
torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 16 + size: ${patch_size} + length: 16 image: - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 test: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 16 + size: ${patch_size} + length: 16 image: - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 test-test: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: false record_float: true sample_pairs: false @@ -231,32 +208,26 @@ tasks: data_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/Chesapeake7.yaml}}' # yamllint disable-line rule:line-length # ---+ Minerva Inbuilt Logging Functions +------------------------- - task_logger: SSLTaskLogger - model_io: ssl_pair_tg + task_logger: minerva.logger.tasklog.SSLTaskLogger + model_io: minerva.modelio.ssl_pair_torchgeo_io # ---+ Dataset Parameters +---------------------------------------- dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 16 + size: ${patch_size} + length: 16 image: - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: - module: 
minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 # === PLOTTING OPTIONS ======================================================== plots: diff --git a/minerva/inbuilt_cfgs/example_MultiLabel.yaml b/minerva/inbuilt_cfgs/example_MultiLabel.yaml index ed33d12fb..b0e120c20 100644 --- a/minerva/inbuilt_cfgs/example_MultiLabel.yaml +++ b/minerva/inbuilt_cfgs/example_MultiLabel.yaml @@ -24,7 +24,7 @@ model_type: multilabel-scene-classifier batch_size: 2 # Number of samples in each batch. input_size: [4, 120, 120] # patch_size plus leading channel dim. patch_size: '${to_patch_size: ${input_size}}' # 2D tuple or float. -n_classes: &n_classes 19 # Number of classes in dataset. +n_classes: 19 # Number of classes in dataset. # ---+ Experiment Execution +-------------------------------------------------- max_epochs: 2 # Maximum number of training epochs. @@ -41,26 +41,25 @@ optim_func: SGD # Name of the optimiser function. 
# ---+ Model Parameters +------------------------------------------------------ model_params: - module: - params: + _target_: minerva.models.FlexiSceneClassifier + input_size: ${input_size} + n_classes: ${n_classes} + encoder_on: true + filter_dim: -1 + fc_dim: 512 + freeze_backbone: false + clamp_outputs: true + backbone_args: + module: minerva.models + name: MinervaPSP input_size: ${input_size} - n_classes: *n_classes - encoder_on: true - filter_dim: -1 - fc_dim: 512 - freeze_backbone: false - clamp_outputs: true - backbone_args: - module: minerva.models - name: MinervaPSP - input_size: ${input_size} - n_classes: 1 - encoder_name: resnet18 - encoder_weights: - psp_out_channels: 512 - segmentation_on: false - classification_on: false - encoder: false + n_classes: 1 + encoder_name: resnet18 + encoder_weights: + psp_out_channels: 512 + segmentation_on: false + classification_on: false + encoder: false # ---+ Optimiser Parameters +-------------------------------------------------- optimiser: @@ -88,15 +87,12 @@ record_int: true # Store integer results in memory. record_float: false # Store floating point results too. Beware memory overload! 
# ---+ Collator +-------------------------------------------------------------- -collator: - module: torchgeo.datasets - name: stack_samples +collator: torchgeo.datasets.stack_samples # === TASKS =================================================================== tasks: fit-train: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: true record_float: true cache_dataset: false @@ -106,71 +102,66 @@ tasks: # ---+ Dataset Parameters +-------------------------------------------- dataset_params: sampler: - module: torch.utils.data - name: RandomSampler - params: - num_samples: 32 + _target_: torch.utils.data.RandomSampler + num_samples: 32 image: transforms: SelectChannels: - module: minerva.transforms + _target_: minerva.transforms.SelectChannels channels: - 1 - 2 - 3 - 7 Normalise: - module: minerva.transforms + _target_: minerva.transforms.Normalise norm_value: 4095 RandomApply: p: 0.25 DetachedColorJitter: - module: minerva.transforms + _target_: minerva.transforms.DetachedColorJitter brightness: 0.2 contrast: 0.1 saturation: 0.1 hue: 0.15 - RandomResizedCrop: - module: kornia.augmentation + resize_crop: + _target_: kornia.augmentation.RandomResizedCrop p: 0.2 size: ${patch_size} cropping_mode: resample keepdim: true - RandomHorizontalFlip: - module: kornia.augmentation + horizontal_flip: + _target_: kornia.augmentation.RandomHorizontalFlip p: 0.2 keepdim: true - RandomGaussianBlur: - module: kornia.augmentation + gaussian_blur: + _target_: kornia.augmentation.RandomGaussianBlur kernel_size: 9 p: 0.2 sigma: [0.01, 0.2] keepdim: true - RandomGaussianNoise: - module: kornia.augmentation + gaussian_noise: + _target_: kornia.augmentation.RandomGaussianNoise p: 0.2 std: 0.05 keepdim: true - RandomErasing: - module: kornia.augmentation + random_erasing: + _target_: kornia.augmentation.RandomErasing p: 0.2 keepdim: true - module: torchgeo.datasets - name: BigEarthNet - paths: BigEarthNet-mini - params: - split: train - bands: s2 - 
download: true - num_classes: *n_classes + _target_: torchgeo.datasets.BigEarthNet + root: BigEarthNet-mini + split: train + bands: s2 + download: true + num_classes: ${n_classes} label: test-test: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: false record_float: true cache_dataset: false @@ -180,32 +171,28 @@ tasks: # ---+ Dataset Parameters +-------------------------------------------- dataset_params: sampler: - module: torch.utils.data - name: RandomSampler - params: - num_samples: 16 + _target_: torch.utils.data.RandomSampler + num_samples: 16 image: transforms: SelectChannels: - module: minerva.transforms + _target_: minerva.transforms.SelectChannels channels: - 1 - 2 - 3 - 7 Normalise: - module: minerva.transforms + _target_: minerva.transforms.Normalise norm_value: 4095 - module: torchgeo.datasets - name: BigEarthNet - paths: BigEarthNet-mini - params: - split: test - bands: s2 - download: true - num_classes: *n_classes + _target_: torchgeo.datasets.BigEarthNet + root: BigEarthNet-mini + split: test + bands: s2 + download: true + num_classes: ${n_classes} label: diff --git a/minerva/inbuilt_cfgs/example_PSP.yaml b/minerva/inbuilt_cfgs/example_PSP.yaml index a8483852a..1359f066d 100644 --- a/minerva/inbuilt_cfgs/example_PSP.yaml +++ b/minerva/inbuilt_cfgs/example_PSP.yaml @@ -41,16 +41,15 @@ optim_func: SGD # Name of the optimiser function. 
# ---+ Model Parameters +------------------------------------------------------ model_params: - module: - params: - input_size: ${input_size} - n_classes: ${n_classes} - encoder: true - segmentation_on: true - classification_on: true - upsampling: 8 - aux_params: - classes: ${n_classes} + _target_: minerva.models.MinervaPSP + input_size: ${input_size} + n_classes: ${n_classes} + encoder: true + segmentation_on: true + classification_on: true + upsampling: 8 + aux_params: + classes: ${n_classes} # ---+ Optimiser Parameters +-------------------------------------------------- optimiser: @@ -78,15 +77,12 @@ record_int: true # Store integer results in memory. record_float: false # Store floating point results too. Beware memory overload! # ---+ Collator +-------------------------------------------------------------- -collator: - module: torchgeo.datasets - name: stack_samples +collator: torchgeo.datasets.stack_samples # === TASKS =================================================================== tasks: fit-train: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: true record_float: true @@ -95,68 +91,64 @@ tasks: # ---+ Dataset Parameters +-------------------------------------------- dataset_params: sampler: - module: torch.utils.data - name: RandomSampler - params: - num_samples: 32 + _target_: torch.utils.data.RandomSampler + num_samples: 32 transforms: RandomResizedCrop: - module: kornia.augmentation + _target_: kornia.augmentation.RandomResizedCrop p: 1.0 size: ${patch_size} cropping_mode: resample keepdim: true RandomHorizontalFlip: - module: kornia.augmentation + _target_: kornia.augmentation.RandomHorizontalFlip p: 0.2 keepdim: true image: transforms: - Normalise: - module: minerva.transforms + normalise: + _target_: minerva.transforms.Normalise norm_value: 4095 RandomApply: p: 0.25 - DetachedColorJitter: - module: minerva.transforms + jitter: + _target_: minerva.transforms.DetachedColorJitter brightness: 0.2 contrast: 
0.1 saturation: 0.1 hue: 0.15 - RandomGaussianBlur: - module: kornia.augmentation + gaussian_blur: + _target_: kornia.augmentation.RandomGaussianBlur kernel_size: 9 p: 0.2 sigma: [0.01, 0.2] keepdim: true - RandomGaussianNoise: - module: kornia.augmentation + gaussian_noise: + _target_: kornia.augmentation.RandomGaussianNoise p: 0.2 std: 0.05 keepdim: true - RandomErasing: - module: kornia.augmentation + random_erasing: + _target_: kornia.augmentation.RandomErasing p: 0.2 keepdim: true - module: minerva.datasets - name: DFC2020 - paths: DFC/DFC2020 - params: - split: test - use_s2hr: true - labels: true + _target_: minerva.datasets.DFC2020 + root: DFC/DFC2020 + split: test + use_s2hr: true + labels: true mask: transforms: MaskResize: - module: minerva.transforms + _target_: minerva.transforms.MaskResize size: 64 + interpolation: NEAREST test-test: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: false record_float: true @@ -165,30 +157,27 @@ tasks: # ---+ Dataset Parameters +-------------------------------------------- dataset_params: sampler: - module: torch.utils.data - name: RandomSampler - params: - num_samples: 16 + _target_: torch.utils.data.RandomSampler + num_samples: 16 image: transforms: Normalise: - module: minerva.transforms + _target_: minerva.transforms.Normalise norm_value: 4095 - module: minerva.datasets - name: DFC2020 - paths: DFC/DFC2020 - params: - split: test - use_s2hr: true - labels: true + _target_: minerva.datasets.DFC2020 + root: DFC/DFC2020 + split: test + use_s2hr: true + labels: true mask: transforms: MaskResize: - module: minerva.transforms + _target_: minerva.transforms.MaskResize size: 64 + interpolation: NEAREST # === PLOTTING OPTIONS ======================================================== plots: diff --git a/minerva/inbuilt_cfgs/example_SceneClassifier.yaml b/minerva/inbuilt_cfgs/example_SceneClassifier.yaml index 5f35abcdb..716355437 100644 --- 
a/minerva/inbuilt_cfgs/example_SceneClassifier.yaml +++ b/minerva/inbuilt_cfgs/example_SceneClassifier.yaml @@ -24,7 +24,7 @@ model_type: scene-classifier batch_size: 2 # Number of samples in each batch. input_size: [4, 64, 64] # patch_size plus leading channel dim. patch_size: '${to_patch_size: ${input_size}}' # 2D tuple or float. -n_classes: &n_classes 10 # Number of classes in dataset. +n_classes: 10 # Number of classes in dataset. # ---+ Experiment Execution +-------------------------------------------------- max_epochs: 2 # Maximum number of training epochs. @@ -41,25 +41,24 @@ optim_func: SGD # Name of the optimiser function. # ---+ Model Parameters +------------------------------------------------------ model_params: - module: - params: + _target_: minerva.models.FlexiSceneClassifier + input_size: ${input_size} + n_classes: ${n_classes} + encoder_on: true + filter_dim: -1 + fc_dim: 512 + freeze_backbone: false + backbone_args: + module: minerva.models + name: MinervaPSP input_size: ${input_size} - n_classes: *n_classes - encoder_on: true - filter_dim: -1 - fc_dim: 512 - freeze_backbone: false - backbone_args: - module: minerva.models - name: MinervaPSP - input_size: ${input_size} - n_classes: 1 - encoder_name: resnet18 - encoder_weights: - psp_out_channels: 512 - segmentation_on: false - classification_on: false - encoder: false + n_classes: 1 + encoder_name: resnet18 + encoder_weights: + psp_out_channels: 512 + segmentation_on: false + classification_on: false + encoder: false # ---+ Optimiser Parameters +-------------------------------------------------- optimiser: @@ -87,15 +86,12 @@ record_int: true # Store integer results in memory. record_float: false # Store floating point results too. Beware memory overload! 
# ---+ Collator +-------------------------------------------------------------- -collator: - module: torchgeo.datasets - name: stack_samples +collator: torchgeo.datasets.stack_samples # === TASKS =================================================================== tasks: fit-train: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: true record_float: true @@ -104,67 +100,62 @@ tasks: # ---+ Dataset Parameters +-------------------------------------------- dataset_params: sampler: - module: torch.utils.data - name: RandomSampler - params: - num_samples: 32 + _target_: torch.utils.data.RandomSampler + num_samples: 32 image: transforms: - Normalise: - module: minerva.transforms - norm_value: 4095 + normalise: + _target_: minerva.transforms.Normalise + norm_value: 255 RandomApply: p: 0.25 - DetachedColorJitter: - module: minerva.transforms + jitter: + _target_: minerva.transforms.DetachedColorJitter brightness: 0.2 contrast: 0.1 saturation: 0.1 hue: 0.15 - RandomResizedCrop: - module: kornia.augmentation + resize_crop: + _target_: kornia.augmentation.RandomResizedCrop p: 0.2 size: ${patch_size} cropping_mode: resample keepdim: true - RandomHorizontalFlip: - module: kornia.augmentation + horizontal_flip: + _target_: kornia.augmentation.RandomHorizontalFlip p: 0.2 keepdim: true - RandomGaussianBlur: - module: kornia.augmentation + gaussian_blur: + _target_: kornia.augmentation.RandomGaussianBlur kernel_size: 9 p: 0.2 sigma: [0.01, 0.2] keepdim: true - RandomGaussianNoise: - module: kornia.augmentation + gaussian_noise: + _target_: kornia.augmentation.RandomGaussianNoise p: 0.2 std: 0.05 keepdim: true - RandomErasing: - module: kornia.augmentation + random_erasing: + _target_: kornia.augmentation.RandomErasing p: 0.2 keepdim: true - module: torchgeo.datasets - name: EuroSAT100 - paths: EuroSAT100 - params: - split: train - bands: - - B02 - - B03 - - B04 - - B08 - download: true + _target_: torchgeo.datasets.EuroSAT100 + root: 
EuroSAT100 + split: train + bands: + - B02 + - B03 + - B04 + - B08 + download: true label: test-test: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: false record_float: true @@ -173,28 +164,24 @@ tasks: # ---+ Dataset Parameters +-------------------------------------------- dataset_params: sampler: - module: torch.utils.data - name: RandomSampler - params: - num_samples: 16 + _target_: torch.utils.data.RandomSampler + num_samples: 16 image: transforms: Normalise: - module: minerva.transforms + _target_: minerva.transforms.Normalise norm_value: 4095 - module: torchgeo.datasets - name: EuroSAT100 - paths: EuroSAT100 - params: - split: test - bands: - - B02 - - B03 - - B04 - - B08 - download: true + _target_: torchgeo.datasets.EuroSAT100 + root: EuroSAT100 + split: test + bands: + - B02 + - B03 + - B04 + - B08 + download: true label: diff --git a/minerva/inbuilt_cfgs/example_UNetR_config.yaml b/minerva/inbuilt_cfgs/example_UNetR_config.yaml index b720f76ba..338488294 100644 --- a/minerva/inbuilt_cfgs/example_UNetR_config.yaml +++ b/minerva/inbuilt_cfgs/example_UNetR_config.yaml @@ -24,10 +24,10 @@ model_type: segmentation batch_size: 8 # Number of samples in each batch. input_size: [4, 224, 224] # patch_size plus leading channel dim. patch_size: '${to_patch_size: ${input_size}}' # 2D tuple or float. -n_classes: &n_classes 8 # Number of classes in dataset. +n_classes: 8 # Number of classes in dataset. # ---+ Experiment Execution +-------------------------------------------------- -max_epochs: 5 # Maximum number of training epochs. +max_epochs: 2 # Maximum number of training epochs. elim: true # Eliminates empty classes from schema. balance: true # Balances dataset classes. pre_train: false # Activate pre-training mode. @@ -41,12 +41,11 @@ optim_func: SGD # Name of the optimiser function. 
# ---+ Model Parameters +------------------------------------------------------ model_params: - module: - params: - input_size: ${input_size} - n_classes: *n_classes - backbone_kwargs: - torch_weights: true + _target_: minerva.models.UNetR18 + input_size: ${input_size} + n_classes: ${n_classes} + backbone_kwargs: + torch_weights: true # ---+ Optimiser Parameters +-------------------------------------------------- optimiser: @@ -67,25 +66,16 @@ wandb_log: true # Activates wandb logging. project: pytest # Define the project name for wandb. wandb_dir: /test/tmp/wandb # Directory to store wandb logs locally. -# === MODEL IO & LOGGING ====================================================== -# ---+ Minerva Inbuilt Logging Functions +------------------------------------- -# task_logger: SupervisedTaskLogger -# step_logger: SupervisedGeoStepLogger -# model_io: sup_tg - record_int: true # Store integer results in memory. record_float: false # Store floating point results too. Beware memory overload! 
# ---+ Collator +-------------------------------------------------------------- -collator: - module: torchgeo.datasets - name: stack_samples +collator: torchgeo.datasets.stack_samples # === TASKS =================================================================== tasks: fit-train: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: true record_float: true @@ -95,33 +85,26 @@ tasks: # ---+ Dataset Parameters +---------------------------------------- dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 320 + size: ${patch_size} + length: 320 image: transforms: Normalise: - module: minerva.transforms + _target_: minerva.transforms.Normalise norm_value: 255 - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 fit-val: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: true record_float: true @@ -131,33 +114,26 @@ tasks: # ---+ Dataset Parameters +---------------------------------------- dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 80 + size: ${patch_size} + length: 80 image: transforms: Normalise: - module: minerva.transforms + _target_: minerva.transforms.Normalise norm_value: 255 - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: 
Chesapeake7 - params: - res: 1.0 + res: 1.0 test-test: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: true record_float: true @@ -167,28 +143,22 @@ tasks: # ---+ Dataset Parameters +---------------------------------------- dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 160 + size: ${patch_size} + length: 160 image: transforms: Normalise: - module: minerva.transforms + _target_: minerva.transforms.Normalise norm_value: 255 - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 # === PLOTTING OPTIONS ======================================================== plots: diff --git a/minerva/inbuilt_cfgs/example_autoencoder_config.yaml b/minerva/inbuilt_cfgs/example_autoencoder_config.yaml index cb162c699..78e45157d 100644 --- a/minerva/inbuilt_cfgs/example_autoencoder_config.yaml +++ b/minerva/inbuilt_cfgs/example_autoencoder_config.yaml @@ -41,12 +41,11 @@ optim_func: SGD # Name of the optimiser function. # ---+ Model Parameters +------------------------------------------------------ model_params: - module: - params: - input_size: ${input_size} - n_classes: ${n_classes} - backbone_kwargs: - torch_weights: false + _target_: minerva.models.UNetR18 + input_size: ${input_size} + n_classes: ${n_classes} + backbone_kwargs: + torch_weights: false # ---+ Optimiser Parameters +-------------------------------------------------- optimiser: @@ -69,24 +68,21 @@ wandb_dir: /test/tmp/wandb # Directory to store wandb logs locally. 
# === MODEL IO & LOGGING ====================================================== # ---+ Minerva Inbuilt Logging Functions +------------------------------------- -task_logger: SupervisedTaskLogger +task_logger: minerva.logger.tasklog.SupervisedTaskLogger step_logger: - name: SupervisedStepLogger -model_io: autoencoder_io + _target_: minerva.logger.steplog.SupervisedStepLogger +model_io: minerva.modelio.autoencoder_io record_int: true # Store integer results in memory. record_float: false # Store floating point results too. Beware memory overload! # ---+ Collator +-------------------------------------------------------------- -collator: - module: torchgeo.datasets - name: stack_samples +collator: torchgeo.datasets.stack_samples # === TASKS =================================================================== tasks: fit-train: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: true record_float: true autoencoder_data_key: mask @@ -97,34 +93,27 @@ tasks: # ---+ Dataset Parameters +-------------------------------------------- dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 36 + size: ${patch_size} + length: 36 image: transforms: Normalise: - module: minerva.transforms + _target_: minerva.transforms.Normalise norm_value: 255 - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 fit-val: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: false record_float: true autoencoder_data_key: mask @@ -135,34 +124,27 @@ tasks: # ---+ Dataset Parameters +-------------------------------------------- 
dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 18 + size: ${patch_size} + length: 18 image: transforms: Normalise: - module: minerva.transforms + _target_: minerva.transforms.Normalise norm_value: 255 - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 test-test: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: false record_float: true autoencoder_data_key: mask @@ -173,30 +155,24 @@ tasks: # ---+ Dataset Parameters +-------------------------------------------- dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 18 + size: ${patch_size} + length: 18 image: transforms: Normalise: - module: minerva.transforms + _target_: minerva.transforms.Normalise norm_value: 255 - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 # === PLOTTING OPTIONS ======================================================== plots: diff --git a/minerva/inbuilt_cfgs/example_config.yaml b/minerva/inbuilt_cfgs/example_config.yaml index 65991055e..b1370f014 100644 --- a/minerva/inbuilt_cfgs/example_config.yaml +++ b/minerva/inbuilt_cfgs/example_config.yaml @@ -14,7 +14,7 @@ cache_dir: tests/tmp/cache # === HYPERPARAMETERS 
========================================================= -# Name of model. Substring before hyphen is model class. +# Name of model. This is no longer used for the model class (see model_params). model_name: FCN32ResNet18-test # Type of model. Can be mlp, scene classifier, segmentation, ssl or siamese. @@ -40,11 +40,10 @@ optim_func: SGD # Name of the optimiser function. # ---+ Model Parameters +------------------------------------------------------ model_params: - module: - params: - input_size: ${input_size} - n_classes: *n_classes - # any other params... + _target_: minerva.models.FCN32ResNet18 + input_size: ${input_size} + n_classes: ${n_classes} + # any other params... # ---+ Optimiser Parameters +-------------------------------------------------- optimiser: @@ -52,12 +51,11 @@ optimiser: lr: ${lr} # ---+ Scheduler Parameters +-------------------------------------------------- -scheduler_params: - name: LinearLR - params: - start_factor: 1.0 - end_factor: 0.5 - total_iters: 5 +scheduler: + _target_: torch.optim.lr_scheduler.LinearLR + start_factor: 1.0 + end_factor: 0.5 + total_iters: 5 # ---+ Loss Function Parameters +---------------------------------------------- loss_params: @@ -75,15 +73,12 @@ project: pytest # Define the project name for wandb. wandb_dir: /test/tmp/wandb # Directory to store wandb logs locally. 
# ---+ Collator +-------------------------------------------------------------- -collator: - module: torchgeo.datasets - name: stack_samples +collator: torchgeo.datasets.stack_samples # === TASKS =================================================================== tasks: fit-train: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: true record_float: true @@ -93,40 +88,32 @@ tasks: # ---+ Dataset Parameters +---------------------------------------- dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 120 + size: ${patch_size} + length: 120 image: transforms: false - images_1: - module: minerva.datasets.__testing - name: TstImgDataset - paths: NAIP - params: + subdatasets: + images_1: + _target_: minerva.datasets.__testing.TstImgDataset + paths: NAIP res: 1.0 - image2: - module: minerva.datasets.__testing - name: TstImgDataset - paths: NAIP - params: + image2: + _target_: minerva.datasets.__testing.TstImgDataset + paths: NAIP res: 1.0 mask: transforms: false - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 fit-val: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch train: false record_float: true @@ -134,74 +121,59 @@ tasks: data_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/Chesapeake7.yaml}}' # yamllint disable-line rule:line-length # ---+ Minerva Inbuilt Logging Functions +------------------------- - # logger: STGLogger - # metrics: SPMetrics - # model_io: sup_tg + task_logger: minerva.logger.tasklog.SupervisedTaskLogger + model_io: minerva.modelio.supervised_torchgeo_io # ---+ Dataset Parameters +---------------------------------------- dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: 
torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 32 + size: ${patch_size} + length: 32 image: transforms: false - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: transforms: false - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 test-test: - name: StandardEpoch - module: minerva.tasks + _target_: minerva.tasks.StandardEpoch record_float: true imagery_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/NAIP.yaml}}' # yamllint disable-line rule:line-length data_config: '${oc.create:${cfg_load: minerva/inbuilt_cfgs/dataset/Chesapeake7.yaml}}' # yamllint disable-line rule:line-length # ---+ Minerva Inbuilt Logging Functions +------------------------- - # logger: STGLogger - # metrics: SPMetrics - # model_io: sup_tg + task_logger: minerva.logger.tasklog.SupervisedTaskLogger + model_io: minerva.modelio.supervised_torchgeo_io # ---+ Dataset Parameters +---------------------------------------- dataset_params: sampler: - module: torchgeo.samplers - name: RandomGeoSampler + _target_: torchgeo.samplers.RandomGeoSampler roi: false - params: - size: ${patch_size} - length: 32 + size: ${patch_size} + length: 32 image: transforms: false - module: minerva.datasets.__testing - name: TstImgDataset + _target_: minerva.datasets.__testing.TstImgDataset paths: NAIP - params: - res: 1.0 + res: 1.0 mask: transforms: false - module: minerva.datasets.__testing - name: TstMaskDataset + _target_: minerva.datasets.__testing.TstMaskDataset paths: Chesapeake7 - params: - res: 1.0 + res: 1.0 # === PLOTTING OPTIONS ======================================================== plots: diff --git a/minerva/logger/steplog.py b/minerva/logger/steplog.py index 20cb0b13e..0b3bc9bde 100644 --- a/minerva/logger/steplog.py +++ 
b/minerva/logger/steplog.py @@ -38,7 +38,6 @@ "SupervisedStepLogger", "SSLStepLogger", "KNNStepLogger", - "get_logger", ] # ===================================================================================================================== @@ -47,16 +46,7 @@ import abc import math from abc import ABC -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Optional, - SupportsFloat, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, Optional, SupportsFloat import mlflow import numpy as np @@ -76,7 +66,7 @@ from wandb.sdk.wandb_run import Run from minerva.utils import utils -from minerva.utils.utils import check_substrings_in_string, func_by_str +from minerva.utils.utils import check_substrings_in_string # ===================================================================================================================== # GLOBALS @@ -133,10 +123,10 @@ def __init__( task_name: str, n_batches: int, batch_size: int, - output_size: Tuple[int, ...], + output_size: tuple[int, ...], record_int: bool = True, record_float: bool = False, - writer: Optional[Union[SummaryWriter, Run]] = None, + writer: Optional[SummaryWriter | Run] = None, model_type: str = "", **kwargs, ) -> None: @@ -157,8 +147,8 @@ def __init__( self.model_type = model_type - self.logs: Dict[str, Any] = {} - self.results: Dict[str, Any] = {} + self.logs: dict[str, Any] = {} + self.results: dict[str, Any] = {} def __call__( self, @@ -186,7 +176,7 @@ def log( loss: Tensor, z: Optional[Tensor] = None, y: Optional[Tensor] = None, - index: Optional[Union[int, BoundingBox]] = None, + index: Optional[int | BoundingBox] = None, *args, **kwargs, ) -> None: @@ -243,7 +233,7 @@ def write_metric( mlflow.log_metric(key, value) # pragma: no cover @property - def get_logs(self) -> Dict[str, Any]: + def get_logs(self) -> dict[str, Any]: """Gets the logs dictionary. 
Returns: @@ -252,7 +242,7 @@ def get_logs(self) -> Dict[str, Any]: return self.logs @property - def get_results(self) -> Dict[str, Any]: + def get_results(self) -> dict[str, Any]: """Gets the results dictionary. Returns: @@ -309,10 +299,10 @@ def __init__( task_name: str, n_batches: int, batch_size: int, - output_size: Tuple[int, int], + output_size: tuple[int, int], record_int: bool = True, record_float: bool = False, - writer: Optional[Union[SummaryWriter, Run]] = None, + writer: Optional[SummaryWriter | Run] = None, model_type: str = "", n_classes: Optional[int] = None, **kwargs, @@ -330,13 +320,13 @@ def __init__( if n_classes is None: raise ValueError("`n_classes` must be specified for this type of logger!") - self.logs: Dict[str, Any] = { + self.logs: dict[str, Any] = { "batch_num": 0, "total_loss": 0.0, "total_correct": 0.0, } - self.results: Dict[str, Any] = { + self.results: dict[str, Any] = { "y": None, "z": None, "probs": None, @@ -354,7 +344,7 @@ def __init__( # Allocate memory for the integer values to be recorded. if self.record_int: - int_log_shape: Tuple[int, ...] + int_log_shape: tuple[int, ...] if check_substrings_in_string(self.model_type, "scene-classifier"): if check_substrings_in_string(self.model_type, "multilabel"): int_log_shape = (self.n_batches, self.batch_size, n_classes) @@ -375,7 +365,7 @@ def __init__( # Allocate memory for the floating point values to be recorded. if self.record_float: - float_log_shape: Tuple[int, ...] + float_log_shape: tuple[int, ...] 
if check_substrings_in_string(self.model_type, "scene-classifier"): float_log_shape = (self.n_batches, self.batch_size, n_classes) else: @@ -416,7 +406,7 @@ def log( loss: Tensor, z: Optional[Tensor] = None, y: Optional[Tensor] = None, - index: Optional[Union[int, BoundingBox]] = None, + index: Optional[int | BoundingBox] = None, *args, **kwargs, ) -> None: @@ -543,7 +533,7 @@ def __init__( batch_size: int, record_int: bool = True, record_float: bool = False, - writer: Optional[Union[SummaryWriter, Run]] = None, + writer: Optional[SummaryWriter | Run] = None, model_type: str = "", **kwargs, ) -> None: @@ -558,14 +548,14 @@ def __init__( **kwargs, ) - self.logs: Dict[str, Any] = { + self.logs: dict[str, Any] = { "batch_num": 0, "total_loss": 0.0, "total_correct": 0.0, "total_top5": 0.0, } - self.results: Dict[str, Any] = { + self.results: dict[str, Any] = { "y": None, "z": None, "probs": None, @@ -580,7 +570,7 @@ def log( loss: Tensor, z: Optional[Tensor] = None, y: Optional[Tensor] = None, - index: Optional[Union[int, BoundingBox]] = None, + index: Optional[int | BoundingBox] = None, *args, **kwargs, ) -> None: @@ -663,10 +653,10 @@ def __init__( task_name: str, n_batches: int, batch_size: int, - output_size: Tuple[int, int], + output_size: tuple[int, int], record_int: bool = True, record_float: bool = False, - writer: Optional[Union[SummaryWriter, Run]] = None, + writer: Optional[SummaryWriter | Run] = None, model_type: str = "", **kwargs, ) -> None: @@ -682,7 +672,7 @@ def __init__( **kwargs, ) - self.logs: Dict[str, Any] = { + self.logs: dict[str, Any] = { "batch_num": 0, "total_loss": 0.0, "avg_loss": 0.0, @@ -707,7 +697,7 @@ def log( loss: Tensor, z: Optional[Tensor] = None, y: Optional[Tensor] = None, - index: Optional[Union[int, BoundingBox]] = None, + index: Optional[int | BoundingBox] = None, *args, **kwargs, ) -> None: @@ -798,19 +788,3 @@ def log( # Writes the loss to the writer. 
self.write_metric("loss", ls, step_num=global_step_num) - - -# ===================================================================================================================== -# METHODS -# ===================================================================================================================== -def get_logger(name) -> Callable[..., Any]: - """Gets the constructor for a step logger to log the results from each step of model fitting during an epoch. - - Returns: - ~typing.Callable[..., ~typing.Any]: The constructor of :class:`~logging.step.log.MinervaStepLogger` - to be intialised within the epoch. - - .. versionadded:: 0.27 - """ - logger: Callable[..., Any] = func_by_str("minerva.logger.steplog", name) - return logger diff --git a/minerva/logger/tasklog.py b/minerva/logger/tasklog.py index 43f4c42dc..f8ad92735 100644 --- a/minerva/logger/tasklog.py +++ b/minerva/logger/tasklog.py @@ -42,13 +42,14 @@ # ===================================================================================================================== import abc from abc import ABC -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: # pragma: no cover from torch.utils.tensorboard.writer import SummaryWriter else: # pragma: no cover SummaryWriter = None +import hydra import numpy as np from torch import Tensor from torchgeo.datasets.utils import BoundingBox @@ -56,7 +57,7 @@ from minerva.utils.utils import check_substrings_in_string -from .steplog import MinervaStepLogger, get_logger +from .steplog import MinervaStepLogger # ===================================================================================================================== @@ -85,8 +86,8 @@ class MinervaTaskLogger(ABC): __metaclass__ = abc.ABCMeta - metric_types: List[str] = [] - special_metric_types: List[str] = [] + metric_types: list[str] = [] + special_metric_types: list[str] = [] logger_cls: str def __init__( @@ -94,11 
+95,11 @@ def __init__( task_name: str, n_batches: int, batch_size: int, - output_size: Tuple[int, ...], - step_logger_params: Optional[Dict[str, Any]] = None, + output_size: tuple[int, ...], + step_logger_params: Optional[dict[str, Any]] = None, record_int: bool = True, record_float: bool = False, - writer: Optional[Union[SummaryWriter, Run]] = None, + writer: Optional[SummaryWriter | Run] = None, **params, ) -> None: super(MinervaTaskLogger, self).__init__() @@ -119,15 +120,14 @@ def __init__( self.writer = writer - if isinstance(step_logger_params, dict): - self._logger = get_logger(step_logger_params.get("name", self.logger_cls)) - if "params" not in step_logger_params: - step_logger_params["params"] = {} - + if not isinstance(step_logger_params, dict): + step_logger_params = {"_target_": self.logger_cls} + elif "_target_" not in step_logger_params: + step_logger_params["_target_"] = self.logger_cls else: - step_logger_params = {"params": {}} + pass - step_logger_params["params"]["n_classes"] = self.n_classes + step_logger_params["n_classes"] = self.n_classes self.step_logger_params = step_logger_params @@ -137,7 +137,7 @@ def __init__( self.metric_types += self.special_metric_types # Creates a dict to hold the loss and accuracy results from training, validation and testing. - self.metrics: Dict[str, Any] = {} + self.metrics: dict[str, Any] = {} for metric in self.metric_types: self.metrics[f"{self.task_name}_{metric}"] = {"x": [], "y": []} @@ -147,7 +147,8 @@ def _make_logger(self) -> None: .. note:: Will overwrite ``self.logger`` with new logger. 
""" - self.step_logger: MinervaStepLogger = self._logger( + self.step_logger: MinervaStepLogger = hydra.utils.instantiate( + self.step_logger_params, task_name=self.task_name, n_batches=self.n_batches, batch_size=self.batch_size, @@ -156,7 +157,6 @@ def _make_logger(self) -> None: record_float=self.record_float, writer=self.writer, model_type=self.model_type, - **self.step_logger_params.get("params", {}), ) def refresh_step_logger(self) -> None: @@ -207,7 +207,7 @@ def calc_metrics(self, epoch_no: int) -> None: self.log_epoch_number(epoch_no) @abc.abstractmethod - def _calc_metrics(self, logs: Dict[str, Any]) -> None: + def _calc_metrics(self, logs: dict[str, Any]) -> None: """Updates metrics with epoch results. Must be defined before use. @@ -226,7 +226,7 @@ def log_epoch_number(self, epoch_no: int) -> None: self.metrics[metric]["x"].append(epoch_no) @property - def get_metrics(self) -> Dict[str, Any]: + def get_metrics(self) -> dict[str, Any]: """Get the ``metrics`` dictionary. Returns: @@ -235,7 +235,7 @@ def get_metrics(self) -> Dict[str, Any]: return self.metrics @property - def get_logs(self) -> Dict[str, Any]: + def get_logs(self) -> dict[str, Any]: """Get the logs of each step from the latest epoch of the task. Returns: @@ -246,7 +246,7 @@ def get_logs(self) -> Dict[str, Any]: return self.step_logger.get_logs @property - def get_results(self) -> Dict[str, Any]: + def get_results(self) -> dict[str, Any]: """Get the results of each step from the latest epoch of the task. Returns: @@ -266,8 +266,8 @@ def log_null(self) -> None: self.metrics[metric]["y"].append(np.NAN) def get_sub_metrics( - self, pattern: Tuple[str, ...] = ("train", "val") - ) -> Dict[str, Any]: + self, pattern: tuple[str, ...] = ("train", "val") + ) -> dict[str, Any]: """Gets a subset of the metrics dictionary with keys containing strings in the pattern. Useful for getting the train and validation metrics for plotting for example. 
@@ -315,19 +315,19 @@ class SupervisedTaskLogger(MinervaTaskLogger): .. versionadded:: 0.27 """ - metric_types: List[str] = ["loss", "acc", "miou"] - logger_cls = "SupervisedStepLogger" + metric_types: list[str] = ["loss", "acc", "miou"] + logger_cls = "minerva.logger.steplog.SupervisedStepLogger" def __init__( self, task_name: str, n_batches: int, batch_size: int, - output_size: Tuple[int, ...], - step_logger_params: Optional[Dict[str, Any]] = None, + output_size: tuple[int, ...], + step_logger_params: Optional[dict[str, Any]] = None, record_int: bool = True, record_float: bool = False, - writer: Optional[Union[SummaryWriter, Run]] = None, + writer: Optional[SummaryWriter | Run] = None, model_type: str = "segmentation", **params, ) -> None: @@ -344,7 +344,7 @@ def __init__( **params, ) - def _calc_metrics(self, logs: Dict[str, Any]) -> None: + def _calc_metrics(self, logs: dict[str, Any]) -> None: """Updates metrics with epoch results. Args: @@ -415,30 +415,34 @@ class SSLTaskLogger(MinervaTaskLogger): metric_types = ["loss", "acc", "top5_acc"] special_metric_types = ["collapse_level", "euc_dist"] - logger_cls = "SSLStepLogger" + logger_cls = "minerva.logger.steplog.SSLStepLogger" def __init__( self, task_name: str, n_batches: int, batch_size: int, - output_size: Tuple[int, ...], - step_logger_params: Optional[Dict[str, Any]] = None, + output_size: tuple[int, ...], + step_logger_params: Optional[dict[str, Any]] = None, record_int: bool = True, record_float: bool = False, - writer: Optional[Union[SummaryWriter, Run]] = None, + writer: Optional[SummaryWriter | Run] = None, model_type: str = "segmentation", sample_pairs: bool = False, **params, ) -> None: if not step_logger_params: - step_logger_params = {} - if "params" not in step_logger_params: - step_logger_params["params"] = {} + step_logger_params = {"_target_": self.logger_cls} - step_logger_params["params"]["sample_pairs"] = sample_pairs - step_logger_params["params"]["collapse_level"] = sample_pairs - 
step_logger_params["params"]["euclidean"] = sample_pairs + step_logger_params["sample_pairs"] = step_logger_params.get( + "sample_pairs", sample_pairs + ) + step_logger_params["collapse_level"] = step_logger_params.get( + "collapse_level", sample_pairs + ) + step_logger_params["euclidean"] = step_logger_params.get( + "euclidean", sample_pairs + ) super(SSLTaskLogger, self).__init__( task_name, @@ -465,7 +469,7 @@ def __init__( if not getattr(self.step_logger, "euclidean", False): del self.metrics[f"{self.task_name}_euc_dist"] - def _calc_metrics(self, logs: Dict[str, Any]) -> None: + def _calc_metrics(self, logs: dict[str, Any]) -> None: """Updates metrics with epoch results. Args: diff --git a/minerva/loss.py b/minerva/loss.py index 111d28b55..1c7d13162 100644 --- a/minerva/loss.py +++ b/minerva/loss.py @@ -33,11 +33,12 @@ __copyright__ = "Copyright (C) 2024 Harry Baker" __all__ = ["SegBarlowTwinsLoss", "AuxCELoss"] +import importlib + # ===================================================================================================================== # IMPORTS # ===================================================================================================================== from typing import Optional -import importlib import torch from torch import Tensor @@ -91,7 +92,12 @@ class AuxCELoss(Module): Source: https://github.com/xitongpu/PSPNet/blob/main/src/model/cell.py """ - def __init__(self, weight: Optional[Tensor] = None, ignore_index: int = 255, alpha: float = 0.4) -> None: + def __init__( + self, + weight: Optional[Tensor] = None, + ignore_index: int = 255, + alpha: float = 0.4, + ) -> None: super().__init__() self.loss = CrossEntropyLoss(weight=weight, ignore_index=ignore_index) self.alpha = alpha diff --git a/minerva/modelio.py b/minerva/modelio.py index 2fe6ac934..e539ec3e5 100644 --- a/minerva/modelio.py +++ b/minerva/modelio.py @@ -32,14 +32,16 @@ __license__ = "MIT License" __copyright__ = "Copyright (C) 2024 Harry Baker" __all__ = [ - "sup_tg", 
- "ssl_pair_tg", + "supervised_torchgeo_io", + "change_detection_io", + "autoencoder_io", + "ssl_pair_torchgeo_io", ] # ===================================================================================================================== # IMPORTS # ===================================================================================================================== -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence import numpy as np import torch @@ -48,29 +50,33 @@ from torchgeo.datasets.utils import BoundingBox from minerva.models import MinervaModel -from minerva.utils.utils import check_substrings_in_string, mask_to_ohe +from minerva.utils.utils import ( + check_substrings_in_string, + get_sample_index, + mask_to_ohe, +) # ===================================================================================================================== # METHODS # ===================================================================================================================== -def sup_tg( - batch: Dict[Any, Any], +def supervised_torchgeo_io( + batch: dict[Any, Any], model: MinervaModel, device: torch.device, # type: ignore[name-defined] train: bool, **kwargs, -) -> Tuple[ +) -> tuple[ Tensor, - Union[Tensor, Tuple[Tensor, ...]], + Tensor | tuple[Tensor, ...], Tensor, - Optional[Union[Sequence[str], Sequence[BoundingBox]]], + Optional[Sequence[str] | Sequence[int] | Sequence[BoundingBox]], ]: """Provides IO functionality for a supervised model using :mod:`torchgeo` datasets. Args: batch (dict[~typing.Any, ~typing.Any]): Batch of data in a :class:`dict`. - Must have ``"image"``, ``"mask"``/ ``"label"`` and ``"bbox"`` / ``"id"`` keys. + Must have ``"image"``, ``"mask"``/ ``"label"`` and ``"bbox"`` / ``"bounds"`` / ``"id"`` keys. model (MinervaModel): Model being fitted. device (~torch.device): `torch` device object to send data to (e.g. CUDA device). train (bool): True to run a step of the model in training mode. 
False for eval mode. @@ -94,65 +100,140 @@ def sup_tg( multilabel = True if check_substrings_in_string(model_type, "multilabel") else False # Extracts the x and y batches from the dict. - images: Tensor = batch["image"] - targets: Tensor = batch[target_key] + x: Tensor = batch["image"] + y: Tensor = batch[target_key] # Check that none of the data is NaN or infinity. if kwargs.get("validate_variables", False): - assert not images.isnan().any() - assert not images.isinf().any() - assert not targets.isnan().any() - assert not targets.isinf().any() + assert not x.isnan().any() + assert not x.isinf().any() + assert not y.isnan().any() + assert not y.isinf().any() # Re-arranges the x and y batches. - x_batch: Tensor = images.to(float_dtype) # type: ignore[attr-defined] - y_batch: Tensor + x = x.to(float_dtype) # type: ignore[attr-defined] # Squeeze out axis 1 if only 1 element wide. if target_key == "mask": - targets = targets.squeeze() + y = y.squeeze() - if isinstance(targets, Tensor): - targets = targets.detach().cpu() - y_batch = torch.tensor(targets, dtype=torch.float if multilabel else torch.long) # type: ignore[attr-defined] + if isinstance(y, Tensor): + y = y.detach().cpu() + y = y.to(dtype=torch.float if multilabel else torch.long) # type: ignore[attr-defined] # Transfer to GPU. - x: Tensor = x_batch.to(device) - y: Tensor = y_batch.to(device) + x = x.to(device) + y = y.to(device) # Runs a step of the epoch. loss, z = model.step(x, y, train=train) - # Get the indices of the batch. Either bounding boxes or filenames. - index: Optional[Union[Sequence[str], Sequence[BoundingBox]]] - if "bbox" in batch: - index = batch["bbox"] - elif "id" in batch: - index = batch["id"] - else: - index = None + # Get the indices of the batch. Either bounding boxes, filenames or index number. 
+ index: Optional[Sequence[str] | Sequence[int] | Sequence[BoundingBox]] = ( + get_sample_index(batch) + ) + + return loss, z, y, index + + +def change_detection_io( + batch: dict[Any, Any], + model: MinervaModel, + device: torch.device, # type: ignore[name-defined] + train: bool, + **kwargs, +) -> tuple[ + Tensor, + Tensor | tuple[Tensor, ...], + Tensor, + Optional[Sequence[str] | Sequence[int] | Sequence[BoundingBox]], +]: + """Provides IO functionality for a change_detection model. + + Args: + batch (dict[~typing.Any, ~typing.Any]): Batch of data in a :class:`dict`. + Must have ``"image1"``, ``"image2"``, ``"mask"``/ ``"label"`` and ``"bbox"`` / ``"bounds"`` / ``"id"`` keys. + model (MinervaModel): Model being fitted. + device (~torch.device): `torch` device object to send data to (e.g. CUDA device). + train (bool): True to run a step of the model in training mode. False for eval mode. + + Kwargs: + mix_precision (bool): Use mixed-precision. Will set the floating tensors to 16-bit + rather than the default 32-bit. + target_key (str): Should be either ``"mask"`` or ``"label"``. + + Returns: + tuple[ + ~torch.Tensor, ~torch.Tensor, ~torch.Tensor, ~typing.Sequence[~torchgeo.datasets.utils.BoundingBox] | None + ]: + The ``loss``, the model output ``z``, the ground truth ``y`` supplied and the bounding boxes + of the input images supplied. + """ + float_dtype = _determine_float_dtype(device, kwargs.get("mix_precision", False)) + target_key = kwargs.get("target_key", "mask") + + model_type = kwargs.get("model_type", "") + multilabel = True if check_substrings_in_string(model_type, "multilabel") else False + + # Extracts the x and y batches from the dict. + x1: Tensor = batch["image1"] + x2: Tensor = batch["image2"] + y: Tensor = batch[target_key] + + # Check that none of the data is NaN or infinity. 
+ if kwargs.get("validate_variables", False): + assert not x1.isnan().any() + assert not x1.isinf().any() + assert not x2.isnan().any() + assert not x2.isinf().any() + assert not y.isnan().any() + assert not y.isinf().any() + + # Re-arranges the x and y batches. + x1 = x1.to(float_dtype) # type: ignore[attr-defined] + x2 = x2.to(float_dtype) # type: ignore[attr-defined] + + # Squeeze out axis 1 if only 1 element wide. + if target_key == "mask": + y = y.squeeze() + + if isinstance(y, Tensor): + y = y.detach().cpu() + y = y.to(dtype=torch.float if multilabel else torch.long) # type: ignore[attr-defined] + + # Transfer to GPU. + x = torch.stack([x1, x2]).to(device) + y = y.to(device) + + # Runs a step of the epoch. + loss, z = model.step(x, y, train=train) + + # Get the indices of the batch. Either bounding boxes, filenames or index number. + index: Optional[Sequence[str] | Sequence[int] | Sequence[BoundingBox]] = ( + get_sample_index(batch) + ) return loss, z, y, index def autoencoder_io( - batch: Dict[Any, Any], + batch: dict[Any, Any], model: MinervaModel, device: torch.device, # type: ignore[name-defined] train: bool, **kwargs, -) -> Tuple[ +) -> tuple[ Tensor, - Union[Tensor, Tuple[Tensor, ...]], + Tensor | tuple[Tensor, ...], Tensor, - Optional[Union[Sequence[str], Sequence[BoundingBox]]], + Optional[Sequence[str] | Sequence[int] | Sequence[BoundingBox]], ]: """Provides IO functionality for an autoencoder using :mod:`torchgeo` datasets by only using the same data for input and ground truth. Args: batch (dict[~typing.Any, ~typing.Any]): Batch of data in a :class:`dict`. - Must have ``"image"``, ``"mask"``/ ``"label"`` and ``"bbox"`` / ``"id"`` keys. + Must have ``"image"``, ``"mask"``/ ``"label"`` and ``"bbox"`` / ``"bounds"`` / ``"id"`` keys. model (MinervaModel): Model being fitted. device (~torch.device): `torch` device object to send data to (e.g. CUDA device). train (bool): True to run a step of the model in training mode. False for eval mode. 
@@ -205,15 +286,13 @@ def autoencoder_io( ) output_masks: LongTensor = masks - if isinstance(input_masks, Tensor): - input_masks = input_masks.detach().cpu().numpy() - - if isinstance(output_masks, Tensor): - output_masks = output_masks.detach().cpu().numpy() - # Transfer to GPU and cast to correct dtypes. - x = torch.tensor(input_masks, dtype=float_dtype, device=device) - y = torch.tensor(output_masks, dtype=torch.long, device=device) + x = torch.tensor( + input_masks.detach().cpu().numpy(), dtype=float_dtype, device=device + ) + y = torch.tensor( + output_masks.detach().cpu().numpy(), dtype=torch.long, device=device + ) elif key == "image": # Extract the images from the batch, set to float, transfer to GPU and make x and y. @@ -228,35 +307,31 @@ def autoencoder_io( # Runs a step of the epoch. loss, z = model.step(x, y, train=train) - # Get the indices of the batch. Either bounding boxes or filenames. - index: Optional[Union[Sequence[str], Sequence[BoundingBox]]] - if "bbox" in batch: - index = batch["bbox"] - elif "id" in batch: - index = batch["id"] - else: - index = None + # Get the indices of the batch. Either bounding boxes, filenames or index number. + index: Optional[Sequence[str] | Sequence[int] | Sequence[BoundingBox]] = ( + get_sample_index(batch) + ) return loss, z, y, index -def ssl_pair_tg( - batch: Tuple[Dict[str, Any], Dict[str, Any]], +def ssl_pair_torchgeo_io( + batch: tuple[dict[str, Any], dict[str, Any]], model: MinervaModel, device: torch.device, # type: ignore[name-defined] train: bool, **kwargs, -) -> Tuple[ +) -> tuple[ Tensor, - Union[Tensor, Tuple[Tensor, ...]], + Tensor | tuple[Tensor, ...], None, - Optional[Union[Sequence[BoundingBox], Sequence[int]]], + Optional[Sequence[BoundingBox] | Sequence[int]], ]: """Provides IO functionality for a self-supervised Siamese model using :mod:`torchgeo` datasets. Args: batch (tuple[dict[str, ~typing.Any], dict[str, ~typing.Any]]): Pair of batches of data in :class:`dict` (s). 
- Must have ``"image"`` and ``"bbox"`` keys. + Must have ``"image"`` and ``"bbox"`` / ``"bounds"`` / ``"id"`` keys. model (MinervaModel): Model being fitted. device (~torch.device): :mod:`torch` device object to send data to (e.g. ``CUDA`` device). train (bool): True to run a step of the model in training mode. False for eval mode. @@ -273,13 +348,9 @@ def ssl_pair_tg( """ float_dtype = _determine_float_dtype(device, kwargs.get("mix_precision", False)) - # Extracts the x_i batch from the dict. - x_i_batch: Tensor = batch[0]["image"] - x_j_batch: Tensor = batch[1]["image"] - - # Ensures images are floats. - x_i_batch = x_i_batch.to(float_dtype) # type: ignore[attr-defined] - x_j_batch = x_j_batch.to(float_dtype) # type: ignore[attr-defined] + # Extracts both batches from the dict and ensures images are floats. + x_i_batch: Tensor = batch[0]["image"].to(float_dtype) # type: ignore[attr-defined] + x_j_batch: Tensor = batch[1]["image"].to(float_dtype) # type: ignore[attr-defined] if kwargs.get("validate_variables", False): try: @@ -287,11 +358,8 @@ def ssl_pair_tg( except AssertionError: print("WARNING: Batches are the same!") - # Stacks each side of the pair batches together. - x_batch = torch.stack([x_i_batch, x_j_batch]) - - # Transfer to GPU. - x = x_batch.to(device, non_blocking=True) + # Stacks each side of the pair batches together and transfer to GPU. + x = torch.stack([x_i_batch, x_j_batch]).to(device, non_blocking=True) # Check that none of the data is NaN or infinity. if kwargs.get("validate_variables", False): @@ -301,12 +369,12 @@ def ssl_pair_tg( # Runs a step of the epoch. 
loss, z = model.step(x, train=train) - if "bbox" in batch[0].keys(): - return loss, z, None, batch[0]["bbox"] + batch[1]["bbox"] - elif "id" in batch[0].keys(): - return loss, z, None, batch[0]["id"] + batch[1]["id"] - else: + index_0, index_1 = get_sample_index(batch[0]), get_sample_index(batch[1]) + + if index_0 is None or index_1 is None: return loss, z, None, None + else: + return loss, z, None, index_0 + index_1 def _determine_float_dtype(device: torch.device, mix_precision: bool) -> torch.dtype: diff --git a/minerva/models/__depreciated.py b/minerva/models/__depreciated.py index 2df888790..544a817d5 100644 --- a/minerva/models/__depreciated.py +++ b/minerva/models/__depreciated.py @@ -37,7 +37,7 @@ # IMPORTS # ===================================================================================================================== from collections import OrderedDict -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence import numpy as np import torch.nn.modules as nn @@ -80,7 +80,7 @@ def __init__( criterion: Optional[Any] = None, input_size: int = 288, n_classes: int = 8, - hidden_sizes: Union[Tuple[int, ...], List[int], int] = (256, 144), + hidden_sizes: tuple[int, ...] | list[int] | int = (256, 144), ) -> None: super(MLP, self).__init__( criterion=criterion, input_size=(input_size,), n_classes=n_classes @@ -161,14 +161,14 @@ class CNN(MinervaModel): def __init__( self, criterion, - input_size: Tuple[int, int, int] = (4, 256, 256), + input_size: tuple[int, int, int] = (4, 256, 256), n_classes: int = 8, - features: Union[Tuple[int, ...], List[int]] = (2, 1, 1), - fc_sizes: Union[Tuple[int, ...], List[int]] = (128, 64), - conv_kernel_size: Union[int, Tuple[int, ...]] = 3, - conv_stride: Union[int, Tuple[int, ...]] = 1, - max_kernel_size: Union[int, Tuple[int, ...]] = 2, - max_stride: Union[int, Tuple[int, ...]] = 2, + features: tuple[int, ...] | list[int] = (2, 1, 1), + fc_sizes: tuple[int, ...] 
| list[int] = (128, 64), + conv_kernel_size: int | tuple[int, ...] = 3, + conv_stride: int | tuple[int, ...] = 1, + max_kernel_size: int | tuple[int, ...] = 2, + max_stride: int | tuple[int, ...] = 2, conv_do: bool = True, fc_do: bool = True, p_conv_do: float = 0.1, diff --git a/minerva/models/change_detector.py b/minerva/models/change_detector.py index 5ece466a5..72d816786 100644 --- a/minerva/models/change_detector.py +++ b/minerva/models/change_detector.py @@ -38,7 +38,7 @@ # IMPORTS # ===================================================================================================================== from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional import torch from torch import Tensor @@ -57,7 +57,7 @@ class ChangeDetector(MinervaModel): def __init__( self, criterion: Optional[Module] = None, - input_size: Optional[Tuple[int]] = None, + input_size: Optional[tuple[int]] = None, n_classes: int = 1, scaler: Optional[GradScaler] = None, fc_dim: int = 512, @@ -65,8 +65,8 @@ def __init__( encoder_on: bool = False, filter_dim: int = 0, freeze_backbone: bool = False, - backbone_weight_path: Optional[Union[str, Path]] = None, - backbone_args: Dict[str, Any] = {}, + backbone_weight_path: Optional[str | Path] = None, + backbone_args: dict[str, Any] = {}, clamp_outputs: bool = False, ) -> None: super().__init__(criterion, input_size, n_classes, scaler) @@ -128,8 +128,7 @@ def forward(self, x: Tensor) -> Tensor: ~torch.Tensor: Likelihoods the network places on the input ``x`` being of each class. 
""" - x = torch.squeeze(x) - x_0, x_1 = torch.chunk(x, 2, dim=1) + x_0, x_1 = x[0], x[1] f_0 = self.backbone(x_0) f_1 = self.backbone(x_1) @@ -140,7 +139,6 @@ def forward(self, x: Tensor) -> Tensor: f = torch.cat((f_0, f_1), 1) - # f = f.view(f.size(0), -1) z: Tensor = self.classification_head(f) assert isinstance(z, Tensor) diff --git a/minerva/models/classifiers.py b/minerva/models/classifiers.py index bb3472798..c95eb42fa 100644 --- a/minerva/models/classifiers.py +++ b/minerva/models/classifiers.py @@ -42,7 +42,7 @@ # IMPORTS # ===================================================================================================================== from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional import torch from torch import Tensor @@ -61,7 +61,7 @@ class FlexiSceneClassifier(MinervaBackbone): def __init__( self, criterion: Optional[Module] = None, - input_size: Optional[Tuple[int]] = None, + input_size: Optional[tuple[int]] = None, n_classes: int = 1, scaler: Optional[GradScaler] = None, fc_dim: int = 512, @@ -69,8 +69,8 @@ def __init__( encoder_on: bool = False, filter_dim: int = 0, freeze_backbone: bool = False, - backbone_weight_path: Optional[Union[str, Path]] = None, - backbone_args: Dict[str, Any] = {}, + backbone_weight_path: Optional[str | Path] = None, + backbone_args: dict[str, Any] = {}, clamp_outputs: bool = False, ) -> None: super().__init__(criterion, input_size, n_classes, scaler) diff --git a/minerva/models/core.py b/minerva/models/core.py index f06513979..493ea94c4 100644 --- a/minerva/models/core.py +++ b/minerva/models/core.py @@ -58,18 +58,7 @@ import warnings from abc import ABC from pathlib import Path -from typing import ( - Any, - Callable, - Iterable, - List, - Optional, - Sequence, - Tuple, - Type, - Union, - overload, -) +from typing import Any, Callable, Iterable, Optional, Sequence, Type, overload import numpy as np import torch @@ -81,8 +70,8 @@ from torch.nn.modules 
import Module from torch.nn.parallel import DataParallel from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler +from torch.optim.optimizer import Optimizer from torchvision.models._api import WeightsEnum from minerva.utils.utils import func_by_str @@ -101,7 +90,7 @@ class FilterOutputs(Module): indexes (int | list[int]): Index(es) of the inputs to pass forward. """ - def __init__(self, indexes: Union[int, List[int]]) -> None: + def __init__(self, indexes: int | list[int]) -> None: self.indexes = indexes def forward(self, inputs: Tensor) -> Tensor: @@ -134,7 +123,7 @@ class MinervaModel(Module, ABC): def __init__( self, criterion: Optional[Module] = None, - input_size: Optional[Tuple[int, ...]] = None, + input_size: Optional[tuple[int, ...]] = None, n_classes: Optional[int] = None, scaler: Optional[GradScaler] = None, ) -> None: @@ -148,7 +137,7 @@ def __init__( self.scaler = scaler # Output shape initialised as None. Should be set by calling determine_output_dim. - self.output_shape: Optional[Tuple[int, ...]] = None + self.output_shape: Optional[tuple[int, ...]] = None # Optimiser initialised as None as the model parameters created by its init is required to init a # torch optimiser. The optimiser MUST be set by calling set_optimiser before the model can be trained. @@ -206,19 +195,19 @@ def update_n_classes(self, n_classes: int) -> None: @overload def step( self, x: Tensor, y: Tensor, train: bool = False - ) -> Tuple[Tensor, Union[Tensor, Tuple[Tensor, ...]]]: ... # pragma: no cover + ) -> tuple[Tensor, Tensor | tuple[Tensor, ...]]: ... # pragma: no cover @overload def step( self, x: Tensor, *, train: bool = False - ) -> Tuple[Tensor, Union[Tensor, Tuple[Tensor, ...]]]: ... # pragma: no cover + ) -> tuple[Tensor, Tensor | tuple[Tensor, ...]]: ... 
# pragma: no cover def step( self, x: Tensor, y: Optional[Tensor] = None, train: bool = False, - ) -> Tuple[Tensor, Union[Tensor, Tuple[Tensor, ...]]]: + ) -> tuple[Tensor, Tensor | tuple[Tensor, ...]]: """Generic step of model fitting using a batch of data. Raises: @@ -248,7 +237,7 @@ def step( if train: self.optimiser.zero_grad() - z: Union[Tensor, Tuple[Tensor, ...]] + z: Tensor | tuple[Tensor, ...] loss: Tensor mix_precision: bool = True if self.scaler else False @@ -300,9 +289,9 @@ class MinervaWrapper(MinervaModel): def __init__( self, - model: Union[Module, Callable[..., Module]], + model: Module | Callable[..., Module], criterion: Optional[Module] = None, - input_size: Optional[Tuple[int, ...]] = None, + input_size: Optional[tuple[int, ...]] = None, n_classes: Optional[int] = None, scaler: Optional[GradScaler] = None, *args, @@ -386,7 +375,7 @@ class MinervaDataParallel(Module): # pragma: no cover def __init__( self, model: Module, - paralleliser: Union[Type[DataParallel], Type[DDP]], # type: ignore[type-arg] + paralleliser: Type[DataParallel] | Type[DDP], # type: ignore[type-arg] *args, **kwargs, ) -> None: @@ -396,7 +385,7 @@ def __init__( self.output_shape = model.output_shape self.n_classes = model.n_classes - def forward(self, *inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: + def forward(self, *inputs: tuple[Tensor, ...]) -> tuple[Tensor, ...]: """Ensures a forward call to the model goes to the actual wrapped model. 
Args: @@ -410,7 +399,7 @@ def forward(self, *inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: assert isinstance(z, tuple) and list(map(type, z)) == [Tensor] * len(z) return z - def __call__(self, *inputs) -> Tuple[Tensor, ...]: + def __call__(self, *inputs) -> tuple[Tensor, ...]: return self.forward(*inputs) def __getattr__(self, name): @@ -515,10 +504,10 @@ def get_torch_weights(weights_name: str) -> Optional[WeightsEnum]: def get_output_shape( model: Module, - image_dim: Union[Sequence[int], int], + image_dim: Sequence[int] | int, sample_pairs: bool = False, change_detection: bool = False, -) -> Tuple[int, ...]: +) -> tuple[int, ...]: """Gets the output shape of a model. Args: @@ -530,7 +519,7 @@ def get_output_shape( Returns: tuple[int, ...]: The shape of the output data from the model. """ - _image_dim: Union[Sequence[int], int] = image_dim + _image_dim: Sequence[int] | int = image_dim try: assert not isinstance(image_dim, int) if len(image_dim) == 1: @@ -542,12 +531,9 @@ def get_output_shape( if not hasattr(_image_dim, "__len__"): assert isinstance(_image_dim, int) random_input = torch.rand([4, _image_dim]) - elif sample_pairs: + elif sample_pairs or change_detection: assert isinstance(_image_dim, Iterable) random_input = torch.rand([2, 4, *_image_dim]) - elif change_detection: - assert isinstance(_image_dim, Iterable) - random_input = torch.rand([4, 2 * _image_dim[0], *_image_dim[1:]]) else: assert isinstance(_image_dim, Iterable) random_input = torch.rand([4, *_image_dim]) @@ -637,7 +623,7 @@ def is_minerva_subtype(model: Module, subtype: type) -> bool: def extract_wrapped_model( - model: Union[MinervaModel, MinervaDataParallel, OptimizedModule] + model: MinervaModel | MinervaDataParallel | OptimizedModule, ) -> MinervaModel: """ Extracts the actual model object from within :class:`MinervaDataParallel` or diff --git a/minerva/models/fcn.py b/minerva/models/fcn.py index d80bb3962..9c32910b3 100644 --- a/minerva/models/fcn.py +++ b/minerva/models/fcn.py @@ 
-52,7 +52,7 @@ # ===================================================================================================================== # IMPORTS # ===================================================================================================================== -from typing import Any, Dict, Literal, Optional, Sequence, Tuple +from typing import Any, Literal, Optional, Sequence import torch import torch.nn.modules as nn @@ -100,12 +100,12 @@ class FCN(MinervaBackbone): def __init__( self, criterion: Any, - input_size: Tuple[int, ...] = (4, 256, 256), + input_size: tuple[int, ...] = (4, 256, 256), n_classes: int = 8, scaler: Optional[GradScaler] = None, backbone_weight_path: Optional[str] = None, freeze_backbone: bool = False, - backbone_kwargs: Dict[str, Any] = {}, + backbone_kwargs: dict[str, Any] = {}, ) -> None: super(FCN, self).__init__( criterion=criterion, @@ -294,7 +294,7 @@ def __init__( f"Variant {self.variant} does not match known types" ) - def forward(self, x: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) -> Tensor: + def forward(self, x: tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) -> Tensor: """Performs a forward pass of the decoder. Depending on DCN variant, will take multiple inputs throughout pass from the encoder. 
diff --git a/minerva/models/psp.py b/minerva/models/psp.py index e0c1a2276..d76984f0e 100644 --- a/minerva/models/psp.py +++ b/minerva/models/psp.py @@ -38,7 +38,7 @@ # ===================================================================================================================== # IMPORTS # ===================================================================================================================== -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Optional import segmentation_models_pytorch as smp import torch @@ -106,9 +106,9 @@ def __init__( psp_dropout: float = 0.2, in_channels: Optional[int] = None, n_classes: int = 1, - activation: Optional[Union[str, Callable[..., Any]]] = None, + activation: Optional[str | Callable[..., Any]] = None, upsampling: int = 8, - aux_params: Optional[Dict[str, Any]] = None, + aux_params: Optional[dict[str, Any]] = None, backbone_weight_path: Optional[str] = None, freeze_backbone: bool = False, encoder: bool = True, @@ -150,7 +150,7 @@ def __init__( def make_segmentation_head( self, n_classes: int, - activation: Optional[Union[str, Callable[..., Any]]] = None, + activation: Optional[str | Callable[..., Any]] = None, upsampling: int = 8, ) -> None: self.segmentation_head = SegmentationHead( @@ -164,7 +164,7 @@ def make_segmentation_head( self.encoder_mode = True self.segmentation_on = True - def make_classification_head(self, aux_params: Dict[str, Any]) -> None: + def make_classification_head(self, aux_params: dict[str, Any]) -> None: # Makes the classification head. 
self.classification_head = ClassificationHead( in_channels=self.encoder.out_channels[-1], **aux_params @@ -198,7 +198,7 @@ def freeze_backbone(self, freeze: bool = True) -> None: """ self.encoder.requires_grad_(False if freeze else True) - def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, ...]]: + def forward(self, x: Tensor) -> Tensor | tuple[Tensor, ...]: f = self.encoder(x) if not self.encoder_mode: @@ -223,7 +223,7 @@ class MinervaPSP(MinervaWrapper): def __init__( self, criterion: Optional[Module] = None, - input_size: Optional[Tuple[int, ...]] = None, + input_size: Optional[tuple[int, ...]] = None, n_classes: int = 1, scaler: Optional[GradScaler] = None, encoder_name: str = "resnet34", @@ -232,9 +232,9 @@ def __init__( psp_out_channels: int = 512, psp_use_batchnorm: bool = True, psp_dropout: float = 0.2, - activation: Optional[Union[str, Callable[..., Any]]] = None, + activation: Optional[str | Callable[..., Any]] = None, upsampling: int = 8, - aux_params: Optional[Dict[str, Any]] = None, + aux_params: Optional[dict[str, Any]] = None, backbone_weight_path: Optional[str] = None, freeze_backbone: bool = False, encoder: bool = False, @@ -315,7 +315,7 @@ def forward(self, x): class PSPUNetDecoder(Module): def __init__( self, - encoder_channels: Tuple[int, int, int, int, int, int], + encoder_channels: tuple[int, int, int, int, int, int], n_classes: int, use_batchnorm: bool = True, dropout: float = 0.2, @@ -467,7 +467,7 @@ class MinervaPSPUNet(MinervaWrapper): def __init__( self, criterion: Optional[Module] = None, - input_size: Optional[Tuple[int, ...]] = None, + input_size: Optional[tuple[int, ...]] = None, n_classes: int = 1, scaler: Optional[GradScaler] = None, encoder_name: str = "resnet34", @@ -476,9 +476,9 @@ def __init__( psp_out_channels: int = 512, psp_use_batchnorm: bool = True, psp_dropout: float = 0.2, - activation: Optional[Union[str, Callable[..., Any]]] = None, + activation: Optional[str | Callable[..., Any]] = None, upsampling: int = 8, - 
aux_params: Optional[Dict[str, Any]] = None, + aux_params: Optional[dict[str, Any]] = None, backbone_weight_path: Optional[str] = None, freeze_backbone: bool = False, encoder: bool = False, diff --git a/minerva/models/resnet.py b/minerva/models/resnet.py index d9e437c0a..36c67cc49 100644 --- a/minerva/models/resnet.py +++ b/minerva/models/resnet.py @@ -46,7 +46,7 @@ # IMPORTS # ===================================================================================================================== import abc -from typing import Any, Callable, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Optional, Type import torch import torch.nn.modules as nn @@ -122,14 +122,14 @@ class ResNet(MinervaModel): def __init__( self, - block: Type[Union[BasicBlock, Bottleneck]], - layers: Union[List[int], Tuple[int, int, int, int]], + block: Type[BasicBlock | Bottleneck], + layers: list[int] | tuple[int, int, int, int], in_channels: int = 3, n_classes: int = 8, zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 64, - replace_stride_with_dilation: Optional[Tuple[bool, bool, bool]] = None, + replace_stride_with_dilation: Optional[tuple[bool, bool, bool]] = None, norm_layer: Optional[Callable[..., Module]] = None, encoder: bool = False, ) -> None: @@ -224,7 +224,7 @@ def __init__( def _make_layer( self, - block: Type[Union[BasicBlock, Bottleneck]], + block: Type[BasicBlock | Bottleneck], planes: int, blocks: int, stride: int = 1, @@ -296,7 +296,7 @@ def _make_layer( def _forward_impl( self, x: Tensor - ) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]: + ) -> Tensor | tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: x = self.conv1(x) x = self.bn1(x) x = self.relu(x) @@ -320,7 +320,7 @@ def _forward_impl( def forward( self, x: Tensor - ) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]: + ) -> Tensor | tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """Performs a forward pass of the :class:`ResNet`. 
Can be called directly as a method (e.g. ``model.forward``) or when data is parsed @@ -390,19 +390,19 @@ class ResNetX(MinervaModel): """ __metaclass__ = abc.ABCMeta - block_type: Union[Type[BasicBlock], Type[Bottleneck]] = BasicBlock - layer_struct: List[int] = [2, 2, 2, 2] + block_type: Type[BasicBlock] | Type[Bottleneck] = BasicBlock + layer_struct: list[int] = [2, 2, 2, 2] weights_name = "ResNet18_Weights.IMAGENET1K_V1" def __init__( self, criterion: Optional[Any] = None, - input_size: Tuple[int, int, int] = (4, 256, 256), + input_size: tuple[int, int, int] = (4, 256, 256), n_classes: int = 8, zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 64, - replace_stride_with_dilation: Optional[Tuple[bool, bool, bool]] = None, + replace_stride_with_dilation: Optional[tuple[bool, bool, bool]] = None, norm_layer: Optional[Callable[..., Module]] = None, encoder: bool = False, torch_weights: bool = False, @@ -431,7 +431,7 @@ def __init__( def forward( self, x: Tensor - ) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]: + ) -> Tensor | tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """Performs a forward pass of the :class:`ResNet`. Can be called directly as a method (e.g. :func:`model.forward`) or when data is parsed @@ -445,9 +445,7 @@ def forward( initialised as an encoder, returns a tuple of outputs from each ``layer`` 1-4. Else, returns :class:`~torch.Tensor` of the likelihoods the network places on the input ``x`` being of each class. """ - z: Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]] = self.network( - x - ) + z: Tensor | tuple[Tensor, Tensor, Tensor, Tensor, Tensor] = self.network(x) if isinstance(z, Tensor): return z elif isinstance(z, tuple): @@ -460,7 +458,7 @@ class ResNet18(ResNetX): by stripping classification layers away. 
""" - layer_struct: List[int] = [2, 2, 2, 2] + layer_struct: list[int] = [2, 2, 2, 2] weights_name = "ResNet18_Weights.IMAGENET1K_V1" @@ -469,7 +467,7 @@ class ResNet34(ResNetX): by stripping classification layers away. """ - layer_struct: List[int] = [3, 4, 6, 3] + layer_struct: list[int] = [3, 4, 6, 3] weights_name = "ResNet34_Weights.IMAGENET1K_V1" @@ -508,8 +506,8 @@ class ResNet152(ResNetX): # ===================================================================================================================== def _preload_weights( resnet: ResNet, - weights: Optional[Union[WeightsEnum, Any]], - input_shape: Tuple[int, int, int], + weights: Optional[WeightsEnum | Any], + input_shape: tuple[int, int, int], encoder_on: bool, ) -> ResNet: # pragma: no cover if not weights: diff --git a/minerva/models/siamese.py b/minerva/models/siamese.py index 62d6564cf..7e8b5fd9b 100644 --- a/minerva/models/siamese.py +++ b/minerva/models/siamese.py @@ -53,7 +53,7 @@ # IMPORTS # ===================================================================================================================== import abc -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any, Optional, Sequence import numpy as np import torch @@ -86,7 +86,7 @@ def __init__(self, *args, **kwargs) -> None: self.backbone: MinervaModel self.proj_head: Module - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """Performs a forward pass of the network by using the forward methods of the backbone and feeding its output into the projection heads. 
@@ -106,7 +106,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """ return self.forward_pair(x) - def forward_pair(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + def forward_pair(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """Performs a forward pass of the network by using the forward methods of the backbone and feeding its output into the projection heads. @@ -131,7 +131,7 @@ def forward_pair(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tenso return g, g_a, g_b, f_a, f_b @abc.abstractmethod - def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]: + def forward_single(self, x: Tensor) -> tuple[Tensor, Tensor]: """Performs a forward pass of a single head of the network by using the forward methods of the backbone and feeding its output into the projection heads. @@ -171,10 +171,10 @@ class SimCLR(MinervaSiamese): def __init__( self, criterion: Any, - input_size: Tuple[int, int, int] = (4, 256, 256), + input_size: tuple[int, int, int] = (4, 256, 256), feature_dim: int = 128, scaler: Optional[GradScaler] = None, - backbone_kwargs: Dict[str, Any] = {}, + backbone_kwargs: dict[str, Any] = {}, ) -> None: super(SimCLR, self).__init__( criterion=criterion, input_size=input_size, scaler=scaler @@ -196,7 +196,7 @@ def __init__( nn.Linear(512, feature_dim, bias=False), ) - def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]: + def forward_single(self, x: Tensor) -> tuple[Tensor, Tensor]: """Performs a forward pass of a single head of the network by using the forward methods of the :attr:`~SimCLR.backbone` and feeding its output into the :attr:`~SimCLR.proj_head`. 
@@ -214,7 +214,7 @@ def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]: return g, f - def step(self, x: Tensor, *args, train: bool = False) -> Tuple[Tensor, Tensor]: + def step(self, x: Tensor, *args, train: bool = False) -> tuple[Tensor, Tensor]: """Overwrites :class:`~models.core.MinervaModel` to account for paired logits. Raises: @@ -318,11 +318,11 @@ class SimSiam(MinervaSiamese): def __init__( self, criterion: Any, - input_size: Tuple[int, int, int] = (4, 256, 256), + input_size: tuple[int, int, int] = (4, 256, 256), feature_dim: int = 128, pred_dim: int = 512, scaler: Optional[GradScaler] = None, - backbone_kwargs: Dict[str, Any] = {}, + backbone_kwargs: dict[str, Any] = {}, ) -> None: super(SimSiam, self).__init__( criterion=criterion, input_size=input_size, scaler=scaler @@ -359,7 +359,7 @@ def __init__( nn.Linear(pred_dim, feature_dim), ) # output layer - def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]: + def forward_single(self, x: Tensor) -> tuple[Tensor, Tensor]: """Performs a forward pass of a single head of :class:`SimSiam` by using the forward methods of the backbone and feeding its output into the :attr:`~SimSiam.proj_head`. @@ -376,7 +376,7 @@ def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]: return p, z.detach() - def step(self, x: Tensor, *args, train: bool = False) -> Tuple[Tensor, Tensor]: + def step(self, x: Tensor, *args, train: bool = False) -> tuple[Tensor, Tensor]: """Overwrites :class:`~models.core.MinervaModel` to account for paired logits. 
Raises: @@ -480,12 +480,12 @@ class SimConv(MinervaSiamese): def __init__( self, criterion: Any, - input_size: Tuple[int, int, int] = (4, 256, 256), + input_size: tuple[int, int, int] = (4, 256, 256), feature_dim: int = 2048, projection_dim: int = 512, scaler: Optional[GradScaler] = None, encoder_weights: Optional[str] = None, - backbone_kwargs: Dict[str, Any] = {}, + backbone_kwargs: dict[str, Any] = {}, ) -> None: super(SimConv, self).__init__( criterion=criterion, input_size=input_size, scaler=scaler @@ -518,7 +518,7 @@ def __init__( nn.UpsamplingBilinear2d(scale_factor=4), ) - def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]: + def forward_single(self, x: Tensor) -> tuple[Tensor, Tensor]: """Performs a forward pass of a single head of the network by using the forward methods of the :attr:`~SimCLR.backbone` and feeding its output into the :attr:`~SimCLR.proj_head`. @@ -536,7 +536,7 @@ def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]: return g, f - def step(self, x: Tensor, *args, train: bool = False) -> Tuple[Tensor, Tensor]: + def step(self, x: Tensor, *args, train: bool = False) -> tuple[Tensor, Tensor]: """Overwrites :class:`~models.core.MinervaModel` to account for paired logits. Raises: diff --git a/minerva/models/unet.py b/minerva/models/unet.py index 200c514ed..b1fb0f641 100644 --- a/minerva/models/unet.py +++ b/minerva/models/unet.py @@ -51,7 +51,7 @@ # IMPORTS # ===================================================================================================================== import abc -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any, Optional, Sequence import torch import torch.nn.functional as F @@ -267,7 +267,7 @@ class UNet(MinervaModel): def __init__( self, criterion: Any, - input_size: Tuple[int, ...] = (4, 256, 256), + input_size: tuple[int, ...] 
= (4, 256, 256), n_classes: int = 8, bilinear: bool = False, scaler: Optional[GradScaler] = None, @@ -356,13 +356,13 @@ class UNetR(MinervaModel): def __init__( self, criterion: Any, - input_size: Tuple[int, ...] = (4, 256, 256), + input_size: tuple[int, ...] = (4, 256, 256), n_classes: int = 8, bilinear: bool = False, scaler: Optional[GradScaler] = None, backbone_weight_path: Optional[str] = None, freeze_backbone: bool = False, - backbone_kwargs: Dict[str, Any] = {}, + backbone_kwargs: dict[str, Any] = {}, ) -> None: super(UNetR, self).__init__( criterion=criterion, diff --git a/minerva/optimisers.py b/minerva/optimisers.py index 90bf4cc7b..5a5a10f9f 100644 --- a/minerva/optimisers.py +++ b/minerva/optimisers.py @@ -31,7 +31,7 @@ # ===================================================================================================================== # IMPORTS # ===================================================================================================================== -from typing import Any, Callable, Dict, Iterable, Optional, Union +from typing import Any, Callable, Iterable, Optional import torch from torch.optim.optimizer import Optimizer @@ -67,7 +67,7 @@ class LARS(Optimizer): def __init__( self, - params: Union[Iterable[Any], Dict[Any, Any]], + params: Iterable[Any] | dict[Any, Any], lr: float, momentum: float = 0.9, weight_decay: float = 0.0005, diff --git a/minerva/pytorchtools.py b/minerva/pytorchtools.py index 538f8b669..1cb61c848 100644 --- a/minerva/pytorchtools.py +++ b/minerva/pytorchtools.py @@ -35,7 +35,7 @@ # IMPORTS # ===================================================================================================================== from pathlib import Path -from typing import Callable, Optional, Union +from typing import Callable, Optional import numpy as np import torch @@ -81,7 +81,7 @@ def __init__( patience: int = 7, verbose: bool = False, delta: float = 0.0, - path: Union[str, Path] = "checkpoint.pt", + path: str | Path = 
"checkpoint.pt", trace_func: Callable[..., None] = print, external_save: bool = False, ): @@ -92,7 +92,7 @@ def __init__( self.early_stop: bool = False self.val_loss_min: float = np.Inf self.delta: float = delta - self.path: Union[str, Path] = path + self.path: str | Path = path self.trace_func: Callable[..., None] = trace_func self.external_save: bool = external_save self.save_model: bool = False diff --git a/minerva/samplers.py b/minerva/samplers.py index 33bce7b3e..9731961e3 100644 --- a/minerva/samplers.py +++ b/minerva/samplers.py @@ -41,22 +41,26 @@ "DistributedSamplerWrapper", "get_greater_bbox", "get_pair_bboxes", + "get_sampler", ] # ===================================================================================================================== # IMPORTS # ===================================================================================================================== import random +import re from operator import itemgetter -from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union +from typing import Any, Iterator, Optional, Sequence +import hydra import torch from torch.utils.data import Dataset, DistributedSampler, Sampler -from torchgeo.datasets import GeoDataset +from torchgeo.datasets import GeoDataset, NonGeoDataset from torchgeo.datasets.utils import BoundingBox from torchgeo.samplers import BatchGeoSampler, RandomGeoSampler, Units from torchgeo.samplers.utils import _to_tuple, get_random_bounding_box +from minerva.datasets.utils import make_bounding_box from minerva.utils import utils @@ -90,7 +94,7 @@ class RandomPairGeoSampler(RandomGeoSampler): def __init__( self, dataset: GeoDataset, - size: Union[Tuple[float, float], float], + size: tuple[float, float] | float, length: int, roi: Optional[BoundingBox] = None, units: Units = Units.PIXELS, @@ -99,7 +103,7 @@ def __init__( super().__init__(dataset, size, length, roi, units) self.max_r = max_r - def __iter__(self) -> Iterator[Tuple[BoundingBox, BoundingBox]]: # type: 
ignore[override] + def __iter__(self) -> Iterator[tuple[BoundingBox, BoundingBox]]: # type: ignore[override] """Return a pair of :class:`~torchgeo.datasets.utils.BoundingBox` indices of a dataset that are geospatially close. @@ -155,7 +159,7 @@ class RandomPairBatchGeoSampler(BatchGeoSampler): def __init__( self, dataset: GeoDataset, - size: Union[Tuple[float, float], float], + size: tuple[float, float] | float, batch_size: int, length: int, roi: Optional[BoundingBox] = None, @@ -176,7 +180,7 @@ def __init__( else: raise ValueError(f"{tiles_per_batch=} is not a multiple of {batch_size=}") - def __iter__(self) -> Iterator[List[Tuple[BoundingBox, BoundingBox]]]: # type: ignore[override] + def __iter__(self) -> Iterator[list[tuple[BoundingBox, BoundingBox]]]: # type: ignore[override] """Return the indices of a dataset. Returns: @@ -208,7 +212,9 @@ def __len__(self) -> int: def get_greater_bbox( - bbox: BoundingBox, r: float, size: Union[float, int, Sequence[float]] + bbox: BoundingBox, + r: float, + size: float | int | Sequence[float], ) -> BoundingBox: """Return a bounding box at ``r`` distance around the first box. @@ -246,10 +252,10 @@ def get_greater_bbox( def get_pair_bboxes( bounds: BoundingBox, - size: Union[Tuple[float, float], float], + size: tuple[float, float] | float, res: float, max_r: float, -) -> Tuple[BoundingBox, BoundingBox]: +) -> tuple[BoundingBox, BoundingBox]: """Samples a pair of bounding boxes geo-spatially close to each other. Args: @@ -345,7 +351,7 @@ class DatasetFromSampler(Dataset): # type: ignore[type-arg] def __init__(self, sampler: Sampler[Any]): """Initialisation for :class:`DatasetFromSampler`.""" self.sampler = sampler - self.sampler_list: Optional[List[Sampler[Any]]] = None + self.sampler_list: Optional[list[Sampler[Any]]] = None def __getitem__(self, index: int) -> Any: """Gets element of the dataset. 
@@ -366,3 +372,43 @@ def __len__(self) -> int: int: Length of the dataset """ return len(self.sampler) # type: ignore[arg-type] + + +# ===================================================================================================================== +# METHODS +# ===================================================================================================================== +def get_sampler( + params: dict[str, Any], + dataset: GeoDataset | NonGeoDataset, + batch_size: Optional[int] = None, +) -> Sampler[Any]: + """Use :meth:`hydra.utils.instantiate` to get the sampler using config parameters. + + Args: + params (dict[str, ~typing.Any]): Sampler parameters. Must include the ``_target_`` key pointing to + the sampler class. + dataset (~torchgeo.datasets.GeoDataset, ~torchgeo.datasets.NonGeoDataset]): Dataset to sample. + batch_size (int): Optional; Batch size to sample if using a batch sampler. + Use if you need to overwrite the config parameters due to distributed computing as the batch size + needs to modified to split the batch across devices. Defaults to ``None``. + + Returns: + ~torch.utils.data.Sampler: Sampler requested by config parameters. 
+ """ + + batch_sampler = True if re.search(r"Batch", params["_target_"]) else False + if batch_sampler and batch_size is not None: + params["batch_size"] = batch_size + + if "roi" in params: + sampler = hydra.utils.instantiate( + params, dataset=dataset, roi=make_bounding_box(params["roi"]) + ) + else: + if "torchgeo" in params["_target_"]: + sampler = hydra.utils.instantiate(params, dataset=dataset) + else: + sampler = hydra.utils.instantiate(params, data_source=dataset) + + assert isinstance(sampler, Sampler) + return sampler diff --git a/minerva/scheduler.py b/minerva/scheduler.py index ff8b33684..c49f0fe9a 100644 --- a/minerva/scheduler.py +++ b/minerva/scheduler.py @@ -37,7 +37,6 @@ # IMPORTS # ===================================================================================================================== import warnings -from typing import List import numpy as np from torch.optim.lr_scheduler import LRScheduler @@ -81,7 +80,7 @@ def __init__( self.n_periods = n_periods super().__init__(optimizer, last_epoch, verbose) - def get_lr(self) -> List[float]: # type: ignore[override] + def get_lr(self) -> list[float]: # type: ignore[override] if not hasattr(self, "_get_lr_called_within_step"): warnings.warn( "To get the last learning rate computed by the scheduler, " diff --git a/minerva/tasks/core.py b/minerva/tasks/core.py index 67faac873..fcd6d84d2 100644 --- a/minerva/tasks/core.py +++ b/minerva/tasks/core.py @@ -41,7 +41,7 @@ import abc from abc import ABC from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence if TYPE_CHECKING: # pragma: no cover from torch.utils.tensorboard.writer import SummaryWriter @@ -66,7 +66,7 @@ wrap_model, ) from minerva.utils import utils, visutils -from minerva.utils.utils import fallback_params, func_by_str +from minerva.utils.utils import fallback_params # 
===================================================================================================================== @@ -158,19 +158,19 @@ class MinervaTask(ABC): """ logger_cls: str = "SupervisedTaskLogger" - model_io_name: str = "sup_tg" + model_io_name: str = "minerva.modelio.supervised_torchgeo_io" def __init__( self, name: str, - model: Union[MinervaModel, MinervaDataParallel, OptimizedModule], + model: MinervaModel | MinervaDataParallel | OptimizedModule, device: torch.device, exp_fn: Path, gpu: int = 0, rank: int = 0, world_size: int = 1, - writer: Optional[Union[SummaryWriter, Run]] = None, - backbone_weight_path: Optional[Union[str, Path]] = None, + writer: Optional[SummaryWriter | Run] = None, + backbone_weight_path: Optional[str | Path] = None, record_int: bool = True, record_float: bool = False, train: bool = False, @@ -181,9 +181,12 @@ def __init__( self.model = model # Gets the datasets, number of batches, class distribution and the modfied parameters for the experiment. - loaders, n_batches, class_dist, task_params, = make_loaders( - rank, world_size, task_name=name, **global_params - ) + ( + loaders, + n_batches, + class_dist, + task_params, + ) = make_loaders(rank, world_size, task_name=name, **global_params) # If there are multiple modes and therefore number of batches, just take the value of the first one. if isinstance(n_batches, dict): @@ -226,6 +229,9 @@ def __init__( self.sample_pairs = fallback_params( "sample_pairs", self.params, self.global_params ) + self.change_detection = fallback_params( + "change_detection", self.params, self.global_params + ) self.n_classes = fallback_params("n_classes", self.params, self.global_params) @@ -259,7 +265,7 @@ def __init__( self.model.to(self.device) # To eliminate classes, we're going to have to do a fair bit of rebuilding of the model... 
- if self.elim: + if self.elim and self.train: # Update the stored number of classes within the model and # then rebuild the classification layers that are dependent on the number of classes. self.model.update_n_classes(self.n_classes) @@ -319,18 +325,24 @@ def make_criterion(self) -> Any: _weights = Tensor(weights) # Use hydra to instantiate the loss function with the weights and return. - return hydra.utils.instantiate(fallback_params("loss_params", self.params, self.global_params), weight=_weights) + return hydra.utils.instantiate( + fallback_params("loss_params", self.params, self.global_params), + weight=_weights, + ) else: # Use hydra to instantiate the loss function based of the config, without weights. - return hydra.utils.instantiate(fallback_params("loss_params", self.params, self.global_params)) + return hydra.utils.instantiate( + fallback_params("loss_params", self.params, self.global_params) + ) def make_optimiser(self) -> None: """Creates a :mod:`torch` optimiser based on config parameters and sets optimiser.""" # Constructs and sets the optimiser for the model based on supplied config parameters. optimiser = hydra.utils.instantiate( - fallback_params("optimiser", self.params, self.global_params), params=self.model.parameters() + fallback_params("optimiser", self.params, self.global_params), + params=self.model.parameters(), ) self.model.set_optimiser(optimiser) @@ -350,8 +362,7 @@ def make_logger(self) -> MinervaTaskLogger: """ # Gets constructor of the metric logger from name in the config. - _logger_cls = func_by_str( - "minerva.logger.tasklog", + _logger_cls = hydra.utils.get_class( utils.fallback_params( "task_logger", self.params, self.global_params, self.logger_cls ), @@ -382,8 +393,7 @@ def get_io_func(self) -> Callable[..., Any]: Returns: ~typing.Callable[..., ~typing.Any]: Model IO function requested from parameters. 
""" - io_func: Callable[..., Any] = func_by_str( - "minerva.modelio", + io_func: Callable[..., Any] = hydra.utils.get_method( utils.fallback_params( "model_io", self.params, self.global_params, self.model_io_name ), @@ -394,7 +404,7 @@ def get_io_func(self) -> Callable[..., Any]: def step(self) -> None: # pragma: no cover raise NotImplementedError - def _generic_step(self, epoch_no: int) -> Optional[Dict[str, Any]]: + def _generic_step(self, epoch_no: int) -> Optional[dict[str, Any]]: self.local_step_num = 0 self.step() @@ -421,11 +431,11 @@ def __call__(self, epoch_no: int) -> Any: return self._generic_step(epoch_no) @property - def get_logs(self) -> Dict[str, Any]: + def get_logs(self) -> dict[str, Any]: return self.logger.get_logs @property - def get_metrics(self) -> Dict[str, Any]: + def get_metrics(self) -> dict[str, Any]: return self.logger.get_metrics def log_null(self, epoch_no: int) -> None: @@ -445,8 +455,8 @@ def print_epoch_results(self, epoch_no: int) -> None: def plot( self, - results: Dict[str, Any], - metrics: Optional[Dict[str, Any]] = None, + results: dict[str, Any], + metrics: Optional[dict[str, Any]] = None, save: bool = True, show: bool = False, ) -> None: @@ -491,7 +501,9 @@ def plot( colours=utils.fallback_params("colours", self.params, self.global_params), save=save, show=show, - model_name=utils.fallback_params("model_name", self.params, self.global_params), + model_name=utils.fallback_params( + "model_name", self.params, self.global_params + ), timestamp=self.global_params["timestamp"], results_dir=self.task_dir, task_cfg=self.params, @@ -545,7 +557,7 @@ def __repr__(self) -> str: # ===================================================================================================================== # METHODS # ===================================================================================================================== -def get_task(task_name: str, task_module: str = "minerva.tasks", *args, **params) -> MinervaTask: +def 
get_task(task_name: str, *args, **params) -> MinervaTask: """Get the requested :class:`MinervaTask` by name. Args: @@ -555,7 +567,7 @@ def get_task(task_name: str, task_module: str = "minerva.tasks", *args, **params Returns: MinervaTask: Constructed :class:`MinervaTask` object. """ - _task = func_by_str(task_module, task_name) + _task = hydra.utils.get_class(task_name) task = _task(*args, **params) assert isinstance(task, MinervaTask) diff --git a/minerva/tasks/epoch.py b/minerva/tasks/epoch.py index 19363a8b3..98e64fb33 100644 --- a/minerva/tasks/epoch.py +++ b/minerva/tasks/epoch.py @@ -60,7 +60,7 @@ class StandardEpoch(MinervaTask): .. versionadded:: 0.27 """ - logger_cls = "SupervisedTaskLogger" + logger_cls = "minerva.logger.tasklog.SupervisedTaskLogger" def step(self) -> None: # Initialises a progress bar for the epoch. @@ -72,7 +72,7 @@ def step(self) -> None: self.model.eval() # Ensure gradients will not be calculated if this is not a training task. - with torch.no_grad() if not self.train else nullcontext(): + with torch.no_grad() if not self.train else nullcontext(): # type: ignore[attr-defined] # Core of the epoch. 
for batch in self.loaders: @@ -85,7 +85,7 @@ def step(self) -> None: ) if self.local_step_num % self.log_rate == 0: - if dist.is_available() and dist.is_initialized(): # type: ignore[attr-defined] # pragma: no cover + if dist.is_available() and dist.is_initialized(): # type: ignore[attr-defined] # pragma: no cover # noqa: E501 loss = results[0].data.clone() dist.all_reduce(loss.div_(dist.get_world_size())) # type: ignore[attr-defined] results = (loss, *results[1:]) diff --git a/minerva/tasks/knn.py b/minerva/tasks/knn.py index 984faab7f..9c650a139 100644 --- a/minerva/tasks/knn.py +++ b/minerva/tasks/knn.py @@ -41,7 +41,7 @@ # ===================================================================================================================== # IMPORTS # ===================================================================================================================== -from typing import TYPE_CHECKING, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional import torch import torch.distributed as dist @@ -137,19 +137,19 @@ class WeightedKNN(MinervaTask): .. 
versionadded:: 0.27 """ - logger_cls = "SSLTaskLogger" + logger_cls = "minerva.logger.tasklog.SSLTaskLogger" def __init__( self, name: str, - model: Union[MinervaModel, MinervaDataParallel], + model: MinervaModel | MinervaDataParallel, device: torch.device, exp_fn: Path, gpu: int = 0, rank: int = 0, world_size: int = 1, - writer: Optional[Union[SummaryWriter, Run]] = None, - backbone_weight_path: Optional[Union[str, Path]] = None, + writer: Optional[SummaryWriter | Run] = None, + backbone_weight_path: Optional[str | Path] = None, record_int: bool = True, record_float: bool = False, k: int = 5, @@ -174,10 +174,12 @@ def __init__( self.temp = temp self.k = k - def generate_feature_bank(self) -> Tuple[Tensor, Tensor]: + def generate_feature_bank(self) -> tuple[Tensor, Tensor]: feature_list = [] target_list = [] + assert isinstance(self.loaders, dict) + for batch in tqdm(self.loaders["features"]): val_data: Tensor = batch["image"].to(self.device, non_blocking=True) val_target: Tensor = batch["mask"].to(self.device, non_blocking=True) @@ -231,6 +233,8 @@ def step(self) -> None: # Generate feature bank and target bank. feature_bank, feature_labels = self.generate_feature_bank() + assert isinstance(self.loaders, dict) + # Loop test data to predict the label by weighted KNN search. 
for batch in tqdm(self.loaders["test"]): test_data: Tensor = batch["image"].to(self.device, non_blocking=True) diff --git a/minerva/tasks/tsne.py b/minerva/tasks/tsne.py index 38bd37f26..4db2bd922 100644 --- a/minerva/tasks/tsne.py +++ b/minerva/tasks/tsne.py @@ -38,7 +38,7 @@ # IMPORTS # ===================================================================================================================== from pathlib import Path -from typing import Optional, Union +from typing import Any, Optional import torch from torch import Tensor @@ -46,6 +46,7 @@ from wandb.sdk.wandb_run import Run from minerva.models import MinervaDataParallel, MinervaModel +from minerva.utils.utils import get_sample_index from minerva.utils.visutils import plot_embedding from .core import MinervaTask @@ -64,14 +65,14 @@ class TSNEVis(MinervaTask): def __init__( self, name: str, - model: Union[MinervaModel, MinervaDataParallel], + model: MinervaModel | MinervaDataParallel, device: torch.device, exp_fn: Path, gpu: int = 0, rank: int = 0, world_size: int = 1, - writer: Union[SummaryWriter, Run, None] = None, - backbone_weight_path: Optional[Union[str, Path]] = None, + writer: Optional[SummaryWriter | Run] = None, + backbone_weight_path: Optional[str | Path] = None, record_int: bool = True, record_float: bool = False, **params, @@ -79,7 +80,7 @@ def __init__( backbone = model.get_backbone() # type: ignore[assignment, operator] # Set dummy optimiser. It won't be used as this is a test. - backbone.set_optimiser(torch.optim.SGD(backbone.parameters(), lr=1.0e-3)) + backbone.set_optimiser(torch.optim.SGD(backbone.parameters(), lr=1.0e-3)) # type: ignore[attr-defined] super().__init__( name, @@ -103,7 +104,8 @@ def step(self) -> None: Passes these embeddings to :mod:`visutils` to train a TSNE algorithm and then visual the cluster. """ # Get a batch of data. 
- data = next(iter(self.loaders)) + assert isinstance(self.loaders, torch.utils.data.DataLoader) + data: dict[str, Any] = next(iter(self.loaders)) # Make sure the model is in evaluation mode. self.model.eval() @@ -118,7 +120,7 @@ def step(self) -> None: plot_embedding( embeddings.detach().cpu(), - data["bbox"], + get_sample_index(data), # type: ignore[arg-type] self.global_params["data_root"], self.params["dataset_params"], show=True, diff --git a/minerva/trainer.py b/minerva/trainer.py index 0f35b99fa..88a9cf470 100644 --- a/minerva/trainer.py +++ b/minerva/trainer.py @@ -39,11 +39,12 @@ # IMPORTS # ===================================================================================================================== import os +import re import warnings from copy import deepcopy from pathlib import Path from platform import python_version -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Optional import hydra import packaging @@ -66,7 +67,6 @@ MinervaDataParallel, MinervaModel, MinervaOnnxModel, - MinervaWrapper, extract_wrapped_model, wrap_model, ) @@ -237,7 +237,7 @@ def __init__( rank: int = 0, world_size: int = 1, verbose: bool = True, - wandb_run: Optional[Union[Run, RunDisabled]] = None, + wandb_run: Optional[Run | RunDisabled] = None, **params, ) -> None: assert not isinstance(wandb_run, RunDisabled) @@ -264,7 +264,7 @@ def __init__( utils.print_config(params) # type: ignore[arg-type] # Now that we have pretty printed the config, it is easier to handle as a dict. - self.params: Dict[str, Any] = OmegaConf.to_object(params) # type: ignore[assignment] + self.params: dict[str, Any] = OmegaConf.to_object(params) # type: ignore[assignment] assert isinstance(self.params, dict) # Set variables for checkpointing the experiment or loading from a previous checkpoint. @@ -324,7 +324,7 @@ def __init__( # Makes a directory for this experiment. 
utils.mkexpdir(self.params["exp_name"]) - self.writer: Optional[Union[SummaryWriter, Run]] = None + self.writer: Optional[SummaryWriter | Run] = None if self.params.get("wandb_log", False): # Sets the `wandb` run object (or None). self.writer = wandb_run @@ -337,9 +337,10 @@ def __init__( else: # pragma: no cover self.writer = None - self.model: Union[ - MinervaModel, MinervaDataParallel, MinervaBackbone, OptimizedModule - ] + self.model: ( + MinervaModel | MinervaDataParallel | MinervaBackbone | OptimizedModule + ) + if Path(self.params.get("pre_train_name", "none")).suffix == ".onnx": # Loads model from `onnx` format. self.model = self.load_onnx_model() @@ -351,7 +352,7 @@ def __init__( self.model = self.make_model() # Determines the output shape of the model. - sample_pairs: Union[bool, Any] = self.sample_pairs + sample_pairs: bool | Any = self.sample_pairs if not isinstance(sample_pairs, bool): sample_pairs = False self.params["sample_pairs"] = False @@ -409,7 +410,9 @@ def __init__( else: pass - self.checkpoint_path = self.exp_fn / (self.params["exp_name"] + "-checkpoint.pt") + self.checkpoint_path = self.exp_fn / ( + self.params["exp_name"] + "-checkpoint.pt" + ) self.backbone_path = self.exp_fn / (self.params["exp_name"] + "-backbone.pt") self.print("Checkpoint will be saved to " + str(self.checkpoint_path)) @@ -455,22 +458,19 @@ def _setup_writer(self) -> None: if isinstance(self.writer, Run): self.writer.watch(self.model) - def get_input_size(self) -> Tuple[int, ...]: + def get_input_size(self) -> tuple[int, ...]: """Determines the input size of the model. Returns: tuple[int, ...]: :class:`tuple` describing the input shape of the model. """ - input_shape: Optional[Tuple[int, ...]] = self.model.input_size # type: ignore + input_shape: Optional[tuple[int, ...]] = self.model.input_size # type: ignore assert input_shape is not None - input_size: Tuple[int, ...] = (self.batch_size, *input_shape) + input_size: tuple[int, ...] 
= (self.batch_size, *input_shape) - if self.sample_pairs: + if self.sample_pairs or self.change_detection: input_size = (2, *input_size) - if self.change_detection: - input_size = (input_size[0], 2 * input_size[1], *input_size[2:]) - return input_size def get_model_cache_path(self) -> Path: @@ -497,43 +497,39 @@ def make_model(self) -> MinervaModel: Returns: MinervaModel: Initialised model. """ - model_params: Dict[str, Any] = deepcopy(self.params["model_params"]) + model_params: dict[str, Any] = deepcopy(self.params["model_params"]) if OmegaConf.is_config(model_params): model_params = OmegaConf.to_object(model_params) # type: ignore[assignment] - module = model_params.pop("module", "minerva.models") - if not module: - module = "minerva.models" - is_minerva = True if module == "minerva.models" else False - - # Gets the model requested by config parameters. - _model = utils.func_by_str(module, self.params["model_name"].split("-")[0]) + is_minerva = True if re.search(r"minerva", model_params["_target_"]) else False if self.fine_tune: # Add the path to the pre-trained weights to the model params. model_params["backbone_weight_path"] = f"{self.get_weights_path()}.pt" - params = model_params.get("params", {}) - if "n_classes" in params.keys(): + if "n_classes" in model_params.keys(): # Updates the number of classes in case it has been altered by class balancing. - params["n_classes"] = self.params["n_classes"] + model_params["n_classes"] = self.params["n_classes"] - if "num_classes" in params.keys(): + if "num_classes" in model_params.keys(): # Updates the number of classes in case it has been altered by class balancing. - params["num_classes"] = self.params["n_classes"] + model_params["num_classes"] = self.params["n_classes"] if self.params.get("mix_precision", False): - params["scaler"] = torch.cuda.amp.grad_scaler.GradScaler() + model_params["scaler"] = torch.cuda.amp.grad_scaler.GradScaler() # Initialise model. 
model: MinervaModel if is_minerva: - model = _model(self.make_criterion(), **params) + model = hydra.utils.instantiate( + model_params, criterion=self.make_criterion() + ) else: - model = MinervaWrapper( - _model, - self.make_criterion(), - **params, + model_params["model"] = hydra.utils.get_method(model_params["_target_"]) + model_params["_target_"] = "minerva.models.MinervaWrapper" + model = hydra.utils.instantiate( + model_params, + criterion=self.make_criterion(), ) if self.params.get("reload", False): @@ -562,7 +558,10 @@ def load_onnx_model(self) -> MinervaModel: package="onnx2torch", ) - model_params = self.params["model_params"].get("params", {}) + model_params = deepcopy(self.params["model_params"]) + + if "_target_" in model_params: + del model_params["_target_"] onnx_model = convert(onnx_load(f"{self.get_weights_path()}.onnx")) model = MinervaOnnxModel(onnx_model, self.make_criterion(), **model_params) @@ -620,11 +619,10 @@ def fit(self) -> None: } ) - tasks: Dict[str, MinervaTask] = {} + tasks: dict[str, MinervaTask] = {} for mode in fit_params.keys(): tasks[mode] = get_task( - fit_params[mode]["name"], - fit_params[mode].get("module", "minerva.tasks"), + fit_params[mode]["_target_"], mode, self.model, self.device, @@ -637,7 +635,9 @@ def fit(self) -> None: **self.params, ) - if tasks[mode].params.get("elim", False): + # Update the number of classes from the task + # if class elimination and training is active in the task. + if tasks[mode].elim and tasks[mode].train: self.params["n_classes"] = tasks[mode].n_classes while self.epoch_no < self.max_epochs: @@ -649,28 +649,34 @@ def fit(self) -> None: # Conduct training or validation epoch. for mode in tasks.keys(): # Only run a validation epoch at set frequency of epochs. Goes to next epoch if not. 
- if ( - utils.check_substrings_in_string(mode, "val") - and (self.epoch_no) % self.val_freq != 0 - ): - tasks[mode].log_null(self.epoch_no - 1) - break + if utils.check_substrings_in_string(mode, "val"): + tasks[mode].model = self.model + + if self.epoch_no % self.val_freq != 0: + tasks[mode].log_null(self.epoch_no - 1) + break if tasks[mode].train: self.model.train() else: self.model.eval() - results: Optional[Dict[str, Any]] - - results = tasks[mode](self.epoch_no - 1) + results: Optional[dict[str, Any]] = tasks[mode](self.epoch_no - 1) # Print epoch results. if self.gpu == 0: tasks[mode].print_epoch_results(self.epoch_no - 1) - if not self.stopper and self.checkpoint_experiment: + if ( + not self.stopper + and self.checkpoint_experiment + and utils.check_substrings_in_string(mode, "train") + ): self.save_checkpoint() + # Update Trainer's copy of the model from the training task. + if utils.check_substrings_in_string(mode, "train"): + self.model = tasks[mode].model + # Sends validation loss to the stopper and updates early stop bool. if ( utils.check_substrings_in_string(mode, "val") @@ -692,7 +698,7 @@ def fit(self) -> None: self.print("\nEarly stopping triggered") # Create a subset of metrics for plotting model history. - fit_metrics: Dict[str, Any] = {} + fit_metrics: dict[str, Any] = {} for _mode in tasks.keys(): fit_metrics = {**fit_metrics, **tasks[_mode].get_metrics} @@ -742,8 +748,7 @@ def test(self, save: bool = True, show: bool = False) -> None: self.params["plot_last_epoch"] = True for task_name in test_params.keys(): task = get_task( - test_params[task_name]["name"], - test_params[task_name].get("module", "minerva.tasks"), + test_params[task_name]["_target_"], task_name, self.model, self.device, @@ -872,6 +877,11 @@ def load_checkpoint(self) -> None: if self.model.scheduler is not None and "scheduler_state_dict" in checkpoint: self.model.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) + # Calculate the output dimensions of the model. 
+ self.model.determine_output_dim( + sample_pairs=self.sample_pairs, change_detection=self.change_detection + ) + # Transfer to GPU. self.model.to(self.device) @@ -960,7 +970,7 @@ def close(self) -> None: # Saves model state dict to PyTorch file. self.save_model_weights() - def save_model_weights(self, fn: Optional[Union[str, Path]] = None) -> None: + def save_model_weights(self, fn: Optional[str | Path] = None) -> None: """Saves model state dict to :mod:`torch` file. Args: @@ -973,9 +983,7 @@ def save_model_weights(self, fn: Optional[Union[str, Path]] = None) -> None: torch.save(model.state_dict(), f"{fn}.pt") - def save_model( - self, fn: Optional[Union[Path, str]] = None, fmt: str = "pt" - ) -> None: + def save_model(self, fn: Optional[Path | str] = None, fmt: str = "pt") -> None: """Saves the model object itself to :mod:`torch` file. Args: diff --git a/minerva/transforms.py b/minerva/transforms.py index 7768d091e..bbb857581 100644 --- a/minerva/transforms.py +++ b/minerva/transforms.py @@ -55,20 +55,9 @@ import re from copy import deepcopy from pathlib import Path -from typing import ( - Any, - Callable, - Dict, - List, - Literal, - Optional, - Sequence, - Tuple, - Union, - cast, - overload, -) +from typing import Any, Callable, Literal, Optional, Sequence, cast, overload +import hydra import numpy as np import rasterio import torch @@ -79,18 +68,14 @@ from torchvision.transforms import ( ColorJitter, ConvertImageDtype, + InterpolationMode, Normalize, RandomApply, Resize, ) from torchvision.transforms.v2 import functional as ft -from minerva.utils.utils import ( - find_tensor_mode, - func_by_str, - get_centre_pixel_value, - mask_transform, -) +from minerva.utils.utils import find_tensor_mode, get_centre_pixel_value, mask_transform # ===================================================================================================================== @@ -106,7 +91,7 @@ class ClassTransform: transform (dict[int, int]): Mapping from one labelling schema to another. 
""" - def __init__(self, transform: Dict[int, int]) -> None: + def __init__(self, transform: dict[int, int]) -> None: self.transform = transform def __call__(self, mask: LongTensor) -> LongTensor: @@ -134,14 +119,14 @@ class PairCreate: def __init__(self) -> None: pass - def __call__(self, sample: Any) -> Tuple[Any, Any]: + def __call__(self, sample: Any) -> tuple[Any, Any]: return self.forward(sample) def __repr__(self) -> str: return f"{self.__class__.__name__}()" @staticmethod - def forward(sample: Any) -> Tuple[Any, Any]: + def forward(sample: Any) -> tuple[Any, Any]: """Takes a sample and returns it and a copy as a :class:`tuple` pair. Args: @@ -220,7 +205,7 @@ def __init__( super().__init__(mean, std, inplace) - def _calc_mean_std(self) -> Tuple[List[float], List[float]]: + def _calc_mean_std(self) -> tuple[list[float], list[float]]: per_img_means = [] per_img_stds = [] for query in self.sampler: @@ -236,7 +221,7 @@ def _calc_mean_std(self) -> Tuple[List[float], List[float]]: return per_band_mean, per_band_std - def _get_tile_mean_std(self, query: BoundingBox) -> Tuple[List[float], List[float]]: + def _get_tile_mean_std(self, query: BoundingBox) -> tuple[list[float], list[float]]: hits = self.dataset.index.intersection(tuple(query), objects=True) filepaths = cast(list[str], [hit.object for hit in hits]) @@ -245,8 +230,8 @@ def _get_tile_mean_std(self, query: BoundingBox) -> Tuple[List[float], List[floa f"query: {query} not found in index with bounds: {self.dataset.bounds}" ) - means: List[float] - stds: List[float] + means: list[float] + stds: list[float] if self.dataset.separate_files: filename_regex = re.compile(self.dataset.filename_regex, re.VERBOSE) @@ -277,9 +262,9 @@ def _get_tile_mean_std(self, query: BoundingBox) -> Tuple[List[float], List[floa def _get_image_mean_std( self, - filepaths: List[str], + filepaths: list[str], band_indexes: Optional[Sequence[int]] = None, - ) -> Tuple[List[float], List[float]]: + ) -> tuple[list[float], list[float]]: 
stats = [self._get_meta_mean_std(fp, band_indexes) for fp in filepaths] means = list(np.mean([stat[0] for stat in stats], axis=0)) @@ -289,7 +274,7 @@ def _get_image_mean_std( def _get_meta_mean_std( self, filepath, band_indexes: Optional[Sequence[int]] = None - ) -> Tuple[List[float], List[float]]: + ) -> tuple[list[float], list[float]]: # Open the Tiff file and get the statistics from the meta (min, max, mean, std). means = [] stds = [] @@ -374,7 +359,7 @@ class ToRGB: """ - def __init__(self, channels: Optional[Tuple[int, int, int]] = None) -> None: + def __init__(self, channels: Optional[tuple[int, int, int]] = None) -> None: self.channels = channels def __call__(self, img: Tensor) -> Tensor: @@ -426,7 +411,7 @@ class SelectChannels: channels (list[int]): Channel indices to keep. """ - def __init__(self, channels: List[int]) -> None: + def __init__(self, channels: list[int]) -> None: self.channels = channels def __call__(self, img: Tensor) -> Tensor: @@ -501,8 +486,8 @@ class MinervaCompose: Args: transforms (~typing.Sequence[~typing.Callable[..., ~typing.Any]] | ~typing.Callable[..., ~typing.Any]): List of transforms to compose. - key (str): Optional; For use with :mod:`torchgeo` samples and must be assigned a value if using. - The key of the data type in the sample dict to transform. + change_detection (bool): Flag for if transforming a change detection dataset which has + ``"image1"`` and ``"image2"`` keys rather than ``"image"``. 
Example: >>> transforms.MinervaCompose([ @@ -514,21 +499,26 @@ class MinervaCompose: def __init__( self, - transforms: Union[ - List[Callable[..., Any]], - Callable[..., Any], - Dict[str, Union[List[Callable[..., Any]], Callable[..., Any]]], - ], + transforms: ( + list[Callable[..., Any]] + | Callable[..., Any] + | dict[str, list[Callable[..., Any]] | Callable[..., Any]] + ), + change_detection: bool = False, ) -> None: - self.transforms: Union[ - List[Callable[..., Any]], Dict[str, List[Callable[..., Any]]] - ] + self.transforms: list[Callable[..., Any]] | dict[str, list[Callable[..., Any]]] + + self.change_detection = change_detection if isinstance(transforms, Sequence): self.transforms = list(transforms) elif callable(transforms): self.transforms = [transforms] elif isinstance(transforms, dict): + if self.change_detection and "image" in transforms: + transforms["image1"] = transforms["image"] + transforms["image2"] = transforms["image"] + del transforms["image"] self.transforms = transforms # type: ignore[assignment] assert isinstance(self.transforms, dict) for key in transforms.keys(): @@ -546,12 +536,10 @@ def __call__(self, sample: Tensor) -> Tensor: ... # pragma: no cover @overload def __call__( - self, sample: Dict[str, Any] - ) -> Dict[str, Any]: ... # pragma: no cover + self, sample: dict[str, Any] + ) -> dict[str, Any]: ... # pragma: no cover - def __call__( - self, sample: Union[Tensor, Dict[str, Any]] - ) -> Union[Tensor, Dict[str, Any]]: + def __call__(self, sample: Tensor | dict[str, Any]) -> Tensor | dict[str, Any]: if isinstance(sample, Tensor): assert not isinstance(self.transforms, dict) return self._transform_input(sample, self.transforms) @@ -563,11 +551,24 @@ def __call__( # Assumes the keys must be "image" and "mask" # We need to apply these first before applying any seperate transforms for modalities. if "both" in self.transforms: + if self.change_detection: + # Transform images1 with new random states (if applicable). 
+ sample["image1"] = self._transform_input( + sample["image1"], self.transforms["both"] + ) - # Transform images with new random states (if applicable). - sample["image"] = self._transform_input( - sample["image"], self.transforms["both"] - ) + # Transform images1 with new random states (if applicable). + sample["image2"] = self._transform_input( + sample["image2"], + self.transforms["both"], + reapply=True, + ) + + else: + # Transform images with new random states (if applicable). + sample["image"] = self._transform_input( + sample["image"], self.transforms["both"] + ) # We'll have to convert the masks to float for these transforms to work # so need to store the current dtype to cast back to after. @@ -596,7 +597,7 @@ def __call__( @staticmethod def _transform_input( - img: Tensor, transforms: List[Callable[..., Any]], reapply: bool = False + img: Tensor, transforms: list[Callable[..., Any]], reapply: bool = False ) -> Tensor: if isinstance(transforms, Sequence): for t in transforms: @@ -615,17 +616,17 @@ def _transform_input( def _add( self, - new_transform: Union[ - "MinervaCompose", - Sequence[Callable[..., Any]], - Callable[..., Any], - Dict[str, Union[Sequence[Callable[..., Any]], Callable[..., Any]]], - ], - ) -> Union[Dict[str, List[Callable[..., Any]]], List[Callable[..., Any]]]: + new_transform: ( + "MinervaCompose" + | Sequence[Callable[..., Any]] + | Callable[..., Any] + | dict[str, Sequence[Callable[..., Any]] | Callable[..., Any]] + ), + ) -> dict[str, list[Callable[..., Any]]] | list[Callable[..., Any]]: def add_transforms( - _new_transform: Union[Sequence[Callable[..., Any]], Callable[..., Any]], - old_transform: List[Callable[..., Any]], - ) -> List[Callable[..., Any]]: + _new_transform: Sequence[Callable[..., Any]] | Callable[..., Any], + old_transform: list[Callable[..., Any]], + ) -> list[Callable[..., Any]]: if isinstance(_new_transform, Sequence): old_transform.extend(_new_transform) return old_transform @@ -665,12 +666,12 @@ def 
add_transforms( def __add__( self, - new_transform: Union[ - "MinervaCompose", - Sequence[Callable[..., Any]], - Callable[..., Any], - Dict[str, Union[Sequence[Callable[..., Any]], Callable[..., Any]]], - ], + new_transform: ( + "MinervaCompose" + | Sequence[Callable[..., Any]] + | Callable[..., Any] + | dict[str, Sequence[Callable[..., Any]] | Callable[..., Any]] + ), ) -> "MinervaCompose": new_compose = deepcopy(self) new_compose.transforms = self._add(new_transform) @@ -678,12 +679,12 @@ def __add__( def __iadd__( self, - new_transform: Union[ - "MinervaCompose", - Sequence[Callable[..., Any]], - Callable[..., Any], - Dict[str, Union[Sequence[Callable[..., Any]], Callable[..., Any]]], - ], + new_transform: ( + "MinervaCompose" + | Sequence[Callable[..., Any]] + | Callable[..., Any] + | dict[str, Sequence[Callable[..., Any]] | Callable[..., Any]] + ), ) -> "MinervaCompose": self.transforms = self._add(new_transform) return self @@ -738,13 +739,13 @@ def __init__(self, from_key: str, to_key: str) -> None: self.from_key = from_key self.to_key = to_key - def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: return self.forward(sample) def __repr__(self) -> str: return f"{self.__class__.__name__}({self.from_key} -> {self.to_key})" - def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]: + def forward(self, sample: dict[str, Any]) -> dict[str, Any]: """Sets the ``to_key`` of ``sample`` to the ``from_key`` and returns. Args: @@ -766,12 +767,14 @@ class SeasonTransform: season (str): How to handle what seasons to return: * ``pair``: Randomly pick 2 seasons to return that will form a pair. * ``random``: Randomly pick a single season to return. + + .. 
versionadded:: 0.28 """ def __init__(self, season: str = "random") -> None: self.season = season - def __call__(self, x: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]: + def __call__(self, x: Tensor) -> tuple[Tensor, Tensor] | Tensor: if self.season == "pair": season1 = np.random.choice([0, 1, 2, 3]) @@ -799,6 +802,8 @@ class ConvertDtypeFromStr(ConvertImageDtype): Args: dtype (str): A tensor type as :class:`str`. + + .. versionadded:: 0.28 """ def __init__(self, dtype: str) -> None: @@ -806,7 +811,20 @@ def __init__(self, dtype: str) -> None: class MaskResize(Resize): - """Wrapper of :class:`torchvision.transforms.Resize` for use with masks that have no channel dimension.""" + """Wrapper of :class:`torchvision.transforms.Resize` for use with masks that have no channel dimension. + + .. versionadded:: 0.28 + """ + + def __init__( + self, + size, + interpolation: str = "NEAREST", + max_size=None, + antialias: bool = True, + ) -> None: + interpolation_mode = getattr(InterpolationMode, interpolation) + super().__init__(size, interpolation_mode, max_size, antialias) def forward(self, img: Tensor) -> Tensor: """ @@ -818,7 +836,7 @@ def forward(self, img: Tensor) -> Tensor: """ org_shape = img.shape - tmp_shape: Tuple[int, int, int, int] + tmp_shape: tuple[int, int, int, int] if len(org_shape) == 4: # Mask already has shape [N,C,H,W] so no need to modify shape for Resize. @@ -836,35 +854,61 @@ def forward(self, img: Tensor) -> Tensor: return torch.squeeze(super().forward(torch.reshape(img, tmp_shape))) +class AdjustGamma: + """Callable version of :meth:`torchvision.transforms.functional.adjust_gamma` + + Args: + gamma (float): Optional; Gamma factor. + gain (float): Optional; Gain scalar. + + .. 
versionadded:: 0.28 + """ + + def __init__(self, gamma: float = 1.0, gain: float = 1.0) -> None: + self.gamma = gamma + self.gain = gain + + def forward(self, img: Tensor) -> Tensor: + img = ft.adjust_gamma(img, gamma=self.gamma, gain=self.gain) + assert isinstance(img, Tensor) + return img + + def __call__(self, img: Tensor) -> Tensor: + return self.forward(img) + + # ===================================================================================================================== # METHODS # ===================================================================================================================== -def _construct_random_transforms(random_params: Dict[str, Any]) -> Any: +def _construct_random_transforms(random_params: dict[str, Any]) -> Any: p = random_params.pop("p", 0.5) random_transforms = [] for ran_name in random_params: - random_transforms.append(get_transform(ran_name, random_params[ran_name])) + random_transforms.append(get_transform(random_params[ran_name])) return RandomApply(random_transforms, p=p) def init_auto_norm( - dataset: RasterDataset, params: Dict[str, Any] = {} + dataset: RasterDataset, + length: int = 128, + roi: Optional[BoundingBox] = None, + inplace=False, ) -> RasterDataset: """Uses :class:~`minerva.transforms.AutoNorm` to automatically find the mean and standard deviation of `dataset` to create a normalisation transform that is then added to the existing transforms of `dataset`. Args: dataset (RasterDataset): Dataset to find and apply the normalisation conditions to. - params (Dict[str, Any]): Parameters for :class:~`minerva.transforms.AutoNorm`. + params (dict[str, Any]): Parameters for :class:~`minerva.transforms.AutoNorm`. Returns: RasterDataset: `dataset` with an additional :class:~`minerva.transforms.AutoNorm` transform added to it's :attr:~`torchgeo.datasets.RasterDataset.transforms` attribute. """ # Creates the AutoNorm transform by sampling `dataset` for its mean and standard deviation stats. 
- auto_norm = AutoNorm(dataset, **params) + auto_norm = AutoNorm(dataset, length=length, roi=roi, inplace=inplace) if dataset.transforms is None: dataset.transforms = MinervaCompose(auto_norm) @@ -887,35 +931,24 @@ def init_auto_norm( return dataset -def get_transform(name: str, transform_params: Dict[str, Any]) -> Callable[..., Any]: +def get_transform(transform_params: dict[str, Any]) -> Callable[..., Any]: """Creates a transform object based on config parameters. Args: - name (str): Name of transform object to import e.g :class:`~torchvision.transforms.RandomResizedCrop`. transform_params (dict[str, ~typing.Any]): Arguements to construct transform with. - Should also include ``"module"`` key defining the import path to the transform object. + Should also include ``"_target_"`` key defining the import path to the transform object. Returns: Initialised transform object specified by config parameters. - .. note:: - If ``transform_params`` contains no ``"module"`` key, it defaults to ``torchvision.transforms``. - Example: - >>> name = "RandomResizedCrop" - >>> params = {"module": "torchvision.transforms", "size": 128} - >>> transform = get_transform(name, params) + >>> params = {"_target": "torchvision.transforms.RandomResizedCrop", "size": 128} + >>> transform = get_transform(params) Raises: TypeError: If created transform object is itself not :class:`~typing.Callable`. """ - params = transform_params.copy() - module = params.pop("module", "torchvision.transforms") - - # Gets the transform requested by config parameters. 
- _transform: Callable[..., Any] = func_by_str(module, name) - - transform: Callable[..., Any] = _transform(**params) + transform: Callable[..., Any] = hydra.utils.instantiate(transform_params) if callable(transform): return transform else: @@ -923,7 +956,8 @@ def get_transform(name: str, transform_params: Dict[str, Any]) -> Callable[..., def make_transformations( - transform_params: Union[Dict[str, Any], Literal[False]] + transform_params: dict[str, Any] | Literal[False], + change_detection: bool = False, ) -> Optional[MinervaCompose]: """Constructs a transform or series of transforms based on parameters provided. @@ -931,11 +965,13 @@ def make_transformations( transform_params (dict[str, ~typing.Any] | ~typing.Literal[False]): Parameters defining transforms desired. The name of each transform should be the key, while the kwargs for the transform should be the value of that key as a dict. + change_detection (bool): Flag for if transforming a change detection dataset which has + ``"image1"`` and ``"image2"`` keys rather than ``"image"``. Example: >>> transform_params = { - >>> "CenterCrop": {"module": "torchvision.transforms", "size": 128}, - >>> "RandomHorizontalFlip": {"module": "torchvision.transforms", "p": 0.7} + >>> "crop": {"_target_": "torchvision.transforms.CenterCrop", "size": 128}, + >>> "flip": {"_target_": "torchvision.transforms.RandomHorizontalFlip", "p": 0.7} >>> } >>> transforms = make_transformations(transform_params) @@ -945,7 +981,7 @@ def make_transformations( If multiple transforms are defined, a Compose object of Transform objects is returned. """ - def construct(type_transform_params: Dict[str, Any]) -> List[Callable[..., Any]]: + def construct(type_transform_params: dict[str, Any]) -> list[Callable[..., Any]]: type_transformations = [] # Get each transform. 
@@ -959,9 +995,7 @@ def construct(type_transform_params: Dict[str, Any]) -> List[Callable[..., Any]] continue else: - type_transformations.append( - get_transform(_name, type_transform_params[_name]) - ) + type_transformations.append(get_transform(type_transform_params[_name])) return type_transformations @@ -971,7 +1005,7 @@ def construct(type_transform_params: Dict[str, Any]) -> List[Callable[..., Any]] if all(transform_params.values()) is None: return None - transformations: Dict[str, Any] = {} + transformations: dict[str, Any] = {} for name in transform_params: if name in ("image", "mask", "label", "both"): @@ -980,6 +1014,8 @@ def construct(type_transform_params: Dict[str, Any]) -> List[Callable[..., Any]] else: transformations[name] = construct(transform_params[name]) else: - return MinervaCompose(construct(transform_params)) + return MinervaCompose( + construct(transform_params), change_detection=change_detection + ) - return MinervaCompose(transformations) + return MinervaCompose(transformations, change_detection=change_detection) diff --git a/minerva/utils/runner.py b/minerva/utils/runner.py index edba8f1ba..7f74b18e6 100644 --- a/minerva/utils/runner.py +++ b/minerva/utils/runner.py @@ -54,7 +54,7 @@ import signal import subprocess from pathlib import Path -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Optional import requests import torch @@ -62,7 +62,7 @@ import torch.multiprocessing as mp import wandb import yaml -from omegaconf import DictConfig, OmegaConf, ListConfig +from omegaconf import DictConfig, ListConfig, OmegaConf from wandb.sdk.lib import RunDisabled from wandb.sdk.wandb_run import Run @@ -120,14 +120,14 @@ def _config_load_resolver(path: str): return cfg -def _construct_patch_size(input_size: Tuple[int, int, int]) -> ListConfig: +def _construct_patch_size(input_size: tuple[int, int, int]) -> ListConfig: return ListConfig(input_size[-2:]) def setup_wandb_run( gpu: int, cfg: DictConfig, -) -> 
Tuple[Optional[Union[Run, RunDisabled]], DictConfig]: +) -> tuple[Optional[Run | RunDisabled], DictConfig]: """Sets up a :mod:`wandb` logger for either every process, the master process or not if not logging. Note: @@ -148,7 +148,7 @@ def setup_wandb_run( ~wandb.sdk.wandb_run.Run | ~wandb.sdk.lib.RunDisabled | None: The :mod:`wandb` run object for this process or ``None`` if ``log_all=False`` and ``rank!=0``. """ - run: Optional[Union[Run, RunDisabled]] = None + run: Optional[Run | RunDisabled] = None if cfg.get("wandb_log", False) or cfg.get("project", None): try: if cfg.get("log_all", False) and cfg.world_size > 1: @@ -290,7 +290,7 @@ def config_args(cfg: DictConfig) -> DictConfig: def _run_preamble( gpu: int, - run: Callable[[int, Optional[Union[Run, RunDisabled]], DictConfig], Any], + run: Callable[[int, Optional[Run | RunDisabled], DictConfig], Any], cfg: DictConfig, ) -> None: # pragma: no cover # Calculates the global rank of this process. @@ -301,7 +301,7 @@ def _run_preamble( if cfg.world_size > 1: dist.init_process_group( # type: ignore[attr-defined] - backend="nccl", + backend=cfg.get("dist_backend", "gloo"), init_method=cfg.dist_url, world_size=cfg.world_size, rank=cfg.rank, @@ -316,7 +316,7 @@ def _run_preamble( def distributed_run( - run: Callable[[int, Optional[Union[Run, RunDisabled]], DictConfig], Any] + run: Callable[[int, Optional[Run | RunDisabled], DictConfig], Any] ) -> Callable[..., Any]: """Runs the supplied function and arguments with distributed computing according to arguments. 
@@ -334,7 +334,9 @@ def distributed_run( OmegaConf.register_new_resolver("cfg_load", _config_load_resolver, replace=True) OmegaConf.register_new_resolver("eval", eval, replace=True) - OmegaConf.register_new_resolver("to_patch_size", _construct_patch_size, replace=True) + OmegaConf.register_new_resolver( + "to_patch_size", _construct_patch_size, replace=True + ) @functools.wraps(run) def inner_decorator(cfg: DictConfig): @@ -360,7 +362,7 @@ def inner_decorator(cfg: DictConfig): def run_trainer( - gpu: int, wandb_run: Optional[Union[Run, RunDisabled]], cfg: DictConfig + gpu: int, wandb_run: Optional[Run | RunDisabled], cfg: DictConfig ) -> None: trainer = Trainer( diff --git a/minerva/utils/utils.py b/minerva/utils/utils.py index b8c563867..8e701b1ad 100644 --- a/minerva/utils/utils.py +++ b/minerva/utils/utils.py @@ -107,7 +107,6 @@ # ---+ Inbuilt +------------------------------------------------------------------------------------------------------- import cmath import functools -import glob import hashlib import importlib import inspect @@ -126,7 +125,7 @@ from types import ModuleType from typing import Any, Callable from typing import Counter as CounterType -from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union, overload +from typing import Iterable, Optional, Sequence, overload # ---+ 3rd Party +----------------------------------------------------------------------------------------------------- import numpy as np @@ -172,8 +171,8 @@ # DECORATORS # ===================================================================================================================== def return_updated_kwargs( - func: Callable[..., Tuple[Any, ...]] -) -> Callable[..., Tuple[Any, ...]]: + func: Callable[..., tuple[Any, ...]] +) -> Callable[..., tuple[Any, ...]]: """Decorator that allows the `kwargs` supplied to the wrapped function to be returned with updated values. 
Assumes that the wrapped function returns a :class:`dict` in the last position of the @@ -212,7 +211,7 @@ def pair_collate(func: Callable[[Any], Any]) -> Callable[[Any], Any]: """ @functools.wraps(func) - def wrapper(samples: Iterable[Tuple[Any, Any]]) -> Tuple[Any, Any]: + def wrapper(samples: Iterable[tuple[Any, Any]]) -> tuple[Any, Any]: a, b = tuple(zip(*samples)) return func(a), func(b) @@ -227,7 +226,7 @@ class Wrapper: def __init__(self, *args, **kwargs) -> None: self.wrap = cls(*args, **kwargs) - def __call__(self, pair: Tuple[Any, Any]) -> Tuple[Any, Any]: + def __call__(self, pair: tuple[Any, Any]) -> tuple[Any, Any]: a, b = pair return self.wrap.__call__(a), self.wrap.__call__(b) @@ -258,26 +257,26 @@ class Wrapper: def __init__(self, *args, **kwargs) -> None: self.wrap: Callable[ [ - Union[Dict[str, Any], Tensor], + dict[str, Any] | Tensor, ], - Dict[str, Any], + dict[str, Any], ] = cls(*args, **kwargs) self.keys = keys @overload def __call__( - self, batch: Dict[str, Any] - ) -> Dict[str, Any]: ... # pragma: no cover + self, batch: dict[str, Any] + ) -> dict[str, Any]: ... # pragma: no cover @overload - def __call__(self, batch: Tensor) -> Dict[str, Any]: ... # pragma: no cover + def __call__(self, batch: Tensor) -> dict[str, Any]: ... 
# pragma: no cover - def __call__(self, batch: Union[Dict[str, Any], Tensor]) -> Dict[str, Any]: + def __call__(self, batch: dict[str, Any] | Tensor) -> dict[str, Any]: if isinstance(batch, Tensor): return self.wrap(batch) elif isinstance(batch, dict) and isinstance(self.keys, Sequence): - aug_batch: Dict[str, Any] = {} + aug_batch: dict[str, Any] = {} for key in self.keys: aug_batch[key] = self.wrap(batch.pop(key)) @@ -311,7 +310,7 @@ class Wrapper: def __init__(self, *args, **kwargs) -> None: self.wrap = cls(*args, **kwargs) - def __getitem__(self, queries: Any = None) -> Tuple[Any, Any]: + def __getitem__(self, queries: Any = None) -> tuple[Any, Any]: return self.wrap[queries[0]], self.wrap[queries[1]] def __getattr__(self, item): @@ -399,7 +398,7 @@ def _optional_import( def _optional_import( module: str, *, name: Optional[str] = None, package: Optional[str] = None -) -> Union[ModuleType, Callable[..., Any]]: +) -> ModuleType | Callable[..., Any]: try: _module: ModuleType = importlib.import_module(module) return _module if name is None else getattr(_module, name) @@ -462,7 +461,7 @@ def is_notebook() -> bool: return True -def get_cuda_device(device_sig: Union[int, str] = "cuda:0") -> _device: +def get_cuda_device(device_sig: int | str = "cuda:0") -> _device: """Finds and returns the ``CUDA`` device, if one is available. Else, returns CPU as device. Assumes there is at most only one ``CUDA`` device. @@ -478,7 +477,7 @@ def get_cuda_device(device_sig: Union[int, str] = "cuda:0") -> _device: return device -def exist_delete_check(fn: Union[str, Path]) -> None: +def exist_delete_check(fn: str | Path) -> None: """Checks if given file exists then deletes if true. 
Args: @@ -491,7 +490,7 @@ def exist_delete_check(fn: Union[str, Path]) -> None: Path(fn).unlink(missing_ok=True) -def mkexpdir(name: str, results_dir: Union[Path, str] = "results") -> None: +def mkexpdir(name: str, results_dir: Path | str = "results") -> None: """Makes a new directory below the results directory with name provided. If directory already exists, no action is taken. @@ -521,7 +520,7 @@ def set_seeds(seed: int) -> None: torch.cuda.random.manual_seed_all(seed) -def check_dict_key(dictionary: Dict[Any, Any], key: Any) -> bool: +def check_dict_key(dictionary: dict[Any, Any], key: Any) -> bool: """Checks if a key exists in a dictionary and if it is ``None`` or ``False``. Args: @@ -583,7 +582,7 @@ def transform_coordinates( y: Sequence[float], src_crs: CRS, new_crs: CRS = WGS84, -) -> Tuple[Sequence[float], Sequence[float]]: ... # pragma: no cover +) -> tuple[Sequence[float], Sequence[float]]: ... # pragma: no cover @overload @@ -592,7 +591,7 @@ def transform_coordinates( y: float, src_crs: CRS, new_crs: CRS = WGS84, -) -> Tuple[Sequence[float], Sequence[float]]: ... # pragma: no cover +) -> tuple[Sequence[float], Sequence[float]]: ... # pragma: no cover @overload @@ -601,21 +600,21 @@ def transform_coordinates( y: Sequence[float], src_crs: CRS, new_crs: CRS = WGS84, -) -> Tuple[Sequence[float], Sequence[float]]: ... # pragma: no cover +) -> tuple[Sequence[float], Sequence[float]]: ... # pragma: no cover @overload def transform_coordinates( x: float, y: float, src_crs: CRS, new_crs: CRS = WGS84 -) -> Tuple[float, float]: ... # pragma: no cover +) -> tuple[float, float]: ... 
# pragma: no cover def transform_coordinates( - x: Union[Sequence[float], float], - y: Union[Sequence[float], float], + x: Sequence[float] | float, + y: Sequence[float] | float, src_crs: CRS, new_crs: CRS = WGS84, -) -> Union[Tuple[Sequence[float], Sequence[float]], Tuple[float, float]]: +) -> tuple[Sequence[float], Sequence[float]] | tuple[float, float]: """Transforms co-ordinates from one :class:`~rasterio.crs.CRS` to another. Args: @@ -641,7 +640,7 @@ def transform_coordinates( y = check_len(y, x) # Transform co-ordinates from source to new CRS and returns a tuple of (x, y) - co_ordinates: Tuple[Sequence[float], Sequence[float]] = rt.warp.transform( # type: ignore + co_ordinates: tuple[Sequence[float], Sequence[float]] = rt.warp.transform( # type: ignore src_crs=src_crs, dst_crs=new_crs, xs=x, ys=y ) @@ -715,9 +714,9 @@ def deg_to_dms(deg: float, axis: str = "lat") -> str: def dec2deg( - dec_co: Union[Sequence[float], NDArray[Shape["*"], Float]], # noqa: F722 + dec_co: Sequence[float] | NDArray[Shape["*"], Float], # noqa: F722 axis: str = "lat", -) -> List[str]: +) -> list[str]: """Wrapper for :func:`deg_to_dms`. Args: @@ -727,14 +726,14 @@ def dec2deg( Returns: list[str]: List of formatted strings in degrees, minutes and seconds. """ - deg_co: List[str] = [] + deg_co: list[str] = [] for co in dec_co: deg_co.append(deg_to_dms(co, axis=axis)) return deg_co -def get_centre_loc(bounds: BoundingBox) -> Tuple[float, float]: +def get_centre_loc(bounds: BoundingBox) -> tuple[float, float]: """Gets the centre co-ordinates of the parsed bounding box. Args: @@ -775,7 +774,7 @@ def get_centre_pixel_value(x: Tensor) -> Any: raise ValueError() -def lat_lon_to_loc(lat: Union[str, float], lon: Union[str, float]) -> str: +def lat_lon_to_loc(lat: str | float, lon: str | float) -> str: """Takes a latitude - longitude co-ordinate and returns a string of the semantic location. 
Args: @@ -804,7 +803,7 @@ def lat_lon_to_loc(lat: Union[str, float], lon: Union[str, float]) -> str: location = query.raw["properties"] # type: ignore # Attempts to add possible fields to address of the location. Not all will be present for every query. - locs: List[str] = [] + locs: list[str] = [] try: locs.append(location["city"]) except KeyError: @@ -889,8 +888,8 @@ def mask_to_ohe(mask: LongTensor, n_classes: int) -> LongTensor: def class_weighting( - class_dist: List[Tuple[int, int]], normalise: bool = False -) -> Dict[int, float]: + class_dist: list[tuple[int, int]], normalise: bool = False +) -> dict[int, float]: """Constructs weights for each class defined by the distribution provided. Note: @@ -912,7 +911,7 @@ def class_weighting( n_samples += mode[1] # Constructs class weights. Each weight is 1 / number of samples for that class. - class_weights: Dict[int, float] = {} + class_weights: dict[int, float] = {} if normalise: for mode in class_dist: class_weights[mode[0]] = n_samples / mode[1] @@ -924,8 +923,8 @@ def class_weighting( def find_empty_classes( - class_dist: List[Tuple[int, int]], class_names: Dict[int, str] -) -> List[int]: + class_dist: list[tuple[int, int]], class_names: dict[int, str] +) -> list[int]: """Finds which classes defined by config files are not present in the dataset. Args: @@ -936,7 +935,7 @@ def find_empty_classes( Returns: list[int]: List of classes not found in ``class_dist`` and are thus empty/ not present in dataset. """ - empty: List[int] = [] + empty: list[int] = [] # Checks which classes are not present in class_dist for label in class_names.keys(): @@ -948,10 +947,10 @@ def find_empty_classes( def eliminate_classes( - empty_classes: Union[List[int], Tuple[int, ...], NDArray[Any, Int]], - old_classes: Dict[int, str], - old_cmap: Optional[Dict[int, str]] = None, -) -> Tuple[Dict[int, str], Dict[int, int], Optional[Dict[int, str]]]: + empty_classes: list[int] | tuple[int, ...] 
| NDArray[Any, Int], + old_classes: dict[int, str], + old_cmap: Optional[dict[int, str]] = None, +) -> tuple[dict[int, str], dict[int, int], Optional[dict[int, str]]]: """Eliminates empty classes from the class text label and class colour dictionaries and re-normalise. This should ensure that the remaining list of classes is still a linearly spaced list of numbers. @@ -1020,7 +1019,7 @@ def eliminate_classes( return reordered_classes, conversion, reordered_colours -def class_transform(label: int, matrix: Dict[int, int]) -> int: +def class_transform(label: int, matrix: dict[int, int]) -> int: """Transforms labels from one schema to another mapped by a supplied dictionary. Args: @@ -1035,20 +1034,20 @@ def class_transform(label: int, matrix: Dict[int, int]) -> int: @overload def mask_transform( # type: ignore[overload-overlap] - array: NDArray[Any, Int], matrix: Dict[int, int] + array: NDArray[Any, Int], matrix: dict[int, int] ) -> NDArray[Any, Int]: ... # pragma: no cover @overload def mask_transform( - array: LongTensor, matrix: Dict[int, int] + array: LongTensor, matrix: dict[int, int] ) -> LongTensor: ... # pragma: no cover def mask_transform( - array: Union[NDArray[Any, Int], LongTensor], - matrix: Dict[int, int], -) -> Union[NDArray[Any, Int], LongTensor]: + array: NDArray[Any, Int] | LongTensor, + matrix: dict[int, int], +) -> NDArray[Any, Int] | LongTensor: """Transforms all labels of an N-dimensional array from one schema to another mapped by a supplied dictionary. 
Args: @@ -1065,11 +1064,11 @@ def mask_transform( def check_test_empty( - pred: Union[Sequence[int], NDArray[Any, Int]], - labels: Union[Sequence[int], NDArray[Any, Int]], - class_labels: Optional[Dict[int, str]] = None, + pred: Sequence[int] | NDArray[Any, Int], + labels: Sequence[int] | NDArray[Any, Int], + class_labels: Optional[dict[int, str]] = None, p_dist: bool = True, -) -> Tuple[NDArray[Any, Int], NDArray[Any, Int], Dict[int, str]]: +) -> tuple[NDArray[Any, Int], NDArray[Any, Int], dict[int, str]]: """Checks if any of the classes in the dataset were not present in both the predictions and ground truth labels. Returns corrected and re-ordered predictions, labels and class labels. @@ -1120,8 +1119,8 @@ def check_test_empty( def class_dist_transform( - class_dist: List[Tuple[int, int]], matrix: Dict[int, int] -) -> List[Tuple[int, int]]: + class_dist: list[tuple[int, int]], matrix: dict[int, int] +) -> list[tuple[int, int]]: """Transforms the class distribution from an old schema to a new one. Args: @@ -1132,14 +1131,14 @@ def class_dist_transform( Returns: list[tuple[int, int]]: Class distribution updated to new labels. """ - new_class_dist: List[Tuple[int, int]] = [] + new_class_dist: list[tuple[int, int]] = [] for mode in class_dist: new_class_dist.append((class_transform(mode[0], matrix), mode[1])) return new_class_dist -def class_frac(patch: pd.Series) -> Dict[Any, Any]: +def class_frac(patch: pd.Series) -> dict[Any, Any]: """Computes the fractional sizes of the classes of the given :term:`patch` and returns a :class:`dict` of the results. @@ -1150,7 +1149,7 @@ def class_frac(patch: pd.Series) -> Dict[Any, Any]: Mapping: Dictionary-like object with keys as class numbers and associated values of fractional size of class plus a key-value pair for the :term:`patch` ID. 
""" - new_columns: Dict[Any, Any] = dict(patch.to_dict()) + new_columns: dict[Any, Any] = dict(patch.to_dict()) counts = 0 for mode in patch["MODES"]: counts += mode[1] @@ -1173,7 +1172,7 @@ def cloud_cover(scene: NDArray[Any, Any]) -> Any: return np.sum(scene) / scene.size -def threshold_scene_select(df: DataFrame, thres: float = 0.3) -> List[str]: +def threshold_scene_select(df: DataFrame, thres: float = 0.3) -> list[str]: """Selects all scenes in a :term:`patch` with a cloud cover less than the threshold provided. Args: @@ -1191,9 +1190,9 @@ def threshold_scene_select(df: DataFrame, thres: float = 0.3) -> List[str]: def find_best_of( patch_id: str, manifest: DataFrame, - selector: Callable[[DataFrame], List[str]] = threshold_scene_select, + selector: Callable[[DataFrame], list[str]] = threshold_scene_select, **kwargs, -) -> List[str]: +) -> list[str]: """Finds the scenes sorted by cloud cover using selector function supplied. Args: @@ -1234,9 +1233,9 @@ def timestamp_now(fmt: str = "%d-%m-%Y_%H%M") -> str: def find_modes( labels: Iterable[int], plot: bool = False, - classes: Optional[Dict[int, str]] = None, - cmap_dict: Optional[Dict[int, str]] = None, -) -> List[Tuple[int, int]]: + classes: Optional[dict[int, str]] = None, + cmap_dict: Optional[dict[int, str]] = None, +) -> list[tuple[int, int]]: """Finds the modal distribution of the classes within the labels provided. Can plot the results as a pie chart if ``plot=True``. @@ -1249,7 +1248,7 @@ def find_modes( list[tuple[int, int]]: Modal distribution of classes in input in order of most common class. 
""" # Finds the distribution of the classes within the data - class_dist: List[Tuple[int, int]] = Counter( + class_dist: list[tuple[int, int]] = Counter( np.array(labels).flatten() ).most_common() @@ -1264,10 +1263,10 @@ def find_modes( def modes_from_manifest( manifest: DataFrame, - classes: Dict[int, str], + classes: dict[int, str], plot: bool = False, - cmap_dict: Optional[Dict[int, str]] = None, -) -> List[Tuple[int, int]]: + cmap_dict: Optional[dict[int, str]] = None, +) -> list[tuple[int, int]]: """Uses the dataset manifest to calculate the fractional size of the classes. Args: @@ -1295,7 +1294,7 @@ def count_samples(cls): class_counter[classification] = count except KeyError: continue - class_dist: List[Tuple[int, int]] = class_counter.most_common() + class_dist: list[tuple[int, int]] = class_counter.most_common() if plot: # Plots a pie chart of the distribution of the classes within the given list of patches @@ -1325,7 +1324,7 @@ def func_by_str(module_path: str, func: str) -> Callable[..., Any]: return func -def check_len(param: Any, comparator: Any) -> Union[Any, Sequence[Any]]: +def check_len(param: Any, comparator: Any) -> Any | Sequence[Any]: """Checks the length of one object against a comparator object. Args: @@ -1383,8 +1382,8 @@ def calc_grad(model: Module) -> Optional[float]: def print_class_dist( - class_dist: List[Tuple[int, int]], - class_labels: Optional[Dict[int, str]] = None, + class_dist: list[tuple[int, int]], + class_labels: Optional[dict[int, str]] = None, ) -> None: """Prints the supplied ``class_dist`` in a pretty table format using :mod:`tabulate`. @@ -1435,7 +1434,7 @@ def calc_frac(count: float, total: float) -> str: def batch_flatten( - x: Union[NDArray[Any, Any], ArrayLike] + x: NDArray[Any, Any] | ArrayLike ) -> NDArray[Shape["*"], Any]: # noqa: F722 """Flattens the supplied array with :func:`numpy`. 
@@ -1455,9 +1454,9 @@ def batch_flatten( def make_classification_report( - pred: Union[Sequence[int], NDArray[Any, Int]], - labels: Union[Sequence[int], NDArray[Any, Int]], - class_labels: Optional[Dict[int, str]] = None, + pred: Sequence[int] | NDArray[Any, Int], + labels: Sequence[int] | NDArray[Any, Int], + class_labels: Optional[dict[int, str]] = None, print_cr: bool = True, p_dist: bool = False, ) -> DataFrame: @@ -1564,9 +1563,9 @@ def calc_contrastive_acc(z: Tensor) -> Tensor: def run_tensorboard( exp_name: str, - path: Union[str, List[str], Tuple[str, ...], Path] = "", + path: str | list[str] | tuple[str, ...] | Path = "", env_name: str = "env", - host_num: Union[str, int] = 6006, + host_num: str | int = 6006, _testing: bool = False, ) -> Optional[int]: """Runs the :mod:`TensorBoard` logs and hosts on a local webpage. @@ -1627,11 +1626,11 @@ def run_tensorboard( def compute_roc_curves( probs: NDArray[Any, Float], - labels: Union[Sequence[int], NDArray[Any, Int]], - class_labels: List[int], + labels: Sequence[int] | NDArray[Any, Int], + class_labels: list[int], micro: bool = True, macro: bool = True, -) -> Tuple[Dict[Any, float], Dict[Any, float], Dict[Any, float]]: +) -> tuple[dict[Any, float], dict[Any, float], dict[Any, float]]: """Computes the false-positive rate, true-positive rate and AUCs for each class using a one-vs-all approach. The micro and macro averages are for each of these variables is also computed. @@ -1660,13 +1659,13 @@ def compute_roc_curves( # Dicts to hold the false-positive rate, true-positive rate and Area Under Curves # of each class and micro, macro averages. - fpr: Dict[Any, Any] = {} - tpr: Dict[Any, Any] = {} - roc_auc: Dict[Any, Any] = {} + fpr: dict[Any, Any] = {} + tpr: dict[Any, Any] = {} + roc_auc: dict[Any, Any] = {} # Holds a list of the classes that were in the targets supplied to the model. # Avoids warnings about empty targets from sklearn! 
- populated_classes: List[int] = [] + populated_classes: list[int] = [] print("Computing class ROC curves") @@ -1844,8 +1843,8 @@ def calc_norm_euc_dist(a: Tensor, b: Tensor) -> Tensor: def fallback_params( key: str, - params_a: Dict[str, Any], - params_b: Dict[str, Any], + params_a: dict[str, Any], + params_b: dict[str, Any], fallback: Optional[Any] = None, ) -> Any: """Search for a value associated with ``key`` from @@ -1868,9 +1867,9 @@ def fallback_params( def compile_dataset_paths( - data_dir: Union[Path, str], - in_paths: Union[List[Union[Path, str]], Union[Path, str]], -) -> List[str]: + data_dir: Path | str, + in_paths: list[Path | str] | Path | str, +) -> list[str]: """Ensures that a list of paths is returned with the data directory prepended, even if a single string is supplied Args: @@ -1885,15 +1884,15 @@ def compile_dataset_paths( else: out_paths = [universal_path(data_dir) / in_paths] - compiled_paths = [] + # Check if each path exists. If not, make the path. for path in out_paths: - compiled_paths.extend(glob.glob(str(path), recursive=True)) + path.mkdir(parents=True, exist_ok=True) - # For each path, get the absolute path, convert to string and return. - return [str(Path(path).absolute()) for path in compiled_paths] + # For each path, get the absolute path, make the path if it does not exist then convert to string and return. + return [str(Path(path).absolute()) for path in out_paths] -def make_hash(obj: Dict[Any, Any]) -> str: +def make_hash(obj: dict[Any, Any]) -> str: """Make a deterministic MD5 hash of a serialisable object using JSON. Source: https://death.andgravity.com/stable-hashing @@ -1958,3 +1957,61 @@ def closest_factors(n): best_pair = (best_pair[1], best_pair[0]) return best_pair + + +def get_sample_index(sample: dict[str, Any]) -> Optional[Any]: + """Get the index for a sample with unkown index key. 
+ + Will try: + * ``bbox`` (:mod:`torchgeo` < 0.6.0) for :class:`~torchgeo.datasets.GeoDataset` + * ``bounds`` (:mod:`torchgeo` >= 0.6.0) for :class:`~torchgeo.datasets.GeoDataset` + * ``id`` for :class:`~torchgeo.datasets.NonGeoDataset` + + Args: + sample (dict[str, ~typing.Any]): Sample dictionary to find index in. + + Returns: + None | ~typing.Any: Sample index or ``None`` if not found. + + .. versionadded:: 0.28 + """ + if "bbox" in sample: + index = sample["bbox"] + elif "bounds" in sample: + index = sample["bounds"] + elif "id" in sample: + index = sample["id"] + else: + index = None + + return index + + +def compare_models(model_1: Module, model_2: Module) -> None: + """Compare two models weight-by-weight. + + Source: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/5 + + Args: + model_1 (torch.nn.Module): First model. + model_2 (torch.nn.Module): Second model. + + Raises: + AssertionError: If models do not match exactly. + + .. versionadded:: 0.28 + """ + models_differ = 0 + for key_item_1, key_item_2 in zip( + model_1.state_dict().items(), model_2.state_dict().items() + ): + if torch.equal(key_item_1[1], key_item_2[1]): + pass + else: + models_differ += 1 + if key_item_1[0] == key_item_2[0]: + print("Mismtach found at", key_item_1[0]) + else: + raise AssertionError + if models_differ == 0: + print("Models match perfectly! :)") diff --git a/minerva/utils/visutils.py b/minerva/utils/visutils.py index 8201c7fcd..dd36c6153 100644 --- a/minerva/utils/visutils.py +++ b/minerva/utils/visutils.py @@ -64,7 +64,7 @@ import os import random from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence import imageio import matplotlib as mlp @@ -126,7 +126,7 @@ def de_interlace(x: Sequence[Any], f: int) -> NDArray[Any, Any]: Returns: ~numpy.ndarray[~typing.Any]: De-interlaced array. Each source array is now sequentially connected. 
""" - new_x: List[NDArray[Any, Any]] = [] + new_x: list[NDArray[Any, Any]] = [] for i in range(f): x_i = [] for j in np.arange(start=i, stop=len(x), step=f): @@ -137,12 +137,12 @@ def de_interlace(x: Sequence[Any], f: int) -> NDArray[Any, Any]: def dec_extent_to_deg( - shape: Tuple[int, int], + shape: tuple[int, int], bounds: BoundingBox, src_crs: CRS, new_crs: CRS = WGS84, spacing: int = 32, -) -> Tuple[Tuple[int, int, int, int], NDArray[Any, Float], NDArray[Any, Float]]: +) -> tuple[tuple[int, int, int, int], NDArray[Any, Float], NDArray[Any, Float]]: """Gets the extent of the image with ``shape`` and with ``bounds`` in latitude, longitude of system ``new_crs``. Args: @@ -194,7 +194,7 @@ def dec_extent_to_deg( def get_mlp_cmap( - cmap_style: Optional[Union[Colormap, str]] = None, n_classes: Optional[int] = None + cmap_style: Optional[Colormap | str] = None, n_classes: Optional[int] = None ) -> Optional[Colormap]: """Creates a cmap from query @@ -225,8 +225,8 @@ def get_mlp_cmap( def discrete_heatmap( data: NDArray[Shape["*, *"], Int], # noqa: F722 - classes: Union[List[str], Tuple[str, ...]], - cmap_style: Optional[Union[str, ListedColormap]] = None, + classes: list[str] | tuple[str, ...], + cmap_style: Optional[str | ListedColormap] = None, block_size: int = 32, ) -> None: """Plots a heatmap with a discrete colour bar. Designed for Radiant Earth MLHub 256x256 SENTINEL images. 
@@ -329,16 +329,16 @@ def labelled_rgb_image( mask: NDArray[Shape["*, *"], Int], # noqa: F722 bounds: BoundingBox, src_crs: CRS, - path: Union[str, Path], + path: str | Path, name: str, - classes: Union[List[str], Tuple[str, ...]], - cmap_style: Optional[Union[str, ListedColormap]] = None, + classes: list[str] | tuple[str, ...], + cmap_style: Optional[str | ListedColormap] = None, new_crs: Optional[CRS] = WGS84, block_size: int = 32, alpha: float = 0.5, show: bool = True, save: bool = True, - figdim: Tuple[Union[int, float], Union[int, float]] = (8.02, 10.32), + figdim: tuple[int | float, int | float] = (8.02, 10.32), ) -> Path: """Produces a layered image of an RGB image, and it's associated label mask heat map alpha blended on top. @@ -363,7 +363,7 @@ def labelled_rgb_image( str: Path to figure save location. """ # Checks that the mask and image shapes will align. - mask_shape: Tuple[int, int] = mask.shape # type: ignore[assignment] + mask_shape: tuple[int, int] = mask.shape # type: ignore[assignment] assert mask_shape == image.shape[:2] assert new_crs is not None @@ -475,14 +475,14 @@ def make_gif( masks: NDArray[Shape["*, *, *"], Any], # noqa: F722 bounds: BoundingBox, src_crs: CRS, - classes: Union[List[str], Tuple[str, ...]], + classes: list[str] | tuple[str, ...], gif_name: str, - path: Union[str, Path], - cmap_style: Optional[Union[str, ListedColormap]] = None, + path: str | Path, + cmap_style: Optional[str | ListedColormap] = None, fps: float = 1.0, new_crs: Optional[CRS] = WGS84, alpha: float = 0.5, - figdim: Tuple[Union[int, float], Union[int, float]] = (8.02, 10.32), + figdim: tuple[int | float, int | float] = (8.02, 10.32), ) -> None: """Wrapper to :func:`labelled_rgb_image` to make a GIF for a patch out of scenes. 
@@ -548,19 +548,19 @@ def make_gif( def prediction_plot( - sample: Dict[str, Any], + sample: dict[str, Any], sample_id: str, - classes: Dict[int, str], + classes: dict[int, str], src_crs: CRS, new_crs: CRS = WGS84, path: str = "", - cmap_style: Optional[Union[str, ListedColormap]] = None, + cmap_style: Optional[str | ListedColormap] = None, exp_id: Optional[str] = None, - fig_dim: Optional[Tuple[Union[int, float], Union[int, float]]] = None, + fig_dim: Optional[tuple[int | float, int | float]] = None, block_size: int = 32, show: bool = True, save: bool = True, - fn_prefix: Optional[Union[str, Path]] = None, + fn_prefix: Optional[str | Path] = None, ) -> None: """ Produces a figure containing subplots of the predicted label mask, the ground truth label mask @@ -704,21 +704,21 @@ def prediction_plot( def seg_plot( - z: Union[List[int], NDArray[Any, Any]], - y: Union[List[int], NDArray[Any, Any]], - ids: List[str], - index: Union[Sequence[Any], NDArray[Any, Any]], - data_dir: Union[Path, str], - dataset_params: Dict[str, Any], - classes: Dict[int, str], - colours: Dict[int, str], - fn_prefix: Optional[Union[str, Path]], + z: list[int] | NDArray[Any, Any], + y: list[int] | NDArray[Any, Any], + ids: list[str], + index: Sequence[Any] | NDArray[Any, Any], + data_dir: Path | str, + dataset_params: dict[str, Any], + classes: dict[int, str], + colours: dict[int, str], + fn_prefix: Optional[str | Path], frac: float = 0.05, - fig_dim: Optional[Tuple[Union[int, float], Union[int, float]]] = (9.3, 10.5), + fig_dim: Optional[tuple[int | float, int | float]] = (9.3, 10.5), model_name: str = "", path: str = "", max_pixel_value: int = 255, - cache_dir: Optional[Union[str, Path]] = None, + cache_dir: Optional[str | Path] = None, ) -> None: """Custom function for pre-processing the outputs from image segmentation testing for data visualisation. 
@@ -840,10 +840,10 @@ def seg_plot( def plot_subpopulations( - class_dist: List[Tuple[int, int]], - class_names: Optional[Dict[int, str]] = None, - cmap_dict: Optional[Dict[int, str]] = None, - filename: Optional[Union[str, Path]] = None, + class_dist: list[tuple[int, int]], + class_names: Optional[dict[int, str]] = None, + cmap_dict: Optional[dict[int, str]] = None, + filename: Optional[str | Path] = None, save: bool = True, show: bool = False, ) -> None: @@ -867,7 +867,7 @@ def plot_subpopulations( counts = [] # List to hold colours of classes in the correct order. - colours: Optional[List[str]] = [] + colours: Optional[list[str]] = [] if class_names is None: class_numbers = [x[0] for x in class_dist] @@ -920,8 +920,8 @@ def plot_subpopulations( def plot_history( - metrics: Dict[str, Any], - filename: Optional[Union[str, Path]] = None, + metrics: dict[str, Any], + filename: Optional[str | Path] = None, save: bool = True, show: bool = False, ) -> None: @@ -972,12 +972,12 @@ def plot_history( def make_confusion_matrix( - pred: Union[List[int], NDArray[Any, Int]], - labels: Union[List[int], NDArray[Any, Int]], - classes: Dict[int, str], - filename: Optional[Union[str, Path]] = None, + pred: list[int] | NDArray[Any, Int], + labels: list[int] | NDArray[Any, Int], + classes: dict[int, str], + filename: Optional[str | Path] = None, cmap_style: str = "Blues", - figsize: Tuple[int, int] = (2, 2), + figsize: tuple[int, int] = (2, 2), show: bool = True, save: bool = False, ) -> None: @@ -1029,12 +1029,12 @@ def make_confusion_matrix( def make_multilabel_confusion_matrix( - preds: Union[List[int], NDArray[Any, Int]], - labels: Union[List[int], NDArray[Any, Int]], - classes: Dict[int, str], - filename: Optional[Union[str, Path]] = None, + preds: list[int] | NDArray[Any, Int], + labels: list[int] | NDArray[Any, Int], + classes: dict[int, str], + filename: Optional[str | Path] = None, cmap_style: str = "Blues", - figsize: Tuple[int, int] = (2, 2), + figsize: tuple[int, int] = 
(2, 2), show: bool = True, save: bool = False, ) -> None: @@ -1105,12 +1105,12 @@ def make_multilabel_confusion_matrix( def make_roc_curves( probs: ArrayLike, - labels: Union[Sequence[int], NDArray[Any, Int]], - class_names: Dict[int, str], - colours: Dict[int, str], + labels: Sequence[int] | NDArray[Any, Int], + class_names: dict[int, str], + colours: dict[int, str], micro: bool = True, macro: bool = True, - filename: Optional[Union[str, Path]] = None, + filename: Optional[str | Path] = None, show: bool = False, save: bool = True, ) -> None: @@ -1206,15 +1206,15 @@ def make_roc_curves( def plot_embedding( embeddings: Any, - index: Union[Sequence[BoundingBox], Sequence[int]], - data_dir: Union[Path, str], - dataset_params: Dict[str, Any], + index: Sequence[BoundingBox] | Sequence[int], + data_dir: Path | str, + dataset_params: dict[str, Any], title: Optional[str] = None, show: bool = False, save: bool = True, - filename: Optional[Union[Path, str]] = None, + filename: Optional[Path | str] = None, max_pixel_value: int = 255, - cache_dir: Optional[Union[Path, str]] = None, + cache_dir: Optional[Path | str] = None, ) -> None: """Using TSNE Clustering, visualises the embeddings from a model. @@ -1316,8 +1316,8 @@ def plot_embedding( def format_plot_names( - model_name: str, timestamp: str, path: Union[Sequence[str], str, Path] -) -> Dict[str, Path]: + model_name: str, timestamp: str, path: Sequence[str] | str | Path +) -> dict[str, Path]: """Creates unique filenames of plots in a standardised format. 
Args: @@ -1357,23 +1357,23 @@ def standard_format(plot_type: str, *sub_dir) -> str: def plot_results( - plots: Dict[str, bool], - z: Optional[Union[List[int], NDArray[Any, Int]]] = None, - y: Optional[Union[List[int], NDArray[Any, Int]]] = None, - metrics: Optional[Dict[str, Any]] = None, - ids: Optional[List[str]] = None, + plots: dict[str, bool], + z: Optional[list[int] | NDArray[Any, Int]] = None, + y: Optional[list[int] | NDArray[Any, Int]] = None, + metrics: Optional[dict[str, Any]] = None, + ids: Optional[list[str]] = None, index: Optional[NDArray[Any, Any]] = None, - probs: Optional[Union[List[float], NDArray[Any, Float]]] = None, + probs: Optional[list[float] | NDArray[Any, Float]] = None, embeddings: Optional[NDArray[Any, Any]] = None, - class_names: Optional[Dict[int, str]] = None, - colours: Optional[Dict[int, str]] = None, + class_names: Optional[dict[int, str]] = None, + colours: Optional[dict[int, str]] = None, save: bool = True, show: bool = False, model_name: Optional[str] = None, timestamp: Optional[str] = None, - results_dir: Optional[Union[Sequence[str], str, Path]] = None, - task_cfg: Optional[Dict[str, Any]] = None, - global_cfg: Optional[Dict[str, Any]] = None, + results_dir: Optional[Sequence[str] | str | Path] = None, + task_cfg: Optional[dict[str, Any]] = None, + global_cfg: Optional[dict[str, Any]] = None, ) -> None: """Orchestrates the creation of various plots from the results of a model fitting. 
diff --git a/pyproject.toml b/pyproject.toml index 7c84e03cb..b3fb51d41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ classifiers = [ "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3 :: Only", "License :: OSI Approved :: MIT License", "Development Status :: 4 - Beta", @@ -38,6 +37,7 @@ dependencies = [ "numba>=0.57.0; python_version>='3.11'", "numpy", "overload", + "opencv-python", "pandas", "psutil", "pyyaml", diff --git a/requirements/requirements.txt b/requirements/requirements.txt index d4a7d2bbb..e1de41ede 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,50 +1,50 @@ -argcomplete==3.4.0 +argcomplete==3.5.0 fastapi>=0.101.0 fiona>=1.9.1 geopy==2.4.1 GitPython>=3.1.35 hydra-core==1.3.2 -imagecodecs==2024.1.1 -imageio==2.34.2 +imagecodecs==2024.6.1 +imageio==2.35.1 inputimeout==1.0.4 -ipykernel==6.29.4 -kornia==0.7.2 -lightly==1.5.7 -lmdb==1.4.1 -matplotlib==3.9.0 -mlflow-skinny==2.14.0 +ipykernel==6.29.5 +kornia==0.7.3 +lightly==1.5.12 +lmdb==1.5.1 +matplotlib==3.9.2 +mlflow-skinny==2.16.0 nptyping==2.5.0 numba>=0.57.0 # not directly required but pinned to ensure Python 3.11 compatibility. numpy==1.26.4 -onnx==1.16.1 -onnx2torch==1.5.14 -opencv-python==4.9.0.80 +onnx==1.16.2 +onnx2torch==1.5.15 +opencv-python-headless==4.10.0.84 overload==1.1 pandas==2.2.2 -patch-ng==1.17.4 +patch-ng==1.18.0 protobuf>=3.19.5 # not directly required, pinned by Snyk to avoid a vulnerability. psutil==6.0.0 -pyyaml==6.0.1 +pyyaml==6.0.2 rasterio>=1.3.6 requests==2.32.3 -scikit-learn==1.5.0 -segmentation-models-pytorch==0.3.3 -setuptools==70.1.0 +scikit-learn==1.5.1 +segmentation-models-pytorch==0.3.4 +setuptools==74.0.0 starlette==0.37.2 # not directly required, pinned by Dependabot to avoid a vulnerability. 
tabulate==0.9.0 -tensorflow==2.16.2 +tensorflow==2.17.0 tifffile==2024.1.30 -timm==0.9.2 -torch==2.3.1 +timm==0.9.7 +torch==2.4.0 torcheval==0.0.7 -torchgeo==0.5.2 +torchgeo==0.6.0 torchinfo==1.8.0 -torchvision==0.18.1 +torchvision==0.19.0 tornado>=6.3.3 -tqdm==4.66.4 -types-PyYAML==6.0.12.20240311 -types-requests==2.32.0.20240602 +tqdm==4.66.5 +types-PyYAML==6.0.12.20240808 +types-requests==2.32.0.20240712 types-tabulate==0.9.0.20240106 -wandb==0.17.4 +wandb==0.17.8 Werkzeug>=2.2.3 # Patches a potential security vulnerability. wheel>=0.38.0 # not directly required, pinned by Snyk to avoid a vulnerability. diff --git a/requirements/requirements_dev.txt b/requirements/requirements_dev.txt index 65a266767..e9e00a78f 100644 --- a/requirements/requirements_dev.txt +++ b/requirements/requirements_dev.txt @@ -1,63 +1,63 @@ -argcomplete==3.4.0 +argcomplete==3.5.0 certifi>=2022.12.7 # not directly required, pinned by Snyk to avoid a vulnerability. fastapi>=0.101.0 fiona>=1.9.1 -flake8==7.1.0 +flake8==7.1.1 geopy==2.4.1 GitPython>=3.1.35 hydra-core==1.3.2 -imagecodecs==2024.1.1 -imageio==2.34.2 +imagecodecs==2024.6.1 +imageio==2.35.1 inputimeout==1.0.4 internet-sabotage3==0.1.6 -ipykernel==6.29.4 -kornia==0.7.2 -lightly==1.5.7 -lmdb==1.4.1 -matplotlib==3.9.0 -mlflow-skinny==2.14.0 -mypy==1.10.1 -myst_parser==3.0.1 +ipykernel==6.29.5 +kornia==0.7.3 +lightly==1.5.12 +lmdb==1.5.1 +matplotlib==3.9.2 +mlflow-skinny==2.16.0 +mypy==1.11.2 +myst_parser==4.0.0 nptyping==2.5.0 numba>=0.57.0 # not directly required but pinned to ensure Python 3.11 compatibility. numpy==1.26.4 -onnx==1.16.1 -onnx2torch==1.5.14 -opencv-python==4.9.0.80 +onnx==1.16.2 +onnx2torch==1.5.15 +opencv-python-headless==4.10.0.84 overload==1.1 pandas==2.2.2 -patch-ng==1.17.4 -pre-commit==3.7.1 +patch-ng==1.18.0 +pre-commit==3.8.0 protobuf>=3.19.5 # not directly required, pinned by Snyk to avoid a vulnerability. psutil==6.0.0 pygments>=2.7.4 # not directly required, pinned by Snyk to avoid a vulnerability. 
pytest==7.4.4 pytest-cov==5.0.0 pytest-lazy-fixture==0.6.3 -pyyaml==6.0.1 +pyyaml==6.0.2 rasterio>=1.3.6 requests==2.32.3 -scikit-learn==1.5.0 -segmentation-models-pytorch==0.3.3 -setuptools==70.1.0 -sphinx==7.3.7 +scikit-learn==1.5.1 +segmentation-models-pytorch==0.3.4 +setuptools==72.1.0 +sphinx==7.4.4 sphinx-rtd-theme==2.0.0 starlette>=0.25.0 # not directly required, pinned by Dependabot to avoid a vulnerability. tabulate==0.9.0 -tensorflow==2.16.2 +tensorflow==2.17.0 tifffile==2024.1.30 -timm==0.9.2 -torch==2.3.1 +timm==0.9.7 +torch==2.4.0 torcheval==0.0.7 -torchgeo==0.5.2 +torchgeo==0.6.0 torchinfo==1.8.0 -torchvision==0.18.1 +torchvision==0.19.0 tornado>=6.3.3 -tox==4.15.1 -tqdm==4.66.4 -types-PyYAML==6.0.12.20240311 -types-requests==2.32.0.20240602 +tox==4.18.0 +tqdm==4.66.5 +types-PyYAML==6.0.12.20240808 +types-requests==2.32.0.20240712 types-tabulate==0.9.0.20240106 -wandb==0.17.4 +wandb==0.17.8 Werkzeug>=2.2.3 # Patches a potential security vulnerability. wheel>=0.38.0 # not directly required, pinned by Snyk to avoid a vulnerability. 
diff --git a/scripts/MinervaExp.py b/scripts/MinervaExp.py index fc40d93f5..f8ff132a4 100644 --- a/scripts/MinervaExp.py +++ b/scripts/MinervaExp.py @@ -39,7 +39,7 @@ # ===================================================================================================================== # IMPORTS # ===================================================================================================================== -from typing import Optional, Union +from typing import Optional import hydra from omegaconf import DictConfig @@ -58,9 +58,7 @@ config_name=DEFAULT_CONFIG_NAME, ) @runner.distributed_run -def main( - gpu: int, wandb_run: Optional[Union[Run, RunDisabled]], cfg: DictConfig -) -> None: +def main(gpu: int, wandb_run: Optional[Run | RunDisabled], cfg: DictConfig) -> None: # Due to the nature of multiprocessing and its interaction with hydra, wandb and SLURM, # the actual code excuted in the job is contained in `run_trainer` in `runner`. diff --git a/scripts/MinervaPipe.py b/scripts/MinervaPipe.py index 5a4ced905..6a95f6b91 100644 --- a/scripts/MinervaPipe.py +++ b/scripts/MinervaPipe.py @@ -36,7 +36,7 @@ import shlex import subprocess import sys -from typing import Any, Dict +from typing import Any import yaml @@ -46,7 +46,7 @@ # ===================================================================================================================== def main(config_path: str): with open(config_path) as f: - config: Dict[str, Any] = yaml.safe_load(f) + config: dict[str, Any] = yaml.safe_load(f) for key in config.keys(): print( diff --git a/scripts/RunTensorBoard.py b/scripts/RunTensorBoard.py index 589b1a7d9..3559429ad 100644 --- a/scripts/RunTensorBoard.py +++ b/scripts/RunTensorBoard.py @@ -33,7 +33,7 @@ # IMPORTS # ===================================================================================================================== import argparse -from typing import List, Optional, Union +from typing import Optional from minerva.utils import utils @@ -42,7 +42,7 
@@ # MAIN # ===================================================================================================================== def main( - path: Optional[Union[str, List[str]]] = None, + path: Optional[str | list[str]] = None, env_name: str = "env2", exp_name: Optional[str] = None, host_num: int = 6006, diff --git a/scripts/TorchWeightDownloader.py b/scripts/TorchWeightDownloader.py index 2d076b7b8..b0f6f2f5c 100644 --- a/scripts/TorchWeightDownloader.py +++ b/scripts/TorchWeightDownloader.py @@ -21,7 +21,7 @@ """Loads :mod:`torch` weights from Torch Hub into cache. Attributes: - resnets (List[str]): List of tags for ``pytorch`` resnet weights to download. + resnets (list[str]): List of tags for ``pytorch`` resnet weights to download. """ # ===================================================================================================================== diff --git a/setup.cfg b/setup.cfg index a48cea02c..ea926089e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,7 +13,6 @@ classifiers = Programming Language :: Python :: 3.12 Programming Language :: Python :: 3.11 Programming Language :: Python :: 3.10 - Programming Language :: Python :: 3.9 Programming Language :: Python :: 3 :: Only License :: OSI Approved :: MIT License Development Status :: 4 - Beta @@ -34,6 +33,7 @@ install_requires = psutil geopy overload + opencv-python nptyping lightly argcomplete diff --git a/tests/conftest.py b/tests/conftest.py index 393afd48a..cbfa31a45 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,7 +41,7 @@ import os import shutil from pathlib import Path -from typing import Any, Dict, Generator, Tuple +from typing import Any, Generator import hydra import numpy as np @@ -199,7 +199,7 @@ def std_batch_size() -> int: @pytest.fixture -def std_n_classes(exp_classes: Dict[int, str]) -> int: +def std_n_classes(exp_classes: dict[int, str]) -> int: return len(exp_classes) @@ -214,12 +214,12 @@ def x_entropy_loss() -> nn.CrossEntropyLoss: @pytest.fixture -def 
small_patch_size() -> Tuple[int, int]: +def small_patch_size() -> tuple[int, int]: return (32, 32) @pytest.fixture -def rgbi_input_size() -> Tuple[int, int, int]: +def rgbi_input_size() -> tuple[int, int, int]: return (4, 32, 32) @@ -230,7 +230,7 @@ def exp_mlp(x_entropy_loss: nn.CrossEntropyLoss) -> MLP: @pytest.fixture def exp_cnn( - x_entropy_loss: nn.CrossEntropyLoss, rgbi_input_size: Tuple[int, int, int] + x_entropy_loss: nn.CrossEntropyLoss, rgbi_input_size: tuple[int, int, int] ) -> CNN: return CNN(x_entropy_loss, rgbi_input_size) @@ -238,7 +238,7 @@ def exp_cnn( @pytest.fixture def exp_fcn( x_entropy_loss: nn.CrossEntropyLoss, - rgbi_input_size: Tuple[int, int, int], + rgbi_input_size: tuple[int, int, int], std_n_classes: int, ) -> FCN32ResNet18: return FCN32ResNet18(x_entropy_loss, rgbi_input_size, std_n_classes) @@ -246,14 +246,14 @@ def exp_fcn( @pytest.fixture def exp_simconv( - rgbi_input_size: Tuple[int, int, int], + rgbi_input_size: tuple[int, int, int], ) -> SimConv: return SimConv(SegBarlowTwinsLoss(), rgbi_input_size, feature_dim=128) @pytest.fixture def random_mask( - small_patch_size: Tuple[int, int], std_n_classes: int + small_patch_size: tuple[int, int], std_n_classes: int ) -> NDArray[Shape["32, 32"], Int]: mask = np.random.randint(0, std_n_classes - 1, size=small_patch_size) assert isinstance(mask, np.ndarray) @@ -262,20 +262,20 @@ def random_mask( @pytest.fixture def random_image( - small_patch_size: Tuple[int, int] + small_patch_size: tuple[int, int] ) -> NDArray[Shape["32, 32, 3"], Float]: return np.random.rand(*small_patch_size, 3) @pytest.fixture def random_rgbi_image( - small_patch_size: Tuple[int, int] + small_patch_size: tuple[int, int] ) -> NDArray[Shape["32, 32, 4"], Float]: return np.random.rand(*small_patch_size, 4) @pytest.fixture -def random_rgbi_tensor(rgbi_input_size: Tuple[int, int, int]) -> Tensor: +def random_rgbi_tensor(rgbi_input_size: tuple[int, int, int]) -> Tensor: return torch.rand(rgbi_input_size) @@ -305,27 +305,27 
@@ def flipped_rgb_img() -> Tensor: @pytest.fixture -def simple_sample(simple_rgb_img: Tensor, simple_mask: LongTensor) -> Dict[str, Tensor]: +def simple_sample(simple_rgb_img: Tensor, simple_mask: LongTensor) -> dict[str, Tensor]: return {"image": simple_rgb_img, "mask": simple_mask} @pytest.fixture def flipped_simple_sample( flipped_rgb_img: Tensor, flipped_simple_mask: LongTensor -) -> Dict[str, Tensor]: +) -> dict[str, Tensor]: return {"image": flipped_rgb_img, "mask": flipped_simple_mask} @pytest.fixture def random_rgbi_batch( - rgbi_input_size: Tuple[int, int, int], std_batch_size: int + rgbi_input_size: tuple[int, int, int], std_batch_size: int ) -> Tensor: return torch.rand((std_batch_size, *rgbi_input_size)) @pytest.fixture def random_tensor_mask( - std_n_classes: int, small_patch_size: Tuple[int, int] + std_n_classes: int, small_patch_size: tuple[int, int] ) -> LongTensor: mask = torch.randint(0, std_n_classes - 1, size=small_patch_size, dtype=torch.long) assert isinstance(mask, LongTensor) @@ -334,7 +334,7 @@ def random_tensor_mask( @pytest.fixture def random_mask_batch( - std_batch_size: int, std_n_classes: int, rgbi_input_size: Tuple[int, int, int] + std_batch_size: int, std_n_classes: int, rgbi_input_size: tuple[int, int, int] ) -> LongTensor: mask = torch.randint( 0, @@ -368,7 +368,7 @@ def bounds_for_test_img() -> BoundingBox: @pytest.fixture -def exp_classes() -> Dict[int, str]: +def exp_classes() -> dict[int, str]: return { 0: "No Data", 1: "Water", @@ -382,7 +382,7 @@ def exp_classes() -> Dict[int, str]: @pytest.fixture -def exp_cmap_dict() -> Dict[int, str]: +def exp_cmap_dict() -> dict[int, str]: return { 0: "#000000", # Transparent 1: "#00c5ff", # Light Blue @@ -411,7 +411,7 @@ def flipped_simple_mask() -> LongTensor: @pytest.fixture -def example_matrix() -> Dict[int, int]: +def example_matrix() -> dict[int, int]: return {1: 1, 3: 3, 4: 2, 5: 0} @@ -421,35 +421,31 @@ def simple_bbox() -> BoundingBox: @pytest.fixture -def exp_dataset_params() 
-> Dict[str, Any]: +def exp_dataset_params() -> dict[str, Any]: return { "image": { "transforms": {"AutoNorm": {"length": 12}}, - "module": "minerva.datasets.__testing", - "name": "TstImgDataset", + "_target_": "minerva.datasets.__testing.TstImgDataset", "paths": "NAIP", - "params": {"res": 1.0, "crs": 26918}, + "res": 1.0, + "crs": 26918, }, "mask": { "transforms": False, - "module": "minerva.datasets.__testing", - "name": "TstMaskDataset", + "_target_": "minerva.datasets.__testing.TstMaskDataset", "paths": "Chesapeake7", - "params": {"res": 1.0}, + "res": 1.0, }, } @pytest.fixture -def exp_sampler_params(small_patch_size: Tuple[int, int]): +def exp_sampler_params(small_patch_size: tuple[int, int]): return { - "module": "torchgeo.samplers", - "name": "RandomGeoSampler", + "_target_": "torchgeo.samplers.RandomGeoSampler", "roi": False, - "params": { - "size": small_patch_size, - "length": 120, - }, + "size": small_patch_size, + "length": 120, } @@ -478,7 +474,7 @@ def default_dataset( @pytest.fixture def default_image_dataset( - default_config: DictConfig, exp_dataset_params: Dict[str, Any] + default_config: DictConfig, exp_dataset_params: dict[str, Any] ) -> RasterDataset: del exp_dataset_params["mask"] dataset, _ = make_dataset( diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B1.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B1.tif new file mode 100644 index 000000000..8fdc0fda9 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B1.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B11.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B11.tif new file mode 100644 index 000000000..8257c4f6f Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B11.tif differ diff --git 
a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B12.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B12.tif new file mode 100644 index 000000000..d00a4de7e Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B12.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B2.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B2.tif new file mode 100644 index 000000000..950ca2e7b Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B2.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B3.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B3.tif new file mode 100644 index 000000000..cd6acd4d0 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B3.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B4.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B4.tif new file mode 100644 index 000000000..a0a7530be Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B4.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B5.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B5.tif new file mode 100644 index 000000000..00a1207ff Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B5.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B6.tif 
b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B6.tif new file mode 100644 index 000000000..42941b2d5 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B6.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B7.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B7.tif new file mode 100644 index 000000000..fcd2a5317 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B7.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B8.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B8.tif new file mode 100644 index 000000000..25a6621e3 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B8.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B8A.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B8A.tif new file mode 100644 index 000000000..bfdc215fb Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B8A.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B9.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B9.tif new file mode 100644 index 000000000..5335622a3 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/B9.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/metadata.json b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/metadata.json new file mode 100644 index 000000000..081ca2731 --- /dev/null 
+++ b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200424T094029_20200424T094123_T33TYN/metadata.json @@ -0,0 +1 @@ +{"AOT_RETRIEVAL_ACCURACY": 0, "CLOUDY_PIXEL_PERCENTAGE": 2.812705, "CLOUD_COVERAGE_ASSESSMENT": 2.812705, "CLOUD_SHADOW_PERCENTAGE": 0.015249, "DARK_FEATURES_PERCENTAGE": 0.468881, "DATASTRIP_ID": "S2B_OPER_MSI_L2A_DS_SGS__20200424T123330_S20200424T094123_N02.14", "DATATAKE_IDENTIFIER": "GS2B_20200424T094029_016364_N02.14", "DATATAKE_TYPE": "INS-NOBS", "DEGRADED_MSI_DATA_PERCENTAGE": 0, "FORMAT_CORRECTNESS": "PASSED", "GENERAL_QUALITY": "PASSED", "GENERATION_TIME": 1587731610000, "GEOMETRIC_QUALITY": "PASSED", "GRANULE_ID": "L2A_T33TYN_A016364_20200424T094123", "HIGH_PROBA_CLOUDS_PERCENTAGE": 0.015214, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B1": 106.543897547236, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B10": 105.92977627649, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B11": 106.2294370992, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B12": 106.526973972026, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B2": 105.427079405054, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B3": 105.700015654406, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B4": 105.929384543531, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B5": 106.065332002664, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B6": 106.186455739266, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B7": 106.317289469222, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8": 105.564528810271, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8A": 106.398164642186, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B9": 106.660211986834, "MEAN_INCIDENCE_ZENITH_ANGLE_B1": 8.67687308176335, "MEAN_INCIDENCE_ZENITH_ANGLE_B10": 8.54845080104398, "MEAN_INCIDENCE_ZENITH_ANGLE_B11": 8.60029719555954, "MEAN_INCIDENCE_ZENITH_ANGLE_B12": 8.66269723035736, "MEAN_INCIDENCE_ZENITH_ANGLE_B2": 8.50990628623815, "MEAN_INCIDENCE_ZENITH_ANGLE_B3": 8.53383114161826, "MEAN_INCIDENCE_ZENITH_ANGLE_B4": 8.56085011020257, "MEAN_INCIDENCE_ZENITH_ANGLE_B5": 8.5810054217592, "MEAN_INCIDENCE_ZENITH_ANGLE_B6": 8.60351046174368, "MEAN_INCIDENCE_ZENITH_ANGLE_B7": 8.62846298114179, "MEAN_INCIDENCE_ZENITH_ANGLE_B8": 8.52063436830031, 
"MEAN_INCIDENCE_ZENITH_ANGLE_B8A": 8.64956945753696, "MEAN_INCIDENCE_ZENITH_ANGLE_B9": 8.70970015807593, "MEAN_SOLAR_AZIMUTH_ANGLE": 156.022265244126, "MEAN_SOLAR_ZENITH_ANGLE": 36.2587132733646, "MEDIUM_PROBA_CLOUDS_PERCENTAGE": 0.024197, "MGRS_TILE": "33TYN", "NODATA_PIXEL_PERCENTAGE": 24.501736, "NOT_VEGETATED_PERCENTAGE": 33.184448, "PROCESSING_BASELINE": "02.14", "PRODUCT_ID": "S2B_MSIL2A_20200424T094029_N0214_R036_T33TYN_20200424T123330", "RADIATIVE_TRANSFER_ACCURACY": 0, "RADIOMETRIC_QUALITY": "PASSED", "REFLECTANCE_CONVERSION_CORRECTION": 0.990710191784445, "SATURATED_DEFECTIVE_PIXEL_PERCENTAGE": 0, "SENSING_ORBIT_DIRECTION": "DESCENDING", "SENSING_ORBIT_NUMBER": 36, "SENSOR_QUALITY": "PASSED", "SNOW_ICE_PERCENTAGE": 4e-06, "SOLAR_IRRADIANCE_B1": 1874.3, "SOLAR_IRRADIANCE_B10": 365.41, "SOLAR_IRRADIANCE_B11": 247.08, "SOLAR_IRRADIANCE_B12": 87.75, "SOLAR_IRRADIANCE_B2": 1959.75, "SOLAR_IRRADIANCE_B3": 1824.93, "SOLAR_IRRADIANCE_B4": 1512.79, "SOLAR_IRRADIANCE_B5": 1425.78, "SOLAR_IRRADIANCE_B6": 1291.13, "SOLAR_IRRADIANCE_B7": 1175.57, "SOLAR_IRRADIANCE_B8": 1041.28, "SOLAR_IRRADIANCE_B8A": 953.93, "SOLAR_IRRADIANCE_B9": 817.58, "SPACECRAFT_NAME": "Sentinel-2B", "THIN_CIRRUS_PERCENTAGE": 2.773294, "UNCLASSIFIED_PERCENTAGE": 0.328074, "VEGETATION_PERCENTAGE": 58.864659, "WATER_PERCENTAGE": 4.325977, "WATER_VAPOUR_RETRIEVAL_ACCURACY": 0, "system:asset_size": 1340927025, "system:band_names": ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12", "AOT", "WVP", "SCL", "TCI_R", "TCI_G", "TCI_B", "MSK_CLDPRB", "MSK_SNWPRB", "QA10", "QA20", "QA60"], "system:bands": {"B11": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_B": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, 
"B12": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "AOT": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "QA10": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B8A": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "QA20": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 4294967295}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "MSK_CLDPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B1": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "QA60": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "B2": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "WVP": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B3": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, 
"dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B4": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "TCI_R": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B5": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B6": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B7": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B8": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B9": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "MSK_SNWPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "SCL": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_G": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}}, 
"system:footprint": {"type": "LinearRing", "coordinates": [[19.135125920080306, 47.77916186962477], [19.135029727720475, 47.779255442991314], [18.17974674339065, 47.80973453124768], [18.175369534026473, 47.80439727501413], [18.158478123290493, 47.76381793259735], [18.135659921426363, 47.70397065763017], [18.02588276983751, 47.409953370306475], [17.84226898482384, 46.89305421884395], [17.83483726398006, 46.87163662286434], [17.82687335431474, 46.839969201428644], [17.825690373350913, 46.83243260435472], [17.82561662238212, 46.831025740554466], [17.82597229677631, 46.830892495549335], [19.058775213444516, 46.79376963860806], [19.058911741293866, 46.793834509947615], [19.077589698914174, 47.04020634864976], [19.096514577054162, 47.28654137136117], [19.115691869935954, 47.53285992686103], [19.135125920080306, 47.77916186962477]]}, "system:id": "COPERNICUS/S2_SR/20200424T094029_20200424T094123_T33TYN", "system:index": "20200424T094029_20200424T094123_T33TYN", "system:time_end": 1587721647687, "system:time_start": 1587721647687, "system:version": 1587930521476594} diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B1.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B1.tif new file mode 100644 index 000000000..9dff310b4 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B1.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B11.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B11.tif new file mode 100644 index 000000000..9efd6e287 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B11.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B12.tif 
b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B12.tif new file mode 100644 index 000000000..47d64de43 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B12.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B2.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B2.tif new file mode 100644 index 000000000..733b98ac7 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B2.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B3.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B3.tif new file mode 100644 index 000000000..972fc9af0 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B3.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B4.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B4.tif new file mode 100644 index 000000000..37dc66eb0 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B4.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B5.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B5.tif new file mode 100644 index 000000000..f3363bd95 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B5.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B6.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B6.tif new file mode 100644 index 000000000..b19a03a30 Binary files /dev/null and 
b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B6.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B7.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B7.tif new file mode 100644 index 000000000..e988bcfba Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B7.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B8.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B8.tif new file mode 100644 index 000000000..ccc50c8ba Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B8.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B8A.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B8A.tif new file mode 100644 index 000000000..27ee7369f Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B8A.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B9.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B9.tif new file mode 100644 index 000000000..b07964457 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/B9.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/metadata.json b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/metadata.json new file mode 100644 index 000000000..d3c3117e5 --- /dev/null +++ b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20200812T094039_20200812T094034_T33TYN/metadata.json @@ -0,0 +1 @@ +{"AOT_RETRIEVAL_ACCURACY": 0, 
"CLOUDY_PIXEL_PERCENTAGE": 0.208477, "CLOUD_COVERAGE_ASSESSMENT": 0.208477, "CLOUD_SHADOW_PERCENTAGE": 0.078845, "DARK_FEATURES_PERCENTAGE": 0.266203, "DATASTRIP_ID": "S2B_OPER_MSI_L2A_DS_EPAE_20200812T120400_S20200812T094034_N02.14", "DATATAKE_IDENTIFIER": "GS2B_20200812T094039_017937_N02.14", "DATATAKE_TYPE": "INS-NOBS", "DEGRADED_MSI_DATA_PERCENTAGE": 0, "FORMAT_CORRECTNESS": "PASSED", "GENERAL_QUALITY": "PASSED", "GENERATION_TIME": 1597233840000, "GEOMETRIC_QUALITY": "PASSED", "GRANULE_ID": "L2A_T33TYN_A017937_20200812T094034", "HIGH_PROBA_CLOUDS_PERCENTAGE": 0.137471, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B1": 106.429324062005, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B10": 105.852008867325, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B11": 106.077757429308, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B12": 106.335669133722, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B2": 105.380300664506, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B3": 105.619830676885, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B4": 105.833171389503, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B5": 105.955584495271, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B6": 106.063210971005, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B7": 106.180783262801, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8": 105.500006984303, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8A": 106.287212521103, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B9": 106.474000451992, "MEAN_INCIDENCE_ZENITH_ANGLE_B1": 8.65065979233341, "MEAN_INCIDENCE_ZENITH_ANGLE_B10": 8.5230203611065, "MEAN_INCIDENCE_ZENITH_ANGLE_B11": 8.5659857030015, "MEAN_INCIDENCE_ZENITH_ANGLE_B12": 8.6286564275332, "MEAN_INCIDENCE_ZENITH_ANGLE_B2": 8.48649538224668, "MEAN_INCIDENCE_ZENITH_ANGLE_B3": 8.50238336028187, "MEAN_INCIDENCE_ZENITH_ANGLE_B4": 8.53311449180501, "MEAN_INCIDENCE_ZENITH_ANGLE_B5": 8.55333421195006, "MEAN_INCIDENCE_ZENITH_ANGLE_B6": 8.57592089704101, "MEAN_INCIDENCE_ZENITH_ANGLE_B7": 8.60096440676022, "MEAN_INCIDENCE_ZENITH_ANGLE_B8": 8.48915559090344, "MEAN_INCIDENCE_ZENITH_ANGLE_B8A": 8.62827228110442, "MEAN_INCIDENCE_ZENITH_ANGLE_B9": 8.68825044784009, "MEAN_SOLAR_AZIMUTH_ANGLE": 152.478041660485, 
"MEAN_SOLAR_ZENITH_ANGLE": 35.147934745323, "MEDIUM_PROBA_CLOUDS_PERCENTAGE": 0.070437, "MGRS_TILE": "33TYN", "NODATA_PIXEL_PERCENTAGE": 23.649484, "NOT_VEGETATED_PERCENTAGE": 20.454341, "PROCESSING_BASELINE": "02.14", "PRODUCT_ID": "S2B_MSIL2A_20200812T094039_N0214_R036_T33TYN_20200812T120400", "RADIATIVE_TRANSFER_ACCURACY": 0, "RADIOMETRIC_QUALITY": "PASSED", "REFLECTANCE_CONVERSION_CORRECTION": 0.97302905813138, "SATURATED_DEFECTIVE_PIXEL_PERCENTAGE": 0, "SENSING_ORBIT_DIRECTION": "DESCENDING", "SENSING_ORBIT_NUMBER": 36, "SENSOR_QUALITY": "PASSED", "SNOW_ICE_PERCENTAGE": 0.000183, "SOLAR_IRRADIANCE_B1": 1874.3, "SOLAR_IRRADIANCE_B10": 365.41, "SOLAR_IRRADIANCE_B11": 247.08, "SOLAR_IRRADIANCE_B12": 87.75, "SOLAR_IRRADIANCE_B2": 1959.75, "SOLAR_IRRADIANCE_B3": 1824.93, "SOLAR_IRRADIANCE_B4": 1512.79, "SOLAR_IRRADIANCE_B5": 1425.78, "SOLAR_IRRADIANCE_B6": 1291.13, "SOLAR_IRRADIANCE_B7": 1175.57, "SOLAR_IRRADIANCE_B8": 1041.28, "SOLAR_IRRADIANCE_B8A": 953.93, "SOLAR_IRRADIANCE_B9": 817.58, "SPACECRAFT_NAME": "Sentinel-2B", "THIN_CIRRUS_PERCENTAGE": 0.000569, "UNCLASSIFIED_PERCENTAGE": 0.25633, "VEGETATION_PERCENTAGE": 74.069816, "WATER_PERCENTAGE": 4.665802, "WATER_VAPOUR_RETRIEVAL_ACCURACY": 0, "system:asset_size": 1331946979, "system:band_names": ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12", "AOT", "WVP", "SCL", "TCI_R", "TCI_G", "TCI_B", "MSK_CLDPRB", "MSK_SNWPRB", "QA10", "QA20", "QA60"], "system:bands": {"B11": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_B": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B12": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": 
[20, 0, 699960, 0, -20, 5300040]}, "AOT": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "QA10": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B8A": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "QA20": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 4294967295}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "MSK_CLDPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B1": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "QA60": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "B2": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "WVP": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B3": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B4": {"data_type": {"type": "PixelType", "precision": 
"int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "TCI_R": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B5": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B6": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B7": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B8": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B9": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "MSK_SNWPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "SCL": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_G": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}}, "system:footprint": {"type": "LinearRing", "coordinates": [[19.135125920080306, 47.77916186962477], [19.135029727720475, 47.779255442991314], 
[18.16693852790548, 47.8100889727091], [18.164267652937216, 47.80632443625982], [18.158840869103887, 47.79620697829836], [18.088526955702534, 47.61022107161706], [17.903701479756446, 47.09670164788835], [17.84229285719976, 46.92329249637601], [17.827427767869775, 46.88046144909253], [17.81606866971548, 46.844018609587174], [17.813931590571993, 46.833263753347495], [17.81382991163962, 46.831316282941685], [17.814185583479816, 46.83118308142532], [19.058775213444516, 46.79376963860806], [19.058911741293866, 46.793834509947615], [19.077589698914174, 47.04020634864976], [19.096514577054162, 47.28654137136117], [19.115691869935954, 47.53285992686103], [19.135125920080306, 47.77916186962477]]}, "system:id": "COPERNICUS/S2_SR/20200812T094039_20200812T094034_T33TYN", "system:index": "20200812T094039_20200812T094034_T33TYN", "system:time_end": 1597225656684, "system:time_start": 1597225656684, "system:version": 1597431921198064} diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B1.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B1.tif new file mode 100644 index 000000000..09c23f408 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B1.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B11.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B11.tif new file mode 100644 index 000000000..262d05a60 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B11.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B12.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B12.tif new file mode 100644 index 000000000..42914e101 Binary files /dev/null and 
b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B12.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B2.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B2.tif new file mode 100644 index 000000000..8bbca800c Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B2.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B3.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B3.tif new file mode 100644 index 000000000..54d75d141 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B3.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B4.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B4.tif new file mode 100644 index 000000000..07d895a7d Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B4.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B5.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B5.tif new file mode 100644 index 000000000..f88d453d9 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B5.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B6.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B6.tif new file mode 100644 index 000000000..b6d0815e3 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B6.tif differ diff --git 
a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B7.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B7.tif new file mode 100644 index 000000000..8ae56682c Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B7.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B8.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B8.tif new file mode 100644 index 000000000..579d2a9a8 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B8.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B8A.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B8A.tif new file mode 100644 index 000000000..bed885ace Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B8A.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B9.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B9.tif new file mode 100644 index 000000000..8973a2922 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/B9.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/metadata.json b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/metadata.json new file mode 100644 index 000000000..83ac26bd5 --- /dev/null +++ b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20201118T095311_20201118T095400_T33TYN/metadata.json @@ -0,0 +1 @@ +{"AOT_RETRIEVAL_ACCURACY": 0, "CLOUDY_PIXEL_PERCENTAGE": 0.453934, "CLOUD_COVERAGE_ASSESSMENT": 0.45393399999999995, "CLOUD_SHADOW_PERCENTAGE": 1.138567, 
"DARK_FEATURES_PERCENTAGE": 15.463047, "DATASTRIP_ID": "S2A_OPER_MSI_L2A_DS_EPAE_20201118T115530_S20201118T095400_N02.14", "DATATAKE_IDENTIFIER": "GS2A_20201118T095311_028247_N02.14", "DATATAKE_TYPE": "INS-NOBS", "DEGRADED_MSI_DATA_PERCENTAGE": 0, "FORMAT_CORRECTNESS": "PASSED", "GENERAL_QUALITY": "PASSED", "GENERATION_TIME": 1605700530000, "GEOMETRIC_QUALITY": "PASSED", "GRANULE_ID": "L2A_T33TYN_A028247_20201118T095400", "HIGH_PROBA_CLOUDS_PERCENTAGE": 0.248884, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B1": 281.681280146803, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B10": 281.572710595024, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B11": 281.081166730384, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B12": 281.107900743002, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B2": 282.19659426411, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B3": 281.893354312561, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B4": 281.719087546553, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B5": 281.69787113297, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B6": 281.699110683518, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B7": 281.697525241078, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8": 282.017178016916, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8A": 281.726600099165, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B9": 281.718430321399, "MEAN_INCIDENCE_ZENITH_ANGLE_B1": 5.67907654400748, "MEAN_INCIDENCE_ZENITH_ANGLE_B10": 5.41638149201038, "MEAN_INCIDENCE_ZENITH_ANGLE_B11": 5.51137814408818, "MEAN_INCIDENCE_ZENITH_ANGLE_B12": 5.63006844946331, "MEAN_INCIDENCE_ZENITH_ANGLE_B2": 5.31075209346664, "MEAN_INCIDENCE_ZENITH_ANGLE_B3": 5.36458098961822, "MEAN_INCIDENCE_ZENITH_ANGLE_B4": 5.43516937389945, "MEAN_INCIDENCE_ZENITH_ANGLE_B5": 5.48025769954631, "MEAN_INCIDENCE_ZENITH_ANGLE_B6": 5.52449969165148, "MEAN_INCIDENCE_ZENITH_ANGLE_B7": 5.57253782358173, "MEAN_INCIDENCE_ZENITH_ANGLE_B8": 5.33541219024446, "MEAN_INCIDENCE_ZENITH_ANGLE_B8A": 5.62367424140266, "MEAN_INCIDENCE_ZENITH_ANGLE_B9": 5.73796871755139, "MEAN_SOLAR_AZIMUTH_ANGLE": 171.248871380871, "MEAN_SOLAR_ZENITH_ANGLE": 67.1303819169954, "MEDIUM_PROBA_CLOUDS_PERCENTAGE": 0.169226, "MGRS_TILE": "33TYN", 
"NODATA_PIXEL_PERCENTAGE": 4e-05, "NOT_VEGETATED_PERCENTAGE": 31.151235, "PROCESSING_BASELINE": "02.14", "PRODUCT_ID": "S2A_MSIL2A_20201118T095311_N0214_R079_T33TYN_20201118T115530", "RADIATIVE_TRANSFER_ACCURACY": 0, "RADIOMETRIC_QUALITY": "PASSED", "REFLECTANCE_CONVERSION_CORRECTION": 1.02242117671629, "SATURATED_DEFECTIVE_PIXEL_PERCENTAGE": 0, "SENSING_ORBIT_DIRECTION": "DESCENDING", "SENSING_ORBIT_NUMBER": 79, "SENSOR_QUALITY": "PASSED", "SNOW_ICE_PERCENTAGE": 6e-05, "SOLAR_IRRADIANCE_B1": 1884.69, "SOLAR_IRRADIANCE_B10": 367.15, "SOLAR_IRRADIANCE_B11": 245.59, "SOLAR_IRRADIANCE_B12": 85.25, "SOLAR_IRRADIANCE_B2": 1959.66, "SOLAR_IRRADIANCE_B3": 1823.24, "SOLAR_IRRADIANCE_B4": 1512.06, "SOLAR_IRRADIANCE_B5": 1424.64, "SOLAR_IRRADIANCE_B6": 1287.61, "SOLAR_IRRADIANCE_B7": 1162.08, "SOLAR_IRRADIANCE_B8": 1041.63, "SOLAR_IRRADIANCE_B8A": 955.32, "SOLAR_IRRADIANCE_B9": 812.92, "SPACECRAFT_NAME": "Sentinel-2A", "THIN_CIRRUS_PERCENTAGE": 0.035823, "UNCLASSIFIED_PERCENTAGE": 12.55997, "VEGETATION_PERCENTAGE": 34.667516, "WATER_PERCENTAGE": 4.56567, "WATER_VAPOUR_RETRIEVAL_ACCURACY": 0, "system:asset_size": 1736727953, "system:band_names": ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12", "AOT", "WVP", "SCL", "TCI_R", "TCI_G", "TCI_B", "MSK_CLDPRB", "MSK_SNWPRB", "QA10", "QA20", "QA60"], "system:bands": {"B11": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_B": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B12": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "AOT": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 
65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "QA10": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B8A": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "QA20": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 4294967295}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "MSK_CLDPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B1": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "QA60": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "B2": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "WVP": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B3": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B4": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 
699960, 0, -10, 5300040]}, "TCI_R": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B5": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B6": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B7": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B8": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B9": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "MSK_SNWPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "SCL": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_G": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}}, "system:footprint": {"type": "LinearRing", "coordinates": [[17.67141078171894, 47.82263288510892], [17.671405837212518, 47.82261820708107], [17.646411626097294, 47.32925082218025], [17.62207290404578, 46.83583851015872], [17.62212427253828, 46.83579584772863], 
[17.622167538661174, 46.8357491351654], [17.622188932533497, 46.8357458205179], [19.058775213444516, 46.79376963860806], [19.0588384118566, 46.793803980922334], [19.05890750034179, 46.79383271512413], [19.058912745577004, 46.79384734936178], [19.077589698914174, 47.04020634864976], [19.096514577054162, 47.28654137136117], [19.115691869935954, 47.53285992686103], [19.135125920080306, 47.77916186962477], [19.13507488572603, 47.77920526079178], [19.135032249348438, 47.77925250976796], [19.135010549350152, 47.77925610053018], [17.67154409192475, 47.82269745296916], [17.671480450885944, 47.82266239915652], [17.67141078171894, 47.82263288510892]]}, "system:id": "COPERNICUS/S2_SR/20201118T095311_20201118T095400_T33TYN", "system:index": "20201118T095311_20201118T095400_T33TYN", "system:time_end": 1605693452958, "system:time_start": 1605693452958, "system:version": 1605833526635574} diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B1.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B1.tif new file mode 100644 index 000000000..253b906f3 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B1.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B11.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B11.tif new file mode 100644 index 000000000..ef83ac19d Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B11.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B12.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B12.tif new file mode 100644 index 000000000..ccdd8fb01 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B12.tif differ diff 
--git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B2.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B2.tif new file mode 100644 index 000000000..17bf01d6a Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B2.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B3.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B3.tif new file mode 100644 index 000000000..225ce192c Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B3.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B4.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B4.tif new file mode 100644 index 000000000..3228fe9dc Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B4.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B5.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B5.tif new file mode 100644 index 000000000..dd1c11bba Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B5.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B6.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B6.tif new file mode 100644 index 000000000..f2f2a950b Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B6.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B7.tif 
b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B7.tif new file mode 100644 index 000000000..d9146ba9d Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B7.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B8.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B8.tif new file mode 100644 index 000000000..41cdc9f90 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B8.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B8A.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B8A.tif new file mode 100644 index 000000000..55bb31b14 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B8A.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B9.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B9.tif new file mode 100644 index 000000000..e01001e63 Binary files /dev/null and b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/B9.tif differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/metadata.json b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/metadata.json new file mode 100644 index 000000000..fa970f6ee --- /dev/null +++ b/tests/fixtures/data/SSL4EO-S12/s2a/0000042/20210218T094029_20210218T094028_T33TYN/metadata.json @@ -0,0 +1 @@ +{"AOT_RETRIEVAL_ACCURACY": 0, "CLOUDY_PIXEL_PERCENTAGE": 0.901432, "CLOUD_COVERAGE_ASSESSMENT": 0.901432, "CLOUD_SHADOW_PERCENTAGE": 0.84547, "DARK_FEATURES_PERCENTAGE": 6.795183, "DATASTRIP_ID": 
"S2B_OPER_MSI_L2A_DS_VGS2_20210218T123009_S20210218T094028_N02.14", "DATATAKE_IDENTIFIER": "GS2B_20210218T094029_020654_N02.14", "DATATAKE_TYPE": "INS-NOBS", "DEGRADED_MSI_DATA_PERCENTAGE": 0, "FORMAT_CORRECTNESS": "PASSED", "GENERAL_QUALITY": "PASSED", "GENERATION_TIME": 1613651409000, "GEOMETRIC_QUALITY": "PASSED", "GRANULE_ID": "L2A_T33TYN_A020654_20210218T094028", "HIGH_PROBA_CLOUDS_PERCENTAGE": 0.16992, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B1": 106.191528469216, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B10": 105.765178013311, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B11": 105.899702597322, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B12": 106.125463551834, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B2": 105.346122751905, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B3": 105.53713441503, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B4": 105.695362490083, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B5": 105.795932240219, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B6": 105.881764502021, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B7": 105.977799431462, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8": 105.437391144718, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8A": 106.062903343031, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B9": 106.246892369833, "MEAN_INCIDENCE_ZENITH_ANGLE_B1": 8.60382616209157, "MEAN_INCIDENCE_ZENITH_ANGLE_B10": 8.4747035379831, "MEAN_INCIDENCE_ZENITH_ANGLE_B11": 8.51482581279656, "MEAN_INCIDENCE_ZENITH_ANGLE_B12": 8.5809124485055, "MEAN_INCIDENCE_ZENITH_ANGLE_B2": 8.42926736088918, "MEAN_INCIDENCE_ZENITH_ANGLE_B3": 8.44752826229191, "MEAN_INCIDENCE_ZENITH_ANGLE_B4": 8.48181484372435, "MEAN_INCIDENCE_ZENITH_ANGLE_B5": 8.50214317942103, "MEAN_INCIDENCE_ZENITH_ANGLE_B6": 8.52486382484645, "MEAN_INCIDENCE_ZENITH_ANGLE_B7": 8.55005538439596, "MEAN_INCIDENCE_ZENITH_ANGLE_B8": 8.4342391916749, "MEAN_INCIDENCE_ZENITH_ANGLE_B8A": 8.57753521260584, "MEAN_INCIDENCE_ZENITH_ANGLE_B9": 8.64032193814476, "MEAN_SOLAR_AZIMUTH_ANGLE": 159.489777929019, "MEAN_SOLAR_ZENITH_ANGLE": 61.0015939922247, "MEDIUM_PROBA_CLOUDS_PERCENTAGE": 0.169342, "MGRS_TILE": "33TYN", "NODATA_PIXEL_PERCENTAGE": 22.564258, 
"NOT_VEGETATED_PERCENTAGE": 63.721645, "PROCESSING_BASELINE": "02.14", "PRODUCT_ID": "S2B_MSIL2A_20210218T094029_N0214_R036_T33TYN_20210218T123009", "RADIATIVE_TRANSFER_ACCURACY": 0, "RADIOMETRIC_QUALITY": "PASSED", "REFLECTANCE_CONVERSION_CORRECTION": 1.02538502693364, "SATURATED_DEFECTIVE_PIXEL_PERCENTAGE": 0, "SENSING_ORBIT_DIRECTION": "DESCENDING", "SENSING_ORBIT_NUMBER": 36, "SENSOR_QUALITY": "PASSED", "SNOW_ICE_PERCENTAGE": 0.140262, "SOLAR_IRRADIANCE_B1": 1874.3, "SOLAR_IRRADIANCE_B10": 365.41, "SOLAR_IRRADIANCE_B11": 247.08, "SOLAR_IRRADIANCE_B12": 87.75, "SOLAR_IRRADIANCE_B2": 1959.75, "SOLAR_IRRADIANCE_B3": 1824.93, "SOLAR_IRRADIANCE_B4": 1512.79, "SOLAR_IRRADIANCE_B5": 1425.78, "SOLAR_IRRADIANCE_B6": 1291.13, "SOLAR_IRRADIANCE_B7": 1175.57, "SOLAR_IRRADIANCE_B8": 1041.28, "SOLAR_IRRADIANCE_B8A": 953.93, "SOLAR_IRRADIANCE_B9": 817.58, "SPACECRAFT_NAME": "Sentinel-2B", "THIN_CIRRUS_PERCENTAGE": 0.56217, "UNCLASSIFIED_PERCENTAGE": 8.106547, "VEGETATION_PERCENTAGE": 14.548491, "WATER_PERCENTAGE": 4.940968, "WATER_VAPOUR_RETRIEVAL_ACCURACY": 0, "system:asset_size": 1352573025, "system:band_names": ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12", "AOT", "WVP", "SCL", "TCI_R", "TCI_G", "TCI_B", "MSK_CLDPRB", "MSK_SNWPRB", "QA10", "QA20", "QA60"], "system:bands": {"B11": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_B": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B12": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "AOT": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 
10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "QA10": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B8A": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "QA20": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 4294967295}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "MSK_CLDPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B1": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "QA60": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "B2": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "WVP": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B3": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B4": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "TCI_R": 
{"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B5": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B6": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B7": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "B8": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}, "B9": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32633", "crs_transform": [60, 0, 699960, 0, -60, 5300040]}, "MSK_SNWPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "SCL": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32633", "crs_transform": [20, 0, 699960, 0, -20, 5300040]}, "TCI_G": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32633", "crs_transform": [10, 0, 699960, 0, -10, 5300040]}}, "system:footprint": {"type": "LinearRing", "coordinates": [[17.797345040594656, 46.83204619803086], [17.797362873719553, 46.83199454391213], [17.797885329067725, 46.83160945562662], [17.79795167729368, 46.83158136738409], [19.058775213444516, 46.79376963860806], [19.058911741293866, 
46.793834509947615], [19.077589698914174, 47.04020634864976], [19.096514577054162, 47.28654137136117], [19.115691869935954, 47.53285992686103], [19.135125920080306, 47.77916186962477], [19.135029727720475, 47.779255442991314], [18.15182335365654, 47.81050538734816], [18.15174592944439, 47.810493350914705], [18.148367147337964, 47.808395609286016], [18.09547502744622, 47.673746790872656], [17.914556442526802, 47.17958035145654], [17.825145976296444, 46.92695559117315], [17.797373483050553, 46.832592829927364], [17.797345040594656, 46.83204619803086]]}, "system:id": "COPERNICUS/S2_SR/20210218T094029_20210218T094028_T33TYN", "system:index": "20210218T094029_20210218T094028_T33TYN", "system:time_end": 1613641650888, "system:time_start": 1613641650888, "system:version": 1614206891239703} diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B1.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B1.tif deleted file mode 100644 index 8b4b55754..000000000 Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B1.tif and /dev/null differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B11.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B11.tif deleted file mode 100644 index 656051ca3..000000000 Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B11.tif and /dev/null differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B12.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B12.tif deleted file mode 100644 index 472d8b414..000000000 Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B12.tif and /dev/null differ diff --git 
a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B2.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B2.tif deleted file mode 100644 index b0ef24579..000000000 Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B2.tif and /dev/null differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B3.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B3.tif deleted file mode 100644 index eb5c25c62..000000000 Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B3.tif and /dev/null differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B4.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B4.tif deleted file mode 100644 index d86fd7fd6..000000000 Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B4.tif and /dev/null differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B5.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B5.tif deleted file mode 100644 index 2f2514b78..000000000 Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B5.tif and /dev/null differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B6.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B6.tif deleted file mode 100644 index 68c6dc87c..000000000 Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B6.tif and /dev/null differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B7.tif 
b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B7.tif deleted file mode 100644 index 4403ab65a..000000000 Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B7.tif and /dev/null differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B8.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B8.tif deleted file mode 100644 index 343b888f6..000000000 Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B8.tif and /dev/null differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B8A.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B8A.tif deleted file mode 100644 index f257d7612..000000000 Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B8A.tif and /dev/null differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B9.tif b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B9.tif deleted file mode 100644 index e643fa0ec..000000000 Binary files a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/B9.tif and /dev/null differ diff --git a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/metadata.json b/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/metadata.json deleted file mode 100644 index 56095b8f3..000000000 --- a/tests/fixtures/data/SSL4EO-S12/s2a/0000099/20200614T181931_20200614T183418_T11SMS/metadata.json +++ /dev/null @@ -1 +0,0 @@ -{"AOT_RETRIEVAL_ACCURACY": 0, "CLOUDY_PIXEL_PERCENTAGE": 1.380343, "CLOUD_COVERAGE_ASSESSMENT": 1.380343, "CLOUD_SHADOW_PERCENTAGE": 0.045924, "DARK_FEATURES_PERCENTAGE": 0.207592, "DATASTRIP_ID": 
"S2A_OPER_MSI_L2A_DS_EPAE_20200614T233609_S20200614T183418_N02.14", "DATATAKE_IDENTIFIER": "GS2A_20200614T181931_026007_N02.14", "DATATAKE_TYPE": "INS-NOBS", "DEGRADED_MSI_DATA_PERCENTAGE": 0, "FORMAT_CORRECTNESS": "PASSED", "GENERAL_QUALITY": "PASSED", "GENERATION_TIME": 1592177769000, "GEOMETRIC_QUALITY": "PASSED", "GRANULE_ID": "L2A_T11SMS_A026007_20200614T183418", "HIGH_PROBA_CLOUDS_PERCENTAGE": 1.025103, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B1": 103.669795569119, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B10": 103.06629594126, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B11": 103.34224573122, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B12": 103.61285978569, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B2": 102.600210808613, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B3": 102.835706432612, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B4": 103.047500410431, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B5": 103.169579176377, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B6": 103.279959829067, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B7": 103.399267251377, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8": 102.718211170632, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B8A": 103.517410582706, "MEAN_INCIDENCE_AZIMUTH_ANGLE_B9": 103.782954154122, "MEAN_INCIDENCE_ZENITH_ANGLE_B1": 10.241410125066, "MEAN_INCIDENCE_ZENITH_ANGLE_B10": 10.1279749581719, "MEAN_INCIDENCE_ZENITH_ANGLE_B11": 10.1660724760143, "MEAN_INCIDENCE_ZENITH_ANGLE_B12": 10.2124812156495, "MEAN_INCIDENCE_ZENITH_ANGLE_B2": 10.0913569044097, "MEAN_INCIDENCE_ZENITH_ANGLE_B3": 10.1094453559007, "MEAN_INCIDENCE_ZENITH_ANGLE_B4": 10.1331032363887, "MEAN_INCIDENCE_ZENITH_ANGLE_B5": 10.1488031307943, "MEAN_INCIDENCE_ZENITH_ANGLE_B6": 10.1664615257693, "MEAN_INCIDENCE_ZENITH_ANGLE_B7": 10.1861200861194, "MEAN_INCIDENCE_ZENITH_ANGLE_B8": 10.0993961975397, "MEAN_INCIDENCE_ZENITH_ANGLE_B8A": 10.217514876513, "MEAN_INCIDENCE_ZENITH_ANGLE_B9": 10.2674475409991, "MEAN_SOLAR_AZIMUTH_ANGLE": 115.516292013745, "MEAN_SOLAR_ZENITH_ANGLE": 19.0927215025527, "MEDIUM_PROBA_CLOUDS_PERCENTAGE": 0.35524, "MGRS_TILE": "11SMS", "NODATA_PIXEL_PERCENTAGE": 62.915778, 
"NOT_VEGETATED_PERCENTAGE": 49.45268, "PROCESSING_BASELINE": "02.14", "PRODUCT_ID": "S2A_MSIL2A_20200614T181931_N0214_R127_T11SMS_20200614T233609", "RADIATIVE_TRANSFER_ACCURACY": 0, "RADIOMETRIC_QUALITY": "PASSED", "REFLECTANCE_CONVERSION_CORRECTION": 0.969929694485884, "SATURATED_DEFECTIVE_PIXEL_PERCENTAGE": 0, "SENSING_ORBIT_DIRECTION": "DESCENDING", "SENSING_ORBIT_NUMBER": 127, "SENSOR_QUALITY": "PASSED", "SNOW_ICE_PERCENTAGE": 0.002174, "SOLAR_IRRADIANCE_B1": 1884.69, "SOLAR_IRRADIANCE_B10": 367.15, "SOLAR_IRRADIANCE_B11": 245.59, "SOLAR_IRRADIANCE_B12": 85.25, "SOLAR_IRRADIANCE_B2": 1959.66, "SOLAR_IRRADIANCE_B3": 1823.24, "SOLAR_IRRADIANCE_B4": 1512.06, "SOLAR_IRRADIANCE_B5": 1424.64, "SOLAR_IRRADIANCE_B6": 1287.61, "SOLAR_IRRADIANCE_B7": 1162.08, "SOLAR_IRRADIANCE_B8": 1041.63, "SOLAR_IRRADIANCE_B8A": 955.32, "SOLAR_IRRADIANCE_B9": 812.92, "SPACECRAFT_NAME": "Sentinel-2A", "THIN_CIRRUS_PERCENTAGE": 0, "UNCLASSIFIED_PERCENTAGE": 0.570473, "VEGETATION_PERCENTAGE": 23.253819, "WATER_PERCENTAGE": 25.086993, "WATER_VAPOUR_RETRIEVAL_ACCURACY": 0, "system:asset_size": 670484430, "system:band_names": ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12", "AOT", "WVP", "SCL", "TCI_R", "TCI_G", "TCI_B", "MSK_CLDPRB", "MSK_SNWPRB", "QA10", "QA20", "QA60"], "system:bands": {"B11": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "TCI_B": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "B12": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "AOT": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], 
"crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "QA10": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "B8A": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "QA20": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 4294967295}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "MSK_CLDPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "B1": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32611", "crs_transform": [60, 0, 399960, 0, -60, 3700020]}, "QA60": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32611", "crs_transform": [60, 0, 399960, 0, -60, 3700020]}, "B2": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "WVP": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "B3": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "B4": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "TCI_R": 
{"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "B5": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "B6": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "B7": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "B8": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}, "B9": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 65535}, "dimensions": [1830, 1830], "crs": "EPSG:32611", "crs_transform": [60, 0, 399960, 0, -60, 3700020]}, "MSK_SNWPRB": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "SCL": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [5490, 5490], "crs": "EPSG:32611", "crs_transform": [20, 0, 399960, 0, -20, 3700020]}, "TCI_G": {"data_type": {"type": "PixelType", "precision": "int", "min": 0, "max": 255}, "dimensions": [10980, 10980], "crs": "EPSG:32611", "crs_transform": [10, 0, 399960, 0, -10, 3700020]}}, "system:footprint": {"type": "LinearRing", "coordinates": [[-116.89627078890739, 32.44924419582073], [-116.89626737417575, 32.44925898927161], [-116.8951045650362, 33.43953070750837], [-116.89514832775068, 33.439572205468394], [-116.89518786695115, 33.4396207947493], [-117.19903171329149, 
33.43950534084795], [-117.19907137449762, 33.43947601948668], [-117.19912234182809, 33.4394638167782], [-117.19913391715312, 33.43944350739506], [-117.19977701964083, 33.43781570659065], [-117.20298056753498, 33.42698251736538], [-117.45058635492418, 32.49597355187842], [-117.45371990599305, 32.48405448940864], [-117.46060922124882, 32.45750750785478], [-117.4618614719186, 32.452630471418175], [-117.46248822151223, 32.44991226060029], [-117.46248078368934, 32.44846000384899], [-117.46243723209412, 32.4484187108336], [-117.46239789227033, 32.448370362912655], [-116.8963741294445, 32.44917659271196], [-116.89632497529956, 32.44921309854787], [-116.89627078890739, 32.44924419582073]]}, "system:id": "COPERNICUS/S2_SR/20200614T181931_20200614T183418_T11SMS", "system:index": "20200614T181931_20200614T183418_T11SMS", "system:time_end": 1592159723525, "system:time_start": 1592159723525, "system:version": 1592325576909226} diff --git a/tests/test_datasets/test_collators.py b/tests/test_datasets/test_collators.py index 31666f396..b35218bcb 100644 --- a/tests/test_datasets/test_collators.py +++ b/tests/test_datasets/test_collators.py @@ -37,8 +37,9 @@ # IMPORTS # ===================================================================================================================== from collections import defaultdict -from typing import Any, Dict, List, Union +from typing import Any +import pytest import torch from numpy.testing import assert_array_equal from torch import Tensor @@ -50,12 +51,15 @@ # ===================================================================================================================== # TESTS # ===================================================================================================================== -def test_get_collator() -> None: - collator_params_1 = {"module": "torchgeo.datasets.utils", "name": "stack_samples"} - collator_params_2 = {"name": "stack_sample_pairs"} - - assert callable(mdt.get_collator(collator_params_1)) - assert 
callable(mdt.get_collator(collator_params_2)) +@pytest.mark.parametrize( + "target", + ( + "torchgeo.datasets.utils.stack_samples", + "minerva.datasets.collators.stack_sample_pairs", + ), +) +def test_get_collator(target: str) -> None: + assert callable(mdt.get_collator(target)) def test_stack_sample_pairs() -> None: @@ -67,13 +71,13 @@ def test_stack_sample_pairs() -> None: mask_2 = torch.randint(0, 8, (52, 52)) # type: ignore[attr-defined] bbox_2 = [BoundingBox(0, 1, 0, 1, 0, 1)] - sample_1: Dict[str, Union[Tensor, List[Any]]] = { + sample_1: dict[str, Tensor | list[Any]] = { "image": image_1, "mask": mask_1, "bbox": bbox_1, } - sample_2: Dict[str, Union[Tensor, List[Any]]] = { + sample_2: dict[str, Tensor | list[Any]] = { "image": image_2, "mask": mask_2, "bbox": bbox_2, diff --git a/tests/test_datasets/test_factory.py b/tests/test_datasets/test_factory.py index a3ac47387..37d720346 100644 --- a/tests/test_datasets/test_factory.py +++ b/tests/test_datasets/test_factory.py @@ -39,7 +39,7 @@ # ===================================================================================================================== from copy import deepcopy from pathlib import Path -from typing import Any, Dict +from typing import Any import pandas as pd import pytest @@ -56,12 +56,14 @@ # ===================================================================================================================== # TESTS # ===================================================================================================================== -def test_make_dataset(exp_dataset_params: Dict[str, Any], data_root: Path) -> None: +def test_make_dataset(exp_dataset_params: dict[str, Any], data_root: Path) -> None: dataset_params2 = { "image": { - "image_1": exp_dataset_params["image"], - "image_2": exp_dataset_params["image"], + "subdatasets": { + "image_1": exp_dataset_params["image"], + "image_2": exp_dataset_params["image"], + }, }, "mask": exp_dataset_params["mask"], } @@ -80,8 +82,12 @@ def 
test_make_dataset(exp_dataset_params: Dict[str, Any], data_root: Path) -> No assert isinstance(subdatasets_4[0], UnionDataset) dataset_params3 = dataset_params2 - dataset_params3["image"]["image_1"]["transforms"] = {"AutoNorm": {"length": 12}} - dataset_params3["image"]["image_2"]["transforms"] = {"AutoNorm": {"length": 12}} + dataset_params3["image"]["subdatasets"]["image_1"]["transforms"] = { + "AutoNorm": {"length": 12} + } + dataset_params3["image"]["subdatasets"]["image_2"]["transforms"] = { + "AutoNorm": {"length": 12} + } dataset_params3["image"]["transforms"] = {"AutoNorm": {"length": 12}} dataset_5, subdatasets_5 = mdt.make_dataset(data_root, dataset_params3, cache=False) @@ -111,7 +117,7 @@ def test_make_dataset(exp_dataset_params: Dict[str, Any], data_root: Path) -> No @pytest.mark.parametrize("sample_pairs", (False, True)) def test_caching_datasets( - exp_dataset_params: Dict[str, Any], + exp_dataset_params: dict[str, Any], data_root: Path, cache_dir: Path, sample_pairs: bool, @@ -160,47 +166,38 @@ def test_caching_datasets( [ ( { - "module": "torchgeo.samplers", - "name": "RandomBatchGeoSampler", + "_target_": "torchgeo.samplers.RandomBatchGeoSampler", "roi": False, - "params": { - "size": 224, - "length": 4096, - }, + "size": 224, + "length": 4096, }, {}, ), ( { - "module": "minerva.samplers", - "name": "RandomPairGeoSampler", + "_target_": "minerva.samplers.RandomPairGeoSampler", "roi": False, - "params": { - "size": 224, - "length": 4096, - }, + "size": 224, + "length": 4096, }, {"sample_pairs": True}, ), ( { - "module": "torchgeo.samplers", - "name": "RandomBatchGeoSampler", + "_target_": "torchgeo.samplers.RandomBatchGeoSampler", "roi": False, - "params": { - "size": 224, - "length": 4096, - }, + "size": 224, + "length": 4096, }, {"world_size": 2}, ), ], ) def test_construct_dataloader( - exp_dataset_params: Dict[str, Any], + exp_dataset_params: dict[str, Any], data_root: Path, - sampler_params: Dict[str, Any], - kwargs: Dict[str, Any], + 
sampler_params: dict[str, Any], + kwargs: dict[str, Any], ) -> None: batch_size = 256 @@ -236,9 +233,10 @@ def test_make_loaders(default_config: DictConfig) -> None: old_params_2 = OmegaConf.to_object(deepcopy(default_config)) assert isinstance(old_params_2, dict) dataset_params = old_params_2["tasks"]["fit-val"]["dataset_params"].copy() - old_params_2["tasks"]["fit-val"]["dataset_params"] = {} - old_params_2["tasks"]["fit-val"]["dataset_params"]["val-1"] = dataset_params - old_params_2["tasks"]["fit-val"]["dataset_params"]["val-2"] = dataset_params + old_params_2["tasks"]["fit-val"]["dataset_params"] = { + "val-1": dataset_params, + "val-2": dataset_params, + } loaders, n_batches, class_dist, params = mdt.make_loaders( # type: ignore[arg-type] **old_params_2, @@ -255,9 +253,8 @@ def test_make_loaders(default_config: DictConfig) -> None: def test_get_manifest( data_root: Path, - exp_dataset_params: Dict[str, Any], - exp_loader_params: Dict[str, Any], - exp_sampler_params: Dict[str, Any], + exp_dataset_params: dict[str, Any], + exp_sampler_params: dict[str, Any], ) -> None: manifest_path = Path("tests", "tmp", "cache", "Chesapeake7_Manifest.csv") diff --git a/tests/test_datasets/test_paired.py b/tests/test_datasets/test_paired.py index cf916f7cf..e52a38049 100644 --- a/tests/test_datasets/test_paired.py +++ b/tests/test_datasets/test_paired.py @@ -38,8 +38,6 @@ # ===================================================================================================================== # IMPORTS # ===================================================================================================================== -from typing import Tuple - import matplotlib.pyplot as plt import pytest from rasterio.crs import CRS @@ -94,7 +92,6 @@ def test_paired_geodatasets(img_root: Path) -> None: assert isinstance(dataset.crs, CRS) assert isinstance(getattr(dataset, "crs"), CRS) assert isinstance(dataset.dataset, TstImgDataset) - assert isinstance(dataset.__getattr__("dataset"), 
TstImgDataset) with pytest.raises(AttributeError): _ = dataset.roi @@ -163,7 +160,7 @@ def test_paired_nongeodatasets(data_root: Path) -> None: def test_paired_concat_datasets( - data_root: Path, small_patch_size: Tuple[int, int] + data_root: Path, small_patch_size: tuple[int, int] ) -> None: def dataset_test(_dataset) -> None: for sub_dataset in _dataset.datasets: diff --git a/tests/test_logging.py b/tests/test_logging.py index f265ee440..eda99d2bb 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -41,7 +41,7 @@ # ===================================================================================================================== # IMPORTS # ===================================================================================================================== -from typing import Any, Dict, List, Tuple, Union +from typing import Any import numpy as np import torch @@ -62,7 +62,7 @@ from minerva.logger.steplog import SupervisedStepLogger from minerva.logger.tasklog import SSLTaskLogger, SupervisedTaskLogger from minerva.loss import SegBarlowTwinsLoss -from minerva.modelio import ssl_pair_tg, sup_tg +from minerva.modelio import ssl_pair_torchgeo_io, supervised_torchgeo_io from minerva.models import FCN16ResNet18, MinervaSiamese, SimCLR18, SimConv from minerva.utils import utils @@ -80,7 +80,7 @@ def test_SupervisedStepLogger( std_n_batches: int, std_n_classes: int, std_batch_size: int, - small_patch_size: Tuple[int, int], + small_patch_size: tuple[int, int], default_device: torch.device, train: bool, model_type: str, @@ -134,29 +134,29 @@ def test_SupervisedStepLogger( record_float=True, writer=writer, model_type=model_type, - step_logger_params={"params": {"n_classes": std_n_classes}}, + step_logger_params={"n_classes": std_n_classes}, ) - correct_loss: Dict[str, List[float]] = {"x": [], "y": []} - correct_acc: Dict[str, List[float]] = {"x": [], "y": []} - correct_miou: Dict[str, List[float]] = {"x": [], "y": []} + correct_loss: dict[str, list[float]] 
= {"x": [], "y": []} + correct_acc: dict[str, list[float]] = {"x": [], "y": []} + correct_miou: dict[str, list[float]] = {"x": [], "y": []} for epoch_no in range(n_epochs): - data: List[Dict[str, Union[Tensor, List[Any]]]] = [] + data: list[dict[str, Tensor | list[Any]]] = [] for i in range(std_n_batches): images = torch.rand(size=(std_batch_size, 4, *small_patch_size)) masks = torch.randint( # type: ignore[attr-defined] 0, std_n_classes, (std_batch_size, *small_patch_size) ) bboxes = [simple_bbox] * std_batch_size - batch: Dict[str, Union[Tensor, List[Any]]] = { + batch: dict[str, Tensor | list[Any]] = { "image": images, "mask": masks, "bbox": bboxes, } data.append(batch) - logger.step(i, i, *sup_tg(batch, model, device=default_device, train=train)) # type: ignore[arg-type] + logger.step(i, i, *supervised_torchgeo_io(batch, model, device=default_device, train=train)) # type: ignore[arg-type] # noqa: E501 logs = logger.get_logs assert logs["batch_num"] == std_n_batches - 1 @@ -184,7 +184,7 @@ def test_SupervisedStepLogger( (std_n_batches, std_batch_size, *output_shape), dtype=np.uint8 ) for i in range(std_n_batches): - mask: Union[Tensor, List[Any]] = data[i]["mask"] + mask: Tensor | list[Any] = data[i]["mask"] assert isinstance(mask, Tensor) y[i] = mask.cpu().numpy() @@ -239,7 +239,7 @@ def test_SSLStepLogger( simple_bbox: BoundingBox, std_n_batches: int, std_batch_size: int, - small_patch_size: Tuple[int, int], + small_patch_size: tuple[int, int], default_device: torch.device, model_cls: MinervaSiamese, model_type: str, @@ -284,9 +284,9 @@ def test_SSLStepLogger( sample_pairs=True, ) - correct_loss: Dict[str, List[float]] = {"x": [], "y": []} - correct_collapse_level: Dict[str, List[float]] = {"x": [], "y": []} - correct_euc_dist: Dict[str, List[float]] = {"x": [], "y": []} + correct_loss: dict[str, list[float]] = {"x": [], "y": []} + correct_collapse_level: dict[str, list[float]] = {"x": [], "y": []} + correct_euc_dist: dict[str, list[float]] = {"x": [], "y": 
[]} for epoch_no in range(n_epochs): for i in range(std_n_batches): @@ -297,7 +297,7 @@ def test_SSLStepLogger( logger.step( i, i, - *ssl_pair_tg((batch, batch), model, device=default_device, train=train), # type: ignore[arg-type] + *ssl_pair_torchgeo_io((batch, batch), model, device=default_device, train=train), # type: ignore[arg-type] # noqa: E501 ) logs = logger.get_logs diff --git a/tests/test_modelio.py b/tests/test_modelio.py index e7f77d300..b1164e6b7 100644 --- a/tests/test_modelio.py +++ b/tests/test_modelio.py @@ -37,7 +37,7 @@ # IMPORTS # ===================================================================================================================== import importlib -from typing import Any, Dict, List, Tuple, Union +from typing import Any import torch import torch.nn.modules as nn @@ -53,20 +53,20 @@ from torch import Tensor from torchgeo.datasets.utils import BoundingBox -from minerva.modelio import autoencoder_io, ssl_pair_tg, sup_tg +from minerva.modelio import autoencoder_io, ssl_pair_torchgeo_io, supervised_torchgeo_io from minerva.models import FCN32ResNet18, SimCLR34 # ===================================================================================================================== # TESTS # ===================================================================================================================== -def test_sup_tg( +def test_supervised_torchgeo_io( simple_bbox: BoundingBox, random_rgbi_batch: Tensor, random_mask_batch: Tensor, std_batch_size: int, std_n_classes: int, - rgbi_input_size: Tuple[int, int, int], + rgbi_input_size: tuple[int, int, int], default_device: torch.device, ) -> None: criterion = nn.CrossEntropyLoss() @@ -76,13 +76,13 @@ def test_sup_tg( for train in (True, False): bboxes = [simple_bbox] * std_batch_size - batch: Dict[str, Union[Tensor, List[Any]]] = { + batch: dict[str, Tensor | list[Any]] = { "image": random_rgbi_batch, "mask": random_mask_batch, "bbox": bboxes, } - results = sup_tg(batch, model, 
default_device, train) + results = supervised_torchgeo_io(batch, model, default_device, train) assert isinstance(results[0], Tensor) assert isinstance(results[1], Tensor) @@ -95,10 +95,10 @@ def test_sup_tg( assert results[3] == batch["bbox"] -def test_ssl_pair_tg( +def test_ssl_pair_torchgeo_io( simple_bbox: BoundingBox, std_batch_size: int, - rgbi_input_size: Tuple[int, int, int], + rgbi_input_size: tuple[int, int, int], default_device: torch.device, ) -> None: criterion = NTXentLoss(0.5) @@ -123,7 +123,7 @@ def test_ssl_pair_tg( "bbox": bboxes_2, } - results = ssl_pair_tg((batch_1, batch_2), model, default_device, train) + results = ssl_pair_torchgeo_io((batch_1, batch_2), model, default_device, train) assert isinstance(results[0], Tensor) assert isinstance(results[1], Tensor) @@ -138,7 +138,7 @@ def test_mask_autoencoder_io( simple_bbox: BoundingBox, std_batch_size: int, std_n_classes: int, - rgbi_input_size: Tuple[int, int, int], + rgbi_input_size: tuple[int, int, int], default_device: torch.device, ) -> None: criterion = nn.CrossEntropyLoss() @@ -152,7 +152,7 @@ def test_mask_autoencoder_io( images = torch.rand(size=(std_batch_size, *rgbi_input_size)) masks = torch.randint(0, 8, (std_batch_size, *rgbi_input_size[1:])) # type: ignore[attr-defined] bboxes = [simple_bbox] * std_batch_size - batch: Dict[str, Union[Tensor, List[Any]]] = { + batch: dict[str, Tensor | list[Any]] = { "image": images, "mask": masks, "bbox": bboxes, @@ -191,7 +191,7 @@ def test_image_autoencoder_io( random_rgbi_batch: Tensor, random_mask_batch: Tensor, std_batch_size: int, - rgbi_input_size: Tuple[int, int, int], + rgbi_input_size: tuple[int, int, int], default_device: torch.device, ) -> None: criterion = nn.CrossEntropyLoss() @@ -203,7 +203,7 @@ def test_image_autoencoder_io( for train in (True, False): bboxes = [simple_bbox] * std_batch_size - batch: Dict[str, Union[Tensor, List[Any]]] = { + batch: dict[str, Tensor | list[Any]] = { "image": random_rgbi_batch, "mask": 
random_mask_batch, "bbox": bboxes, diff --git a/tests/test_models/test_core.py b/tests/test_models/test_core.py index 6a2dd5841..181e504ea 100644 --- a/tests/test_models/test_core.py +++ b/tests/test_models/test_core.py @@ -39,7 +39,6 @@ import importlib import os from platform import python_version -from typing import Tuple import internet_sabotage import numpy as np @@ -114,7 +113,7 @@ def test_minerva_model(x_entropy_loss, std_n_classes: int, std_n_batches: int) - assert z.size() == (std_n_batches, std_n_classes) -def test_minerva_backbone(rgbi_input_size: Tuple[int, int, int]) -> None: +def test_minerva_backbone(rgbi_input_size: tuple[int, int, int]) -> None: loss_func = NTXentLoss(0.3) model = SimCLR18(loss_func, input_size=rgbi_input_size) @@ -124,7 +123,7 @@ def test_minerva_backbone(rgbi_input_size: Tuple[int, int, int]) -> None: def test_minerva_wrapper( x_entropy_loss, - small_patch_size: Tuple[int, int], + small_patch_size: tuple[int, int], std_n_classes: int, std_n_batches: int, ) -> None: diff --git a/tests/test_models/test_fcn.py b/tests/test_models/test_fcn.py index 15cb845de..2cf91cfee 100644 --- a/tests/test_models/test_fcn.py +++ b/tests/test_models/test_fcn.py @@ -36,8 +36,6 @@ # ===================================================================================================================== # IMPORTS # ===================================================================================================================== -from typing import Tuple - import pytest import torch from torch import Tensor @@ -86,7 +84,7 @@ def test_fcn( random_mask_batch: Tensor, std_batch_size: int, std_n_classes: int, - rgbi_input_size: Tuple[int, int, int], + rgbi_input_size: tuple[int, int, int], ) -> None: model: MinervaModel = model_cls(x_entropy_loss, input_size=rgbi_input_size) optimiser = torch.optim.SGD(model.parameters(), lr=1.0e-3) diff --git a/tests/test_models/test_resnets.py b/tests/test_models/test_resnets.py index bb4f0fdd6..d6a1f34ca 100644 --- 
a/tests/test_models/test_resnets.py +++ b/tests/test_models/test_resnets.py @@ -36,8 +36,6 @@ # ===================================================================================================================== # IMPORTS # ===================================================================================================================== -from typing import Tuple - import pytest import torch from torch import LongTensor, Tensor @@ -94,7 +92,7 @@ def test_resnets( random_scene_classification_batch: LongTensor, std_batch_size: int, std_n_classes: int, - rgbi_input_size: Tuple[int, int, int], + rgbi_input_size: tuple[int, int, int], ) -> None: model: MinervaModel = model_cls( x_entropy_loss, input_size=rgbi_input_size, zero_init_residual=zero_init @@ -121,7 +119,7 @@ def test_replace_stride( random_scene_classification_batch: LongTensor, std_batch_size: int, std_n_classes: int, - rgbi_input_size: Tuple[int, int, int], + rgbi_input_size: tuple[int, int, int], ) -> None: for model in ( ResNet50(x_entropy_loss, input_size=rgbi_input_size), @@ -144,7 +142,7 @@ def test_replace_stride( def test_resnet_encoder( x_entropy_loss, random_rgbi_batch: Tensor, - rgbi_input_size: Tuple[int, int, int], + rgbi_input_size: tuple[int, int, int], ) -> None: encoder = ResNet18(x_entropy_loss, input_size=rgbi_input_size, encoder=True) optimiser = torch.optim.SGD(encoder.parameters(), lr=1.0e-3) @@ -157,7 +155,7 @@ def test_resnet_encoder( assert len(encoder(random_rgbi_batch)) == 5 -def test_preload_weights(rgbi_input_size: Tuple[int, int, int]) -> None: +def test_preload_weights(rgbi_input_size: tuple[int, int, int]) -> None: resnet = ResNet(BasicBlock, [2, 2, 2, 2]) new_resnet = _preload_weights(resnet, None, rgbi_input_size, encoder_on=False) diff --git a/tests/test_models/test_unet.py b/tests/test_models/test_unet.py index 18ee9df1a..0c4880b92 100644 --- a/tests/test_models/test_unet.py +++ b/tests/test_models/test_unet.py @@ -36,8 +36,6 @@ # 
===================================================================================================================== # IMPORTS # ===================================================================================================================== -from typing import Tuple - import pytest import torch from torch import Tensor @@ -62,7 +60,7 @@ def unet_test( y: Tensor, batch_size: int, n_classes: int, - input_size: Tuple[int, int, int], + input_size: tuple[int, int, int], ) -> None: optimiser = torch.optim.SGD(model.parameters(), lr=1.0e-3) model.set_optimiser(optimiser) @@ -97,7 +95,7 @@ def test_unetrs( random_mask_batch: Tensor, std_batch_size: int, std_n_classes: int, - rgbi_input_size: Tuple[int, int, int], + rgbi_input_size: tuple[int, int, int], ) -> None: model: MinervaModel = model_cls(x_entropy_loss, rgbi_input_size) @@ -117,7 +115,7 @@ def test_unet( random_mask_batch: Tensor, std_batch_size: int, std_n_classes: int, - rgbi_input_size: Tuple[int, int, int], + rgbi_input_size: tuple[int, int, int], ) -> None: model = UNet(x_entropy_loss, input_size=rgbi_input_size) unet_test( diff --git a/tests/test_samplers.py b/tests/test_samplers.py index e01d697ed..42cc08cd3 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -38,7 +38,7 @@ # ===================================================================================================================== from collections import defaultdict from pathlib import Path -from typing import Any, Dict +from typing import Any import pytest from torch.utils.data import DataLoader @@ -60,7 +60,7 @@ def test_randompairgeosampler(img_root: Path) -> None: dataset = PairedGeoDataset(TstImgDataset, str(img_root), res=1.0) sampler = RandomPairGeoSampler(dataset, size=32, length=32, max_r=52) - loader: DataLoader[Dict[str, Any]] = DataLoader( + loader: DataLoader[dict[str, Any]] = DataLoader( dataset, batch_size=8, sampler=sampler, collate_fn=stack_sample_pairs ) @@ -78,7 +78,7 @@ def 
test_randompairbatchgeosampler(img_root: Path) -> None: sampler = RandomPairBatchGeoSampler( dataset, size=32, length=32, batch_size=8, max_r=52, tiles_per_batch=1 ) - loader: DataLoader[Dict[str, Any]] = DataLoader( + loader: DataLoader[dict[str, Any]] = DataLoader( dataset, batch_sampler=sampler, collate_fn=stack_sample_pairs ) diff --git a/tests/test_tasks/test_tasks_core.py b/tests/test_tasks/test_tasks_core.py index f2924fc7f..353fb3b04 100644 --- a/tests/test_tasks/test_tasks_core.py +++ b/tests/test_tasks/test_tasks_core.py @@ -45,4 +45,4 @@ # ===================================================================================================================== def test_get_task(): with pytest.raises(TypeError): - _ = get_task("MinervaTask") + _ = get_task("minerva.tasks.MinervaTask") diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 7015b0a8f..e025728bc 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -38,7 +38,7 @@ import shutil from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Optional import hydra import pytest @@ -56,9 +56,7 @@ # TESTS # ===================================================================================================================== @runner.distributed_run -def run_trainer( - gpu: int, wandb_run: Optional[Union[Run, RunDisabled]], cfg: DictConfig -): +def run_trainer(gpu: int, wandb_run: Optional[Run | RunDisabled], cfg: DictConfig): params = deepcopy(cfg) params["calc_norm"] = True @@ -145,10 +143,17 @@ def test_trainer_3(default_config: DictConfig) -> None: params1 = deepcopy(default_config) trainer1 = Trainer(0, **params1) - trainer1.save_model(fn=trainer1.get_model_cache_path()) + + pre_train_path = trainer1.get_model_cache_path() + trainer1.save_model(fn=pre_train_path, fmt="onnx") params2 = deepcopy(default_config) - OmegaConf.update(params2, "pre_train_name", params1["model_name"], force_add=True) + OmegaConf.update( + 
params2, + "pre_train_name", + str(pre_train_path.with_suffix(".onnx")), + force_add=True, + ) params2["fine_tune"] = True params2["max_epochs"] = 2 params2["elim"] = False @@ -177,8 +182,8 @@ def test_trainer_3(default_config: DictConfig) -> None: def test_trainer_4( inbuilt_cfg_root: Path, cfg_name: str, - cfg_args: Dict[str, Any], - kwargs: Dict[str, Any], + cfg_args: dict[str, Any], + kwargs: dict[str, Any], ) -> None: with hydra.initialize(version_base="1.3", config_path=str(inbuilt_cfg_root)): diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 62d210c54..96a3c1f54 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -36,7 +36,7 @@ # ===================================================================================================================== # IMPORTS # ===================================================================================================================== -from typing import Any, Dict, List, Optional +from typing import Any, Optional import pytest import torch @@ -77,7 +77,7 @@ ], ) def test_class_transform( - example_matrix: Dict[int, int], input_mask: LongTensor, output: LongTensor + example_matrix: dict[int, int], input_mask: LongTensor, output: LongTensor ) -> None: transform = ClassTransform(example_matrix) @@ -276,7 +276,7 @@ def test_dublicator( ], ) def test_tg_to_torch( - transform, keys: Optional[List[str]], args: Any, in_img, expected + transform, keys: Optional[list[str]], args: Any, in_img, expected ) -> None: transformation = (utils.tg_to_torch(transform, keys=keys))(args) @@ -376,22 +376,21 @@ def test_init_auto_norm(default_image_dataset: RasterDataset, transforms) -> Non and transforms is not None ): with pytest.raises(TypeError): - _ = init_auto_norm(default_image_dataset, params) + _ = init_auto_norm(default_image_dataset, **params) else: - dataset = init_auto_norm(default_image_dataset, params) + dataset = init_auto_norm(default_image_dataset, **params) assert isinstance(dataset, 
RasterDataset) assert isinstance(dataset.transforms.transforms[-1], AutoNorm) # type: ignore[union-attr] def test_get_transform() -> None: - name = "RandomResizedCrop" - params = {"module": "torchvision.transforms", "size": 128} - transform = get_transform(name, params) + params = {"_target_": "torchvision.transforms.RandomResizedCrop", "size": 128} + transform = get_transform(params) assert callable(transform) with pytest.raises(TypeError): - _ = get_transform("DataFrame", {"module": "pandas"}) + _ = get_transform({"_target_": "pandas.DataFrame"}) @pytest.mark.parametrize( @@ -399,35 +398,50 @@ def test_get_transform() -> None: [ ( { - "CenterCrop": {"module": "torchvision.transforms", "size": 128}, - "RandomHorizontalFlip": {"module": "torchvision.transforms", "p": 0.7}, + "crop": {"_target_": "torchvision.transforms.CenterCrop", "size": 128}, + "flip": { + "_target_": "torchvision.transforms.RandomHorizontalFlip", + "p": 0.7, + }, }, "mask", ), ( { "RandomApply": { - "CenterCrop": {"module": "torchvision.transforms", "size": 128}, + "crop": { + "_target_": "torchvision.transforms.CenterCrop", + "size": 128, + }, "p": 0.3, }, - "RandomHorizontalFlip": {"module": "torchvision.transforms", "p": 0.7}, + "flip": { + "_target_": "torchvision.transforms.RandomHorizontalFlip", + "p": 0.7, + }, }, "image", ), ( { - "CenterCrop": {"module": "torchvision.transforms", "size": 128}, + "crop": {"_target_": "torchvision.transforms.CenterCrop", "size": 128}, "RandomApply": { - "CenterCrop": {"module": "torchvision.transforms", "size": 128}, + "crop": { + "_target_": "torchvision.transforms.CenterCrop", + "size": 128, + }, "p": 0.3, }, - "RandomHorizontalFlip": {"module": "torchvision.transforms", "p": 0.7}, + "flip": { + "_target_": "torchvision.transforms.RandomHorizontalFlip", + "p": 0.7, + }, }, "image", ), ], ) -def test_make_transformations(params: Dict[str, Any], key: str) -> None: +def test_make_transformations(params: dict[str, Any], key: str) -> None: if params: 
transforms = make_transformations({key: params}) assert callable(transforms) diff --git a/tests/test_utils/test_runner.py b/tests/test_utils/test_runner.py index 0d07128e1..980f48eab 100644 --- a/tests/test_utils/test_runner.py +++ b/tests/test_utils/test_runner.py @@ -39,7 +39,7 @@ import os import subprocess import time -from typing import Optional, Union +from typing import Optional import pytest import requests @@ -111,7 +111,7 @@ def test_config_env_vars(default_config: DictConfig) -> None: @runner.distributed_run def _run_func( - gpu: int, wandb_run: Optional[Union[Run, RunDisabled]], cfg: DictConfig + gpu: int, wandb_run: Optional[Run | RunDisabled], cfg: DictConfig ) -> None: time.sleep(0.5) return diff --git a/tests/test_utils/test_utils.py b/tests/test_utils/test_utils.py index 34c6e6f1f..f88d9298e 100644 --- a/tests/test_utils/test_utils.py +++ b/tests/test_utils/test_utils.py @@ -45,7 +45,7 @@ from collections import defaultdict from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Tuple, Union +from typing import Any import numpy as np import pandas as pd @@ -85,7 +85,7 @@ def test_is_notebook() -> None: def test_return_updated_kwargs() -> None: @utils.return_updated_kwargs - def example_func(*args, **kwargs) -> Tuple[Any, Dict[str, Any]]: + def example_func(*args, **kwargs) -> tuple[Any, dict[str, Any]]: _ = ( kwargs["update_1"] * kwargs["update_3"] - kwargs["static_2"] / args[1] * args[0] @@ -196,13 +196,13 @@ def test_ohe_labels() -> None: assert_array_equal(correct_targets, targets) -def test_empty_classes(exp_classes: Dict[int, str]) -> None: +def test_empty_classes(exp_classes: dict[int, str]) -> None: class_distribution = [(3, 321), (4, 112), (1, 671), (5, 456)] assert utils.find_empty_classes(class_distribution, exp_classes) == [0, 2, 6, 7] def test_eliminate_classes( - exp_classes: Dict[int, str], exp_cmap_dict: Dict[int, str] + exp_classes: dict[int, str], exp_cmap_dict: dict[int, str] ) -> None: empty = 
[0, 2, 7] new_classes = { @@ -246,12 +246,12 @@ def test_eliminate_classes( ], ) def test_check_test_empty( - exp_classes: Dict[int, str], - in_labels: List[int], - in_pred: List[int], - out_labels: List[int], - out_pred: List[int], - out_classes: Dict[int, str], + exp_classes: dict[int, str], + in_labels: list[int], + in_pred: list[int], + out_labels: list[int], + out_pred: list[int], + out_classes: dict[int, str], ) -> None: results = utils.check_test_empty(in_pred, in_labels, exp_classes) @@ -260,7 +260,7 @@ def test_check_test_empty( assert results[2] == out_classes -def test_find_modes(exp_classes: Dict[int, str]) -> None: +def test_find_modes(exp_classes: dict[int, str]) -> None: labels = [1, 1, 3, 5, 1, 4, 1, 5, 3, 3] class_dist = utils.find_modes(labels, plot=True) @@ -370,12 +370,12 @@ def test_batch_flatten(x, exp_len: int) -> None: ], ) def test_transform_coordinates( - x: Union[List[float], float], - y: Union[List[float], float], + x: list[float] | float, + y: list[float] | float, src_crs: CRS, dest_crs: CRS, - exp_x: Union[List[float], float], - exp_y: Union[List[float], float], + exp_x: list[float] | float, + exp_y: list[float] | float, ) -> None: out_x, out_y = utils.transform_coordinates(x, y, src_crs, dest_crs) @@ -451,9 +451,7 @@ def test_get_centre_loc() -> None: (-77.844504, 166.707506, "McMurdo Station"), # McMurdo Station, Antartica. 
], ) -def test_lat_lon_to_loc( - lat: Union[float, str], lon: Union[float, str], loc: str -) -> None: +def test_lat_lon_to_loc(lat: float | str, lon: float | str, loc: str) -> None: try: requests.head("http://www.google.com/", timeout=1.0) except (requests.ConnectionError, requests.ReadTimeout): @@ -546,7 +544,7 @@ def test_find_best_of() -> None: assert scene == ["2018_06_21"] -def test_modes_from_manifest(exp_classes: Dict[int, str]) -> None: +def test_modes_from_manifest(exp_classes: dict[int, str]) -> None: df = pd.DataFrame() class_dist = [ @@ -608,7 +606,7 @@ def test_compute_roc_curves() -> None: labels = [0, 3, 2, 1, 3, 2, 1, 0] class_labels = [0, 1, 2, 3] - fpr: Dict[Any, NDArray[Any, Any]] = { + fpr: dict[Any, NDArray[Any, Any]] = { 0: np.array([0.0, 0.0, 0.0, 0.5, 5.0 / 6.0, 1.0]), 1: np.array([0.0, 0.0, 1.0 / 6.0, 1.0 / 6.0, 1.0]), 2: np.array([0.0, 0.0, 1.0 / 6.0, 0.5, 0.5, 1.0]), @@ -637,7 +635,7 @@ def test_compute_roc_curves() -> None: "macro": np.array([0.0, 1.0 / 6.0, 0.5, 5.0 / 6.0, 1.0]), } - tpr: Dict[Any, NDArray[Any, Any]] = { + tpr: dict[Any, NDArray[Any, Any]] = { 0: np.array([0.0, 0.5, 1.0, 1.0, 1.0, 1.0]), 1: np.array([0.0, 0.5, 0.5, 1.0, 1.0]), 2: np.array([0.0, 0.5, 0.5, 0.5, 1.0, 1.0]), diff --git a/tests/test_utils/test_visutils.py b/tests/test_utils/test_visutils.py index e2820a3f0..8542a8616 100644 --- a/tests/test_utils/test_visutils.py +++ b/tests/test_utils/test_visutils.py @@ -40,7 +40,7 @@ import shutil import tempfile from pathlib import Path -from typing import Any, Dict, Tuple +from typing import Any import matplotlib as mlp import numpy as np @@ -116,7 +116,7 @@ def test_get_mlp_cmap() -> None: def test_discrete_heatmap( - random_mask, exp_classes: Dict[int, str], exp_cmap_dict: Dict[int, str] + random_mask, exp_classes: dict[int, str], exp_cmap_dict: dict[int, str] ) -> None: cmap = ListedColormap(exp_cmap_dict.values()) # type: ignore visutils.discrete_heatmap(random_mask, list(exp_classes.values()), cmap_style=cmap) @@ 
-151,8 +151,8 @@ def test_labelled_rgb_image( random_mask, random_image, bounds_for_test_img, - exp_classes: Dict[int, str], - exp_cmap_dict: Dict[int, str], + exp_classes: dict[int, str], + exp_cmap_dict: dict[int, str], ) -> None: path = tempfile.gettempdir() name = "pretty_pic" @@ -174,7 +174,7 @@ def test_labelled_rgb_image( def test_make_gif( - bounds_for_test_img, exp_classes: Dict[int, str], exp_cmap_dict: Dict[int, str] + bounds_for_test_img, exp_classes: dict[int, str], exp_cmap_dict: dict[int, str] ) -> None: dates = ["2018-01-15", "2018-07-03", "2018-11-30"] images = np.random.rand(3, 32, 32, 3) @@ -209,7 +209,7 @@ def test_prediction_plot( random_image, random_mask, bounds_for_test_img, - exp_classes: Dict[int, str], + exp_classes: dict[int, str], results_dir: Path, ) -> None: pred = np.random.randint(0, 8, size=(32, 32)) @@ -229,9 +229,9 @@ def test_seg_plot( results_root: Path, data_root: Path, default_dataset: GeoDataset, - exp_dataset_params: Dict[str, Any], - exp_classes: Dict[int, str], - exp_cmap_dict: Dict[int, str], + exp_dataset_params: dict[str, Any], + exp_classes: dict[int, str], + exp_cmap_dict: dict[int, str], cache_dir: Path, monkeypatch, ) -> None: @@ -271,7 +271,7 @@ def test_seg_plot( def test_plot_subpopulations( - exp_classes: Dict[int, str], exp_cmap_dict: Dict[int, str] + exp_classes: dict[int, str], exp_cmap_dict: dict[int, str] ) -> None: class_dist = [(1, 25000), (0, 1300), (2, 100), (3, 2)] @@ -305,7 +305,7 @@ def test_plot_history() -> None: filename.unlink(missing_ok=True) -def test_make_confusion_matrix(exp_classes: Dict[int, str]) -> None: +def test_make_confusion_matrix(exp_classes: dict[int, str]) -> None: batch_size = 2 patch_size = (32, 32) @@ -346,11 +346,11 @@ def test_format_names() -> None: def test_plot_results( default_dataset: GeoDataset, - exp_classes: Dict[int, str], - exp_cmap_dict: Dict[int, str], + exp_classes: dict[int, str], + exp_cmap_dict: dict[int, str], results_dir: Path, - default_config: Dict[str, 
Any], - small_patch_size: Tuple[int, int], + default_config: dict[str, Any], + small_patch_size: tuple[int, int], std_batch_size: int, std_n_classes: int, ) -> None: @@ -420,7 +420,7 @@ def test_plot_results( def test_plot_embeddings( results_root: Path, default_dataset: GeoDataset, - exp_dataset_params: Dict[str, Any], + exp_dataset_params: dict[str, Any], data_root: Path, cache_dir: Path, ) -> None: diff --git a/tox.ini b/tox.ini index bfea17c61..6e9b2d9ec 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,6 @@ [tox] minversion = 3.8.0 -envlist = minerva-{312, 311, 310, 39} +envlist = minerva-{312, 311, 310} isolated_build = true skip_missing_interpreters = true @@ -9,7 +9,6 @@ python = 3.12: minerva-312 3.11: minerva-311 3.10: minerva-310 - 3.9: minerva-39 [testenv] skip_install = true @@ -33,7 +32,3 @@ deps = -r{toxinidir}/requirements/requirements_dev.txt [testenv:minerva-310] basepython = python3.10 deps = -r{toxinidir}/requirements/requirements_dev.txt - -[testenv:minerva-39] -basepython = python3.9 -deps = -r{toxinidir}/requirements/requirements_dev.txt