The checkpoints contain model weights, optimizer state, etc.
For details, see the code for checkpoint saving and checkpoint loading.
Usage
- Extract the checkpoint(s) into YOUR_PATH/survae_flows/experiments/image/log/.
- Sample from the models, evaluate the test log-likelihood, or continue training using the scripts below.
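The extraction step can also be scripted. This is a minimal sketch that assumes each checkpoint is downloaded as an archive in a format Python's shutil.unpack_archive recognizes (zip, tar, tar.gz, and similar) and that YOUR_PATH is the directory containing the survae_flows checkout. The helper name extract_checkpoints is hypothetical.

```python
import shutil
from pathlib import Path


def extract_checkpoints(archive: str, survae_root: str) -> Path:
    """Hypothetical helper: unpack one checkpoint archive into the image log directory.

    Assumes an archive format that shutil.unpack_archive can detect from the
    file extension (e.g. .zip, .tar, .tar.gz).
    """
    log_dir = Path(survae_root) / "survae_flows" / "experiments" / "image" / "log"
    log_dir.mkdir(parents=True, exist_ok=True)  # create the log directory if missing
    shutil.unpack_archive(archive, extract_dir=str(log_dir))
    return log_dir
```

For example, extract_checkpoints("cifar10_8bit.zip", "YOUR_PATH") (archive name hypothetical) places the archive's contents, such as cifar10_8bit/pool_flow/more/maxpool, under the log directory the scripts below expect.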
Sampling:
To sample from the models, use the eval_sample.py script.
CIFAR-10:
python eval_sample.py --model YOUR_PATH/survae_flows/experiments/image/log/cifar10_8bit/pool_flow/more/maxpool
python eval_sample.py --model YOUR_PATH/survae_flows/experiments/image/log/cifar10_8bit/pool_flow/more/nonpool
ImageNet 32x32:
python eval_sample.py --model YOUR_PATH/survae_flows/experiments/image/log/imagenet32_8bit/pool_flow/more/maxpool
python eval_sample.py --model YOUR_PATH/survae_flows/experiments/image/log/imagenet32_8bit/pool_flow/more/nonpool
ImageNet 64x64:
python eval_sample.py --model YOUR_PATH/survae_flows/experiments/image/log/imagenet64_8bit/pool_flow/more/maxpool
python eval_sample.py --model YOUR_PATH/survae_flows/experiments/image/log/imagenet64_8bit/pool_flow/more/nonpool
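All six checkpoint paths follow the pattern log/<dataset>_8bit/pool_flow/more/<variant>, so the sampling commands can be generated instead of typed. A minimal sketch, assuming it runs from the directory that contains eval_sample.py; the helper names are hypothetical and only the --model flag shown above is used.

```python
import subprocess
from pathlib import Path
from typing import List

DATASETS = ["cifar10", "imagenet32", "imagenet64"]
VARIANTS = ["maxpool", "nonpool"]


def model_dir(survae_root: str, dataset: str, variant: str) -> str:
    """Checkpoint directory for one dataset/variant pair, following the layout above."""
    path = (Path(survae_root) / "survae_flows" / "experiments" / "image" / "log"
            / f"{dataset}_8bit" / "pool_flow" / "more" / variant)
    return path.as_posix()


def sample_commands(survae_root: str) -> List[List[str]]:
    """One eval_sample.py invocation per dataset and pooling variant."""
    return [["python", "eval_sample.py", "--model", model_dir(survae_root, d, v)]
            for d in DATASETS for v in VARIANTS]


def run_commands(commands: List[List[str]]) -> None:
    """Run each command in turn, stopping at the first non-zero exit status."""
    for cmd in commands:
        subprocess.run(cmd, check=True)
```

Calling run_commands(sample_commands("YOUR_PATH")) runs the six commands listed above in order.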
Log-likelihood:
To compute the test log-likelihood for the models, use the eval_loglik.py script.
CIFAR-10:
python eval_loglik.py --model YOUR_PATH/survae_flows/experiments/image/log/cifar10_8bit/pool_flow/more/maxpool --k 1000 --kbs 10
python eval_loglik.py --model YOUR_PATH/survae_flows/experiments/image/log/cifar10_8bit/pool_flow/more/nonpool --k 1000 --kbs 10
ImageNet 32x32:
python eval_loglik.py --model YOUR_PATH/survae_flows/experiments/image/log/imagenet32_8bit/pool_flow/more/maxpool
python eval_loglik.py --model YOUR_PATH/survae_flows/experiments/image/log/imagenet32_8bit/pool_flow/more/nonpool
ImageNet 64x64:
python eval_loglik.py --model YOUR_PATH/survae_flows/experiments/image/log/imagenet64_8bit/pool_flow/more/maxpool
python eval_loglik.py --model YOUR_PATH/survae_flows/experiments/image/log/imagenet64_8bit/pool_flow/more/nonpool
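The CIFAR-10 commands pass --k 1000 --kbs 10. A plausible reading, which is an assumption rather than something stated here, is that --k sets the number of importance samples used to estimate log p(x) and --kbs sets how many of those samples are evaluated per batch. Under that assumption, the standard importance-weighted estimate log p(x) ≈ log((1/k) Σ_i w_i), with w_i = p(x, u_i) / q(u_i | x), can be accumulated batch by batch with a streaming log-sum-exp, so all k weights never need to be held at once. The pure-Python sketch below illustrates that idea; the function names and the sample_log_weight callable are hypothetical.

```python
import math
from typing import Callable, Iterable, Iterator, List


def chunked_logsumexp(chunks: Iterable[List[float]]) -> float:
    """Numerically stable log(sum(exp(w))) accumulated over chunks of log-weights."""
    running_max = -math.inf
    running_sum = 0.0  # sum of exp(w - running_max) over all chunks seen so far
    for chunk in chunks:
        if not chunk:
            continue
        new_max = max(running_max, max(chunk))
        if new_max == -math.inf:
            continue  # every weight so far is zero; nothing to add
        if running_sum > 0.0:
            # rescale the accumulated sum to the new reference maximum
            running_sum *= math.exp(running_max - new_max)
        running_sum += sum(math.exp(w - new_max) for w in chunk)
        running_max = new_max
    return running_max + math.log(running_sum) if running_sum > 0.0 else -math.inf


def iw_log_likelihood(sample_log_weight: Callable[[], float], k: int, kbs: int) -> float:
    """Estimate log p(x) as log((1/k) * sum_i w_i), drawing k log-weights in batches of kbs.

    sample_log_weight is a hypothetical callable returning one log-weight
    log p(x, u_i) - log q(u_i | x) for a fresh sample u_i ~ q(u | x).
    """
    if k < 1 or kbs < 1:
        raise ValueError("k and kbs must be positive")

    def batches() -> Iterator[List[float]]:
        remaining = k
        while remaining > 0:
            n = min(kbs, remaining)  # the last batch may be smaller than kbs
            yield [sample_log_weight() for _ in range(n)]
            remaining -= n

    return chunked_logsumexp(batches()) - math.log(k)
```

With k = 1000 and kbs = 10 this draws 100 batches; a larger batch size trades memory for fewer passes, while k controls how tight the estimate is.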
Continue Training:
To continue training for additional epochs with a new, fixed learning rate, use the train_more.py script.
CIFAR-10:
python train_more.py --model YOUR_PATH/survae_flows/experiments/image/log/cifar10_8bit/pool_flow/more/maxpool --new_epochs 600 --new_lr 1e-5
python train_more.py --model YOUR_PATH/survae_flows/experiments/image/log/cifar10_8bit/pool_flow/more/nonpool --new_epochs 600 --new_lr 1e-5
ImageNet 32x32:
python train_more.py --model YOUR_PATH/survae_flows/experiments/image/log/imagenet32_8bit/pool_flow/more/maxpool --new_epochs 30 --new_lr 1e-5
python train_more.py --model YOUR_PATH/survae_flows/experiments/image/log/imagenet32_8bit/pool_flow/more/nonpool --new_epochs 30 --new_lr 1e-5
ImageNet 64x64:
python train_more.py --model YOUR_PATH/survae_flows/experiments/image/log/imagenet64_8bit/pool_flow/more/maxpool --new_epochs 25 --new_lr 1e-5
python train_more.py --model YOUR_PATH/survae_flows/experiments/image/log/imagenet64_8bit/pool_flow/more/nonpool --new_epochs 25 --new_lr 1e-5
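The three datasets use different --new_epochs values (600, 30 and 25) with the same --new_lr 1e-5. A minimal sketch that records these settings in one place and rebuilds the command lines above, using only the flags shown in this README; TRAIN_MORE_SETTINGS, train_more_command and shell_line are hypothetical names.

```python
import shlex
from pathlib import Path
from typing import Dict, List, Tuple

# (new_epochs, new_lr) per dataset, copied from the train_more.py commands above
TRAIN_MORE_SETTINGS: Dict[str, Tuple[int, str]] = {
    "cifar10": (600, "1e-5"),
    "imagenet32": (30, "1e-5"),
    "imagenet64": (25, "1e-5"),
}


def train_more_command(dataset: str, variant: str, survae_root: str = "YOUR_PATH") -> List[str]:
    """Argument list for resuming one checkpoint with its dataset-specific settings."""
    new_epochs, new_lr = TRAIN_MORE_SETTINGS[dataset]
    model = (Path(survae_root) / "survae_flows" / "experiments" / "image" / "log"
             / f"{dataset}_8bit" / "pool_flow" / "more" / variant)
    return ["python", "train_more.py", "--model", model.as_posix(),
            "--new_epochs", str(new_epochs), "--new_lr", new_lr]


def shell_line(dataset: str, variant: str, survae_root: str = "YOUR_PATH") -> str:
    """Quote the argument list as a single shell command line."""
    return shlex.join(train_more_command(dataset, variant, survae_root))
```

The learning rate is kept as the string "1e-5" so it reaches the script exactly as written above; shell_line("imagenet32", "maxpool") reproduces the first ImageNet 32x32 command.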