-
Notifications
You must be signed in to change notification settings - Fork 9
/
hparams.py
40 lines (33 loc) · 1.72 KB
/
hparams.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import tensorflow as tf
def _finalize_hparams(hparams, hparam_string):
    """Apply optional command-line overrides to ``hparams`` and log the result.

    Shared tail of the ``create_*_hparams`` factories below.

    Args:
      hparams: A ``tf.contrib.training.HParams`` populated with defaults.
      hparam_string: Optional override string in the format accepted by
        ``HParams.parse``. Falsy values (``None`` or ``""``) skip parsing
        and leave the defaults unchanged.

    Returns:
      The same ``hparams`` object. ``HParams.parse`` updates it in place.

    Raises:
      Whatever ``HParams.parse`` raises for a malformed or unknown override
      (per the TF docs this is ``ValueError``). NOTE(review): confirm for the
      TF version in use.
    """
    if hparam_string:
        tf.logging.info('Parsing command line hparams: %s', hparam_string)
        hparams.parse(hparam_string)
    # Always log the effective values, whether or not overrides were given.
    tf.logging.info('Final parsed hparams: %s', hparams.values())
    return hparams


def create_mag_hparams(hparam_string=None):
    """Create the hyperparameters for the ``'mag'`` model scope.

    Defaults: ``lamb=15.0``, ``batch_size=8``, ``image_size=288``,
    ``label_size=36``, on a ``resnet_v1_50`` backbone.

    Args:
      hparam_string: Optional override string passed to ``HParams.parse``.
        If falsy, the defaults are used as-is.

    Returns:
      A ``tf.contrib.training.HParams`` with the defaults plus any overrides.
    """
    hparams = tf.contrib.training.HParams(
        learning_rate=0.001,
        lr_decay_step=200000,
        lr_decay_rate=0.77,
        momentum=0.99,
        lamb=15.0,
        batch_size=8,
        image_size=288,
        label_size=36,
        scope='mag',
        model_name='resnet_v1_50')
    return _finalize_hparams(hparams, hparam_string)


def create_metagrasp_hparams(hparam_string=None):
    """Create the hyperparameters for the ``'metagrasp'`` model scope.

    Same optimizer and backbone defaults as ``create_mag_hparams``, but with
    ``lamb=120.0``, ``batch_size=16`` and ``label_size=288``.

    Args:
      hparam_string: Optional override string passed to ``HParams.parse``.
        If falsy, the defaults are used as-is.

    Returns:
      A ``tf.contrib.training.HParams`` with the defaults plus any overrides.
    """
    hparams = tf.contrib.training.HParams(
        learning_rate=0.001,
        lr_decay_step=200000,
        lr_decay_rate=0.77,
        momentum=0.99,
        lamb=120.0,
        batch_size=16,
        image_size=288,
        label_size=288,
        scope='metagrasp',
        model_name='resnet_v1_50')
    return _finalize_hparams(hparams, hparam_string)