
Stochastic Weight Averaging for Low-Precision Training (SWALP)

This repository contains a PyTorch implementation of the paper:

SWALP: Stochastic Weight Averaging in Low-Precision Training.

Guandao Yang, Tianyi Zhang, Polina Kirichenko, Junwen Bai, Andrew Gordon Wilson, Christopher De Sa


Introduction

Low precision operations can provide scalability, memory savings, portability, and energy efficiency. This paper proposes SWALP, an approach to low precision training that averages low-precision SGD iterates with a modified learning rate schedule. SWALP is easy to implement and can match the performance of full-precision SGD even with all numbers quantized down to 8 bits, including the gradient accumulators. Additionally, we show that SWALP converges arbitrarily close to the optimal solution for quadratic objectives, and to a noise ball asymptotically smaller than low precision SGD in strongly convex settings.

This repo contains the code to replicate our experiments on the CIFAR datasets with VGG16 and PreResNet164.
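For intuition, the sketch below shows the core idea of the SWALP loop in simplified form: run SGD with low-precision weights and a cyclic learning rate, and maintain a full-precision running average of the iterates. This is an illustrative sketch, not the repo's implementation; `quantize`, `swalp_train`, and the schedule details are placeholders (the repo quantizes with Small-block BFP and also handles activations, gradients, and accumulators).

```python
import copy
import torch

def quantize(p, bits=8):
    # Hypothetical placeholder: symmetric uniform quantization to `bits` bits.
    # The repo uses small-block block floating point instead.
    scale = p.abs().max().clamp(min=1e-8) / (2 ** (bits - 1) - 1)
    return (p / scale).round().clamp(-(2 ** (bits - 1)), 2 ** (bits - 1) - 1) * scale

def swalp_train(model, loader, loss_fn, epochs, swa_start, lr_base, cycle_len):
    opt = torch.optim.SGD(model.parameters(), lr=lr_base)
    swa_model = copy.deepcopy(model)  # full-precision running average of the weights
    n_avg = 0
    for epoch in range(epochs):
        for step, (x, y) in enumerate(loader):
            if epoch >= swa_start:
                # Simplified cyclic learning rate during the averaging phase.
                t = (step % cycle_len) / cycle_len
                for group in opt.param_groups:
                    group["lr"] = lr_base * (1.0 - t)
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()
            with torch.no_grad():
                for p in model.parameters():
                    p.copy_(quantize(p))  # keep the stored weights in low precision
        if epoch >= swa_start:
            # Average the low-precision SGD iterates in full precision.
            n_avg += 1
            with torch.no_grad():
                for p_avg, p in zip(swa_model.parameters(), model.parameters()):
                    p_avg.mul_((n_avg - 1) / n_avg).add_(p, alpha=1.0 / n_avg)
    return swa_model  # batch-norm statistics should be recomputed before evaluation
```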

Citing this Work

Please cite our work if you find this approach useful in your research:

@misc{gu2019swalp,
    title={SWALP : Stochastic Weight Averaging in Low-Precision Training},
    author={Guandao Yang and Tianyi Zhang and Polina Kirichenko and Junwen Bai and Andrew Gordon Wilson and Christopher De Sa},
    year={2019},
    eprint={1904.11943},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

Dependencies

Install the required packages with `pip install -r requirements.txt`.

Usage

We provide scripts to run Small-block Block Floating Point (BFP) experiments on CIFAR10 and CIFAR100 with VGG16 or PreResNet164. The following commands reproduce our experimental results.

seed=100                                      # Specify the experiment seed.
bash exp/block_vgg_swa.sh CIFAR10 ${seed}     # SWALP training of VGG16 with Small-block BFP on CIFAR10
bash exp/block_vgg_swa.sh CIFAR100 ${seed}    # SWALP training of VGG16 with Small-block BFP on CIFAR100
bash exp/block_resnet_swa.sh CIFAR10 ${seed}  # SWALP training of PreResNet164 with Small-block BFP on CIFAR10
bash exp/block_resnet_swa.sh CIFAR100 ${seed} # SWALP training of PreResNet164 with Small-block BFP on CIFAR100
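For reference, the sketch below illustrates the idea behind block floating point quantization under simplified assumptions: each small block of values shares a single exponent taken from the block's largest magnitude, and values are rounded to a `wl`-bit mantissa at that scale. `block_quantize` is an illustrative helper, not the quantizer used by the scripts above.

```python
import torch

def block_quantize(x, wl=8, block_dim=0):
    # Illustrative block floating point (BFP) quantization: every block (here, a
    # slice along `block_dim`) shares one exponent derived from its max magnitude,
    # and entries are rounded to a signed `wl`-bit mantissa at that shared scale.
    reduce_dims = tuple(d for d in range(x.dim()) if d != block_dim)
    max_abs = x.abs().amax(dim=reduce_dims, keepdim=True).clamp(min=1e-30)
    shared_exp = torch.floor(torch.log2(max_abs))   # one exponent per block
    scale = torch.pow(2.0, shared_exp - (wl - 2))   # mantissa step size
    q = torch.round(x / scale).clamp(-(2 ** (wl - 1)), 2 ** (wl - 1) - 1)
    return q * scale

x = torch.randn(4, 6)
print(block_quantize(x, wl=8, block_dim=0))
```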

Results

The low-precision results (SGD-LP and SWALP) are produced by running the scripts in the /exp folder. The full-precision results (SGD-FP and SWA-FP) are produced with the code in the SWA repo. The table reports test error (%).

| Dataset  | Model        | SGD-FP     | SWA-FP     | SGD-LP     | SWALP      |
|----------|--------------|------------|------------|------------|------------|
| CIFAR10  | VGG16        | 6.81±0.09  | 6.51±0.14  | 7.61±0.15  | 6.70±0.12  |
| CIFAR10  | PreResNet164 | 4.63±0.18  | 4.03±0.10  |            |            |
| CIFAR100 | VGG16        | 27.23±0.17 | 25.93±0.21 | 29.59±0.32 | 26.65±0.29 |
| CIFAR100 | PreResNet164 | 22.20±0.57 | 19.95±0.19 |            |            |

Other implementations

Tianyi Zhang provides an implementation using QPyTorch, a low-precision training framework, at this link.

References

We use the SWA repo as a starter template. Network architecture implementations are adapted from: