From ca4d467b8a97e37b6a9f93aafc3c49bc4d84a9dc Mon Sep 17 00:00:00 2001 From: guarin <43336610+guarin@users.noreply.github.com> Date: Wed, 25 Sep 2024 08:21:53 +0200 Subject: [PATCH] Guarin lig 5512 fix local type check (#1654) * Fix typing inconsistencies between Python versions * Remove python_version mypy flag --- lightly/models/batchnorm.py | 33 ++++++++++--------- lightly/transforms/dino_transform.py | 3 +- lightly/transforms/jigsaw.py | 13 +++++--- lightly/transforms/rotation.py | 6 ++-- .../utils/benchmarking/linear_classifier.py | 4 ++- lightly/utils/lars.py | 18 +++++++--- pyproject.toml | 3 +- tests/helpers.py | 19 +++++++++++ tests/transforms/__init__.py | 0 tests/transforms/test_byol_transform.py | 4 ++- tests/transforms/test_densecl_transform.py | 4 ++- tests/transforms/test_dino_transform.py | 4 ++- tests/transforms/test_fastsiam_transform.py | 4 ++- tests/transforms/test_moco_transform.py | 6 ++-- tests/transforms/test_msn_transform.py | 4 ++- tests/transforms/test_pirl_transform.py | 4 ++- tests/transforms/test_simclr_transform.py | 4 ++- tests/transforms/test_simsiam_transform.py | 4 ++- tests/transforms/test_smog_transform.py | 4 ++- tests/transforms/test_swav_transform.py | 4 ++- tests/transforms/test_vicreg_transform.py | 4 ++- tests/transforms/test_vicregl_transform.py | 7 +++- 22 files changed, 110 insertions(+), 46 deletions(-) create mode 100644 tests/helpers.py create mode 100644 tests/transforms/__init__.py diff --git a/lightly/models/batchnorm.py b/lightly/models/batchnorm.py index 7f05fe48e..4653882bc 100644 --- a/lightly/models/batchnorm.py +++ b/lightly/models/batchnorm.py @@ -56,29 +56,30 @@ def forward(self, input: Tensor) -> Tensor: # during training, use different stats for each split and otherwise # use the stats from the first split + momentum = 0.0 if self.momentum is None else self.momentum if self.training or not self.track_running_stats: result = nn.functional.batch_norm( - input.view(-1, C * self.num_splits, H, W), - self.running_mean, - self.running_var, - self.weight.repeat(self.num_splits), - self.bias.repeat(self.num_splits), - True, - self.momentum, - self.eps, + input=input.view(-1, C * self.num_splits, H, W), + running_mean=self.running_mean, + running_var=self.running_var, + weight=self.weight.repeat(self.num_splits), + bias=self.bias.repeat(self.num_splits), + training=True, + momentum=momentum, + eps=self.eps, ).view(N, C, H, W) else: # We have to ignore the type errors here, because we know that running_mean # and running_var are not None, but the type checker does not. result = nn.functional.batch_norm( - input, - self.running_mean[: self.num_features], # type: ignore[index] - self.running_var[: self.num_features], # type: ignore[index] - self.weight, - self.bias, - False, - self.momentum, - self.eps, + input=input, + running_mean=self.running_mean[: self.num_features], # type: ignore[index] + running_var=self.running_var[: self.num_features], # type: ignore[index] + weight=self.weight, + bias=self.bias, + training=False, + momentum=momentum, + eps=self.eps, ) return result diff --git a/lightly/transforms/dino_transform.py b/lightly/transforms/dino_transform.py index b88624890..4c59b39e7 100644 --- a/lightly/transforms/dino_transform.py +++ b/lightly/transforms/dino_transform.py @@ -219,7 +219,8 @@ def __init__( T.RandomResizedCrop( size=crop_size, scale=crop_scale, - interpolation=PIL.Image.BICUBIC, + # Type ignore needed because BICUBIC is not recognized as an attribute. + interpolation=PIL.Image.BICUBIC, # type: ignore[attr-defined] ), T.RandomHorizontalFlip(p=hf_prob), T.RandomVerticalFlip(p=vf_prob), diff --git a/lightly/transforms/jigsaw.py b/lightly/transforms/jigsaw.py index adebb808b..4ed24bb1c 100644 --- a/lightly/transforms/jigsaw.py +++ b/lightly/transforms/jigsaw.py @@ -1,7 +1,7 @@ # Copyright (c) 2021. Lightly AG and its affiliates. # All Rights Reserved -from typing import List +from typing import TYPE_CHECKING, Callable, List import numpy as np import torch @@ -10,6 +10,9 @@ from torch import Tensor from torchvision import transforms as T +if TYPE_CHECKING: + from numpy.typing import NDArray + class Jigsaw(object): """Implementation of Jigsaw image augmentation, inspired from PyContrast library. @@ -49,7 +52,7 @@ def __init__( self.crop_size = crop_size self.grid_size = int(img_size / self.n_grid) self.side = self.grid_size - self.crop_size - self.transform = transform + self.transform: Callable[[PILImage], Tensor] = transform yy, xx = np.meshgrid(np.arange(n_grid), np.arange(n_grid)) self.yy = np.reshape(yy * self.grid_size, (n_grid * n_grid,)) @@ -66,11 +69,11 @@ def __call__(self, img: PILImage) -> Tensor: """ r_x = np.random.randint(0, self.side + 1, self.n_grid * self.n_grid) r_y = np.random.randint(0, self.side + 1, self.n_grid * self.n_grid) - img = np.asarray(img, np.uint8) - crops: List[PILImage] = [] + img_arr = np.asarray(img, np.uint8) + crops: List[NDArray[np.uint8]] = [] for i in range(self.n_grid * self.n_grid): crops.append( - img[ + img_arr[ self.xx[i] + r_x[i] : self.xx[i] + r_x[i] + self.crop_size, self.yy[i] + r_y[i] : self.yy[i] + r_y[i] + self.crop_size, :, diff --git a/lightly/transforms/rotation.py b/lightly/transforms/rotation.py index a165e735d..b6b37a63b 100644 --- a/lightly/transforms/rotation.py +++ b/lightly/transforms/rotation.py @@ -1,7 +1,7 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved -from typing import Tuple, Union +from typing import Callable, Tuple, Union import numpy as np import torchvision.transforms as T @@ -65,7 +65,9 @@ class RandomRotateDegrees: """ def __init__(self, prob: float, degrees: Union[float, Tuple[float, float]]): - self.transform = T.RandomApply([T.RandomRotation(degrees=degrees)], p=prob) + self.transform: Callable[ + [Union[Image, Tensor]], Union[Image, Tensor] + ] = T.RandomApply([T.RandomRotation(degrees=degrees)], p=prob) def __call__(self, image: Union[Image, Tensor]) -> Union[Image, Tensor]: """Rotates the images with a given probability. diff --git a/lightly/utils/benchmarking/linear_classifier.py b/lightly/utils/benchmarking/linear_classifier.py index 647b8dd45..acb030d47 100644 --- a/lightly/utils/benchmarking/linear_classifier.py +++ b/lightly/utils/benchmarking/linear_classifier.py @@ -129,7 +129,9 @@ def validation_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> Tensor: self.log_dict(log_dict, prog_bar=True, sync_dist=True, batch_size=batch_size) return loss - def configure_optimizers( + # Type ignore is needed because return type of LightningModule.configure_optimizers + # is complicated and typing changes between versions. + def configure_optimizers( # type: ignore[override] self, ) -> Tuple[List[Optimizer], List[Dict[str, Union[Any, str]]]]: parameters = list(self.classification_head.parameters()) diff --git a/lightly/utils/lars.py b/lightly/utils/lars.py index 063149d36..315f14559 100644 --- a/lightly/utils/lars.py +++ b/lightly/utils/lars.py @@ -1,8 +1,7 @@ -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, overload import torch -from torch import Tensor -from torch.optim.optimizer import Optimizer, required # type: ignore[attr-defined] +from torch.optim.optimizer import Optimizer class LARS(Optimizer): @@ -69,7 +68,7 @@ class LARS(Optimizer): def __init__( self, params: Any, - lr: float = required, + lr: float, momentum: float = 0, dampening: float = 0, weight_decay: float = 0, @@ -77,7 +76,7 @@ def __init__( trust_coefficient: float = 0.001, eps: float = 1e-8, ): - if lr is not required and lr < 0.0: + if lr < 0.0: raise ValueError(f"Invalid learning rate: {lr}") if momentum < 0.0: raise ValueError(f"Invalid momentum value: {momentum}") @@ -104,6 +103,15 @@ def __setstate__(self, state: Dict[str, Any]) -> None: for group in self.param_groups: group.setdefault("nesterov", False) + # Type ignore for overloads is required for Python 3.7 + @overload # type: ignore[override] + def step(self, closure: None = None) -> None: + ... + + @overload # type: ignore[override] + def step(self, closure: Callable[[], float]) -> float: + ... + @torch.no_grad() def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: """Performs a single optimization step. diff --git a/pyproject.toml b/pyproject.toml index 40bbb50ef..12363495f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,7 +145,6 @@ omit = ["lightly/openapi_generated/*"] [tool.mypy] ignore_missing_imports = true -python_version = "3.10" warn_unused_configs = true strict_equality = true @@ -167,7 +166,7 @@ no_implicit_optional = true strict_optional = true # Configuring warnings -warn_unused_ignores = true +warn_unused_ignores = false # Different ignores are required for different Python versions warn_no_return = true warn_return_any = true warn_redundant_casts = true diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 000000000..907589037 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,19 @@ +from typing import Any, List + +from torch import Tensor + + +def assert_list_tensor(items: Any) -> List[Tensor]: + """Makes sure that the input is a list of tensors. + + Should be used in tests where functions return Union[List[Tensor], List[Image]] and + we want to make sure that the output is a list of tensors. + + Example: + >>> output: Union[List[Tensor], List[Image]] = transform(images) + >>> tensors: List[Tensor] = assert_list_tensor(output) + + """ + assert isinstance(items, list) + assert all(isinstance(item, Tensor) for item in items) + return items diff --git a/tests/transforms/__init__.py b/tests/transforms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/transforms/test_byol_transform.py b/tests/transforms/test_byol_transform.py index ecf72df48..b81748cb1 100644 --- a/tests/transforms/test_byol_transform.py +++ b/tests/transforms/test_byol_transform.py @@ -6,6 +6,8 @@ BYOLView2Transform, ) +from .. import helpers + def test_view_on_pil_image() -> None: single_view_transform = BYOLView1Transform(input_size=32) @@ -20,7 +22,7 @@ def test_multi_view_on_pil_image() -> None: view_2_transform=BYOLView2Transform(input_size=32), ) sample = Image.new("RGB", (100, 100)) - output = multi_view_transform(sample) + output = helpers.assert_list_tensor(multi_view_transform(sample)) assert len(output) == 2 assert output[0].shape == (3, 32, 32) assert output[1].shape == (3, 32, 32) diff --git a/tests/transforms/test_densecl_transform.py b/tests/transforms/test_densecl_transform.py index 299379dd9..3a5be00d8 100644 --- a/tests/transforms/test_densecl_transform.py +++ b/tests/transforms/test_densecl_transform.py @@ -2,11 +2,13 @@ from lightly.transforms import DenseCLTransform +from .. import helpers + def test_multi_view_on_pil_image() -> None: multi_view_transform = DenseCLTransform(input_size=32) sample = Image.new("RGB", (100, 100)) - output = multi_view_transform(sample) + output = helpers.assert_list_tensor(multi_view_transform(sample)) assert len(output) == 2 assert output[0].shape == (3, 32, 32) assert output[1].shape == (3, 32, 32) diff --git a/tests/transforms/test_dino_transform.py b/tests/transforms/test_dino_transform.py index 74bfea478..17137303b 100644 --- a/tests/transforms/test_dino_transform.py +++ b/tests/transforms/test_dino_transform.py @@ -2,6 +2,8 @@ from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform +from .. import helpers + def test_view_on_pil_image() -> None: single_view_transform = DINOViewTransform(crop_size=32) @@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None: def test_multi_view_on_pil_image() -> None: multi_view_transform = DINOTransform(global_crop_size=32, local_crop_size=8) sample = Image.new("RGB", (100, 100)) - output = multi_view_transform(sample) + output = helpers.assert_list_tensor(multi_view_transform(sample)) assert len(output) == 8 # global views assert all(out.shape == (3, 32, 32) for out in output[:2]) diff --git a/tests/transforms/test_fastsiam_transform.py b/tests/transforms/test_fastsiam_transform.py index 672cf0a41..224014373 100644 --- a/tests/transforms/test_fastsiam_transform.py +++ b/tests/transforms/test_fastsiam_transform.py @@ -2,11 +2,13 @@ from lightly.transforms.fast_siam_transform import FastSiamTransform +from .. import helpers + def test_multi_view_on_pil_image() -> None: multi_view_transform = FastSiamTransform(num_views=3, input_size=32) sample = Image.new("RGB", (100, 100)) - output = multi_view_transform(sample) + output = helpers.assert_list_tensor(multi_view_transform(sample)) assert len(output) == 3 assert output[0].shape == (3, 32, 32) assert output[1].shape == (3, 32, 32) diff --git a/tests/transforms/test_moco_transform.py b/tests/transforms/test_moco_transform.py index aa43a216f..ea1a332fd 100644 --- a/tests/transforms/test_moco_transform.py +++ b/tests/transforms/test_moco_transform.py @@ -2,11 +2,13 @@ from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform +from .. import helpers + def test_moco_v1_multi_view_on_pil_image() -> None: multi_view_transform = MoCoV1Transform(input_size=32) sample = Image.new("RGB", (100, 100)) - output = multi_view_transform(sample) + output = helpers.assert_list_tensor(multi_view_transform(sample)) assert len(output) == 2 assert output[0].shape == (3, 32, 32) assert output[1].shape == (3, 32, 32) @@ -15,7 +17,7 @@ def test_moco_v1_multi_view_on_pil_image() -> None: def test_moco_v2_multi_view_on_pil_image() -> None: multi_view_transform = MoCoV2Transform(input_size=32) sample = Image.new("RGB", (100, 100)) - output = multi_view_transform(sample) + output = helpers.assert_list_tensor(multi_view_transform(sample)) assert len(output) == 2 assert output[0].shape == (3, 32, 32) assert output[1].shape == (3, 32, 32) diff --git a/tests/transforms/test_msn_transform.py b/tests/transforms/test_msn_transform.py index 4f5be7b53..84194e694 100644 --- a/tests/transforms/test_msn_transform.py +++ b/tests/transforms/test_msn_transform.py @@ -2,6 +2,8 @@ from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform +from .. import helpers + def test_view_on_pil_image() -> None: single_view_transform = MSNViewTransform(crop_size=32) @@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None: def test_multi_view_on_pil_image() -> None: multi_view_transform = MSNTransform(random_size=32, focal_size=8) sample = Image.new("RGB", (100, 100)) - output = multi_view_transform(sample) + output = helpers.assert_list_tensor(multi_view_transform(sample)) assert len(output) == 12 # global views assert all(out.shape == (3, 32, 32) for out in output[:2]) diff --git a/tests/transforms/test_pirl_transform.py b/tests/transforms/test_pirl_transform.py index 20c7c8705..f0a62d14b 100644 --- a/tests/transforms/test_pirl_transform.py +++ b/tests/transforms/test_pirl_transform.py @@ -2,11 +2,13 @@ from lightly.transforms.pirl_transform import PIRLTransform +from .. import helpers + def test_multi_view_on_pil_image() -> None: multi_view_transform = PIRLTransform(input_size=32) sample = Image.new("RGB", (100, 100)) - output = multi_view_transform(sample) + output = helpers.assert_list_tensor(multi_view_transform(sample)) assert len(output) == 2 assert output[0].shape == (3, 32, 32) assert output[1].shape == (9, 3, 10, 10) diff --git a/tests/transforms/test_simclr_transform.py b/tests/transforms/test_simclr_transform.py index 78a9a5cca..d76b39039 100644 --- a/tests/transforms/test_simclr_transform.py +++ b/tests/transforms/test_simclr_transform.py @@ -2,6 +2,8 @@ from lightly.transforms.simclr_transform import SimCLRTransform, SimCLRViewTransform +from .. import helpers + def test_view_on_pil_image() -> None: single_view_transform = SimCLRViewTransform(input_size=32) @@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None: def test_multi_view_on_pil_image() -> None: multi_view_transform = SimCLRTransform(input_size=32) sample = Image.new("RGB", (100, 100)) - output = multi_view_transform(sample) + output = helpers.assert_list_tensor(multi_view_transform(sample)) assert len(output) == 2 assert output[0].shape == (3, 32, 32) assert output[1].shape == (3, 32, 32) diff --git a/tests/transforms/test_simsiam_transform.py b/tests/transforms/test_simsiam_transform.py index 39a88721a..f692ee03e 100644 --- a/tests/transforms/test_simsiam_transform.py +++ b/tests/transforms/test_simsiam_transform.py @@ -2,6 +2,8 @@ from lightly.transforms.simsiam_transform import SimSiamTransform, SimSiamViewTransform +from .. import helpers + def test_view_on_pil_image() -> None: single_view_transform = SimSiamViewTransform(input_size=32) @@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None: def test_multi_view_on_pil_image() -> None: multi_view_transform = SimSiamTransform(input_size=32) sample = Image.new("RGB", (100, 100)) - output = multi_view_transform(sample) + output = helpers.assert_list_tensor(multi_view_transform(sample)) assert len(output) == 2 assert output[0].shape == (3, 32, 32) assert output[1].shape == (3, 32, 32) diff --git a/tests/transforms/test_smog_transform.py b/tests/transforms/test_smog_transform.py index 042d46f9f..783b95885 100644 --- a/tests/transforms/test_smog_transform.py +++ b/tests/transforms/test_smog_transform.py @@ -2,6 +2,8 @@ from lightly.transforms.smog_transform import SMoGTransform, SmoGViewTransform +from .. import helpers + def test_view_on_pil_image() -> None: single_view_transform = SmoGViewTransform(crop_size=32) @@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None: def test_multi_view_on_pil_image() -> None: multi_view_transform = SMoGTransform(crop_sizes=(32, 8)) sample = Image.new("RGB", (100, 100)) - output = multi_view_transform(sample) + output = helpers.assert_list_tensor(multi_view_transform(sample)) assert len(output) == 8 assert all(out.shape == (3, 32, 32) for out in output[:4]) assert all(out.shape == (3, 8, 8) for out in output[4:]) diff --git a/tests/transforms/test_swav_transform.py b/tests/transforms/test_swav_transform.py index 7c2cdd2c0..3c0707087 100644 --- a/tests/transforms/test_swav_transform.py +++ b/tests/transforms/test_swav_transform.py @@ -2,6 +2,8 @@ from lightly.transforms.swav_transform import SwaVTransform, SwaVViewTransform +from .. import helpers + def test_view_on_pil_image() -> None: single_view_transform = SwaVViewTransform() @@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None: def test_multi_view_on_pil_image() -> None: multi_view_transform = SwaVTransform(crop_sizes=(32, 8)) sample = Image.new("RGB", (100, 100)) - output = multi_view_transform(sample) + output = helpers.assert_list_tensor(multi_view_transform(sample)) assert len(output) == 8 assert all(out.shape == (3, 32, 32) for out in output[:2]) assert all(out.shape == (3, 8, 8) for out in output[2:]) diff --git a/tests/transforms/test_vicreg_transform.py b/tests/transforms/test_vicreg_transform.py index 06e710f25..0aa5ea08b 100644 --- a/tests/transforms/test_vicreg_transform.py +++ b/tests/transforms/test_vicreg_transform.py @@ -2,6 +2,8 @@ from lightly.transforms.vicreg_transform import VICRegTransform, VICRegViewTransform +from .. import helpers + def test_view_on_pil_image() -> None: single_view_transform = VICRegViewTransform(input_size=32) @@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None: def test_multi_view_on_pil_image() -> None: multi_view_transform = VICRegTransform(input_size=32) sample = Image.new("RGB", (100, 100)) - output = multi_view_transform(sample) + output = helpers.assert_list_tensor(multi_view_transform(sample)) assert len(output) == 2 assert output[0].shape == (3, 32, 32) assert output[1].shape == (3, 32, 32) diff --git a/tests/transforms/test_vicregl_transform.py b/tests/transforms/test_vicregl_transform.py index e697807c4..bf05a6c09 100644 --- a/tests/transforms/test_vicregl_transform.py +++ b/tests/transforms/test_vicregl_transform.py @@ -1,7 +1,12 @@ +from typing import List + from PIL import Image +from torch import Tensor from lightly.transforms.vicregl_transform import VICRegLTransform, VICRegLViewTransform +from .. import helpers + def test_view_on_pil_image() -> None: single_view_transform = VICRegLViewTransform() @@ -19,7 +24,7 @@ def test_multi_view_on_pil_image() -> None: local_grid_size=2, ) sample = Image.new("RGB", (100, 100)) - output = multi_view_transform(sample) + output = helpers.assert_list_tensor(multi_view_transform(sample)) assert len(output) == 16 # (2 global crops * 2) + (6 local crops * 2) global_views = output[:2] local_views = output[2:8]