Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support download inception state from url #233

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
pass_training_status=True)

# Note set your inception_pkl's path
inception_pkl = 'work_dirs/inception_pkl/imagenet.pkl'
inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/'
'fid_inception_pkl/imagenet.pkl')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may rename to 'imagenet_128x128.pkl'

evaluation = dict(
type='GenerativeEvalHook',
interval=10000,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
pass_training_status=True)

# Note set your inception_pkl's path
inception_pkl = 'work_dirs/inception_pkl/imagenet.pkl'
inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/'
'fid_inception_pkl/imagenet.pkl')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the same as above

evaluation = dict(
type='GenerativeEvalHook',
interval=10000,
Expand Down
2 changes: 1 addition & 1 deletion configs/improved_ddpm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Denoising diffusion probabilistic models (DDPM) are a class of generative models

For FID evaluation, we follow the pipeline of [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch/blob/98459431a5d618d644d54cd1e9fceb1e5045648d/calculate_inception_moments.py#L52), where the whole training set is adopted to extract inception statistics, and Pytorch Studio GAN uses 50000 randomly selected samples. Besides, we also use [Tero's Inception](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt) for feature extraction.

You can download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k-64x64](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet_64x64.pkl).
MMGen will automatically download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k-64x64](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet_64x64.pkl).

You can use following commands to extract those inception states by yourself.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
is_dynamic_ddp=False, # Note that this flag should be False.
pass_training_status=True)

inception_pkl = './work_dirs/inception_pkl/cifar10.pkl'
inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/'
'fid_inception_pkl/cifar10.pkl')
metrics = dict(
fid50k=dict(
type='FID',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
is_dynamic_ddp=False, # Note that this flag should be False.
pass_training_status=True)

inception_pkl = './work_dirs/inception_pkl/imagenet_64x64.pkl'
inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/'
'fid_inception_pkl/imagenet_64x64.pkl')
metrics = dict(
fid50k=dict(
type='FID',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
is_dynamic_ddp=False, # Note that this flag should be False.
pass_training_status=True)

inception_pkl = './work_dirs/inception_pkl/imagenet_64x64.pkl'
inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/'
'fid_inception_pkl/imagenet_64x64.pkl')
metrics = dict(
fid50k=dict(
type='FID',
Expand Down
2 changes: 1 addition & 1 deletion configs/sagan/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ For IS metric, our implementation is different from PyTorch-Studio GAN in the fo

For FID evaluation, we follow the pipeline of [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch/blob/98459431a5d618d644d54cd1e9fceb1e5045648d/calculate_inception_moments.py#L52), where the whole training set is adopted to extract inception statistics, and Pytorch Studio GAN uses 50000 randomly selected samples. Besides, we also use [Tero's Inception](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt) for feature extraction.

You can download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet.pkl).
MMGen will automatically download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet.pkl).

You can use following commands to extract those inception states by yourself.
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
interval=1000)
]

inception_pkl = './work_dirs/inception_pkl/imagenet.pkl'
inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/'
'fid_inception_pkl/imagenet.pkl')

evaluation = dict(
type='GenerativeEvalHook',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
priority='VERY_HIGH')
]

inception_pkl = './work_dirs/inception_pkl/imagenet.pkl'
inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/'
'fid_inception_pkl/imagenet.pkl')

evaluation = dict(
type='GenerativeEvalHook',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
interval=1000)
]

inception_pkl = './work_dirs/inception_pkl/cifar10.pkl'
inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/'
'fid_inception_pkl/cifar10.pkl')

evaluation = dict(
type='GenerativeEvalHook',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
interval=1000)
]

inception_pkl = './work_dirs/inception_pkl/cifar10.pkl'
inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/'
'fid_inception_pkl/cifar10.pkl')

evaluation = dict(
type='GenerativeEvalHook',
Expand Down
2 changes: 1 addition & 1 deletion configs/sngan_proj/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ For IS metric, our implementation is different from PyTorch-Studio GAN in the fo

For FID evaluation, we follow the pipeline of [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch/blob/98459431a5d618d644d54cd1e9fceb1e5045648d/calculate_inception_moments.py#L52), where the whole training set is adopted to extract inception statistics, and Pytorch Studio GAN uses 50000 randomly selected samples. Besides, we also use [Tero's Inception](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt) for feature extraction.

You can download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet.pkl).
MMGen will automatically download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet.pkl).

