From bb942a2b4f26b16331ba1cc070ade096de1b0dc3 Mon Sep 17 00:00:00 2001 From: LeoXing Date: Sun, 23 Jan 2022 15:21:41 +0800 Subject: [PATCH 1/6] support download inception stat from url --- mmgen/core/evaluation/metrics.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/mmgen/core/evaluation/metrics.py b/mmgen/core/evaluation/metrics.py index ba7802bda..02bd9a38f 100644 --- a/mmgen/core/evaluation/metrics.py +++ b/mmgen/core/evaluation/metrics.py @@ -505,15 +505,23 @@ def __init__(self, def prepare(self): """Prepare for evaluating models with this metric.""" # if `inception_pkl` is provided, read mean and cov stat - if self.inception_pkl is not None and mmcv.is_filepath( - self.inception_pkl): - with open(self.inception_pkl, 'rb') as f: + if self.inception_pkl is not None: + if self.inception_pkl[:4] == 'http': + inception_path = download_from_url(self.inception_pkl) + elif mmcv.is_filepath(self.inception_pkl): + inception_path = self.inception_pkl + else: + raise FileNotFoundError('Cannot load inception pkl from ' + f'{self.inception_pkl}') + + # load from path + with open(inception_path, 'rb') as f: reference = pickle.load(f) self.real_mean = reference['mean'] self.real_cov = reference['cov'] - mmcv.print_log( - f'Load reference inception pkl from {self.inception_pkl}', - 'mmgen') + mmcv.print_log( + f'Load reference inception pkl from {self.inception_pkl}', + 'mmgen') self.num_real_feeded = self.num_images @torch.no_grad() From 516a2a988ce989608c99d8aeb3443c11583fa942 Mon Sep 17 00:00:00 2001 From: LeoXing Date: Sun, 23 Jan 2022 15:48:37 +0800 Subject: [PATCH 2/6] update configs --- .../biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py | 3 ++- .../biggan/biggan_torch-sn_imagenet1k_128x128_b32x8_1500k.py | 3 ++- configs/improved_ddpm/README.md | 2 +- ...sine_hybird_timestep-4k_drop0.3_cifar10_32x32_b8x16_500k.py | 3 ++- ..._hybird_timestep-4k_drop0.3_imagenet1k_64x64_b8x16_1500k.py | 3 ++- ...m_cosine_hybird_timestep-4k_imagenet1k_64x64_b8x16_1500k.py | 3 ++- configs/sagan/README.md | 2 +- ..._woReLUinplace_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b64x4.py | 3 ++- ...e_noaug_bigGAN_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b32x8.py | 3 ++- .../sagan_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py | 3 ++- .../sagan_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py | 3 ++- configs/sngan_proj/README.md | 2 +- ..._wReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py | 3 ++- ...woReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py | 3 ++- .../sngan_proj_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py | 3 ++- ...sngan_proj_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py | 3 ++- 16 files changed, 29 insertions(+), 16 deletions(-) diff --git a/configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py b/configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py index 9d09137f2..44eda7199 100644 --- a/configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py +++ b/configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py @@ -37,7 +37,8 @@ pass_training_status=True) # Note set your inception_pkl's path -inception_pkl = 'work_dirs/inception_pkl/imagenet.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet.pkl') evaluation = dict( type='GenerativeEvalHook', interval=10000, diff --git a/configs/biggan/biggan_torch-sn_imagenet1k_128x128_b32x8_1500k.py b/configs/biggan/biggan_torch-sn_imagenet1k_128x128_b32x8_1500k.py index bafbeeea9..78404617a 100644 --- a/configs/biggan/biggan_torch-sn_imagenet1k_128x128_b32x8_1500k.py +++ b/configs/biggan/biggan_torch-sn_imagenet1k_128x128_b32x8_1500k.py @@ -40,7 +40,8 @@ pass_training_status=True) # Note set your inception_pkl's path -inception_pkl = 'work_dirs/inception_pkl/imagenet.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet.pkl') evaluation = dict( type='GenerativeEvalHook', interval=10000, diff --git a/configs/improved_ddpm/README.md b/configs/improved_ddpm/README.md index f5cd098ec..3187cd99b 100644 --- a/configs/improved_ddpm/README.md +++ b/configs/improved_ddpm/README.md @@ -49,7 +49,7 @@ Denoising diffusion probabilistic models (DDPM) are a class of generative models For FID evaluation, we follow the pipeline of [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch/blob/98459431a5d618d644d54cd1e9fceb1e5045648d/calculate_inception_moments.py#L52), where the whole training set is adopted to extract inception statistics, and Pytorch Studio GAN uses 50000 randomly selected samples. Besides, we also use [Tero's Inception](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt) for feature extraction. -You can download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k-64x64](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet_64x64.pkl). +MMGen will automatically download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k-64x64](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet_64x64.pkl). You can use following commands to extract those inception states by yourself. diff --git a/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_cifar10_32x32_b8x16_500k.py b/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_cifar10_32x32_b8x16_500k.py index 9f1b125cb..295294762 100644 --- a/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_cifar10_32x32_b8x16_500k.py +++ b/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_cifar10_32x32_b8x16_500k.py @@ -36,7 +36,8 @@ is_dynamic_ddp=False, # Note that this flag should be False. pass_training_status=True) -inception_pkl = './work_dirs/inception_pkl/cifar10.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/cifar10.pkl') metrics = dict( fid50k=dict( type='FID', diff --git a/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_imagenet1k_64x64_b8x16_1500k.py b/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_imagenet1k_64x64_b8x16_1500k.py index 3eeb7df0a..5868aa7f3 100644 --- a/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_imagenet1k_64x64_b8x16_1500k.py +++ b/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_imagenet1k_64x64_b8x16_1500k.py @@ -39,7 +39,8 @@ is_dynamic_ddp=False, # Note that this flag should be False. pass_training_status=True) -inception_pkl = './work_dirs/inception_pkl/imagenet_64x64.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet_64x64.pkl') metrics = dict( fid50k=dict( type='FID', diff --git a/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_imagenet1k_64x64_b8x16_1500k.py b/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_imagenet1k_64x64_b8x16_1500k.py index 3bdc54d27..eabb00a2f 100644 --- a/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_imagenet1k_64x64_b8x16_1500k.py +++ b/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_imagenet1k_64x64_b8x16_1500k.py @@ -36,7 +36,8 @@ is_dynamic_ddp=False, # Note that this flag should be False. pass_training_status=True) -inception_pkl = './work_dirs/inception_pkl/imagenet_64x64.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet_64x64.pkl') metrics = dict( fid50k=dict( type='FID', diff --git a/configs/sagan/README.md b/configs/sagan/README.md index 2238d8033..ce25f89bd 100644 --- a/configs/sagan/README.md +++ b/configs/sagan/README.md @@ -71,7 +71,7 @@ For IS metric, our implementation is different from PyTorch-Studio GAN in the fo For FID evaluation, we follow the pipeline of [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch/blob/98459431a5d618d644d54cd1e9fceb1e5045648d/calculate_inception_moments.py#L52), where the whole training set is adopted to extract inception statistics, and Pytorch Studio GAN uses 50000 randomly selected samples. Besides, we also use [Tero's Inception](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt) for feature extraction. -You can download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet.pkl). +MMGen will automatically download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet.pkl). You can use following commands to extract those inception states by yourself. ``` diff --git a/configs/sagan/sagan_128_woReLUinplace_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b64x4.py b/configs/sagan/sagan_128_woReLUinplace_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b64x4.py index ea431e71a..430320548 100644 --- a/configs/sagan/sagan_128_woReLUinplace_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b64x4.py +++ b/configs/sagan/sagan_128_woReLUinplace_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b64x4.py @@ -18,7 +18,8 @@ interval=1000) ] -inception_pkl = './work_dirs/inception_pkl/imagenet.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet.pkl') evaluation = dict( type='GenerativeEvalHook', diff --git a/configs/sagan/sagan_128_woReLUinplace_noaug_bigGAN_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b32x8.py b/configs/sagan/sagan_128_woReLUinplace_noaug_bigGAN_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b32x8.py index 5e1088819..33c7234db 100644 --- a/configs/sagan/sagan_128_woReLUinplace_noaug_bigGAN_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b32x8.py +++ b/configs/sagan/sagan_128_woReLUinplace_noaug_bigGAN_Glr-1e-4_Dlr-4e-4_ndisc1_imagenet1k_b32x8.py @@ -45,7 +45,8 @@ priority='VERY_HIGH') ] -inception_pkl = './work_dirs/inception_pkl/imagenet.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet.pkl') evaluation = dict( type='GenerativeEvalHook', diff --git a/configs/sagan/sagan_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py b/configs/sagan/sagan_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py index 98abcf45e..5c04d9525 100644 --- a/configs/sagan/sagan_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py +++ b/configs/sagan/sagan_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py @@ -25,7 +25,8 @@ interval=1000) ] -inception_pkl = './work_dirs/inception_pkl/cifar10.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/cifar10.pkl') evaluation = dict( type='GenerativeEvalHook', diff --git a/configs/sagan/sagan_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py b/configs/sagan/sagan_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py index 395ca8ce9..7cb487208 100644 --- a/configs/sagan/sagan_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py +++ b/configs/sagan/sagan_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py @@ -18,7 +18,8 @@ interval=1000) ] -inception_pkl = './work_dirs/inception_pkl/cifar10.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/cifar10.pkl') evaluation = dict( type='GenerativeEvalHook', diff --git a/configs/sngan_proj/README.md b/configs/sngan_proj/README.md index 09a4f93a6..3660b3e65 100644 --- a/configs/sngan_proj/README.md +++ b/configs/sngan_proj/README.md @@ -70,7 +70,7 @@ For IS metric, our implementation is different from PyTorch-Studio GAN in the fo For FID evaluation, we follow the pipeline of [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch/blob/98459431a5d618d644d54cd1e9fceb1e5045648d/calculate_inception_moments.py#L52), where the whole training set is adopted to extract inception statistics, and Pytorch Studio GAN uses 50000 randomly selected samples. Besides, we also use [Tero's Inception](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt) for feature extraction. -You can download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet.pkl). +MMGen will automatically download the preprocessed inception state by the following url: [CIFAR10](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/cifar10.pkl) and [ImageNet1k](https://download.openmmlab.com/mmgen/evaluation/fid_inception_pkl/imagenet.pkl). You can use following commands to extract those inception states by yourself. ``` diff --git a/configs/sngan_proj/sngan_proj_128_wReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py b/configs/sngan_proj/sngan_proj_128_wReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py index 90078d2b8..e51141095 100644 --- a/configs/sngan_proj/sngan_proj_128_wReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py +++ b/configs/sngan_proj/sngan_proj_128_wReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py @@ -30,7 +30,8 @@ log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')]) -inception_pkl = './work_dirs/inception_pkl/imagenet.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet.pkl') evaluation = dict( type='GenerativeEvalHook', diff --git a/configs/sngan_proj/sngan_proj_128_woReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py b/configs/sngan_proj/sngan_proj_128_woReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py index 96f86a3cc..5bfed2de3 100644 --- a/configs/sngan_proj/sngan_proj_128_woReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py +++ b/configs/sngan_proj/sngan_proj_128_woReLUinplace_Glr-2e-4_Dlr-5e-5_ndisc5_imagenet1k_b128x2.py @@ -24,7 +24,8 @@ log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')]) -inception_pkl = './work_dirs/inception_pkl/imagenet.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/imagenet.pkl') evaluation = dict( type='GenerativeEvalHook', diff --git a/configs/sngan_proj/sngan_proj_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py b/configs/sngan_proj/sngan_proj_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py index 4d6bd2d91..699d64f0f 100644 --- a/configs/sngan_proj/sngan_proj_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py +++ b/configs/sngan_proj/sngan_proj_32_wReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py @@ -28,7 +28,8 @@ interval=5000) ] -inception_pkl = './work_dirs/inception_pkl/cifar10.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/cifar10.pkl') evaluation = dict( type='GenerativeEvalHook', diff --git a/configs/sngan_proj/sngan_proj_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py b/configs/sngan_proj/sngan_proj_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py index 8db95fea7..86da61a49 100644 --- a/configs/sngan_proj/sngan_proj_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py +++ b/configs/sngan_proj/sngan_proj_32_woReLUinplace_lr-2e-4_ndisc5_cifar10_b64x1.py @@ -22,7 +22,8 @@ interval=5000) ] -inception_pkl = './work_dirs/inception_pkl/cifar10.pkl' +inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/cifar10.pkl') evaluation = dict( type='GenerativeEvalHook', From 16a80886b57fd32b24ea23ac2710377e3f647246 Mon Sep 17 00:00:00 2001 From: LeoXing Date: Sun, 23 Jan 2022 16:17:01 +0800 Subject: [PATCH 3/6] update unit test --- mmgen/core/evaluation/metrics.py | 2 +- tests/test_cores/test_metrics.py | 81 ++++++++++++++++++-------------- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/mmgen/core/evaluation/metrics.py b/mmgen/core/evaluation/metrics.py index 02bd9a38f..8e5074879 100644 --- a/mmgen/core/evaluation/metrics.py +++ b/mmgen/core/evaluation/metrics.py @@ -118,7 +118,7 @@ def _load_inception_torch(inception_args, metric): assert metric in ['FID', 'IS'] if metric == 'FID': inception_model = InceptionV3([3], **inception_args) - elif metric == 'IS': + else: # metric == 'IS' inception_model = inception_v3(pretrained=True, transform_input=False) mmcv.print_log( 'Load Inception V3 Network from Pytorch Model Zoo ' diff --git a/tests/test_cores/test_metrics.py b/tests/test_cores/test_metrics.py index a5aa46a62..0133cf95a 100644 --- a/tests/test_cores/test_metrics.py +++ b/tests/test_cores/test_metrics.py @@ -11,32 +11,38 @@ from mmgen.models import build_model from mmgen.models.architectures import InceptionV3 -# def test_inception_download(): -# from mmgen.core.evaluation.metrics import load_inception -# from mmgen.utils import MMGEN_CACHE_DIR -# args_FID_pytorch = dict(type='pytorch', normalize_input=False) -# args_FID_tero = dict(type='StyleGAN', inception_path='') -# args_IS_pytorch = dict(type='pytorch') -# args_IS_tero = dict( -# type='StyleGAN', -# inception_path=osp.join(MMGEN_CACHE_DIR, 'inception-2015-12-05.pt')) +def test_inception_download(): + from mmgen.core.evaluation.metrics import load_inception + from mmgen.utils import MMGEN_CACHE_DIR -# tar_style_list = ['pytorch', 'StyleGAN', 'pytorch', 'StyleGAN'] + args_FID_pytorch = dict(type='pytorch', normalize_input=False) + args_FID_tero = dict(type='StyleGAN') + args_IS_pytorch = dict(type='pytorch') + args_IS_tero = dict( + type='StyleGAN', + inception_path=osp.join(MMGEN_CACHE_DIR, 'inception-2015-12-05.pt')) -# for inception_args, metric, tar_style in zip( -# [args_FID_pytorch, args_FID_tero, args_IS_pytorch, args_IS_tero], -# ['FID', 'FID', 'IS', 'IS'], tar_style_list): -# model, style = load_inception(inception_args, metric) -# assert style == tar_style + arg_list = [args_FID_pytorch, args_FID_tero, args_IS_pytorch, args_IS_tero] + metric_list = ['FID', 'FID', 'IS', 'IS'] + tar_style_list = ['pytorch', 'StyleGAN', 'pytorch', 'StyleGAN'] -# args_empty = '' -# with pytest.raises(TypeError) as exc_info: -# load_inception(args_empty, 'FID') + for inception_args, metric, tar_style in zip(arg_list, metric_list, + tar_style_list): + model, style = load_inception(inception_args, metric) + assert style == tar_style + assert style == tar_style -# args_error_path = dict(type='StyleGAN', inception_path='error-path') -# with pytest.raises(RuntimeError) as exc_info: -# load_inception(args_error_path, 'FID') + args_empty = '' + with pytest.raises(TypeError): + load_inception(args_empty, 'FID') + + args_error_path = dict(type='StyleGAN', inception_path='error-path') + with pytest.raises(RuntimeError): + load_inception(args_error_path, 'FID') + + with pytest.raises(AssertionError): + load_inception(dict(type='pytorch', normalize_input=False), 'PPL') def test_swd_metric(): @@ -144,21 +150,24 @@ def test_fid(self): assert fid_score > 0 and mean > 0 and cov > 0 # To reduce the size of git repo, we remove the following test - # fid = FID( - # 3, - # inception_args=dict( - # normalize_input=False, load_fid_inception=False), - # inception_pkl=osp.join( - # osp.dirname(__file__), '..', 'data', 'test_dirty.pkl')) - # assert fid.num_real_feeded == 3 - # for b in self.reals: - # fid.feed(b, 'reals') - - # for b in self.fakes: - # fid.feed(b, 'fakes') - - # fid_score, mean, cov = fid.summary() - # assert fid_score > 0 and mean > 0 and cov > 0 + + inception_pkl = ('https://download.openmmlab.com/mmgen/evaluation/' + 'fid_inception_pkl/cifar10.pkl') + fid = FID( + 3, + inception_args=dict( + normalize_input=False, load_fid_inception=False), + inception_pkl=inception_pkl) + fid.prepare() + assert fid.num_real_feeded == 3 + for b in self.reals: + fid.feed(b, 'reals') + + for b in self.fakes: + fid.feed(b, 'fakes') + + fid_score, mean, cov = fid.summary() + assert fid_score > 0 and mean > 0 and cov > 0 class TestPR: From 370fdea632196b18f6cae96ccd35a2797646cba7 Mon Sep 17 00:00:00 2001 From: LeoXing Date: Sun, 23 Jan 2022 18:48:09 +0800 Subject: [PATCH 4/6] fix bug in pt151 --- tests/test_cores/test_metrics.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_cores/test_metrics.py b/tests/test_cores/test_metrics.py index 0133cf95a..e03cac282 100644 --- a/tests/test_cores/test_metrics.py +++ b/tests/test_cores/test_metrics.py @@ -30,8 +30,11 @@ def test_inception_download(): for inception_args, metric, tar_style in zip(arg_list, metric_list, tar_style_list): model, style = load_inception(inception_args, metric) - assert style == tar_style - assert style == tar_style + + if torch.__version__ < '1.6.0': + assert style == 'pytorch' + else: + assert style == tar_style args_empty = '' with pytest.raises(TypeError): From 9898ceb2fc38720bb447e1503f026f47d294c65f Mon Sep 17 00:00:00 2001 From: LeoXing Date: Sun, 23 Jan 2022 20:29:03 +0800 Subject: [PATCH 5/6] revise inception load behavior for pt <= 1.6 --- mmgen/core/evaluation/metrics.py | 5 +++++ tests/test_cores/test_metrics.py | 9 ++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/mmgen/core/evaluation/metrics.py b/mmgen/core/evaluation/metrics.py index 8e5074879..647ba76da 100644 --- a/mmgen/core/evaluation/metrics.py +++ b/mmgen/core/evaluation/metrics.py @@ -57,6 +57,11 @@ def load_inception(inception_args, metric): inceptoin_type = _inception_args.pop('type', None) if torch.__version__ < '1.6.0': + # reset inception_args for FID (Inception for IS do not use + # inception_args) + if metric == 'FID': + _inception_args = dict(normalize_input=False) + mmcv.print_log( 'Current Pytorch Version not support script module, load ' 'Inception Model from torch model zoo. If you want to use ' diff --git a/tests/test_cores/test_metrics.py b/tests/test_cores/test_metrics.py index e03cac282..d7be2ccfd 100644 --- a/tests/test_cores/test_metrics.py +++ b/tests/test_cores/test_metrics.py @@ -40,9 +40,12 @@ def test_inception_download(): with pytest.raises(TypeError): load_inception(args_empty, 'FID') - args_error_path = dict(type='StyleGAN', inception_path='error-path') - with pytest.raises(RuntimeError): - load_inception(args_error_path, 'FID') + # pt lower than this version cannot load Tero's inception and direct use + # torch ones, only test this for pt >= 1.6 + if torch.__version__ >= '1.6.0': + args_error_path = dict(type='StyleGAN', inception_path='error-path') + with pytest.raises(RuntimeError): + load_inception(args_error_path, 'FID') with pytest.raises(AssertionError): load_inception(dict(type='pytorch', normalize_input=False), 'PPL') From 18b5c818e4383ec2a1e4c2fd81225cd7627f3075 Mon Sep 17 00:00:00 2001 From: LeoXing Date: Sun, 23 Jan 2022 21:15:41 +0800 Subject: [PATCH 6/6] try to cover more unit test --- tests/test_cores/test_metrics.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/test_cores/test_metrics.py b/tests/test_cores/test_metrics.py index d7be2ccfd..2e5d27cc3 100644 --- a/tests/test_cores/test_metrics.py +++ b/tests/test_cores/test_metrics.py @@ -32,6 +32,7 @@ def test_inception_download(): model, style = load_inception(inception_args, metric) if torch.__version__ < '1.6.0': + print(inception_args, metric, tar_style) assert style == 'pytorch' else: assert style == tar_style @@ -175,6 +176,34 @@ def test_fid(self): fid_score, mean, cov = fid.summary() assert fid_score > 0 and mean > 0 and cov > 0 + # test load + inception_pkl = osp.expanduser('~/.cache/openmmlab/mmgen/cifar10.pkl') + fid = FID( + 3, + inception_args=dict( + normalize_input=False, load_fid_inception=False), + inception_pkl=inception_pkl) + fid.prepare() + assert fid.num_real_feeded == 3 + for b in self.reals: + fid.feed(b, 'reals') + + for b in self.fakes: + fid.feed(b, 'fakes') + + fid_score, mean, cov = fid.summary() + assert fid_score > 0 and mean > 0 and cov > 0 + + # test raise load error + inception_pkl = 'wrong_path' + fid = FID( + 3, + inception_args=dict( + normalize_input=False, load_fid_inception=False), + inception_pkl=inception_pkl) + with pytest.raises(FileNotFoundError): + fid.prepare() + class TestPR: