From 47e033c466dd7a34a9aa8772193305af033f32c1 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Fri, 7 Apr 2023 15:38:04 +0800 Subject: [PATCH] [Fix] fix pbn bug (#1466) --- mmpretrain/engine/hooks/precise_bn_hook.py | 4 ++-- .../test_hooks/test_precise_bn_hook.py | 23 ++++++++++--------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/mmpretrain/engine/hooks/precise_bn_hook.py b/mmpretrain/engine/hooks/precise_bn_hook.py index 9f5f584191e..4fb0e4c419e 100644 --- a/mmpretrain/engine/hooks/precise_bn_hook.py +++ b/mmpretrain/engine/hooks/precise_bn_hook.py @@ -123,8 +123,8 @@ def update_bn_stats( prog_bar = ProgressBar(num_iter) for data in itertools.islice(loader, num_iter): - batch_inputs, data_samples = model.data_preprocessor(data, False) - model(batch_inputs, data_samples) + data = model.data_preprocessor(data, False) + model(**data) for i, bn in enumerate(bn_layers): running_means[i] += bn.running_mean / num_iter diff --git a/tests/test_engine/test_hooks/test_precise_bn_hook.py b/tests/test_engine/test_hooks/test_precise_bn_hook.py index 5a3917e545e..f549b0dbbe4 100644 --- a/tests/test_engine/test_hooks/test_precise_bn_hook.py +++ b/tests/test_engine/test_hooks/test_precise_bn_hook.py @@ -9,10 +9,11 @@ import torch import torch.nn as nn from mmengine.logging import MMLogger -from mmengine.model import BaseDataPreprocessor, BaseModel +from mmengine.model import BaseModel from mmengine.runner import Runner from torch.utils.data import DataLoader, Dataset +from mmpretrain.models.utils import ClsDataPreprocessor from mmpretrain.registry import HOOKS from mmpretrain.structures import DataSample @@ -31,12 +32,12 @@ def __len__(self): return 10 -class MockDataPreprocessor(BaseDataPreprocessor): +class MockDataPreprocessor(ClsDataPreprocessor): """mock preprocessor that do nothing.""" - def forward(self, data, training): + def forward(self, data, training=False): - return data['imgs'], DataSample() + return dict(inputs=data['imgs'], data_samples=DataSample()) class ExampleModel(BaseModel): @@ -48,9 +49,9 @@ def __init__(self): self.bn = nn.BatchNorm1d(1) self.test_cfg = None - def forward(self, batch_inputs, data_samples, mode='tensor'): - batch_inputs = batch_inputs.to(next(self.parameters()).device) - return self.bn(self.conv(batch_inputs)) + def forward(self, inputs, data_samples, mode='tensor'): + inputs = inputs.to(next(self.parameters()).device) + return self.bn(self.conv(inputs)) def train_step(self, data, optim_wrapper): outputs = {'loss': 0.5, 'num_samples': 1} @@ -64,8 +65,8 @@ def __init__(self): self.bn = nn.BatchNorm1d(1) self.test_cfg = None - def forward(self, batch_inputs, data_samples, mode='tensor'): - return self.bn(batch_inputs) + def forward(self, inputs, data_samples, mode='tensor'): + return self.bn(inputs) class GNExampleModel(ExampleModel): @@ -84,8 +85,8 @@ def __init__(self): delattr(self, 'bn') self.test_cfg = None - def forward(self, batch_inputs, data_samples, mode='tensor'): - return self.conv(batch_inputs) + def forward(self, inputs, data_samples, mode='tensor'): + return self.conv(inputs) class TestPreciseBNHookHook(TestCase):