A repository implementing the code for learning perturbation sets from data, for the MNIST, CIFAR10, and Multi-Illumination datasets. Created by Eric Wong with Zico Kolter, with the code structure loosely based on the robustness repository here. See our paper on arXiv here and our corresponding blog post.
- 7/16/2020 - Paper and blog post released
One of the core tenets of making machine learning models robust to adversarial attacks is to define the threat model containing all possible perturbations, which is critical for performing a proper robustness evaluation. However, well-defined threat models have been largely limited to mathematically nice sets that can be described a priori, such as the Lp ball or Wasserstein metric, whereas many real-world transformations may be impossible to define mathematically. This work aims to bridge this gap by learning perturbation sets that are generated from an Lp ball in an underlying latent space. This simple characterization of a perturbation set allows us to leverage state-of-the-art approaches in adversarial training directly in the latent space, while at the same time capturing complex real-world perturbations.
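To make this concrete, here is a minimal sketch (not this repository's actual API; `decoder`, `prior_mean`, and `epsilon` are hypothetical placeholders) of what a learned perturbation set provides: perturbed examples are produced by decoding latent codes within an epsilon ball around the predicted prior, and an adversarial attack is simply projected gradient ascent over that latent code.

```python
import torch

# Hypothetical sketch: a learned perturbation set decodes latent codes z
# (restricted to an L2 ball of radius epsilon around the prior mean) into
# perturbed versions of x. `decoder` and `prior_mean` stand in for the
# components of a trained CVAE; batching is ignored for brevity.

def sample_perturbation(decoder, prior_mean, x, epsilon):
    """Draw a random example from the learned perturbation set of x."""
    mu = prior_mean(x)
    z = torch.randn_like(mu)
    z = epsilon * z / z.norm()              # random direction on the epsilon sphere
    return decoder(x, mu + z)

def latent_pgd(decoder, prior_mean, model, loss_fn, x, y, epsilon,
               step_size=0.1, steps=10):
    """Adversarial example search over the latent ball instead of pixel space."""
    mu = prior_mean(x)
    delta = torch.zeros_like(mu, requires_grad=True)
    for _ in range(steps):
        loss = loss_fn(model(decoder(x, mu + delta)), y)
        grad, = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta += step_size * grad / grad.norm().clamp_min(1e-12)
            # project back into the epsilon ball
            delta *= (epsilon / delta.norm().clamp_min(1e-12)).clamp(max=1.0)
    return decoder(x, mu + delta.detach())
```

The actual training, evaluation, and attack code lives in the scripts and configuration files described below.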
- `configs/` contains configuration files to train perturbation sets
- `configs_eval/` contains configuration files to evaluate perturbation sets
- `configs_robust/` contains configuration files to train robust models (e.g. with perturbation sets and data augmentation baselines)
- `configs_attack/` contains configuration files to evaluate robust models (e.g. with standard and robust metrics)
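The exact fields in these JSON files are defined by the corresponding scripts; purely as an illustration (the attribute names below are hypothetical, not the repository's schema), a configuration is just a JSON file that gets loaded and handed to the training or evaluation loop:

```python
import json
from types import SimpleNamespace

# Hypothetical illustration of how a JSON config might be consumed.
# Replace <perturbation> with a real file from configs/; the actual field
# names are whatever train.py / eval.py expect.
with open("configs/<perturbation>.json") as f:
    config = json.load(f, object_hook=lambda d: SimpleNamespace(**d))

# e.g. config.model, config.epochs, ... (hypothetical attribute names)
```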
- To train a perturbation set with configuration `<perturbation>`, use `python train.py --config configs/<perturbation>.json`
- To evaluate a perturbation set with configuration `<perturbation>` on a metric `<metric>`, use `python eval.py --config configs/<perturbation>.json --eval-config configs_eval/<metric>.json`
- To gather perturbation statistics on the validation set (e.g. to get the maximum radius for the latent space) with configuration `<perturbation>`, use `python latent_distances.py <perturbation>`
- To train a robust model with `<method>`, use `python robust_train.py --config configs_robust/<method>.json`
- To evaluate a robust model trained with `<method>` on an attack `<attack>`, use `python robust_eval.py --config configs_robust/<method>.json --config-attack configs_attack/<attack>.json`

A full list of all commands and configurations to run all experiments in the paper is in `all.sh`.
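If you would rather drive these scripts from Python than from `all.sh`, a small wrapper along the following lines works; the configuration names are placeholders to be replaced with real files from `configs/`:

```python
import subprocess

# Hypothetical convenience wrapper around the command-line interface above.
# Replace the placeholder names with real configuration files from configs/.
perturbation_configs = ["<perturbation_a>", "<perturbation_b>"]

for name in perturbation_configs:
    # Train the perturbation set, then gather latent-space statistics.
    subprocess.run(["python", "train.py", "--config", f"configs/{name}.json"],
                   check=True)
    subprocess.run(["python", "latent_distances.py", name], check=True)
```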
We include some convenience scripts for generating the CIFAR10 common corruptions and Multi-Illumination datasets in `datasets/`, which are based on the corresponding official repositories:
- CIFAR10 common corruptions: https://github.com/hendrycks/robustness
- Multi illumination: https://github.com/lmurmann/multi_illumination
We have released the following files here: https://drive.google.com/drive/folders/1azgGxCHNuLO2ef_FHkxl-o8WNn-mWc1v?usp=sharing
- Pretrained model weights for all learned perturbation sets
- Pretrained model weights for downstream robustly trained classifiers
- A copy of the CIFAR10C dataset (both train and test)
Within this folder, `perturbation_sets/` contains the model weights for the following learned perturbation sets:
- CIFAR10 common corruptions learned perturbation sets using the three strategies from Tables 5 and 6
- MI learned perturbation sets at varying resolutions from Tables 8 and 9
Additionally, `robust_models/` contains the model weights for the robust models trained using our CVAE perturbation sets:
- CIFAR10 classifier trained with CVAE adversarial training from Table 1
- CIFAR10 classifier trained with CVAE data augmentation from Table 1
- Certifiably robust CIFAR10 classifier using randomized smoothing at the three noise levels in Table 7.
- MI segmentation model trained with CVAE adversarial training from Table 10
- MI segmentation model trained with CVAE data augmentation from Table 10
- Certifiably robust MI segmentation model using randomized smoothing as reported in Figure 16.
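How you load these weights depends on the corresponding training script, but as a rough, hypothetical example of inspecting a downloaded checkpoint (the filename and stored keys below are placeholders):

```python
import torch

# Hypothetical example of inspecting a downloaded checkpoint; the actual
# filenames and dictionary keys are determined by the training scripts.
checkpoint = torch.load("perturbation_sets/<checkpoint_file>", map_location="cpu")

if isinstance(checkpoint, dict):
    print(list(checkpoint.keys()))  # see what the checkpoint actually stores
```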
- `perturbation_learning/` contains the main components of the CVAE perturbation sets:
  - `cvae.py` contains the general CVAE implementation for a generic encoder, decoder, and prior network (a minimal sketch of this structure is given after this list). A specific CVAE is implemented by defining a module with these corresponding components, as in the following files:
  - `mnist_conv.py`, `mnist_fc.py`, and `mnist_stn.py` implement the networks for MNIST perturbation sets, with the last one using `STNModule.py` to implement the spatial transformer.
  - `cifar10_rectangle.py` implements residual networks for the CIFAR10 common corruptions perturbation sets.
  - `mi_unet.py` implements UNets for multi-illumination perturbation sets, using the UNet components defined in `unet_parts.py`.
  - `scaled_tanh.py` implements the scaled tanh activation function used to stabilize the log variance prediction.
  - `datasets.py` contains the dataloaders for the MNIST/CIFAR10/MI datasets and their corresponding perturbed datasets.
  - `perturbations.py` contains the MNIST perturbations defined as PyTorch transforms.
- `robustness/` contains the main components of the robustness experiments using these CVAE perturbation sets:
  - `classifiers.py` aggregates the models and interfaces them to a config file.
  - `wideresnet.py` implements a standard WideResNet for CIFAR10 classification.
  - `unet_model.py` implements a standard UNet for MI material segmentation using `unet_parts.py`.
  - `smoothing_core.py` implements a generic randomized smoothing prediction and certification wrapper for CVAE models (a sketch of the idea is given after this list).
- `datasets/` contains some convenience scripts for fetching and generating the CIFAR10 common corruptions dataset and the MI dataset.
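For orientation, here is a minimal sketch of the conditional VAE structure that `cvae.py` generalizes; the class and method names are illustrative assumptions rather than the repository's actual interface. An encoder infers a latent posterior from a (clean, perturbed) pair, a prior network predicts a latent distribution from the clean input alone, and a decoder maps the clean input plus a latent code to a perturbed example; the learned perturbation set is then the decoder's image of a small ball around the prior mean.

```python
import torch
import torch.nn as nn

class ConditionalVAE(nn.Module):
    """Illustrative CVAE skeleton (not the repository's actual implementation).

    encoder(x, x_pert) -> (mu_q, logvar_q)   posterior over the latent code
    prior(x)           -> (mu_p, logvar_p)   prior predicted from the clean input
    decoder(x, z)      -> reconstruction of the perturbed input
    """

    def __init__(self, encoder, decoder, prior):
        super().__init__()
        self.encoder, self.decoder, self.prior = encoder, decoder, prior

    @staticmethod
    def reparameterize(mu, logvar):
        # Sample z = mu + sigma * eps with eps ~ N(0, I)
        return mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

    def forward(self, x, x_pert):
        mu_q, logvar_q = self.encoder(x, x_pert)
        mu_p, logvar_p = self.prior(x)
        z = self.reparameterize(mu_q, logvar_q)
        recon = self.decoder(x, z)
        # KL(q(z | x, x_pert) || p(z | x)) between two diagonal Gaussians
        kl = 0.5 * (logvar_p - logvar_q
                    + (logvar_q.exp() + (mu_q - mu_p) ** 2) / logvar_p.exp()
                    - 1).sum(dim=-1)
        return recon, kl
```

At test time only the prior and decoder are needed: perturbations of `x` are obtained by decoding latent codes within an epsilon ball around the prior mean, which is exactly the set that the adversarial training and smoothing code operates over.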
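In the same spirit, here is a rough sketch of randomized smoothing through a learned perturbation set (placeholder names again; the actual prediction and certification logic, including the statistical confidence bounds, is in `robustness/smoothing_core.py`): Gaussian noise is added in the latent space, each noisy code is decoded into a perturbed input, and the base model's predictions are aggregated by majority vote.

```python
import torch

def smoothed_predict(classifier, decoder, prior_mean, x, sigma, n_samples=100):
    """Illustrative smoothed prediction (placeholder names, not the repo's API):
    add Gaussian noise in the latent space, decode each noisy code into a
    perturbed input, classify it, and take a majority vote over predictions."""
    mu = prior_mean(x)
    counts = None
    with torch.no_grad():
        for _ in range(n_samples):
            z = mu + sigma * torch.randn_like(mu)
            logits = classifier(decoder(x, z))
            pred = logits.argmax(dim=-1)
            one_hot = torch.nn.functional.one_hot(pred, logits.shape[-1])
            counts = one_hot if counts is None else counts + one_hot
    return counts.argmax(dim=-1)   # majority-vote class
```

Because the noise lives in the latent space, the resulting certificate covers a ball of latent codes, i.e. a region of the learned perturbation set rather than a raw pixel-space ball.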