PyTorch Out-of-Distribution Detection
A Python library for Out-of-Distribution (OOD) Detection with Deep Neural Networks based on PyTorch.
The library provides:
- Out-of-Distribution Detection Methods
- Loss Functions
- Datasets
- Neural Network Architectures, as well as pre-trained weights
- Data Augmentations
- Useful Utilities
and is designed to be compatible with frameworks like pytorch-lightning and pytorch-segmentation-models.
The library also covers some methods from closely related fields, such as Open-Set Recognition, Novelty Detection,
Confidence Estimation and Anomaly Detection.
The documentation is available here.
NOTE: An important convention adopted in pytorch-ood is that OOD detectors predict outlier scores that should be larger for outliers than for inliers. If you notice that the scores predicted by a detector do not match the formulas in the corresponding publication, we may have adjusted the score calculation to comply with this convention.
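As an illustration of this convention, here is a minimal sketch in plain PyTorch (not the library's internal code) of two common scores written so that outliers receive larger values; the logits and threshold below are placeholders:

import torch

# Hypothetical logits for two inputs; purely illustrative
logits = torch.tensor([[5.0, 0.1, 0.2],
                       [1.0, 1.1, 0.9]])

# Negative maximum softmax probability: larger for outliers
msp_score = -logits.softmax(dim=1).max(dim=1).values

# Energy score: negative log-sum-exp of the logits, also larger for outliers
energy_score = -torch.logsumexp(logits, dim=1)

# Inputs are flagged as OOD when their score exceeds a threshold,
# typically chosen on validation data (e.g., at 95% TPR)
is_ood = energy_score > -2.0  # placeholder threshold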
Load a WideResNet-40 model (used in major publications), pre-trained on CIFAR-10 with the Energy-Bounded Learning Loss [8] (weights from the original paper), and predict on some dataset data_loader using Energy-based OOD Detection (EBO) [8], calculating the common metrics. OOD data must be marked with labels < 0.
from pytorch_ood.detector import EnergyBased
from pytorch_ood.utils import OODMetrics
from pytorch_ood.model import WideResNet

data_loader = ...  # your data, OOD with label < 0

# Create Neural Network
model = WideResNet(num_classes=10, pretrained="er-cifar10-tune").eval().cuda()
preprocess = WideResNet.transform_for("er-cifar10-tune")

# Create detector
detector = EnergyBased(model)

# Evaluate
metrics = OODMetrics()

for x, y in data_loader:
    x = preprocess(x).cuda()
    metrics.update(detector(x), y)

print(metrics.compute())
You can find more examples in the documentation.
Evaluate detectors against common benchmarks, for example the OpenOOD ImageNet benchmark
(including ImageNet-O, OpenImages-O, Textures, SVHN, MNIST). All datasets (except for ImageNet itself) will be downloaded automatically.
import pandas as pd
from pytorch_ood.benchmark import ImageNet_OpenOOD
from pytorch_ood.detector import MaxSoftmax
from torchvision.models import resnet50
from torchvision.models.resnet import ResNet50_Weights

model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1).eval().to("cuda:0")
trans = ResNet50_Weights.IMAGENET1K_V1.transforms()

benchmark = ImageNet_OpenOOD(root="data", image_net_root="data/imagenet-2012/", transform=trans)
detector = MaxSoftmax(model)
results = benchmark.evaluate(detector, loader_kwargs={"batch_size": 64}, device="cuda:0")

df = pd.DataFrame(results)
print(df)
This produces the following table:
| Dataset     | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
|-------------|-------|---------|----------|----------|
| ImageNetO   | 28.64 | 2.52    | 94.85    | 91.20    |
| OpenImagesO | 84.98 | 62.61   | 94.67    | 49.95    |
| Textures    | 80.46 | 37.50   | 96.80    | 67.75    |
| SVHN        | 97.62 | 95.56   | 98.77    | 11.58    |
| MNIST       | 90.04 | 90.45   | 89.88    | 39.03    |
The package can be installed via PyPI:
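pip install pytorch-ood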
Dependencies
- torch
- torchvision
- scipy
- torchmetrics

Optional Dependencies
- scikit-learn for ViM
- gdown to download some datasets and model weights
- pandas for the examples
- segmentation-models-pytorch to run the examples for anomaly segmentation
Detectors:

| Detector | Description | Year | Ref |
|---|---|---|---|
| OpenMax | Implementation of the OpenMax Layer as proposed in the paper Towards Open Set Deep Networks. | 2016 | [1] |
| Monte Carlo Dropout | Implements Monte Carlo Dropout. | 2016 | [4] |
| Maximum Softmax Probability | Implements the Softmax Baseline for OOD and Error detection. | 2017 | [5] |
| Temperature Scaling | Implements Temperature Scaling for the Softmax. | 2017 | [6] |
| ODIN | ODIN is an input preprocessing method that aims to increase the discriminability of the softmax outputs for In- and Out-of-Distribution data. | 2018 | [2] |
| Mahalanobis | Implements the Mahalanobis Method. | 2018 | [3] |
| Energy-Based OOD Detection | Implements the Energy Score of Energy-based Out-of-distribution Detection. | 2020 | [8] |
| Entropy | Uses entropy to detect OOD inputs. | 2021 | [40] |
| ReAct | ReAct: Out-of-distribution Detection With Rectified Activations. | 2021 | [44] |
| Maximum Logit | Implements the MaxLogit method. | 2022 | [24] |
| KL-Matching | Implements the KL-Matching method for Multi-Class classification. | 2022 | [24] |
| ViM | Implements Virtual Logit Matching. | 2022 | [36] |
| Weighted Energy-Based | Implements Weighted Energy-Based OOD Detection. | 2022 | [37] |
| Nearest Neighbor | Implements Deep Nearest Neighbors for OOD Detection. | 2022 | [38] |
| DICE | Implements Sparsification for OOD Detection. | 2022 | [41] |
| ASH | Implements Extremely Simple Activation Shaping. | 2023 | [42] |
| SHE | Implements Simplified Hopfield Networks. | 2023 | [43] |
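All detectors share a common interface: wrap a model, optionally fit on in-distribution data, then call the detector to obtain outlier scores. Below is a minimal sketch; the classifier model, its features attribute (a penultimate-layer feature extractor), x, and train_loader are assumed placeholders, and the fit() call follows the pattern used in the documented examples:

from pytorch_ood.detector import MaxSoftmax, Mahalanobis

# Logit-based detectors such as MaxSoftmax require no fitting
detector = MaxSoftmax(model)
scores = detector(x)  # larger scores indicate outliers

# Distance-based detectors such as Mahalanobis are first fitted on
# in-distribution data (assumed: model.features is a feature extractor)
detector = Mahalanobis(model.features)
detector.fit(train_loader, device="cuda")
scores = detector(x)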
Objective Functions:

| Objective Function | Description | Year | Ref |
|---|---|---|---|
| Objectosphere | Implementation of the paper Reducing Network Agnostophobia. | 2016 | [9] |
| Center Loss | Generalized version of the Center Loss from the paper A Discriminative Feature Learning Approach for Deep Face Recognition. | 2016 | [14] |
| Outlier Exposure | Implementation of the paper Deep Anomaly Detection With Outlier Exposure. | 2018 | [10] |
| Confidence Loss | The model learns a confidence estimate in addition to the class membership prediction. | 2018 | [7] |
| Deep SVDD | Implementation of the Deep Support Vector Data Description from the paper Deep One-Class Classification. | 2018 | [11] |
| Energy-Bounded Loss | Adds a regularization term to the cross-entropy that aims to increase the energy gap between IN and OOD samples. | 2020 | [8] |
| CAC Loss | Class Anchor Clustering Loss from Class Anchor Clustering: a Distance-based Loss for Training Open Set Classifiers. | 2021 | [13] |
| Entropic Open-Set Loss | Entropy maximization and meta classification for OOD in semantic segmentation. | 2021 | [40] |
| II Loss | Implementation of the II Loss function from Learning a neural network-based representation for open set recognition. | 2022 | [12] |
| MCHAD Loss | Implementation of the MCHAD Loss from the paper Multi Class Hypersphere Anomaly Detection. | 2022 | [35] |
| VOS Energy-Based Loss | Implementation of the paper VOS: Learning what you don't know by virtual outlier synthesis. | 2022 | [37] |
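The loss functions follow the same label convention: in-distribution samples keep labels >= 0, while outliers carry labels < 0. As a sketch of a single training step with Outlier Exposure, assuming model, optimizer, and a train_loader that mixes in-distribution samples with auxiliary outliers are already set up:

from pytorch_ood.loss import OutlierExposureLoss

criterion = OutlierExposureLoss()

for x, y in train_loader:
    # y >= 0 for in-distribution samples, y < 0 for auxiliary outliers
    logits = model(x.cuda())
    loss = criterion(logits, y.cuda())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()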
Image Datasets:

| Dataset | Description | Year | Ref |
|---|---|---|---|
| Chars74k | The Chars74K dataset contains 74,000 images across 64 classes, comprising English letters and Arabic numerals. | 2012 | [31] |
| TinyImages | The TinyImages dataset is often used as auxiliary OOD training data; however, its use is discouraged. | 2012 | [30] |
| Textures | Textures dataset, also known as DTD, often used as OOD examples. | 2013 | [29] |
| FoolingImages | OOD images generated to fool certain Deep Neural Networks. | 2015 | [16] |
| Tiny ImageNet | A derived version of ImageNet with 64x64-sized images. | 2015 | [17] |
| TinyImages300k | A cleaned version of the TinyImages dataset with 300,000 images, often used as auxiliary OOD training data. | 2018 | [10] |
| LSUN | A version of the Large-scale Scene UNderstanding dataset with 10,000 images, often used as auxiliary OOD training data. | 2018 | [2] |
| MNIST-C | Corrupted version of MNIST. | 2019 | [21] |
| CIFAR10-C | Corrupted version of CIFAR 10. | 2019 | [15] |
| CIFAR100-C | Corrupted version of CIFAR 100. | 2019 | [15] |
| ImageNet-C | Corrupted version of ImageNet. | 2019 | [15] |
| ImageNet - A, O, R | Different outlier variants for ImageNet. | 2019 | [18] |
| ImageNet - V2 | A new test set for ImageNet. | 2019 | [19] |
| ImageNet - ES | Event stream (ES) version of ImageNet. | 2021 | [20] |
| iNaturalist | A subset of iNaturalist with 10,000 images. | 2021 | [34] |
| Fractals | A dataset with fractals from PIXMIX: Dreamlike Pictures Comprehensively Improve Safety Measures. | 2022 | [39] |
| Feature Visualizations | A dataset with feature visualizations from PIXMIX: Dreamlike Pictures Comprehensively Improve Safety Measures. | 2022 | [39] |
| FS Static | The FishyScapes (FS) Static dataset contains real-world OOD images from the CityScapes dataset. | 2021 | [22] |
| FS LostAndFound | The FishyScapes dataset contains images from the CityScapes dataset blended with unknown objects scraped from the web. | 2021 | [22] |
| MVTech-AD | MVTec AD is a dataset for benchmarking anomaly detection methods with a focus on industrial inspection. | 2021 | [23] |
| StreetHazards | Anomaly segmentation dataset. | 2022 | [24] |
| CIFAR100-GAN | Images sampled from low-likelihood regions of a BigGAN trained on CIFAR 100, from the paper On Outlier Exposure with Generative Models. | 2022 | [25] |
| SSB - hard | The hard split of the Semantic Shift Benchmark, which contains 49,000 images. | 2022 | [26] |
| NINCO | The NINCO (No ImageNet Class Objects) dataset, which contains 5,879 images of 64 OOD classes. | 2023 | [27] |
| SuMNIST | The SuMNIST dataset is based on MNIST, but each image displays four numbers instead of one. | 2023 | [28] |
| Gaussian Noise | Dataset with samples drawn from a normal distribution. | | |
| Uniform Noise | Dataset with samples drawn from a uniform distribution. | | |
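The image datasets follow the torchvision dataset interface, so in-distribution and OOD test data can be merged into a single loader; the ToUnknown target transform relabels all OOD samples to -1 to satisfy the label convention. A sketch, assuming a suitable preprocess transform is defined:

from torch.utils.data import ConcatDataset, DataLoader
from torchvision.datasets import CIFAR10
from pytorch_ood.dataset.img import Textures
from pytorch_ood.utils import ToUnknown

# In-distribution test data keeps its class labels (>= 0)
in_data = CIFAR10(root="data", train=False, transform=preprocess, download=True)

# OOD data is relabeled to -1 via ToUnknown()
out_data = Textures(root="data", transform=preprocess,
                    target_transform=ToUnknown(), download=True)

data_loader = DataLoader(ConcatDataset([in_data, out_data]), batch_size=64)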
Text Datasets:

| Dataset | Description | Year | Ref |
|---|---|---|---|
| Multi30k | Multi-30k dataset, as used by Hendrycks et al. in the OOD baseline paper. | 2016 | [32] |
| WikiText2 | Texts from Wikipedia, often used as auxiliary OOD training data. | 2016 | [33] |
| WikiText103 | Texts from Wikipedia, often used as auxiliary OOD training data. | 2016 | [33] |
| NewsGroup20 | Texts from different newsgroups, as used by Hendrycks et al. in the OOD baseline paper. | | |
Augmentation Methods:

| Augmentation | Description | Year | Ref |
|---|---|---|---|
| PixMix | PixMix image augmentation method. | 2022 | [39] |
| COCO Outlier Pasting | From "Entropy maximization and meta classification for OOD in semantic segmentation". | 2021 | [40] |
We encourage everyone to contribute to this project by adding implementations of OOD detection methods, datasets, etc., or by checking the existing implementations for bugs.
pytorch-ood was presented at a CVPR Workshop in 2022. If you use it in a scientific publication, please consider citing:
@InProceedings{kirchheim2022pytorch,
author = {Kirchheim, Konstantin and Filax, Marco and Ortmeier, Frank},
title = {PyTorch-OOD: A Library for Out-of-Distribution Detection Based on PyTorch},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
month = {June},
year = {2022},
pages = {4351-4360}
}
The code is licensed under Apache 2.0. We have taken care to make sure any third party code included or adapted has compatible (permissive) licenses such as MIT, BSD, etc.
The legal implications of using pre-trained models in commercial services are, to our knowledge, not fully understood.
[1] Bendale, A., & Boult, T. E. (2016). Towards open set deep networks. CVPR.
[2] Liang, S., et al. (2017). Enhancing the reliability of out-of-distribution image detection in neural networks. ICLR.
[3] Lee, K., et al. (2018). A simple unified framework for detecting out-of-distribution samples and adversarial attacks. NeurIPS.
[4] Gal, Y., & Ghahramani, Z. (2016). Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. ICML.
[5] Hendrycks, D., & Gimpel, K. (2016). A baseline for detecting misclassified and out-of-distribution examples in neural networks. ICLR.
[6] Guo, C., et al. (2017). On calibration of modern neural networks. ICML.
[7] DeVries, T., & Taylor, G. W. (2018). Learning confidence for out-of-distribution detection in neural networks. arXiv.
[8] Liu, W., et al. (2020). Energy-based out-of-distribution detection. NeurIPS.
[9] Dhamija, A. R., et al. (2018). Reducing network agnostophobia. NeurIPS.
[10] Hendrycks, D., Mazeika, M., & Dietterich, T. (2018). Deep anomaly detection with outlier exposure. ICLR.
[11] Ruff, L., et al. (2018). Deep one-class classification. ICML.
[12] Hassen, M., & Chan, P. K. (2020). Learning a neural-network-based representation for open set recognition. SDM.
[13] Miller, D., et al. (2021). Class anchor clustering: A loss for distance-based open set recognition. WACV.
[14] Wen, Y., et al. (2016). A discriminative feature learning approach for deep face recognition. ECCV.
[15] Hendrycks, D., & Dietterich, T. (2019). Benchmarking neural network robustness to common corruptions and perturbations. ICLR.
[16] Nguyen, A., et al. (2015). Deep neural networks are easily fooled: High confidence predictions for unrecognizable images. CVPR.
[17] Le, Y., et al. (2015). Tiny ImageNet Visual Recognition Challenge. Stanford.
[18] Hendrycks, D., et al. (2021). Natural adversarial examples. CVPR.
[19] Recht, B., et al. (2019). Do ImageNet classifiers generalize to ImageNet? PMLR.
[20] Lin, Y., et al. (2021). ES-ImageNet: A million event-stream classification dataset for spiking neural networks. Frontiers in Neuroscience.
[21] Mu, N., & Gilmer, J. (2019). MNIST-C: A robustness benchmark for computer vision. ICLR Workshop.
[22] Blum, H., et al. (2021). The Fishyscapes benchmark: Measuring blind spots in semantic segmentation. IJCV.
[23] Bergmann, P., et al. (2021). The MVTec Anomaly Detection dataset: A comprehensive real-world dataset for unsupervised anomaly detection. IJCV.
[24] Hendrycks, D., et al. (2022). Scaling out-of-distribution detection for real-world settings. ICML.
[25] Kirchheim, K., & Ortmeier, F. (2022). On outlier exposure with generative models. NeurIPS.
[26] Vaze, S., et al. (2022). Open-set recognition: A good closed-set classifier is all you need. ICLR.
[27] Bitterwolf, J., et al. (2023). In or out? Fixing ImageNet out-of-distribution detection evaluation. ICML.
[28] Kirchheim, K. (2023). Towards deep anomaly detection with structured knowledge representations. SAFECOMP.
[29] Cimpoi, M., et al. (2014). Describing textures in the wild. CVPR.
[30] Torralba, A., et al. (2007). 80 million tiny images: A large dataset for non-parametric object and scene recognition. IEEE Transactions on Pattern Analysis and Machine Intelligence.
[31] de Campos, T. E., et al. (2009). Character recognition in natural images. VISAPP.
[32] Elliott, D., et al. (2016). Multi30k: Multilingual English-German image descriptions. Proceedings of the 5th Workshop on Vision and Language.
[33] Merity, S., et al. (2016). Pointer sentinel mixture models. arXiv.
[34] Huang, R., & Li, Y. (2021). MOS: Towards scaling out-of-distribution detection for large semantic space. CVPR.
[35] Kirchheim, K., et al. (2022). Multi class hypersphere anomaly detection. ICPR.
[36] Wang, H., et al. (2022). ViM: Out-of-distribution with virtual-logit matching. CVPR.
[37] Du, X., et al. (2022). VOS: Learning what you don't know by virtual outlier synthesis. ICLR.
[38] Sun, Y., et al. (2022). Out-of-distribution detection with deep nearest neighbors. ICML.
[39] Hendrycks, D., et al. (2022). PixMix: Dreamlike pictures comprehensively improve safety measures. CVPR.
[40] Chan, R., et al. (2021). Entropy maximization and meta classification for out-of-distribution detection in semantic segmentation. CVPR.
[41] Sun, Y., et al. (2022). DICE: Leveraging sparsification for out-of-distribution detection. ECCV.
[42] Djurisic, A., et al. (2023). Extremely simple activation shaping for out-of-distribution detection. ICLR.
[43] Zhang, et al. (2023). Out-of-distribution detection based on in-distribution data patterns memorization with modern Hopfield energy. ICLR.
[44] Sun, Y., et al. (2021). ReAct: Out-of-distribution detection with rectified activations. NeurIPS.