Skip to content

Commit

Permalink
Added some arguments from command line.
Browse files Browse the repository at this point in the history
  • Loading branch information
Flavio Piccoli committed Jul 24, 2017
1 parent 677da90 commit 6f93f1d
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,39 @@ def forward(self,x):
parser = argparse.ArgumentParser()
parser.add_argument("-r", "--regen", help="Regenerate images using the model specified.",
default="")
parser.add_argument("-ps", "--patchsize", help="Dimension of the patch.",
default=8)
parser.add_argument("-nrow", "--nrow", help="Batchsize will be nrow*nrow.",
default=5)
parser.add_argument("-di", "--degin", help="Degree of net input.",
default=3)
parser.add_argument("-do", "--degout", help="Degree of polynomial regressor.",
default=3)
parser.add_argument("-dir", "--dir", help="Folder containing images.",
default='/media/flavio/Volume/datasets/places-instagram/')
args = parser.parse_args()

# set args to int
args.patchsize = int(args.patchsize)
args.nrow = int(args.nrow)
args.degin = int(args.degin)
args.degout = int(args.degout)
# print args
conf_txt = ''
for arg in vars(args):
conf_txt = conf_txt + '{:>10} = '.format(arg) + str(getattr(args, arg)) + '\n'
print(conf_txt)
# write args on file
out_file = open("config.txt","w")
out_file.write(conf_txt)
out_file.close()
# for a in args:
# import ipdb; ipdb.set_trace()
# ------------------ TRAIN ------------------
# set parameters
img_dim = [256,256]
patchSize = 256
nRow = 6
patchSize = args.patchsize
nRow = args.nrow
batchSize = nRow*nRow
batchSizeVal = 50
nepochs = 4
Expand All @@ -178,8 +204,8 @@ def forward(self,x):
nc = 200
nf = 2000
lr = 0.0001 #0.0002
deg_poly_in = 1
deg_poly_out = 3
deg_poly_in = args.degin
deg_poly_out = args.degout
# init net
net = Net(img_dim, patchSize, nc, nf, deg_poly_in, deg_poly_out).cuda()

Expand All @@ -192,7 +218,7 @@ def forward(self,x):
optimizer = optim.Adam(net.parameters(), lr=lr)

# create dataloaders
base_dir = '/media/flavio/Volume/datasets/places-instagram/'
base_dir = args.dir
img_dirs = [os.path.join(base_dir,'images_orig/'), os.path.join(base_dir,'images/')]
gt_train = os.path.join(base_dir,'train-list.txt')
gt_valid = os.path.join(base_dir,'smallvalidation-list.txt')
Expand Down

0 comments on commit 6f93f1d

Please sign in to comment.