Skip to content

Commit

Permalink
[Fix] fix pbn bug (#1466)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ezra-Yu authored Apr 7, 2023
1 parent 5ea46fb commit 47e033c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
4 changes: 2 additions & 2 deletions mmpretrain/engine/hooks/precise_bn_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def update_bn_stats(
prog_bar = ProgressBar(num_iter)

for data in itertools.islice(loader, num_iter):
batch_inputs, data_samples = model.data_preprocessor(data, False)
model(batch_inputs, data_samples)
data = model.data_preprocessor(data, False)
model(**data)

for i, bn in enumerate(bn_layers):
running_means[i] += bn.running_mean / num_iter
Expand Down
23 changes: 12 additions & 11 deletions tests/test_engine/test_hooks/test_precise_bn_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import torch
import torch.nn as nn
from mmengine.logging import MMLogger
from mmengine.model import BaseDataPreprocessor, BaseModel
from mmengine.model import BaseModel
from mmengine.runner import Runner
from torch.utils.data import DataLoader, Dataset

from mmpretrain.models.utils import ClsDataPreprocessor
from mmpretrain.registry import HOOKS
from mmpretrain.structures import DataSample

Expand All @@ -31,12 +32,12 @@ def __len__(self):
return 10


class MockDataPreprocessor(BaseDataPreprocessor):
class MockDataPreprocessor(ClsDataPreprocessor):
"""mock preprocessor that do nothing."""

def forward(self, data, training):
def forward(self, data, training=False):

return data['imgs'], DataSample()
return dict(inputs=data['imgs'], data_samples=DataSample())


class ExampleModel(BaseModel):
Expand All @@ -48,9 +49,9 @@ def __init__(self):
self.bn = nn.BatchNorm1d(1)
self.test_cfg = None

def forward(self, batch_inputs, data_samples, mode='tensor'):
batch_inputs = batch_inputs.to(next(self.parameters()).device)
return self.bn(self.conv(batch_inputs))
def forward(self, inputs, data_samples, mode='tensor'):
inputs = inputs.to(next(self.parameters()).device)
return self.bn(self.conv(inputs))

def train_step(self, data, optim_wrapper):
outputs = {'loss': 0.5, 'num_samples': 1}
Expand All @@ -64,8 +65,8 @@ def __init__(self):
self.bn = nn.BatchNorm1d(1)
self.test_cfg = None

def forward(self, batch_inputs, data_samples, mode='tensor'):
return self.bn(batch_inputs)
def forward(self, inputs, data_samples, mode='tensor'):
return self.bn(inputs)


class GNExampleModel(ExampleModel):
Expand All @@ -84,8 +85,8 @@ def __init__(self):
delattr(self, 'bn')
self.test_cfg = None

def forward(self, batch_inputs, data_samples, mode='tensor'):
return self.conv(batch_inputs)
def forward(self, inputs, data_samples, mode='tensor'):
return self.conv(inputs)


class TestPreciseBNHookHook(TestCase):
Expand Down

0 comments on commit 47e033c

Please sign in to comment.