You can use following commands to extract those inception states by yourself.
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@

log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])

inception_pkl = './work_dirs/inception_pkl/imagenet.pkl'
inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/'
'fid_inception_pkl/imagenet.pkl')

evaluation = dict(
type='GenerativeEvalHook',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@

log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])

inception_pkl = './work_dirs/inception_pkl/imagenet.pkl'
inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/'
'fid_inception_pkl/imagenet.pkl')

evaluation = dict(
type='GenerativeEvalHook',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
interval=5000)
]

inception_pkl = './work_dirs/inception_pkl/cifar10.pkl'
inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/'
'fid_inception_pkl/cifar10.pkl')

evaluation = dict(
type='GenerativeEvalHook',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
interval=5000)
]

inception_pkl = './work_dirs/inception_pkl/cifar10.pkl'
inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/'
'fid_inception_pkl/cifar10.pkl')

evaluation = dict(
type='GenerativeEvalHook',
Expand Down
27 changes: 20 additions & 7 deletions mmgen/core/evaluation/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def load_inception(inception_args, metric):
inceptoin_type = _inception_args.pop('type', None)

if torch.__version__ < '1.6.0':
# reset inception_args for FID (Inception for IS do not use
# inception_args)
if metric == 'FID':
_inception_args = dict(normalize_input=False)

mmcv.print_log(
'Current Pytorch Version not support script module, load '
'Inception Model from torch model zoo. If you want to use '
Expand Down Expand Up @@ -118,7 +123,7 @@ def _load_inception_torch(inception_args, metric):
assert metric in ['FID', 'IS']
if metric == 'FID':
inception_model = InceptionV3([3], **inception_args)
elif metric == 'IS':
else: # metric == 'IS'
inception_model = inception_v3(pretrained=True, transform_input=False)
mmcv.print_log(
'Load Inception V3 Network from Pytorch Model Zoo '
Expand Down Expand Up @@ -505,15 +510,23 @@ def __init__(self,
def prepare(self):
"""Prepare for evaluating models with this metric."""
# if `inception_pkl` is provided, read mean and cov stat
if self.inception_pkl is not None and mmcv.is_filepath(
self.inception_pkl):
with open(self.inception_pkl, 'rb') as f:
if self.inception_pkl is not None:
if self.inception_pkl[:4] == 'http':
inception_path = download_from_url(self.inception_pkl)
elif mmcv.is_filepath(self.inception_pkl):
inception_path = self.inception_pkl
else:
raise FileNotFoundError('Cannot load inception pkl from '
f'{self.inception_pkl}')

# load from path
with open(inception_path, 'rb') as f:
reference = pickle.load(f)
self.real_mean = reference['mean']
self.real_cov = reference['cov']
mmcv.print_log(
f'Load reference inception pkl from {self.inception_pkl}',
'mmgen')
mmcv.print_log(
f'Load reference inception pkl from {self.inception_pkl}',
'mmgen')
self.num_real_feeded = self.num_images

@torch.no_grad()
Expand Down
122 changes: 83 additions & 39 deletions tests/test_cores/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,45 @@
from mmgen.models import build_model
from mmgen.models.architectures import InceptionV3

# def test_inception_download():
# from mmgen.core.evaluation.metrics import load_inception
# from mmgen.utils import MMGEN_CACHE_DIR

# args_FID_pytorch = dict(type='pytorch', normalize_input=False)
# args_FID_tero = dict(type='StyleGAN', inception_path='')
# args_IS_pytorch = dict(type='pytorch')
# args_IS_tero = dict(
# type='StyleGAN',
# inception_path=osp.join(MMGEN_CACHE_DIR, 'inception-2015-12-05.pt'))
def test_inception_download():
from mmgen.core.evaluation.metrics import load_inception
from mmgen.utils import MMGEN_CACHE_DIR

args_FID_pytorch = dict(type='pytorch', normalize_input=False)
args_FID_tero = dict(type='StyleGAN')
args_IS_pytorch = dict(type='pytorch')
args_IS_tero = dict(
type='StyleGAN',
inception_path=osp.join(MMGEN_CACHE_DIR, 'inception-2015-12-05.pt'))

arg_list = [args_FID_pytorch, args_FID_tero, args_IS_pytorch, args_IS_tero]
metric_list = ['FID', 'FID', 'IS', 'IS']
tar_style_list = ['pytorch', 'StyleGAN', 'pytorch', 'StyleGAN']

for inception_args, metric, tar_style in zip(arg_list, metric_list,
tar_style_list):
model, style = load_inception(inception_args, metric)

if torch.__version__ < '1.6.0':
print(inception_args, metric, tar_style)
assert style == 'pytorch'
else:
assert style == tar_style

args_empty = ''
with pytest.raises(TypeError):
load_inception(args_empty, 'FID')

# pt lower than this version cannot load Tero's inception and direct use
# torch ones, only test this for pt >= 1.6
if torch.__version__ >= '1.6.0':
args_error_path = dict(type='StyleGAN', inception_path='error-path')
with pytest.raises(RuntimeError):
load_inception(args_error_path, 'FID')

# tar_style_list = ['pytorch', 'StyleGAN', 'pytorch', 'StyleGAN']

# for inception_args, metric, tar_style in zip(
# [args_FID_pytorch, args_FID_tero, args_IS_pytorch, args_IS_tero],
# ['FID', 'FID', 'IS', 'IS'], tar_style_list):
# model, style = load_inception(inception_args, metric)
# assert style == tar_style

# args_empty = ''
# with pytest.raises(TypeError) as exc_info:
# load_inception(args_empty, 'FID')

# args_error_path = dict(type='StyleGAN', inception_path='error-path')
# with pytest.raises(RuntimeError) as exc_info:
# load_inception(args_error_path, 'FID')
with pytest.raises(AssertionError):
load_inception(dict(type='pytorch', normalize_input=False), 'PPL')


def test_swd_metric():
Expand Down Expand Up @@ -144,21 +157,52 @@ def test_fid(self):
assert fid_score > 0 and mean > 0 and cov > 0

# To reduce the size of git repo, we remove the following test
# fid = FID(
# 3,
# inception_args=dict(
# normalize_input=False, load_fid_inception=False),
# inception_pkl=osp.join(
# osp.dirname(__file__), '..', 'data', 'test_dirty.pkl'))
# assert fid.num_real_feeded == 3
# for b in self.reals:
# fid.feed(b, 'reals')

# for b in self.fakes:
# fid.feed(b, 'fakes')

# fid_score, mean, cov = fid.summary()
# assert fid_score > 0 and mean > 0 and cov > 0

inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/'
'fid_inception_pkl/cifar10.pkl')
fid = FID(
3,
inception_args=dict(
normalize_input=False, load_fid_inception=False),
inception_pkl=inception_pkl)
fid.prepare()
assert fid.num_real_feeded == 3
for b in self.reals:
fid.feed(b, 'reals')

for b in self.fakes:
fid.feed(b, 'fakes')

fid_score, mean, cov = fid.summary()
assert fid_score > 0 and mean > 0 and cov > 0

# test load
inception_pkl = osp.expanduser('~/.cache/openmmlab/mmgen/cifar10.pkl')
fid = FID(
3,
inception_args=dict(
normalize_input=False, load_fid_inception=False),
inception_pkl=inception_pkl)
fid.prepare()
assert fid.num_real_feeded == 3
for b in self.reals:
fid.feed(b, 'reals')

for b in self.fakes:
fid.feed(b, 'fakes')

fid_score, mean, cov = fid.summary()
assert fid_score > 0 and mean > 0 and cov > 0

# test raise load error
inception_pkl = 'wrong_path'
fid = FID(
3,
inception_args=dict(
normalize_input=False, load_fid_inception=False),
inception_pkl=inception_pkl)
with pytest.raises(FileNotFoundError):
fid.prepare()


class TestPR:
Expand Down