Commit: update dch

bl0 committed Apr 22, 2018
1 parent 5d17857 · commit 4499aee
Showing 8 changed files with 133 additions and 169 deletions.
16 changes: 16 additions & 0 deletions DeepHash/model/dch/__init__.py
@@ -0,0 +1,16 @@
from .util import Dataset
from .dch import DCH

def train(train_img, database_img, query_img, config):
model = DCH(config)
img_database = Dataset(database_img, config.output_dim)
img_query = Dataset(query_img, config.output_dim)
img_train = Dataset(train_img, config.output_dim)
model.train(img_train)
return model.save_dir

def validation(database_img, query_img, config):
model = DCH(config)
img_database = Dataset(database_img, config.output_dim)
img_query = Dataset(query_img, config.output_dim)
return model.validation(img_query, img_database, config.R)
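
For orientation, these two wrappers are the public entry points of the new dch package. A minimal sketch of driving them (here `args` stands for the parsed argparse namespace built in examples/dch/train_val_script.py later in this commit, and the .txt paths are illustrative):

    import model.dch as model
    import data_provider.image as dataset

    # `args` is assumed to be the argparse.Namespace from the example script;
    # the file lists below are placeholders for --img-tr/--img-te/--img-db.
    query_img, database_img = dataset.import_validation('test.txt', 'database.txt')
    train_img = dataset.import_train('train.txt')

    save_dir = model.train(train_img, database_img, query_img, args)  # trains DCH, returns model.save_dir
    maps = model.validation(database_img, query_img, args)            # dict of mAP results
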
121 changes: 46 additions & 75 deletions DeepHash/model/prunehash/prunehash.py → DeepHash/model/dch/dch.py
@@ -16,44 +16,28 @@
import model.plot as plot
from architecture.single_model import img_alexnet_layers
from evaluation import MAPs
from .util import Dataset


class PruneHash(object):
def __init__(self, config, stage):
class DCH(object):
def __init__(self, config):
### Initialize setting
print ("initializing")
np.set_printoptions(precision=4)
self.stage = stage
self.device = config['device']
self.output_dim = config['output_dim']
self.n_class = config['label_dim']
self.cq_lambda = config['cq_lambda']
self.alpha = config['alpha']
self.bias = config['bias']
self.gamma = config['gamma']

self.batch_size = config['batch_size'] if self.stage == "train" else config['val_batch_size']
self.max_iter = config['max_iter']
self.img_model = config['img_model']
self.loss_type = config['loss_type']
self.learning_rate = config['learning_rate']
self.learning_rate_decay_factor = config['learning_rate_decay_factor']
self.decay_step = config['decay_step']

self.finetune_all = config['finetune_all']

with tf.name_scope('stage'):
# 0 for training, 1 for validation
self.stage = tf.placeholder_with_default(tf.constant(0), [])
for k, v in vars(config).items():
setattr(self, k, v)
self.file_name = 'loss_{}_lr_{}_cqlambda_{}_alpha_{}_bias_{}_gamma_{}_dataset_{}'.format(
self.loss_type,
self.learning_rate,
self.cq_lambda,
self.lr,
self.q_lambda,
self.alpha,
self.bias,
self.gamma,
config['dataset'])
self.save_dir = config['save_dir']
self.save_file = os.path.join(config['save_dir'], self.file_name + '.npy')
self.log_dir = config['log_dir']
self.dataset)
self.save_file = os.path.join(self.save_dir, self.file_name + '.npy')

### Setup session
print ("launching session")
@@ -63,27 +47,25 @@ def __init__(self, config, stage):
self.sess = tf.Session(config=configProto)

### Create variables and placeholders
self.img = tf.placeholder(tf.float32, [None, 256, 256, 3])
self.img_label = tf.placeholder(tf.float32, [None, self.label_dim])
self.img_last_layer, self.deep_param_img, self.train_layers, self.train_last_layer = self.load_model()

with tf.device(self.device):
self.img = tf.placeholder(tf.float32, [self.batch_size, 256, 256, 3])
self.img_label = tf.placeholder(tf.float32, [self.batch_size, self.n_class])

if self.stage == 'train':
self.model_weights = config['model_weights']
else:
self.model_weights = self.save_file
self.img_last_layer, self.deep_param_img, self.train_layers, self.train_last_layer = self.load_model()

self.global_step = tf.Variable(0, trainable=False)
self.train_op = self.apply_loss_function(self.global_step)
self.sess.run(tf.global_variables_initializer())
self.global_step = tf.Variable(0, trainable=False)
self.train_op = self.apply_loss_function(self.global_step)
self.sess.run(tf.global_variables_initializer())
return

def load_model(self):
if self.img_model == 'alexnet':
img_output = img_alexnet_layers(
self.img, self.batch_size, self.output_dim,
self.stage, self.model_weights)
self.img,
self.batch_size,
self.output_dim,
self.stage,
self.model_weights,
self.with_tanh,
self.val_batch_size)
else:
raise Exception('cannot use such CNN model as ' + self.img_model)
return img_output
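
For reference, the constructor above replaces the old string-valued `stage` argument with an integer placeholder that defaults to training and is fed 1 at validation time (see the feed_dict changes further down). A minimal TF1 sketch of that pattern, with illustrative names:

    import tensorflow as tf

    # 0 = training (default), 1 = validation, mirroring placeholder_with_default above.
    stage = tf.placeholder_with_default(tf.constant(0), [])
    is_training = tf.equal(stage, 0)

    with tf.Session() as sess:
        print(sess.run(is_training))                        # True: nothing fed, default is training
        print(sess.run(is_training, feed_dict={stage: 1}))  # False: validation pass
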
@@ -139,7 +121,7 @@ def reduce_shaper(t):
r = tf.reshape(r, [-1, 1])
ip = r - 2*tf.matmul(u, tf.transpose(u)) + tf.transpose(r)

ip = tf.constant(self.gamma) / (ip + tf.constant(self.gamma)*tf.constant(self.gamma))
ip = self.gamma / (ip + self.gamma ** 2)
else:
ip = tf.clip_by_value(tf.matmul(u, tf.transpose(u)), -1.5e1, 1.5e1)
ones = tf.ones([tf.shape(u)[0], tf.shape(u)[0]])
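
For reference, the old and new forms of the gamma line compute the same quantity; the rewrite only drops the tf.constant wrappers. A NumPy sketch of the weighting, where `u` is a batch of real-valued codes (illustrative name):

    import numpy as np

    def cauchy_style_weight(u, gamma=20.0):
        # Pairwise squared Euclidean distances via ||u_i||^2 - 2 u_i.u_j + ||u_j||^2,
        # mirroring r - 2*tf.matmul(u, tf.transpose(u)) + tf.transpose(r) above.
        r = np.sum(np.square(u), axis=1, keepdims=True)   # shape (n, 1)
        dist = r - 2.0 * u.dot(u.T) + r.T                 # shape (n, n)
        return gamma / (dist + gamma ** 2)
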
@@ -158,13 +140,12 @@ def apply_loss_function(self, global_step):
self.cos_loss = self.cross_entropy(self.img_last_layer, self.img_label, self.alpha, True, True, self.bias)

self.q_loss_img = tf.reduce_mean(tf.square(tf.subtract(tf.abs(self.img_last_layer), tf.constant(1.0))))
self.q_lambda = tf.Variable(self.cq_lambda, name='cq_lambda')
self.q_loss = tf.multiply(self.q_lambda, self.q_loss_img)
self.q_loss = self.q_lambda * self.q_loss_img
self.loss = self.cos_loss + self.q_loss

### Last layer has a 10 times learning rate
self.lr = tf.train.exponential_decay(self.learning_rate, global_step, self.decay_step, self.learning_rate_decay_factor, staircase=True)
opt = tf.train.MomentumOptimizer(learning_rate=self.lr, momentum=0.9)
lr = tf.train.exponential_decay(self.lr, global_step, self.decay_step, self.lr, staircase=True)
opt = tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9)
grads_and_vars = opt.compute_gradients(self.loss, self.train_layers+self.train_last_layer)
fcgrad, _ = grads_and_vars[-2]
fbgrad, _ = grads_and_vars[-1]
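
For reference, a sketch of the decayed schedule this code sets up (TF1). One hedged observation: the rewritten line passes self.lr as the fourth argument of tf.train.exponential_decay, which is the decay rate; presumably the --decay-factor value from the example script was intended there, as in the sketch below (values shown are the script defaults):

    import tensorflow as tf

    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(
        learning_rate=0.005,    # --lr
        global_step=global_step,
        decay_steps=10000,      # --decay-step
        decay_rate=0.1,         # --decay-factor
        staircase=True)
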
@@ -174,11 +155,11 @@ def apply_loss_function(self, global_step):
tf.summary.scalar('loss', self.loss)
tf.summary.scalar('cos_loss', self.cos_loss)
tf.summary.scalar('q_loss', self.q_loss)
tf.summary.scalar('lr', self.lr)
tf.summary.scalar('lr', lr)
self.merged = tf.summary.merge_all()


if self.stage == "train" and self.finetune_all:
if self.finetune_all:
return opt.apply_gradients([(grads_and_vars[0][0], self.train_layers[0]),
(grads_and_vars[1][0]*2, self.train_layers[1]),
(grads_and_vars[2][0], self.train_layers[2]),
@@ -208,13 +189,10 @@ def train(self, img_dataset):
shutil.rmtree(tflog_path)
train_writer = tf.summary.FileWriter(tflog_path, self.sess.graph)

for train_iter in range(self.max_iter):
for train_iter in range(self.iter_num):
images, labels = img_dataset.next_batch(self.batch_size)
start_time = time.time()

assign_lambda = self.q_lambda.assign(self.cq_lambda)
self.sess.run([assign_lambda])

_, loss, cos_loss, output, summary = self.sess.run([self.train_op, self.loss, self.cos_loss, self.img_last_layer, self.merged],
feed_dict={self.img: images,
self.img_label: labels})
@@ -224,7 +202,7 @@ def train(self, img_dataset):
img_dataset.feed_batch_output(self.batch_size, output)
duration = time.time() - start_time

if train_iter % 1 == 0:
if train_iter % 100 == 0:
print("%s #train# step %4d, loss = %.4f, cross_entropy loss = %.4f, %.1f sec/batch"
%(datetime.now(), train_iter+1, loss, cos_loss, duration))

@@ -236,24 +214,29 @@

def validation(self, img_query, img_database, R=100):
print("%s #validation# start validation" % (datetime.now()))
query_batch = int(ceil(img_query.n_samples / self.batch_size))
query_batch = int(ceil(img_query.n_samples / float(self.val_batch_size)))
img_query.finish_epoch()
print("%s #validation# totally %d query in %d batches" % (datetime.now(), img_query.n_samples, query_batch))
for i in range(query_batch):
images, labels = img_query.next_batch(self.batch_size)
images, labels = img_query.next_batch(self.val_batch_size)
output, loss = self.sess.run([self.img_last_layer, self.cos_loss],
feed_dict={self.img: images, self.img_label: labels})
img_query.feed_batch_output(self.batch_size, output)
feed_dict={self.img: images,
self.img_label: labels,
self.stage: 1})
img_query.feed_batch_output(self.val_batch_size, output)
print('Cosine Loss: %s'%loss)

database_batch = int(ceil(img_database.n_samples / self.batch_size))
database_batch = int(ceil(img_database.n_samples / float(self.val_batch_size)))
img_database.finish_epoch()
print("%s #validation# totally %d database in %d batches" % (datetime.now(), img_database.n_samples, database_batch))
for i in range(database_batch):
images, labels = img_database.next_batch(self.batch_size)
images, labels = img_database.next_batch(self.val_batch_size)

output, loss = self.sess.run([self.img_last_layer, self.cos_loss],
feed_dict={self.img: images, self.img_label: labels})
img_database.feed_batch_output(self.batch_size, output)
#print output[:10, :10]
feed_dict={self.img: images,
self.img_label: labels,
self.stage: 1})
img_database.feed_batch_output(self.val_batch_size, output)
if i % 100 == 0:
print('Cosine Loss[%d/%d]: %s'%(i, database_batch, loss))

@@ -283,15 +266,3 @@ def validation(self, img_query, img_database, R=100):
'i2i_map_radius_2': mmap,
}

def train(train_img, config):
model = PruneHash(config, 'train')
img_dataset = Dataset(train_img, config['output_dim'])
model.train(img_dataset)
return model.save_file

def validation(database_img, query_img, config):
model = PruneHash(config, 'val')
img_database = Dataset(database_img, config['output_dim'])
img_query = Dataset(query_img, config['output_dim'])
return model.validation(img_query, img_database, config['R'])

File renamed without changes.
Empty file.
70 changes: 70 additions & 0 deletions examples/dch/train_val_script.py
@@ -0,0 +1,70 @@
import os
import argparse
import warnings
import numpy as np
import scipy.io as sio
import model.dch as model
import data_provider.image as dataset

from pprint import pprint

warnings.filterwarnings("ignore", category = DeprecationWarning)
warnings.filterwarnings("ignore", category = FutureWarning)

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

parser = argparse.ArgumentParser(description='Triplet Hashing')
parser.add_argument('--lr', '--learning-rate', default=0.005, type=float)
parser.add_argument('--output-dim', default=64, type=int) # 256, 128
parser.add_argument('--alpha', default=0.5, type=float)
parser.add_argument('--bias', default=0.0, type=float)
parser.add_argument('--gamma', default=20, type=float)
parser.add_argument('--iter-num', default=2000, type=int)
parser.add_argument('--q-lambda', default=0, type=float)
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--gpus', default='0', type=str)
parser.add_argument('--log-dir', default='tflog', type=str)
parser.add_argument('-b', '--batch-size', default=128, type=int)
parser.add_argument('-vb', '--val-batch-size', default=16, type=int)
parser.add_argument('--decay-step', default=10000, type=int)
parser.add_argument('--decay-factor', default=0.1, type=int)
parser.add_argument('--loss-type', default='pruned_cross_entropy', type=str)

tanh_parser = parser.add_mutually_exclusive_group(required=False)
tanh_parser.add_argument('--with-tanh', dest='with_tanh', action='store_true')
tanh_parser.add_argument('--without-tanh', dest='with_tanh', action='store_false')
parser.set_defaults(with_tanh=True)

parser.add_argument('--img-model', default='alexnet', type=str)
parser.add_argument('--model-weights', type=str,
default='../../DeepHash/architecture/single_model/pretrained_model/reference_pretrain.npy')
parser.add_argument('--finetune-all', default=True, type=bool)
parser.add_argument('--save-dir', default="./models/", type=str)
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true')

args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus

label_dims = {'cifar10': 10, 'cub': 200, 'nuswide_81': 81, 'coco': 80}
Rs = {'cifar10': 54000, 'nuswide_81': 5000, 'coco': 5000}
args.R = Rs[args.dataset]
args.label_dim = label_dims[args.dataset]
args.img_tr = "/home/caoyue/data/{}/train.txt".format(args.dataset)
args.img_te = "/home/caoyue/data/{}/test.txt".format(args.dataset)
args.img_db = "/home/caoyue/data/{}/database.txt".format(args.dataset)

pprint(vars(args))

query_img, database_img = dataset.import_validation(args.img_te, args.img_db)

if not args.evaluate:
train_img = dataset.import_train(args.img_tr)
model_weights = model.train(train_img, database_img, query_img, args)
args.model_weights = model_weights

maps = model.validation(database_img, query_img, args)
for key in maps:
print(("{}\t{}".format(key, maps[key])))

pprint(vars(args))
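
For reference, the refactored DCH.__init__ earlier in this commit copies every parsed flag onto the model via setattr, so each argument defined here becomes an attribute such as self.lr, self.q_lambda, or self.val_batch_size. A minimal sketch of that pattern (flag subset and class name are illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', default=0.005, type=float)
    parser.add_argument('--gamma', default=20.0, type=float)
    config = parser.parse_args([])   # take the defaults for this sketch

    class Model(object):
        def __init__(self, config):
            # Copy every parsed flag onto the instance, as DCH.__init__ does.
            for k, v in vars(config).items():
                setattr(self, k, v)

    m = Model(config)
    print(m.lr, m.gamma)             # the flags are now plain attributes
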
2 changes: 1 addition & 1 deletion examples/dtq/train_val_script.py
@@ -37,7 +37,7 @@

parser.add_argument('--img-model', default='alexnet', type=str)
parser.add_argument('--model-weights', type=str,
default='../../core/architecture/single_model/pretrained_model/reference_pretrain.npy')
default='../../DeepHash/architecture/single_model/pretrained_model/reference_pretrain.npy')
parser.add_argument('--finetune-all', default=True, type=bool)
parser.add_argument('--max-iter-update-b', default=3, type=int)
parser.add_argument('--max-iter-update-Cb', default=1, type=int)
21 changes: 0 additions & 21 deletions examples/prunehash/train_val.sh

This file was deleted.
