Skip to content

Commit

Permalink
Guarin lig 5512 fix local type check (#1654)
Browse files Browse the repository at this point in the history
* Fix typing inconsistencies between Python versions
* Remove python_version mypy flag
  • Loading branch information
guarin authored Sep 25, 2024
1 parent c529521 commit ca4d467
Show file tree
Hide file tree
Showing 22 changed files with 110 additions and 46 deletions.
33 changes: 17 additions & 16 deletions lightly/models/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,29 +56,30 @@ def forward(self, input: Tensor) -> Tensor:

# during training, use different stats for each split and otherwise
# use the stats from the first split
momentum = 0.0 if self.momentum is None else self.momentum
if self.training or not self.track_running_stats:
result = nn.functional.batch_norm(
input.view(-1, C * self.num_splits, H, W),
self.running_mean,
self.running_var,
self.weight.repeat(self.num_splits),
self.bias.repeat(self.num_splits),
True,
self.momentum,
self.eps,
input=input.view(-1, C * self.num_splits, H, W),
running_mean=self.running_mean,
running_var=self.running_var,
weight=self.weight.repeat(self.num_splits),
bias=self.bias.repeat(self.num_splits),
training=True,
momentum=momentum,
eps=self.eps,
).view(N, C, H, W)
else:
# We have to ignore the type errors here, because we know that running_mean
# and running_var are not None, but the type checker does not.
result = nn.functional.batch_norm(
input,
self.running_mean[: self.num_features], # type: ignore[index]
self.running_var[: self.num_features], # type: ignore[index]
self.weight,
self.bias,
False,
self.momentum,
self.eps,
input=input,
running_mean=self.running_mean[: self.num_features], # type: ignore[index]
running_var=self.running_var[: self.num_features], # type: ignore[index]
weight=self.weight,
bias=self.bias,
training=False,
momentum=momentum,
eps=self.eps,
)

return result
Expand Down
3 changes: 2 additions & 1 deletion lightly/transforms/dino_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ def __init__(
T.RandomResizedCrop(
size=crop_size,
scale=crop_scale,
interpolation=PIL.Image.BICUBIC,
# Type ignore needed because BICUBIC is not recognized as an attribute.
interpolation=PIL.Image.BICUBIC, # type: ignore[attr-defined]
),
T.RandomHorizontalFlip(p=hf_prob),
T.RandomVerticalFlip(p=vf_prob),
Expand Down
13 changes: 8 additions & 5 deletions lightly/transforms/jigsaw.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2021. Lightly AG and its affiliates.
# All Rights Reserved

from typing import List
from typing import TYPE_CHECKING, Callable, List

import numpy as np
import torch
Expand All @@ -10,6 +10,9 @@
from torch import Tensor
from torchvision import transforms as T

if TYPE_CHECKING:
from numpy.typing import NDArray


class Jigsaw(object):
"""Implementation of Jigsaw image augmentation, inspired from PyContrast library.
Expand Down Expand Up @@ -49,7 +52,7 @@ def __init__(
self.crop_size = crop_size
self.grid_size = int(img_size / self.n_grid)
self.side = self.grid_size - self.crop_size
self.transform = transform
self.transform: Callable[[PILImage], Tensor] = transform

yy, xx = np.meshgrid(np.arange(n_grid), np.arange(n_grid))
self.yy = np.reshape(yy * self.grid_size, (n_grid * n_grid,))
Expand All @@ -66,11 +69,11 @@ def __call__(self, img: PILImage) -> Tensor:
"""
r_x = np.random.randint(0, self.side + 1, self.n_grid * self.n_grid)
r_y = np.random.randint(0, self.side + 1, self.n_grid * self.n_grid)
img = np.asarray(img, np.uint8)
crops: List[PILImage] = []
img_arr = np.asarray(img, np.uint8)
crops: List[NDArray[np.uint8]] = []
for i in range(self.n_grid * self.n_grid):
crops.append(
img[
img_arr[
self.xx[i] + r_x[i] : self.xx[i] + r_x[i] + self.crop_size,
self.yy[i] + r_y[i] : self.yy[i] + r_y[i] + self.crop_size,
:,
Expand Down
6 changes: 4 additions & 2 deletions lightly/transforms/rotation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

from typing import Tuple, Union
from typing import Callable, Tuple, Union

import numpy as np
import torchvision.transforms as T
Expand Down Expand Up @@ -65,7 +65,9 @@ class RandomRotateDegrees:
"""

def __init__(self, prob: float, degrees: Union[float, Tuple[float, float]]):
self.transform = T.RandomApply([T.RandomRotation(degrees=degrees)], p=prob)
self.transform: Callable[
[Union[Image, Tensor]], Union[Image, Tensor]
] = T.RandomApply([T.RandomRotation(degrees=degrees)], p=prob)

def __call__(self, image: Union[Image, Tensor]) -> Union[Image, Tensor]:
"""Rotates the images with a given probability.
Expand Down
4 changes: 3 additions & 1 deletion lightly/utils/benchmarking/linear_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def validation_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> Tensor:
self.log_dict(log_dict, prog_bar=True, sync_dist=True, batch_size=batch_size)
return loss

def configure_optimizers(
# Type ignore is needed because return type of LightningModule.configure_optimizers
# is complicated and typing changes between versions.
def configure_optimizers( # type: ignore[override]
self,
) -> Tuple[List[Optimizer], List[Dict[str, Union[Any, str]]]]:
parameters = list(self.classification_head.parameters())
Expand Down
18 changes: 13 additions & 5 deletions lightly/utils/lars.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, overload

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer, required # type: ignore[attr-defined]
from torch.optim.optimizer import Optimizer


class LARS(Optimizer):
Expand Down Expand Up @@ -69,15 +68,15 @@ class LARS(Optimizer):
def __init__(
self,
params: Any,
lr: float = required,
lr: float,
momentum: float = 0,
dampening: float = 0,
weight_decay: float = 0,
nesterov: bool = False,
trust_coefficient: float = 0.001,
eps: float = 1e-8,
):
if lr is not required and lr < 0.0:
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}")
Expand All @@ -104,6 +103,15 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
for group in self.param_groups:
group.setdefault("nesterov", False)

# Type ignore for overloads is required for Python 3.7
@overload # type: ignore[override]
def step(self, closure: None = None) -> None:
...

@overload # type: ignore[override]
def step(self, closure: Callable[[], float]) -> float:
...

@torch.no_grad()
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""Performs a single optimization step.
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ omit = ["lightly/openapi_generated/*"]

[tool.mypy]
ignore_missing_imports = true
python_version = "3.10"
warn_unused_configs = true
strict_equality = true

Expand All @@ -167,7 +166,7 @@ no_implicit_optional = true
strict_optional = true

# Configuring warnings
warn_unused_ignores = true
warn_unused_ignores = false # Different ignores are required for different Python versions
warn_no_return = true
warn_return_any = true
warn_redundant_casts = true
Expand Down
19 changes: 19 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Any, List

from torch import Tensor


def assert_list_tensor(items: Any) -> List[Tensor]:
"""Makes sure that the input is a list of tensors.
Should be used in tests where functions return Union[List[Tensor], List[Image]] and
we want to make sure that the output is a list of tensors.
Example:
>>> output: Union[List[Tensor], List[Image]] = transform(images)
>>> tensors: List[Tensor] = assert_list_tensor(output)
"""
assert isinstance(items, list)
assert all(isinstance(item, Tensor) for item in items)
return items
Empty file added tests/transforms/__init__.py
Empty file.
4 changes: 3 additions & 1 deletion tests/transforms/test_byol_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
BYOLView2Transform,
)

from .. import helpers


def test_view_on_pil_image() -> None:
single_view_transform = BYOLView1Transform(input_size=32)
Expand All @@ -20,7 +22,7 @@ def test_multi_view_on_pil_image() -> None:
view_2_transform=BYOLView2Transform(input_size=32),
)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 2
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (3, 32, 32)
4 changes: 3 additions & 1 deletion tests/transforms/test_densecl_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from lightly.transforms import DenseCLTransform

from .. import helpers


def test_multi_view_on_pil_image() -> None:
multi_view_transform = DenseCLTransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 2
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (3, 32, 32)
4 changes: 3 additions & 1 deletion tests/transforms/test_dino_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform

from .. import helpers


def test_view_on_pil_image() -> None:
single_view_transform = DINOViewTransform(crop_size=32)
Expand All @@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None:
def test_multi_view_on_pil_image() -> None:
multi_view_transform = DINOTransform(global_crop_size=32, local_crop_size=8)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 8
# global views
assert all(out.shape == (3, 32, 32) for out in output[:2])
Expand Down
4 changes: 3 additions & 1 deletion tests/transforms/test_fastsiam_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from lightly.transforms.fast_siam_transform import FastSiamTransform

from .. import helpers


def test_multi_view_on_pil_image() -> None:
multi_view_transform = FastSiamTransform(num_views=3, input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 3
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (3, 32, 32)
Expand Down
6 changes: 4 additions & 2 deletions tests/transforms/test_moco_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform

from .. import helpers


def test_moco_v1_multi_view_on_pil_image() -> None:
multi_view_transform = MoCoV1Transform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 2
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (3, 32, 32)
Expand All @@ -15,7 +17,7 @@ def test_moco_v1_multi_view_on_pil_image() -> None:
def test_moco_v2_multi_view_on_pil_image() -> None:
multi_view_transform = MoCoV2Transform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 2
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (3, 32, 32)
4 changes: 3 additions & 1 deletion tests/transforms/test_msn_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform

from .. import helpers


def test_view_on_pil_image() -> None:
single_view_transform = MSNViewTransform(crop_size=32)
Expand All @@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None:
def test_multi_view_on_pil_image() -> None:
multi_view_transform = MSNTransform(random_size=32, focal_size=8)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 12
# global views
assert all(out.shape == (3, 32, 32) for out in output[:2])
Expand Down
4 changes: 3 additions & 1 deletion tests/transforms/test_pirl_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from lightly.transforms.pirl_transform import PIRLTransform

from .. import helpers


def test_multi_view_on_pil_image() -> None:
multi_view_transform = PIRLTransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 2
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (9, 3, 10, 10)
4 changes: 3 additions & 1 deletion tests/transforms/test_simclr_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from lightly.transforms.simclr_transform import SimCLRTransform, SimCLRViewTransform

from .. import helpers


def test_view_on_pil_image() -> None:
single_view_transform = SimCLRViewTransform(input_size=32)
Expand All @@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None:
def test_multi_view_on_pil_image() -> None:
multi_view_transform = SimCLRTransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 2
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (3, 32, 32)
4 changes: 3 additions & 1 deletion tests/transforms/test_simsiam_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from lightly.transforms.simsiam_transform import SimSiamTransform, SimSiamViewTransform

from .. import helpers


def test_view_on_pil_image() -> None:
single_view_transform = SimSiamViewTransform(input_size=32)
Expand All @@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None:
def test_multi_view_on_pil_image() -> None:
multi_view_transform = SimSiamTransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 2
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (3, 32, 32)
4 changes: 3 additions & 1 deletion tests/transforms/test_smog_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from lightly.transforms.smog_transform import SMoGTransform, SmoGViewTransform

from .. import helpers


def test_view_on_pil_image() -> None:
single_view_transform = SmoGViewTransform(crop_size=32)
Expand All @@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None:
def test_multi_view_on_pil_image() -> None:
multi_view_transform = SMoGTransform(crop_sizes=(32, 8))
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 8
assert all(out.shape == (3, 32, 32) for out in output[:4])
assert all(out.shape == (3, 8, 8) for out in output[4:])
Loading

0 comments on commit ca4d467

Please sign in to comment.