Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dinov2 with Registers] Some fixes #35411

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PretrainedConfig
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices

Expand Down Expand Up @@ -69,10 +70,6 @@ class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
Whether to use the SwiGLU feedforward neural network.
num_register_tokens (`int`, *optional*, defaults to 4):
Number of register tokens to use.
interpolate_antialias (`bool`, *optional*, defaults to `True`):
Whether to use antialiasing when interpolating the image patches.
interpolate_offset (`float`, *optional*, defaults to 0.0):
Offset to use when interpolating the image patches.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the
Expand Down Expand Up @@ -105,7 +102,7 @@ class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
>>> configuration = model.config
```"""

model_type = "dinov2-with-registers-base"
model_type = "dinov2_with_registers"

def __init__(
self,
Expand All @@ -126,8 +123,6 @@ def __init__(
drop_path_rate=0.0,
use_swiglu_ffn=False,
num_register_tokens=4,
interpolate_antialias=True,
interpolate_offset=0.0,
out_features=None,
out_indices=None,
apply_layernorm=True,
Expand All @@ -153,8 +148,6 @@ def __init__(
self.drop_path_rate = drop_path_rate
self.use_swiglu_ffn = use_swiglu_ffn
self.num_register_tokens = num_register_tokens
self.interpolate_antialias = interpolate_antialias
self.interpolate_offset = interpolate_offset
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union
Expand All @@ -37,6 +38,7 @@
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_dinov2_with_registers import Dinov2WithRegistersConfig
Expand Down Expand Up @@ -99,43 +101,62 @@ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.patch_size = config.patch_size
self.config = config

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility
with the original implementation.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
Adapted from:
- https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
- https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:

# Skip interpolation for matching dimensions (unless tracing)
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings

# Handle class token and patch embeddings separately
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]

# Calculate new dimensions
height = height // self.config.patch_size
width = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
height, width = height + self.config.interpolate_offset, width + self.config.interpolate_offset
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)

# Reshape for interpolation
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

# Store original dtype for restoration after interpolation
target_dtype = patch_pos_embed.dtype

# Interpolate at float32 precision
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.to(dtype=torch.float32),
scale_factor=(float(height / math.sqrt(num_positions)), float(width / math.sqrt(num_positions))),
size=(torch_int(height), torch_int(width)), # Explicit size instead of scale_factor
mode="bicubic",
align_corners=False,
antialias=self.config.interpolate_antialias,
)
patch_pos_embed = patch_pos_embed.to(dtype=target_dtype)
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")
# Maintain backward compatibility with antialias if configured
antialias=getattr(self.config, "interpolate_antialias", False),
).to(dtype=target_dtype)

# Validate output dimensions if not tracing
if not torch.jit.is_tracing():
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")

# Reshape back to original format
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

# Combine class and patch embeddings
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

from typing import Optional

import torch
Expand All @@ -30,7 +30,7 @@
)
from ...configuration_utils import PretrainedConfig
from ...modeling_outputs import BackboneOutput
from ...utils import logging
from ...utils import logging, torch_int
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices


Expand Down Expand Up @@ -83,10 +83,6 @@ class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
Whether to use the SwiGLU feedforward neural network.
num_register_tokens (`int`, *optional*, defaults to 4):
Number of register tokens to use.
interpolate_antialias (`bool`, *optional*, defaults to `True`):
Whether to use antialiasing when interpolating the image patches.
interpolate_offset (`float`, *optional*, defaults to 0.0):
Offset to use when interpolating the image patches.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the
Expand Down Expand Up @@ -119,7 +115,7 @@ class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
>>> configuration = model.config
```"""

model_type = "dinov2-with-registers-base"
model_type = "dinov2_with_registers"

def __init__(
self,
Expand All @@ -140,8 +136,6 @@ def __init__(
drop_path_rate=0.0,
use_swiglu_ffn=False,
num_register_tokens=4,
interpolate_antialias=True,
interpolate_offset=0.0,
out_features=None,
out_indices=None,
apply_layernorm=True,
Expand All @@ -167,8 +161,6 @@ def __init__(
self.drop_path_rate = drop_path_rate
self.use_swiglu_ffn = use_swiglu_ffn
self.num_register_tokens = num_register_tokens
self.interpolate_antialias = interpolate_antialias
self.interpolate_offset = interpolate_offset
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
Expand Down Expand Up @@ -196,43 +188,62 @@ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.patch_size = config.patch_size
self.config = config

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility
with the original implementation.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
Adapted from:
- https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
- https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:

# Skip interpolation for matching dimensions (unless tracing)
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings

# Handle class token and patch embeddings separately
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]

# Calculate new dimensions
height = height // self.config.patch_size
width = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
height, width = height + self.config.interpolate_offset, width + self.config.interpolate_offset
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)

# Reshape for interpolation
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

# Store original dtype for restoration after interpolation
target_dtype = patch_pos_embed.dtype

# Interpolate at float32 precision
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.to(dtype=torch.float32),
scale_factor=(float(height / math.sqrt(num_positions)), float(width / math.sqrt(num_positions))),
size=(torch_int(height), torch_int(width)), # Explicit size instead of scale_factor
mode="bicubic",
align_corners=False,
antialias=self.config.interpolate_antialias,
)
patch_pos_embed = patch_pos_embed.to(dtype=target_dtype)
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")
# Maintain backward compatibility with antialias if configured
antialias=getattr(self.config, "interpolate_antialias", False),
).to(dtype=target_dtype)

# Validate output dimensions if not tracing
if not torch.jit.is_tracing():
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")

# Reshape back to original format
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

# Combine class and patch embeddings
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
Expand Down