TL;DR: A simple approach to improving single-model deep-learning uncertainty: replace the model's last layer with a Gaussian process layer. Spectral normalization can be applied to any model. An online demo is available on Google Colab.

## Install

This is a preview version, so install it from source:

```bash
git clone https://github.com/iamownt/sngp_wrapper.git
cd sngp_wrapper
pip install -e .
```
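As an optional sanity check, the two helpers used in the example below should import cleanly:

```bash
python -c "from sngp_wrapper.covert_utils import convert_to_sn_my, replace_layer_with_gaussian"
```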

## Example

```python
# Step 1: import the module
import torch
from tqdm import tqdm

from sngp_wrapper.covert_utils import convert_to_sn_my, replace_layer_with_gaussian

# Define args such as args.spec_norm_bound, args.spec_norm_replace_list,
# and args.gaussian_process (e.g., via argparse).
GP_KWARGS = {
    'num_inducing': 1024,
    'gp_scale': 1.0,
    'gp_bias': 0.,
    'gp_kernel_type': 'gaussian',
    'gp_input_normalization': True,
    'gp_cov_discount_factor': -1,  # -1 means an exact (no-momentum) covariance update
    'gp_cov_ridge_penalty': 1.,
    'gp_output_bias_trainable': False,
    'gp_scale_random_features': False,
    'gp_use_custom_random_features': True,
    'gp_random_feature_type': 'orf',
    'gp_output_imagenet_initializer': True,
}
model = ...
if args.spec_norm_bound is not None:
    # Wrap the layers listed in args.spec_norm_replace_list with spectral normalization.
    model = convert_to_sn_my(model, args.spec_norm_replace_list, args.spec_norm_bound)
if args.gaussian_process:
    # Replace the layer named "classifier" with a Gaussian process layer.
    replace_layer_with_gaussian(container=model, signature="classifier", **GP_KWARGS)
```
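The training loop below passes GP-specific keyword arguments through the model's `forward`, so the model must hand them on to its head. A minimal sketch of a compatible model (the class `SmallNet` is hypothetical; the attribute name `classifier` has to match the `signature` argument above):

```python
import torch.nn as nn

class SmallNet(nn.Module):
    """Hypothetical model whose `classifier` head is the layer that
    replace_layer_with_gaussian(signature="classifier") swaps out."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x, **kwargs):
        # Pass GP-specific kwargs (e.g. update_precision_matrix) through to the
        # head; before conversion, the plain nn.Linear head receives none.
        return self.classifier(self.backbone(x), **kwargs)

model = SmallNet()
```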
    
```python
# Step 2: train the model
train_loader, val_loader = ..., ...
optimizer = ...
criterion = ...
for epoch in range(args.epochs):
    model.train()
    if args.gaussian_process:
        # With GP_KWARGS["gp_cov_discount_factor"] == -1 (exact covariance update),
        # the covariance matrix must be reset at the start of each epoch; with a
        # momentum-based update (factor != -1) this is unnecessary.
        model.classifier.reset_covariance_matrix()
        kwargs = {'return_random_features': False, 'return_covariance': False,
                  'update_precision_matrix': True, 'update_covariance_matrix': False}
    else:
        kwargs = {}
    for idx, batch in tqdm(enumerate(train_loader)):
        optimizer.zero_grad()
        preds = model(batch["input"], **kwargs)
        loss = criterion(preds, batch["target"])
        loss.backward()
        optimizer.step()

    # Step 3: evaluate or test the model
    model.eval()
    logit_list, uncertainty_list, label_list = [], [], []
    if args.gaussian_process:
        model.classifier.update_covariance_matrix()
        eval_kwargs = {'return_random_features': False, 'return_covariance': True,
                       'update_precision_matrix': False, 'update_covariance_matrix': False}
    else:
        eval_kwargs = {}
    for idx, batch in tqdm(enumerate(val_loader)):
        output = model(batch["input"], **eval_kwargs)
        if isinstance(output, tuple):  # (logits, covariance) when return_covariance=True
            logits, covariance = output
            logits = logits.cpu().detach().numpy()
            uncertainty = torch.diagonal(covariance).cpu().detach().numpy()
            uncertainty_list.extend(uncertainty)
        else:
            logits = output.cpu().detach().numpy()
        logit_list.extend(logits)
        label_list.extend(batch["target"].long())
    # save or evaluate the model with logit_list, uncertainty_list, label_list
```
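With `logit_list` and `uncertainty_list` collected, one common way to fold the predictive variance into class probabilities is the mean-field approximation used in the SNGP literature. A sketch, not part of this package; the λ = π/8 factor follows the convention in uncertainty-baselines:

```python
import numpy as np

def mean_field_logits(logits, variances, mean_field_factor=np.pi / 8):
    # Shrink each example's logits according to its predictive variance.
    return logits / np.sqrt(1.0 + mean_field_factor * variances)[:, None]

logits = np.asarray(logit_list)           # shape (num_examples, num_classes)
variances = np.asarray(uncertainty_list)  # shape (num_examples,)
adjusted = mean_field_logits(logits, variances)
# Numerically stable softmax over the adjusted logits.
exp = np.exp(adjusted - adjusted.max(axis=1, keepdims=True))
probs = exp / exp.sum(axis=1, keepdims=True)
```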

## Acknowledgements

We would like to thank the authors of the original paper for their work, as well as the authors of uncertainty-baselines and ml-sigma-reparam.