Merge branch 'update_xpu_code' of https://github.com/openvinotoolkit/training_extensions into update_xpu_code
kprokofi committed Nov 13, 2024
2 parents 702056b + 7be2712 commit e7fb947
Showing 3 changed files with 62 additions and 65 deletions.
2 changes: 1 addition & 1 deletion src/otx/algo/instance_segmentation/heads/__init__.py
@@ -24,4 +24,4 @@
"MaskDINODecoderHeadModule",
"MaskDINOEncoderHeadModule",
"MaskDINOHead",
]
]
123 changes: 61 additions & 62 deletions src/otx/algo/instance_segmentation/heads/maskdino_encoder.py
@@ -6,12 +6,11 @@

from __future__ import annotations

from typing import Callable
from typing import Any, Callable, ClassVar

import numpy as np
import torch
from torch import Tensor, nn
from torch.amp import autocast
from torch.nn import functional as f
from torch.nn.init import normal_

@@ -310,66 +309,66 @@ def __init__(

def forward_features(self, features: dict[str, Tensor]) -> tuple[Tensor, Tensor, list[Tensor]]:
"""Forward pass of the encoder."""
with autocast(device_type=features[self.transformer_in_features[0]].device.type, enabled=False):
# backbone features
srcs = []
pos = []
# additional downsampled features
srcsl: list[Tensor] = []
posl = []
if self.total_num_feature_levels > self.transformer_num_feature_levels:
smallest_feat = features[self.transformer_in_features[self.low_resolution_index]].float()
_len_srcs = self.transformer_num_feature_levels
for lvl in range(_len_srcs, self.total_num_feature_levels):
src = self.input_proj[lvl](smallest_feat) if lvl == _len_srcs else self.input_proj[lvl](srcsl[-1])
srcsl.append(src)
posl.append(self.pe_layer(src))
srcsl = srcsl[::-1]
# Reverse feature maps
for idx, feat in enumerate(self.transformer_in_features[::-1]):
x = features[feat].float() # deformable detr does not support half precision
srcs.append(self.input_proj[idx](x))
pos.append(self.pe_layer(x))
srcs.extend(srcsl)
pos.extend(posl)
y, spatial_shapes, level_start_index = self.transformer(srcs, pos)
bs = y.shape[0]

split_size_or_sections = [None] * self.total_num_feature_levels
for i in range(self.total_num_feature_levels):
if i < self.total_num_feature_levels - 1:
split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i]
else:
split_size_or_sections[i] = y.shape[1] - level_start_index[i]
y = torch.split(y, split_size_or_sections, dim=1)

out = []
multi_scale_features = []
num_cur_levels = 0
for i, z in enumerate(y):
out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]))

# append `out` with extra FPN levels
# Reverse feature maps into top-down order (from low to high resolution)
for idx, feat in enumerate(self.in_features[: self.num_fpn_levels][::-1]):
x = features[feat].float()
lateral_conv = self.lateral_convs[idx]
output_conv = self.output_convs[idx]
cur_fpn = lateral_conv(x)
# Following FPN implementation, we use nearest upsampling here
y = cur_fpn + f.interpolate(
out[self.high_resolution_index],
size=cur_fpn.shape[-2:],
mode="bilinear",
align_corners=False,
)
y = output_conv(y)
out.append(y)
for o in out:
if num_cur_levels < self.total_num_feature_levels:
multi_scale_features.append(o)
num_cur_levels += 1
return self.mask_features(out[-1]), out[0], multi_scale_features
# backbone features
srcs = []
pos = []
# additional downsampled features
srcsl: list[Tensor] = []
posl = []
if self.total_num_feature_levels > self.transformer_num_feature_levels:
smallest_feat = features[self.transformer_in_features[self.low_resolution_index]].float()
_len_srcs = self.transformer_num_feature_levels
for lvl in range(_len_srcs, self.total_num_feature_levels):
src = self.input_proj[lvl](smallest_feat) if lvl == _len_srcs else self.input_proj[lvl](srcsl[-1])
srcsl.append(src)
posl.append(self.pe_layer(src))
srcsl = srcsl[::-1]
# Reverse feature maps
for idx, feat in enumerate(self.transformer_in_features[::-1]):
x = features[feat].float() # deformable detr does not support half precision
srcs.append(self.input_proj[idx](x))
pos.append(self.pe_layer(x))
srcs.extend(srcsl)
pos.extend(posl)
y, spatial_shapes, level_start_index = self.transformer(srcs, pos)
bs = y.shape[0]

split_size_or_sections = [None] * self.total_num_feature_levels
for i in range(self.total_num_feature_levels):
if i < self.total_num_feature_levels - 1:
split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i]
else:
split_size_or_sections[i] = y.shape[1] - level_start_index[i]
y = torch.split(y, split_size_or_sections, dim=1)

out = []
multi_scale_features = []
num_cur_levels = 0
for i, z in enumerate(y):
out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]))

# append `out` with extra FPN levels
# Reverse feature maps into top-down order (from low to high resolution)
for idx, feat in enumerate(self.in_features[: self.num_fpn_levels][::-1]):
x = features[feat].float()
lateral_conv = self.lateral_convs[idx]
output_conv = self.output_convs[idx]
cur_fpn = lateral_conv(x)
# Following the FPN design, upsample the coarser feature map before summation (bilinear interpolation here)
y = cur_fpn + f.interpolate(
out[self.high_resolution_index],
size=cur_fpn.shape[-2:],
mode="bilinear",
align_corners=False,
)
y = output_conv(y)
out.append(y)
for o in out:
if num_cur_levels < self.total_num_feature_levels:
multi_scale_features.append(o)
num_cur_levels += 1
return self.mask_features(out[-1]), out[0], multi_scale_features


class MaskDINOEncoderHead:
"""MaskDINO Encoder Head Factory Selector."""
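For reference, the wrapper removed above used torch.amp.autocast(..., enabled=False) to keep the deformable-attention inputs in float32 even when the surrounding model runs under mixed precision (hence the .float() casts and the "deformable detr does not support half precision" comment). A minimal, self-contained sketch of that mechanism, assuming a CUDA device and a made-up linear layer rather than anything from this repository:

import torch
from torch import nn
from torch.amp import autocast

proj = nn.Linear(256, 256).cuda()          # hypothetical layer, for illustration only
x = torch.randn(2, 256, device="cuda")

with autocast(device_type="cuda", dtype=torch.float16):
    y = proj(x)                             # computed in float16 under autocast
    with autocast(device_type="cuda", enabled=False):
        z = proj(x.float())                 # nested region is forced back to float32
print(y.dtype, z.dtype)                     # torch.float16 torch.float32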
2 changes: 0 additions & 2 deletions src/otx/engine/engine.py
@@ -1147,8 +1147,6 @@ def _build_trainer(self, **kwargs) -> None:
self._cache.update(strategy="xpu_single")
# add plugin for Automatic Mixed Precision on XPU
if self._cache.args.get("precision", 32) == 16:
msg = "XPU doesn't support fp16 now, so bfp16 will be used instead."
warn(msg, stacklevel=1)
self._cache.update(
plugins=[
MixedPrecision(
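The hunk above is truncated at the MixedPrecision( call, so the arguments engine.py actually passes are not visible in this diff. As a hedged sketch only (Lightning 2.x API assumed; the precision and device values below are illustrative assumptions, not the repository's), a bf16 autocast plugin for an XPU device is typically built like this:

from lightning.pytorch.plugins.precision import MixedPrecision

# Illustrative values only; the real arguments are elided by the diff above.
precision_plugin = MixedPrecision(precision="bf16-mixed", device="xpu")
trainer_kwargs = {"plugins": [precision_plugin]}  # later passed to the Lightning Trainer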
