- Download and extract the NYU v2 dataset to folder `./data` using `python download_nyud2.py`
- (Optional) We provide the required meta files `nyu2_train_FDS_subset.csv` and `test_balanced_mask.npy` in folder `./data`; they are used, respectively, for efficient FDS feature-statistics computation and as the balanced test-set mask. To reproduce the results in the paper, use these two files directly. To try different FDS computation subsets or balanced test-set masks, run `python preprocess_nyud2.py`.
- PyTorch (>= 1.2, tested on 1.6)
- numpy, pandas, scipy, tqdm, matplotlib, PIL, gdown, tensorboardX
Fit the GMM required by GAI:

```bash
# preprocess gmm
python preprocess_gmm.py
```
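GAI relies on a Gaussian mixture fitted to the training-label distribution. The sketch below shows one way to fit a 1-D GMM with plain expectation-maximisation in NumPy. It is an illustration under stated assumptions, not the contents of `preprocess_gmm.py`: the function name `fit_gmm_1d`, the quantile initialisation, and the returned dict keys `means` / `variances` / `weights` are all hypothetical, and the real `gmm.pkl` may be structured differently.

```python
import numpy as np


def fit_gmm_1d(labels, n_components=8, n_iter=200, eps=1e-6):
    """Fit a 1-D Gaussian mixture to regression labels with plain EM.

    Returns a dict of 1-D arrays 'means', 'variances', 'weights'
    (a hypothetical format chosen for this sketch).
    """
    x = np.asarray(labels, dtype=np.float64).ravel()
    # Deterministic init: means at evenly spaced quantiles, shared variance, uniform weights.
    means = np.quantile(x, (np.arange(n_components) + 0.5) / n_components)
    variances = np.full(n_components, x.var() + eps)
    weights = np.full(n_components, 1.0 / n_components)
    for _ in range(n_iter):
        # E-step: log-responsibilities of each component for each label, shape (N, K).
        log_p = (np.log(weights)
                 - 0.5 * np.log(2 * np.pi * variances)
                 - 0.5 * (x[:, None] - means) ** 2 / variances)
        log_p -= log_p.max(axis=1, keepdims=True)  # stabilise before exp
        resp = np.exp(log_p)
        resp /= resp.sum(axis=1, keepdims=True)
        # M-step: re-estimate mixture weights, means and variances.
        nk = resp.sum(axis=0) + eps
        weights = nk / nk.sum()
        means = (resp * x[:, None]).sum(axis=0) / nk
        variances = (resp * (x[:, None] - means) ** 2).sum(axis=0) / nk + eps
    return {"means": means, "variances": variances, "weights": weights}
```

Plain NumPy keeps the sketch dependency-free; scikit-learn's `GaussianMixture` is a common off-the-shelf alternative.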
Train with GAI, passing the fitted GMM via `--gmm`:

```bash
python train.py \
  --bmse --imp gai --gmm gmm.pkl --init_noise_sigma 1.0 --fix_noise_sigma
```
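As we understand the Balanced MSE formulation, GAI (GMM-based analytical integration) adds a balancing term to the usual Gaussian negative log-likelihood: per sample, the loss is $-\log \mathcal{N}(y; \hat{y}, \sigma^2) + \log \sum_k \phi_k \mathcal{N}(\hat{y}; \mu_k, \sigma_k^2 + \sigma^2)$, where $(\phi_k, \mu_k, \sigma_k^2)$ is the training-label GMM. Below is a NumPy sketch of that per-batch loss, not the repository's implementation; the `gmm` dict keys are hypothetical, and `noise_var` stands for $\sigma^2$, which the command above appears to initialise with `--init_noise_sigma` and keep fixed with `--fix_noise_sigma` (our reading of the flags).

```python
import numpy as np


def _logsumexp(a: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable log(sum(exp(a))) along `axis`."""
    m = a.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def gai_loss(pred, target, gmm, noise_var):
    """GAI-style Balanced MSE averaged over a batch (sketch).

    gmm: dict of 1-D arrays 'means', 'variances', 'weights' (hypothetical format).
    The 0.5*log(2*pi) constants of the two terms cancel and are dropped.
    """
    pred = np.asarray(pred, dtype=float).reshape(-1, 1)      # (N, 1)
    target = np.asarray(target, dtype=float).reshape(-1, 1)  # (N, 1)
    means = np.asarray(gmm["means"], dtype=float).reshape(1, -1)  # (1, K)
    variances = np.asarray(gmm["variances"], dtype=float).reshape(1, -1)
    weights = np.asarray(gmm["weights"], dtype=float).reshape(1, -1)

    # -log N(target; pred, noise_var), up to the dropped constant.
    nll = (pred - target) ** 2 / (2 * noise_var) + 0.5 * np.log(noise_var)
    # log sum_k w_k N(pred; mu_k, var_k + noise_var), up to the same constant.
    sum_var = variances + noise_var
    log_terms = np.log(weights) - 0.5 * np.log(sum_var) - (pred - means) ** 2 / (2 * sum_var)
    balance = _logsumexp(log_terms, axis=1)  # (N,)
    return float(np.mean(nll[:, 0] + balance))
```

The balancing term penalises predictions that fall where the training labels are dense, which counteracts the bias toward frequent depth values.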
Train with BNI (no GMM file is passed):

```bash
python train.py \
  --bmse --imp bni --init_noise_sigma 1.0 --fix_noise_sigma
```
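BNI (bin-based numerical integration), as we understand it, replaces the GMM with a histogram of training labels: the balancing term becomes $\log \sum_b p_b \mathcal{N}(\hat{y}; c_b, \sigma^2)$ over bin centres $c_b$ with empirical frequencies $p_b$. Because both Gaussians now share $\sigma^2$, the $\log \sigma$ and $\log 2\pi$ terms cancel. The sketch below is illustrative only; `label_histogram`, `bni_loss`, and the default bin count are hypothetical and need not match the repository's binning.

```python
import numpy as np


def label_histogram(train_labels, n_bins=100):
    """Bin centres and normalised frequencies of the training labels."""
    counts, edges = np.histogram(np.asarray(train_labels, dtype=float).ravel(), bins=n_bins)
    centers = 0.5 * (edges[:-1] + edges[1:])
    keep = counts > 0  # drop empty bins so log(weight) stays finite
    return centers[keep], counts[keep] / counts.sum()


def bni_loss(pred, target, bin_centers, bin_weights, noise_var):
    """BNI-style Balanced MSE averaged over a batch (sketch)."""
    pred = np.asarray(pred, dtype=float).reshape(-1, 1)      # (N, 1)
    target = np.asarray(target, dtype=float).reshape(-1, 1)  # (N, 1)
    centers = np.asarray(bin_centers, dtype=float).reshape(1, -1)         # (1, B)
    log_w = np.log(np.asarray(bin_weights, dtype=float)).reshape(1, -1)   # (1, B)

    # MSE term scaled by the noise variance.
    nll = (pred[:, 0] - target[:, 0]) ** 2 / (2 * noise_var)
    # log sum_b p_b exp(-(pred - c_b)^2 / (2 * noise_var)), computed stably.
    logits = log_w - (pred - centers) ** 2 / (2 * noise_var)
    m = logits.max(axis=1, keepdims=True)
    balance = (m + np.log(np.exp(logits - m).sum(axis=1, keepdims=True)))[:, 0]
    return float(np.mean(nll + balance))
```

Compared with GAI, this avoids fitting a GMM but its accuracy depends on the bin width chosen for the histogram.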
Evaluate a trained checkpoint:

```bash
python test.py --eval_model <path_to_evaluation_ckpt>
```
Below are our reproduced results on NYUD2-DIR (metric: RMSE).
Model | Overall | Many-Shot | Medium-Shot | Few-Shot | Download |
---|---|---|---|---|---|
GAI | 1.279 | 0.819 | 0.917 | 1.705 | model |
BNI | 1.281 | 0.833 | 0.856 | 1.714 | model |
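The table breaks RMSE down by many-, medium-, and few-shot regions. As a sketch of how such a breakdown can be computed, assume each ground-truth value is assigned to a label bin and each bin is labelled by how many training labels it contains. The bin edges and the thresholds `many_thr` / `few_thr` below are hypothetical parameters; the exact bin width and thresholds used for NYUD2-DIR are not stated in this section.

```python
import numpy as np


def shot_rmse(pred, gt, train_labels, bin_edges, many_thr, few_thr):
    """RMSE overall and per shot region (sketch with hypothetical thresholds).

    A bin is 'many' if it holds more than many_thr training labels,
    'few' if it holds fewer than few_thr, and 'medium' otherwise.
    """
    pred = np.asarray(pred, dtype=float).ravel()
    gt = np.asarray(gt, dtype=float).ravel()
    train_counts, _ = np.histogram(np.asarray(train_labels, dtype=float).ravel(), bins=bin_edges)
    # Bin index of each ground-truth value; out-of-range values go to the edge bins.
    idx = np.clip(np.digitize(gt, bin_edges) - 1, 0, len(train_counts) - 1)
    counts = train_counts[idx]
    regions = {
        "overall": np.ones_like(gt, dtype=bool),
        "many": counts > many_thr,
        "medium": (counts >= few_thr) & (counts <= many_thr),
        "few": counts < few_thr,
    }
    # NaN marks a region with no test values.
    return {
        name: float(np.sqrt(np.mean((pred[m] - gt[m]) ** 2))) if m.any() else float("nan")
        for name, m in regions.items()
    }
```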