Guarin lig 5512 fix local type check #1654

Merged · 10 commits · Sep 25, 2024
33 changes: 17 additions & 16 deletions lightly/models/batchnorm.py
@@ -56,29 +56,30 @@ def forward(self, input: Tensor) -> Tensor:

# during training, use different stats for each split and otherwise
# use the stats from the first split
momentum = 0.0 if self.momentum is None else self.momentum
if self.training or not self.track_running_stats:
result = nn.functional.batch_norm(
input.view(-1, C * self.num_splits, H, W),
self.running_mean,
self.running_var,
self.weight.repeat(self.num_splits),
self.bias.repeat(self.num_splits),
True,
self.momentum,
self.eps,
input=input.view(-1, C * self.num_splits, H, W),
running_mean=self.running_mean,
running_var=self.running_var,
weight=self.weight.repeat(self.num_splits),
bias=self.bias.repeat(self.num_splits),
training=True,
momentum=momentum,
eps=self.eps,
).view(N, C, H, W)
else:
# We have to ignore the type errors here, because we know that running_mean
# and running_var are not None, but the type checker does not.
result = nn.functional.batch_norm(
input,
self.running_mean[: self.num_features], # type: ignore[index]
self.running_var[: self.num_features], # type: ignore[index]
self.weight,
self.bias,
False,
self.momentum,
self.eps,
input=input,
running_mean=self.running_mean[: self.num_features], # type: ignore[index]
running_var=self.running_var[: self.num_features], # type: ignore[index]
weight=self.weight,
bias=self.bias,
training=False,
momentum=momentum,
eps=self.eps,
)

return result
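The gist of this change: arguments to nn.functional.batch_norm are now passed by keyword, and the module's Optional[float] momentum is coalesced into a plain float, since the functional API annotates momentum as a float. A minimal, self-contained sketch of that narrowing pattern, with illustrative names (bn, x) that are not part of the PR:

# Sketch only: narrow Optional[float] momentum before calling batch_norm.
from typing import Optional

import torch
from torch import Tensor, nn

bn = nn.BatchNorm2d(8)
x = torch.randn(4, 8, 16, 16)

maybe_momentum: Optional[float] = bn.momentum  # Optional on the module
momentum: float = 0.0 if maybe_momentum is None else maybe_momentum

out: Tensor = nn.functional.batch_norm(
    input=x,
    running_mean=bn.running_mean,
    running_var=bn.running_var,
    weight=bn.weight,
    bias=bn.bias,
    training=True,
    momentum=momentum,  # mypy accepts the narrowed float here
    eps=bn.eps,
)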
3 changes: 2 additions & 1 deletion lightly/transforms/dino_transform.py
@@ -219,7 +219,8 @@ def __init__(
T.RandomResizedCrop(
size=crop_size,
scale=crop_scale,
interpolation=PIL.Image.BICUBIC,
# Type ignore needed because BICUBIC is not recognized as an attribute.
interpolation=PIL.Image.BICUBIC, # type: ignore[attr-defined]
),
T.RandomHorizontalFlip(p=hf_prob),
T.RandomVerticalFlip(p=vf_prob),
13 changes: 8 additions & 5 deletions lightly/transforms/jigsaw.py
@@ -1,7 +1,7 @@
# Copyright (c) 2021. Lightly AG and its affiliates.
# All Rights Reserved

from typing import List
from typing import TYPE_CHECKING, Callable, List

import numpy as np
import torch
@@ -10,6 +10,9 @@
from torch import Tensor
from torchvision import transforms as T

if TYPE_CHECKING:
from numpy.typing import NDArray


class Jigsaw(object):
"""Implementation of Jigsaw image augmentation, inspired from PyContrast library.
@@ -49,7 +52,7 @@ def __init__(
self.crop_size = crop_size
self.grid_size = int(img_size / self.n_grid)
self.side = self.grid_size - self.crop_size
self.transform = transform
self.transform: Callable[[PILImage], Tensor] = transform

yy, xx = np.meshgrid(np.arange(n_grid), np.arange(n_grid))
self.yy = np.reshape(yy * self.grid_size, (n_grid * n_grid,))
@@ -66,11 +69,11 @@ def __call__(self, img: PILImage) -> Tensor:
"""
r_x = np.random.randint(0, self.side + 1, self.n_grid * self.n_grid)
r_y = np.random.randint(0, self.side + 1, self.n_grid * self.n_grid)
img = np.asarray(img, np.uint8)
crops: List[PILImage] = []
img_arr = np.asarray(img, np.uint8)
crops: List[NDArray[np.uint8]] = []
for i in range(self.n_grid * self.n_grid):
crops.append(
img[
img_arr[
self.xx[i] + r_x[i] : self.xx[i] + r_x[i] + self.crop_size,
self.yy[i] + r_y[i] : self.yy[i] + r_y[i] + self.crop_size,
:,
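The new TYPE_CHECKING block makes numpy.typing.NDArray visible to mypy without importing it at runtime. A short, hedged sketch of the pattern (crop_rows is a hypothetical helper, not from the PR):

# Sketch only: NDArray imported purely for static analysis.
from typing import TYPE_CHECKING, List

import numpy as np

if TYPE_CHECKING:
    from numpy.typing import NDArray


def crop_rows(img_arr: "NDArray[np.uint8]", n: int) -> "List[NDArray[np.uint8]]":
    # String annotations are not evaluated at runtime, so the guarded
    # import never has to exist outside of type checking.
    return [img_arr[i] for i in range(n)]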
6 changes: 4 additions & 2 deletions lightly/transforms/rotation.py
@@ -1,7 +1,7 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

from typing import Tuple, Union
from typing import Callable, Tuple, Union

import numpy as np
import torchvision.transforms as T
@@ -65,7 +65,9 @@ class RandomRotateDegrees:
"""

def __init__(self, prob: float, degrees: Union[float, Tuple[float, float]]):
self.transform = T.RandomApply([T.RandomRotation(degrees=degrees)], p=prob)
self.transform: Callable[
[Union[Image, Tensor]], Union[Image, Tensor]
] = T.RandomApply([T.RandomRotation(degrees=degrees)], p=prob)

def __call__(self, image: Union[Image, Tensor]) -> Union[Image, Tensor]:
"""Rotates the images with a given probability.
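As in jigsaw.py, the transform attribute now carries an explicit Callable annotation so mypy knows both the argument and the return type of the wrapped torchvision transform. A hedged, minimal illustration of the idea (RandomRotateWrapper is a hypothetical class):

# Sketch only: annotate the attribute as Callable to pin down its call signature.
from typing import Callable

import torchvision.transforms as T
from torch import Tensor


class RandomRotateWrapper:
    def __init__(self, prob: float) -> None:
        # Without the annotation, mypy may infer a much broader type for this attribute.
        self.transform: Callable[[Tensor], Tensor] = T.RandomApply(
            [T.RandomRotation(degrees=45)], p=prob
        )

    def __call__(self, image: Tensor) -> Tensor:
        return self.transform(image)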
4 changes: 3 additions & 1 deletion lightly/utils/benchmarking/linear_classifier.py
@@ -129,7 +129,9 @@ def validation_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> Tensor:
self.log_dict(log_dict, prog_bar=True, sync_dist=True, batch_size=batch_size)
return loss

def configure_optimizers(
# Type ignore is needed because return type of LightningModule.configure_optimizers
# is complicated and typing changes between versions.
def configure_optimizers( # type: ignore[override]
self,
) -> Tuple[List[Optimizer], List[Dict[str, Union[Any, str]]]]:
parameters = list(self.classification_head.parameters())
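The ignore[override] suppresses mypy's override check because LightningModule.configure_optimizers carries a broad return annotation that differs across Lightning versions and does not always include this tuple form. A hedged sketch of the same situation with hypothetical classes:

# Sketch only: an override whose return type is not a subtype of the parent's annotation.
from typing import Dict, List, Optional, Tuple


class BaseModule:
    def configure_optimizers(self) -> Optional[Dict[str, object]]:
        return None


class MyModule(BaseModule):
    # The tuple form is not assignable to Optional[Dict[...]], so mypy reports an
    # incompatible override; the explicit ignore marks the mismatch as intentional.
    def configure_optimizers(  # type: ignore[override]
        self,
    ) -> Tuple[List[int], List[Dict[str, object]]]:
        return [0], [{"interval": "step"}]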
18 changes: 13 additions & 5 deletions lightly/utils/lars.py
@@ -1,8 +1,7 @@
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, overload

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer, required # type: ignore[attr-defined]
from torch.optim.optimizer import Optimizer


class LARS(Optimizer):
@@ -69,15 +68,15 @@ class LARS(Optimizer):
def __init__(
self,
params: Any,
lr: float = required,
lr: float,
momentum: float = 0,
dampening: float = 0,
weight_decay: float = 0,
nesterov: bool = False,
trust_coefficient: float = 0.001,
eps: float = 1e-8,
):
if lr is not required and lr < 0.0:
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}")
@@ -104,6 +103,15 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
for group in self.param_groups:
group.setdefault("nesterov", False)

# Type ignore for overloads is required for Python 3.7
@overload # type: ignore[override]
def step(self, closure: None = None) -> None:
...

@overload # type: ignore[override]
def step(self, closure: Callable[[], float]) -> float:
...

@torch.no_grad()
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""Performs a single optimization step.
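The paired @overload declarations let mypy resolve step() to None and step(closure) to float while the runtime implementation stays a single method; per the inline comment, the ignore[override] on the overloads is only needed under Python 3.7. A minimal, hedged sketch of the same pattern on a toy class:

# Sketch only: paired overloads for an optimizer-style step method.
from typing import Callable, Optional, overload


class ToyOptimizer:
    @overload
    def step(self, closure: None = None) -> None:
        ...

    @overload
    def step(self, closure: Callable[[], float]) -> float:
        ...

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        loss = None
        if closure is not None:
            loss = closure()  # re-evaluate the model and report the loss
        # ... parameter update would happen here ...
        return loss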
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -145,7 +145,6 @@ omit = ["lightly/openapi_generated/*"]

[tool.mypy]
ignore_missing_imports = true
python_version = "3.10"
warn_unused_configs = true
strict_equality = true

@@ -167,7 +166,7 @@ no_implicit_optional = true
strict_optional = true

# Configuring warnings
warn_unused_ignores = true
warn_unused_ignores = false # Different ignores are required for different Python versions
warn_no_return = true
warn_return_any = true
warn_redundant_casts = true
19 changes: 19 additions & 0 deletions tests/helpers.py
@@ -0,0 +1,19 @@
from typing import Any, List

from torch import Tensor


def assert_list_tensor(items: Any) -> List[Tensor]:
"""Makes sure that the input is a list of tensors.

Should be used in tests where functions return Union[List[Tensor], List[Image]] and
we want to make sure that the output is a list of tensors.

Example:
>>> output: Union[List[Tensor], List[Image]] = transform(images)
>>> tensors: List[Tensor] = assert_list_tensor(output)

"""
assert isinstance(items, list)
assert all(isinstance(item, Tensor) for item in items)
return items
Empty file added tests/transforms/__init__.py
4 changes: 3 additions & 1 deletion tests/transforms/test_byol_transform.py
@@ -6,6 +6,8 @@
BYOLView2Transform,
)

from .. import helpers


def test_view_on_pil_image() -> None:
single_view_transform = BYOLView1Transform(input_size=32)
@@ -20,7 +22,7 @@ def test_multi_view_on_pil_image() -> None:
view_2_transform=BYOLView2Transform(input_size=32),
)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 2
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (3, 32, 32)
4 changes: 3 additions & 1 deletion tests/transforms/test_densecl_transform.py
@@ -2,11 +2,13 @@

from lightly.transforms import DenseCLTransform

from .. import helpers


def test_multi_view_on_pil_image() -> None:
multi_view_transform = DenseCLTransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 2
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (3, 32, 32)
4 changes: 3 additions & 1 deletion tests/transforms/test_dino_transform.py
@@ -2,6 +2,8 @@

from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform

from .. import helpers


def test_view_on_pil_image() -> None:
single_view_transform = DINOViewTransform(crop_size=32)
@@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None:
def test_multi_view_on_pil_image() -> None:
multi_view_transform = DINOTransform(global_crop_size=32, local_crop_size=8)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 8
# global views
assert all(out.shape == (3, 32, 32) for out in output[:2])
4 changes: 3 additions & 1 deletion tests/transforms/test_fastsiam_transform.py
@@ -2,11 +2,13 @@

from lightly.transforms.fast_siam_transform import FastSiamTransform

from .. import helpers


def test_multi_view_on_pil_image() -> None:
multi_view_transform = FastSiamTransform(num_views=3, input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 3
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (3, 32, 32)
6 changes: 4 additions & 2 deletions tests/transforms/test_moco_transform.py
@@ -2,11 +2,13 @@

from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform

from .. import helpers


def test_moco_v1_multi_view_on_pil_image() -> None:
multi_view_transform = MoCoV1Transform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 2
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (3, 32, 32)
@@ -15,7 +17,7 @@ def test_moco_v1_multi_view_on_pil_image() -> None:
def test_moco_v2_multi_view_on_pil_image() -> None:
multi_view_transform = MoCoV2Transform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 2
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (3, 32, 32)
4 changes: 3 additions & 1 deletion tests/transforms/test_msn_transform.py
@@ -2,6 +2,8 @@

from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform

from .. import helpers


def test_view_on_pil_image() -> None:
single_view_transform = MSNViewTransform(crop_size=32)
@@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None:
def test_multi_view_on_pil_image() -> None:
multi_view_transform = MSNTransform(random_size=32, focal_size=8)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 12
# global views
assert all(out.shape == (3, 32, 32) for out in output[:2])
4 changes: 3 additions & 1 deletion tests/transforms/test_pirl_transform.py
@@ -2,11 +2,13 @@

from lightly.transforms.pirl_transform import PIRLTransform

from .. import helpers


def test_multi_view_on_pil_image() -> None:
multi_view_transform = PIRLTransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 2
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (9, 3, 10, 10)
4 changes: 3 additions & 1 deletion tests/transforms/test_simclr_transform.py
@@ -2,6 +2,8 @@

from lightly.transforms.simclr_transform import SimCLRTransform, SimCLRViewTransform

from .. import helpers


def test_view_on_pil_image() -> None:
single_view_transform = SimCLRViewTransform(input_size=32)
@@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None:
def test_multi_view_on_pil_image() -> None:
multi_view_transform = SimCLRTransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 2
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (3, 32, 32)
4 changes: 3 additions & 1 deletion tests/transforms/test_simsiam_transform.py
@@ -2,6 +2,8 @@

from lightly.transforms.simsiam_transform import SimSiamTransform, SimSiamViewTransform

from .. import helpers


def test_view_on_pil_image() -> None:
single_view_transform = SimSiamViewTransform(input_size=32)
@@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None:
def test_multi_view_on_pil_image() -> None:
multi_view_transform = SimSiamTransform(input_size=32)
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 2
assert output[0].shape == (3, 32, 32)
assert output[1].shape == (3, 32, 32)
4 changes: 3 additions & 1 deletion tests/transforms/test_smog_transform.py
@@ -2,6 +2,8 @@

from lightly.transforms.smog_transform import SMoGTransform, SmoGViewTransform

from .. import helpers


def test_view_on_pil_image() -> None:
single_view_transform = SmoGViewTransform(crop_size=32)
@@ -13,7 +15,7 @@ def test_view_on_pil_image() -> None:
def test_multi_view_on_pil_image() -> None:
multi_view_transform = SMoGTransform(crop_sizes=(32, 8))
sample = Image.new("RGB", (100, 100))
output = multi_view_transform(sample)
output = helpers.assert_list_tensor(multi_view_transform(sample))
assert len(output) == 8
assert all(out.shape == (3, 32, 32) for out in output[:4])
assert all(out.shape == (3, 8, 8) for out in output[4:])