-
Notifications
You must be signed in to change notification settings - Fork 100
/
args.py
163 lines (148 loc) · 9.36 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import argparse
# Activation names accepted by the --nonlinearity option.
NONLINEARITIES = [
    "tanh",
    "relu",
    "softplus",
    "elu",
    "swish",
    "square",
    "identity",
]

# ODE solver names accepted by the --solver option.
SOLVERS = [
    "dopri5",
    "bdf",
    "rk4",
    "midpoint",
    "adams",
    "explicit_adams",
    "fixed_adams",
]

# Layer types accepted by the --layer_type option.
LAYERS = [
    "ignore",
    "concat",
    "concat_v2",
    "squash",
    "concatsquash",
    "scale",
    "concatscale",
]
def str2bool(value):
    """Convert a command-line string into a bool.

    Used as the ``type=`` converter for the value-taking boolean options in
    ``add_args`` instead of ``eval``, which would execute arbitrary Python
    supplied on the command line. Matching is case-insensitive and ignores
    surrounding whitespace; ``true``/``1``/``yes`` map to True and
    ``false``/``0``/``no`` map to False, so the usual spellings the old
    ``eval``-based options accepted (``True``, ``False``, ``1``, ``0``) keep
    working.

    Args:
        value: the raw option string (a bool is passed through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean
            spelling; argparse reports it as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', '1', 'yes'):
        return True
    if normalized in ('false', '0', 'no'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean (True/False), got %r' % (value,))


def add_args(parser):
    """Register every experiment option on ``parser`` and return it.

    Options are grouped into model-architecture, training, data, logging,
    validation, resuming, distributed-training and evaluation sections.
    Boolean options that take an explicit value (e.g. ``--train_T False``)
    are parsed with ``str2bool``; on/off switches use ``store_true``.

    Args:
        parser: an ``argparse.ArgumentParser`` to populate.

    Returns:
        The same ``parser`` object.
    """
    # model architecture options
    parser.add_argument('--input_dim', type=int, default=3,
                        help='Number of input dimensions (3 for 3D point clouds)')
    parser.add_argument('--dims', type=str, default='256')
    parser.add_argument('--latent_dims', type=str, default='256')
    parser.add_argument("--num_blocks", type=int, default=1,
                        help='Number of stacked CNFs.')
    parser.add_argument("--latent_num_blocks", type=int, default=1,
                        help='Number of stacked CNFs in the latent flow.')
    parser.add_argument("--layer_type", type=str, default="concatsquash", choices=LAYERS)
    parser.add_argument('--time_length', type=float, default=0.5)
    parser.add_argument('--train_T', type=str2bool, default=True, choices=[True, False])
    parser.add_argument("--nonlinearity", type=str, default="tanh", choices=NONLINEARITIES)
    parser.add_argument('--use_adjoint', type=str2bool, default=True, choices=[True, False])
    parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS)
    parser.add_argument('--atol', type=float, default=1e-5)
    parser.add_argument('--rtol', type=float, default=1e-5)
    parser.add_argument('--batch_norm', type=str2bool, default=True, choices=[True, False])
    parser.add_argument('--sync_bn', type=str2bool, default=False, choices=[True, False])
    # float default so the unparsed value has the same type as a parsed one
    parser.add_argument('--bn_lag', type=float, default=0.)

    # training options
    parser.add_argument('--use_latent_flow', action='store_true',
                        help='Whether to use the latent flow to model the prior.')
    parser.add_argument('--use_deterministic_encoder', action='store_true',
                        help='Whether to use a deterministic encoder.')
    parser.add_argument('--zdim', type=int, default=128,
                        help='Dimension of the shape code')
    parser.add_argument('--optimizer', type=str, default='adam',
                        help='Optimizer to use', choices=['adam', 'adamax', 'sgd'])
    parser.add_argument('--batch_size', type=int, default=50,
                        help='Batch size (of datasets) for training')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Learning rate for the optimizer.')
    parser.add_argument('--beta1', type=float, default=0.9,
                        help='Beta1 for Adam.')
    parser.add_argument('--beta2', type=float, default=0.999,
                        help='Beta2 for Adam.')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='Momentum for SGD')
    parser.add_argument('--weight_decay', type=float, default=0.,
                        help='Weight decay for the optimizer.')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs for training (default: 100)')
    parser.add_argument('--seed', type=int, default=None,
                        help='Seed for initializing training.')
    parser.add_argument('--recon_weight', type=float, default=1.,
                        help='Weight for the reconstruction loss.')
    parser.add_argument('--prior_weight', type=float, default=1.,
                        help='Weight for the prior loss.')
    parser.add_argument('--entropy_weight', type=float, default=1.,
                        help='Weight for the entropy loss.')
    parser.add_argument('--scheduler', type=str, default='linear',
                        help='Type of learning rate schedule')
    parser.add_argument('--exp_decay', type=float, default=1.,
                        help='Learning rate schedule exponential decay rate')
    parser.add_argument('--exp_decay_freq', type=int, default=1,
                        help='Learning rate exponential decay frequency')

    # data options
    parser.add_argument('--dataset_type', type=str, default="shapenet15k",
                        help="Dataset types.",
                        choices=['shapenet15k', 'modelnet40_15k', 'modelnet10_15k'])
    parser.add_argument('--cates', type=str, nargs='+', default=["airplane"],
                        help="Categories to be trained (useful only if 'shapenet15k' is selected)")
    parser.add_argument('--data_dir', type=str, default="data/ShapeNetCore.v2.PC15k",
                        help="Path to the training data")
    parser.add_argument('--mn40_data_dir', type=str, default="data/ModelNet40.PC15k",
                        help="Path to ModelNet40")
    parser.add_argument('--mn10_data_dir', type=str, default="data/ModelNet10.PC15k",
                        help="Path to ModelNet10")
    parser.add_argument('--dataset_scale', type=float, default=1.,
                        help='Scale of the dataset (x,y,z * scale = real output, default=1).')
    parser.add_argument('--random_rotate', action='store_true',
                        help='Whether to randomly rotate each shape.')
    parser.add_argument('--normalize_per_shape', action='store_true',
                        help='Whether to perform normalization per shape.')
    parser.add_argument('--normalize_std_per_axis', action='store_true',
                        help='Whether to perform normalization per axis.')
    parser.add_argument("--tr_max_sample_points", type=int, default=2048,
                        help='Max number of sampled points (train)')
    parser.add_argument("--te_max_sample_points", type=int, default=2048,
                        help='Max number of sampled points (test)')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Number of data loading threads')

    # logging and saving frequency
    parser.add_argument('--log_name', type=str, default=None, help="Name for the log dir")
    parser.add_argument('--viz_freq', type=int, default=10)
    parser.add_argument('--val_freq', type=int, default=10)
    parser.add_argument('--log_freq', type=int, default=10)
    parser.add_argument('--save_freq', type=int, default=10)

    # validation options
    parser.add_argument('--no_validation', action='store_true',
                        help='Whether to disable validation altogether.')
    parser.add_argument('--save_val_results', action='store_true',
                        help='Whether to save the validation results.')
    parser.add_argument('--eval_classification', action='store_true',
                        help='Whether to evaluate classification accuracy on MN40 and MN10.')
    parser.add_argument('--no_eval_sampling', action='store_true',
                        help='Whether to skip evaluating sampling.')
    parser.add_argument('--max_validate_shapes', type=int, default=None,
                        help='Max number of shapes used for validation pass.')

    # resuming
    parser.add_argument('--resume_checkpoint', type=str, default=None,
                        help='Path to the checkpoint to be loaded.')
    parser.add_argument('--resume_optimizer', action='store_true',
                        help='Whether to resume the optimizer state when resuming training.')
    parser.add_argument('--resume_non_strict', action='store_true',
                        help='Whether to resume in non-strict mode.')
    parser.add_argument('--resume_dataset_mean', type=str, default=None,
                        help='Path to the file storing the dataset mean.')
    parser.add_argument('--resume_dataset_std', type=str, default=None,
                        help='Path to the file storing the dataset std.')

    # distributed training
    parser.add_argument('--world_size', default=1, type=int,
                        help='Number of distributed nodes.')
    parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist_backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--distributed', action='store_true',
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use. None means using all available GPUs.')

    # Evaluation options
    parser.add_argument('--evaluate_recon', default=False, action='store_true',
                        help='Whether to evaluate reconstruction.')
    parser.add_argument('--num_sample_shapes', default=10, type=int,
                        help='Number of shapes to be sampled (for demo.py).')
    parser.add_argument('--num_sample_points', default=2048, type=int,
                        help='Number of points (per-shape) to be sampled (for demo.py).')
    return parser
def get_parser():
    """Create the command-line parser for the experiment.

    Builds an ``argparse.ArgumentParser`` with the experiment description
    and hands it to ``add_args`` to register every option.

    Returns:
        Whatever ``add_args`` returns for the freshly built parser.
    """
    return add_args(
        argparse.ArgumentParser(
            description='Flow-based Point Cloud Generation Experiment'))
def get_args(argv=None):
    """Parse the experiment's command-line arguments.

    Args:
        argv: optional list of argument strings. With the default ``None``,
            argparse reads ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets tests or notebooks parse arguments without
            touching ``sys.argv``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = get_parser()
    return parser.parse_args(argv)