From e3c343d6db103dec49246e4def14ad5253f2d083 Mon Sep 17 00:00:00 2001 From: princewen Date: Sat, 12 Jan 2019 20:40:53 +0800 Subject: [PATCH] cmn --- recommendation/Basic-CMN-Demo/README.md | 18 ++ recommendation/Basic-CMN-Demo/pretrain.py | 73 +++++ recommendation/Basic-CMN-Demo/train.py | 148 +++++++++ .../Basic-CMN-Demo/util/attention.py | 189 +++++++++++ recommendation/Basic-CMN-Demo/util/cmn.py | 160 ++++++++++ recommendation/Basic-CMN-Demo/util/data.py | 134 ++++++++ .../Basic-CMN-Demo/util/evaluation.py | 87 +++++ recommendation/Basic-CMN-Demo/util/gmf.py | 85 +++++ recommendation/Basic-CMN-Demo/util/helper.py | 274 ++++++++++++++++ recommendation/Basic-CMN-Demo/util/layers.py | 297 ++++++++++++++++++ 10 files changed, 1465 insertions(+) create mode 100755 recommendation/Basic-CMN-Demo/README.md create mode 100644 recommendation/Basic-CMN-Demo/pretrain.py create mode 100644 recommendation/Basic-CMN-Demo/train.py create mode 100644 recommendation/Basic-CMN-Demo/util/attention.py create mode 100644 recommendation/Basic-CMN-Demo/util/cmn.py create mode 100644 recommendation/Basic-CMN-Demo/util/data.py create mode 100644 recommendation/Basic-CMN-Demo/util/evaluation.py create mode 100644 recommendation/Basic-CMN-Demo/util/gmf.py create mode 100644 recommendation/Basic-CMN-Demo/util/helper.py create mode 100644 recommendation/Basic-CMN-Demo/util/layers.py diff --git a/recommendation/Basic-CMN-Demo/README.md b/recommendation/Basic-CMN-Demo/README.md new file mode 100755 index 00000000..00b5c24d --- /dev/null +++ b/recommendation/Basic-CMN-Demo/README.md @@ -0,0 +1,18 @@ +# Collaborative Memory Network for Recommendation Systems + +https://arxiv.org/pdf/1804.10862.pdf + + +* Python 3.6 +* TensorFlow 1.8+ +* dm-sonnet + + +## Data Format +The structure of the data in the npz file is as follows: + +``` +train_data = [[user id, item id], ...] 
"""Pretrain user/item embeddings with the pairwise GMF model.

Trains ``PairwiseGMF`` on (user, positive item, negative item) triples and
stores the learned parameters in an ``.npz`` file under the keys ``user``,
``item`` and ``v``.
"""
import argparse
import os
import numpy as np
import tensorflow as tf

from tqdm import tqdm

from util.gmf import PairwiseGMF
from util.helper import BaseConfig
from util.data import Dataset


parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# The default has to be a string: argparse only applies ``type`` to string
# defaults, and ``os.environ`` rejects non-str values with a TypeError.
parser.add_argument('-g', '--gpu', help='set gpu device number 0-3', type=str, default='0')
parser.add_argument('--iters', help='Max iters', type=int, default=15)
parser.add_argument('-b', '--batch_size', help='Batch Size', type=int, default=128)
parser.add_argument('-e', '--embedding', help='Embedding Size', type=int, default=50)
parser.add_argument('--dataset', help='path to npz file', type=str, default='pretrain_data/citeulike-a.npz')
parser.add_argument('-n', '--neg', help='Negative Samples Count', type=int, default=4)
parser.add_argument('--l2', help='l2 Regularization', type=float, default=0.001)
parser.add_argument('-o', '--output', help='save filename for trained embeddings', type=str,
                    default='pretrain/citeulike-a_e50.npz')


FLAGS = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu

# Create the output directory up front so a missing folder is reported
# before training instead of when np.savez runs at the very end.
output_dir = os.path.dirname(FLAGS.output)
if output_dir and not os.path.isdir(output_dir):
    os.makedirs(output_dir)


class Config(BaseConfig):
    """Hyper-parameters for pretraining, filled in from the command line."""
    filename = FLAGS.dataset
    embed_size = FLAGS.embedding
    batch_size = FLAGS.batch_size
    l2 = FLAGS.l2
    # Placeholders; overwritten below once the dataset has been loaded.
    user_count = -1
    item_count = -1
    optimizer = 'adam'
    neg_count = FLAGS.neg
    learning_rate = 0.001


config = Config()
dataset = Dataset(config.filename)
config.item_count = dataset.item_count
config.user_count = dataset.user_count
tf.logging.info("\n\n%s\n\n" % config)

model = PairwiseGMF(config)
sv = tf.train.Supervisor(logdir=None, save_model_secs=0, save_summaries_secs=0)
sess = sv.prepare_or_wait_for_session(
    config=tf.ConfigProto(gpu_options=tf.GPUOptions(
        per_process_gpu_memory_fraction=0.1,
        allow_growth=True)))

# get_data also yields the final partial batch, so round the count up.
example_count = dataset.train_size * FLAGS.neg
batches_per_epoch = (example_count + FLAGS.batch_size - 1) // FLAGS.batch_size

for i in range(FLAGS.iters):
    if sv.should_stop():
        break
    progress = tqdm(enumerate(dataset.get_data(FLAGS.batch_size, False, FLAGS.neg)),
                    dynamic_ncols=True, total=batches_per_epoch)
    loss = []
    for k, example in progress:
        # Columns of ``example``: user id, positive item id, negative item id.
        feed = {
            model.input_users: example[:, 0],
            model.input_items: example[:, 1],
            model.input_items_negative: example[:, 2],
        }
        batch_loss, _ = sess.run([model.loss, model.train], feed)
        loss.append(batch_loss)
        progress.set_description(u"[{}] Loss: {:,.4f} » » » » ".format(i, batch_loss))

    print("Epoch {}: Avg Loss/Batch {:<20,.6f}".format(i, np.mean(loss)))

user_embed, item_embed, v = sess.run([model.user_memory.embeddings, model.item_memory.embeddings, model.v.w])
np.savez(FLAGS.output, user=user_embed, item=item_embed, v=v)
print('Saving to: %s' % FLAGS.output)
sv.request_stop()
+parser.add_argument('-e', '--embedding', help='Embedding Size', type=int, default=50) +parser.add_argument('--dataset', help='path to file', type=str, default='pretrain_data/citeulike-a.npz') +parser.add_argument('--hops', help='Number of hops/layers', type=int, default=2) +parser.add_argument('-n', '--neg', help='Negative Samples Count', type=int, default=4) +parser.add_argument('--l2', help='l2 Regularization', type=float, default=0.1) +parser.add_argument('-l', '--logdir', help='Set custom name for logdirectory', + type=str, default=None) +parser.add_argument('--resume', help='Resume existing from logdir', action="store_true") +parser.add_argument('--pretrain', help='Load pretrained user/item embeddings', type=str, + default='pretrain/citeulike-a_e50.npz') +parser.set_defaults(optimizer='rmsprop', learning_rate=0.001, decay=0.9, momentum=0.9) +FLAGS = parser.parse_args() +preprocess_args(FLAGS) +os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu + +# Create results in here unless we specify a logdir +BASE_DIR = 'result/' +if FLAGS.logdir is not None and not os.path.exists(FLAGS.logdir): + os.mkdir(FLAGS.logdir) + +class Config(BaseConfig): + logdir = create_exp_directory(BASE_DIR) if FLAGS.logdir is None else FLAGS.logdir + filename = FLAGS.dataset + embed_size = FLAGS.embedding + batch_size = FLAGS.batch_size + hops = FLAGS.hops + l2 = FLAGS.l2 + user_count = -1 + item_count = -1 + optimizer = FLAGS.optimizer + tol = 1e-5 + neg_count = FLAGS.neg + optimizer_params = FLAGS.optimizer_params + grad_clip = 5.0 + decay_rate = 0.9 + learning_rate = FLAGS.learning_rate + pretrain = FLAGS.pretrain + max_neighbors = -1 + +config = Config() + +if FLAGS.resume: + config.save_directory = config.logdir + config.load() + +dictConfig(get_logging_config(config.logdir)) +dataset = Dataset(config.filename) + +config.item_count = dataset.item_count +config.user_count = dataset.user_count +config.save_directory = config.logdir +config.max_neighbors = dataset._max_user_neighbors + 
+tf.logging.info('\n\n%s\n\n' % config) + +if not FLAGS.resume: + config.save() + +model = CollaborativeMemoryNetwork(config) + +sv = tf.train.Supervisor(logdir=config.logdir, save_model_secs=60 * 10, + save_summaries_secs=0) + +sess = sv.prepare_or_wait_for_session(config=tf.ConfigProto( + gpu_options=tf.GPUOptions(allow_growth=True))) + +if not FLAGS.resume: + pretrain = np.load(FLAGS.pretrain) + sess.graph._unsafe_unfinalize() + tf.logging.info('Loading Pretrained Embeddings.... from %s' % FLAGS.pretrain) + sess.run([ + model.user_memory.embeddings.assign(pretrain['user']*0.5), + model.item_memory.embeddings.assign(pretrain['item']*0.5)]) + +# Train Loop +for i in range(FLAGS.iters): + if sv.should_stop(): + break + + progress = tqdm(enumerate(dataset.get_data(FLAGS.batch_size, True, FLAGS.neg)), + dynamic_ncols=True, total=(dataset.train_size * FLAGS.neg) // FLAGS.batch_size) + loss = [] + for k, example in progress: + ratings, pos_neighborhoods, pos_neighborhood_length, \ + neg_neighborhoods, neg_neighborhood_length = example + feed = { + model.input_users: ratings[:, 0], + model.input_items: ratings[:, 1], + model.input_items_negative: ratings[:, 2], + model.input_neighborhoods: pos_neighborhoods, + model.input_neighborhood_lengths: pos_neighborhood_length, + model.input_neighborhoods_negative: neg_neighborhoods, + model.input_neighborhood_lengths_negative: neg_neighborhood_length + } + batch_loss, _ = sess.run([model.loss, model.train], feed) + loss.append(batch_loss) + progress.set_description(u"[{}] Loss: {:,.4f} » » » » ".format(i, batch_loss)) + + tf.logging.info("Epoch {}: Avg Loss/Batch {:<20,.6f}".format(i, np.mean(loss))) + evaluate_model(sess, dataset.test_data, dataset.item_users_list, model.input_users, model.input_items, + model.input_neighborhoods, model.input_neighborhood_lengths, + model.dropout, model.score, config.max_neighbors) + +EVAL_AT = range(1, 11) +hrs, ndcgs = [], [] +s = "" +scores, out = get_model_scores(sess, dataset.test_data, 
class MemoryMask(snt.AbstractModule):
    """
    Masks attention scores that lie beyond each row's sequence length.

    Masked positions are clamped to the smallest finite float32 value so
    that a following softmax gives them (numerically) zero weight.
    """

    def __init__(self, name='MemoryMask'):
        super(MemoryMask, self).__init__(name=name)

    def _build(self, inputs, mask_length, maxlen=None,):
        """
        Clamp every entry past ``mask_length`` to the float32 minimum.

        Adapted from the Sonnet attention module.

        :param inputs: [batch size, length], dtype=tf.float32
        :param mask_length: [batch size] Tensor of ints giving the number of
            valid entries in each row of ``inputs``
        :param maxlen: maximum sequence length handed to
            ``tf.sequence_mask``; if None it is inferred from ``mask_length``
        :returns: [batch size, max(mask_length)] Tensor with the mask applied
        :raises ValueError: if ``mask_length`` is not a 1-d Tensor
        """
        rank = len(mask_length.shape)
        if rank != 1:
            raise ValueError('Mask Length must be a 1-d Tensor, got %s' % mask_length.shape)

        # Numerical limits of float32; plain numpy, creates no graph ops.
        float_limits = np.finfo(np.float32)

        # [batch_size, length]; True marks a slot that is kept.
        valid = tf.sequence_mask(mask_length, maxlen=maxlen, name='SequenceMask')
        inputs.shape.assert_is_compatible_with(valid.shape)

        # Every row must keep at least one slot.
        slots_per_row = tf.reduce_sum(tf.cast(valid, dtype=tf.int32), axis=[1])
        has_slots = tf.assert_positive(slots_per_row)

        with tf.control_dependencies([has_slots]):
            keep = tf.cast(valid, dtype=tf.float32)
            drop = tf.cast(tf.logical_not(valid), dtype=tf.float32)

            # Kept slots get an upper bound of float max (a no-op under
            # minimum); dropped slots get float min, which always wins.
            ceiling = float_limits.max * keep + float_limits.min * drop
            width = tf.reduce_max(mask_length)

            return tf.minimum(inputs[:, :width], ceiling[:, :width])
class VariableLengthMemoryLayer(snt.AbstractModule):
    """Multi-hop attention read over a variable-length neighbourhood memory."""

    def __init__(self, hops, embed_size, activation_fn, initializers=None,
                 regularizers=None, name='MemoryLayer', ):
        super(VariableLengthMemoryLayer, self).__init__(name=name)
        self._hops = hops
        self._embed_size = embed_size
        self._activation_fn = activation_fn
        self._initializers = initializers
        self._regularizers = regularizers

    def _build(self, query, memory, output_memory, seq_length, maxlen=32):
        """
        Run ``hops`` rounds of attention over the memory.

        :param query: pair ``(user_query, item_query)``; their sum is the
            query used for the first hop
        :param memory: rank-3 internal memory the attention scores are
            computed against
        :param output_memory: rank-3 external memory that is averaged with
            the attention weights
        :param seq_length: length of each sequence in the batch
        :param maxlen: int, the maximum length over the entire dataset
        :return: list with one attention result per hop, in hop order
        """
        for bank in (memory, output_memory):
            bank.shape.assert_has_rank(3)

        # Only keep as many slots as the longest sequence in this batch.
        longest = tf.reduce_max(seq_length)
        keys = memory[:, :longest]
        values = output_memory[:, :longest]

        user_part, item_part = query
        state = tf.add(user_part, item_part, name='InitialQuery')

        reads = []
        for hop_idx in range(self._hops):
            if reads:
                # Every hop after the first maps the previous query and adds
                # the previous hop's read: f(Wz + o + b), with z = m_u + e_i.
                mapper = snt.Linear(self._embed_size, True,
                                    regularizers=self._regularizers,
                                    initializers=self._initializers,
                                    name='HopMap%s' % hop_idx)
                with tf.name_scope('Map'):
                    state = self._activation_fn(mapper(state) + reads[-1].output)
                tf.add_to_collection(GraphKeys.ACTIVATIONS, state)
                tf.logging.info('Creating Hop Mapping {} with {}'.format(hop_idx+1,
                                                                         self._activation_fn))

            attend = ApplyAttentionMemory('AttentionHop%s' % hop_idx)
            reads.append(attend(keys, values, state, seq_length, maxlen=maxlen))

        return reads
def _construct(self):
    """
    Wire up scoring, loss and the training op.

    The positive and the negative item go through the same memory layer and
    output module; the resulting pair of scores feeds the loss layer.
    """

    def rank(item_embedding, neighborhoods, lengths):
        # Query is m_u + e_i; the memory layer receives both parts.
        keys = self.user_memory(neighborhoods)
        values = self.user_output(neighborhoods)
        hops = self._mem_layer((self._cur_user, item_embedding),
                               keys, values, lengths,
                               self.config.max_neighbors)
        # Combine the user-item product with the last hop's read.
        features = tf.concat([self._cur_user * item_embedding, hops[-1].output],
                             axis=1)
        return self._output_module(features)

    # Positive
    self.score = rank(self._cur_item,
                      self.input_neighborhoods,
                      self.input_neighborhood_lengths)

    # Negative
    negative_output = rank(self._cur_item_negative,
                           self.input_neighborhoods_negative,
                           self.input_neighborhood_lengths_negative)

    # Loss and Optimizer
    self.loss = LossLayer()(self.score, negative_output)
    self._optimizer = OptimizerLayer(self.config.optimizer,
                                     clip=self.config.grad_clip,
                                     params=self.config.optimizer_params)
    self.train = self._optimizer(self.loss)

    tf.add_to_collection(GraphKeys.PREDICTION, self.score)
'NegativeItemID') + self.input_neighborhoods = tf.placeholder(tf.int32, [None, None], + 'Neighborhood') + + self.input_neighborhood_lengths = tf.placeholder(tf.int32, [None], + 'NeighborhoodLengthID') + + self.input_neighborhoods_negative = tf.placeholder(tf.int32, + [None, None], + 'NeighborhoodNeg') + + self.input_neighborhood_lengths_negative = tf.placeholder(tf.int32, + [None], + 'NeighborhoodLengthIDNeg') + # Add our placeholders + add_to_collection(GraphKeys.PLACEHOLDER, [self.input_users, + self.input_items, + self.input_items_negative, + self.input_neighborhoods, + self.input_neighborhood_lengths, + self.input_neighborhoods_negative, + self.input_neighborhood_lengths_negative, + self.dropout]) + + def _construct_weights(self): + """ + Constructs the user/item memories and user/item external memory/outputs + + Also add the embedding lookups + """ + self.user_memory = snt.Embed(self.config.user_count, self.config.embed_size, + initializers=self._embedding_initializers, + name='MemoryEmbed') + + self.user_output = snt.Embed(self.config.user_count, self.config.embed_size, + initializers=self._embedding_initializers, + name='MemoryOutput') + + self.item_memory = snt.Embed(self.config.item_count, + self.config.embed_size, + initializers=self._embedding_initializers, + name="ItemMemory") + self._mem_layer = VariableLengthMemoryLayer(self.config.hops, + self.config.embed_size, + tf.nn.relu, + initializers=self._hops_init, + regularizers=self._regularizers, + name='UserMemoryLayer') + + self._output_module = snt.Sequential([ + DenseLayer(self.config.embed_size, True, tf.nn.relu, + initializers=self._initializers, + regularizers=self._regularizers, + name='Layer'), + snt.Linear(1, False, + initializers=self._output_initializers, + regularizers=self._regularizers, + name='OutputVector'), + tf.squeeze]) + + # [batch, embedding size] + self._cur_user = self.user_memory(self.input_users) + self._cur_user_output = self.user_output(self.input_users) + + # Item memories a 
class Dataset(object):
    """
    Implicit-feedback dataset with negative sampling and item neighbourhoods.

    The ``.npz`` file must provide ``train_data`` (rows of
    ``[user id, item id]``) and ``test_data`` (a pickled dict).  For every
    item the users that interacted with it in the training data form its
    neighbourhood.
    """

    def __init__(self, filename):
        """
        Load the data and build the user/item interaction indexes.

        :param filename: path to the ``.npz`` file
        """
        # NOTE(security): ``test_data`` is a pickled Python dict, so it can
        # only be read with allow_pickle=True (NumPy >= 1.16.3 refuses
        # otherwise).  Unpickling runs arbitrary code -- only load files
        # from a trusted source.
        self._data = np.load(filename, allow_pickle=True)
        self.train_data = self._data['train_data']
        self.test_data = self._data['test_data'].tolist()
        self._train_index = np.arange(len(self.train_data), dtype=np.uint)
        self._n_users, self._n_items = self.train_data.max(axis=0) + 1

        # Neighborhoods
        self.user_items = defaultdict(set)
        self.item_users = defaultdict(set)
        # List version so batches can be filled without casting the sets.
        self.item_users_list = defaultdict(list)
        for u, i in self.train_data:
            self.user_items[u].add(i)
            # Skip repeated (user, item) rows: the list must stay exactly
            # as long as the set, because get_data sizes the destination
            # slice from the number of unique users.
            if u not in self.item_users[i]:
                self.item_users[i].add(u)
                self.item_users_list[i].append(u)

        self._max_user_neighbors = max(len(x) for x in self.item_users.values())

    @property
    def train_size(self):
        """
        :return: number of examples in training set
        :rtype: int
        """
        return len(self.train_data)

    @property
    def user_count(self):
        """Number of users, i.e. the largest user id in the training data + 1."""
        return self._n_users

    @property
    def item_count(self):
        """Number of items, i.e. the largest item id in the training data + 1."""
        return self._n_items

    def _sample_item(self):
        """
        Draw an item id uniformly from ``[0, item_count)``.
        """
        return np.random.randint(0, self.item_count)

    def _sample_negative_item(self, user_id):
        """
        Uniformly sample an item the user has not interacted with.

        Only items that occur in the training data are returned.

        :param user_id: id of the user to sample for
        :raises ValueError: if ``user_id`` is out of range, or if the user
            has interacted with every item of the training data
        """
        # Valid ids are 0 .. user_count - 1, so user_count itself is invalid.
        if user_id >= self.user_count:
            raise ValueError("Trying to sample user id: {} >= user count: {}".format(
                user_id, self.user_count))

        positive_items = self.user_items[user_id]

        # Candidates are the items present in the training data.  Comparing
        # against item_count would miss gaps in the id range and the
        # rejection loop below would then never terminate.
        if len(positive_items) >= len(self.item_users):
            raise ValueError("The User has rated more items than possible %s / %s" % (
                len(positive_items), len(self.item_users)))

        n = self._sample_item()
        while n in positive_items or n not in self.item_users:
            n = self._sample_item()
        return n

    def _generate_data(self, neg_count):
        """
        Materialise all training triples into ``self._examples``.

        :param neg_count: number of negative items sampled per training row
        """
        self._examples = np.zeros((self.train_size * neg_count, 3),
                                  dtype=np.uint32)
        idx = 0
        for user_idx, item_idx in self.train_data:
            for _ in range(neg_count):
                neg_item_idx = self._sample_negative_item(user_idx)
                self._examples[idx, :] = [user_idx, item_idx, neg_item_idx]
                idx += 1

    def _fill_neighborhood(self, neighbors, lengths, row, item_idx):
        """Write the neighbourhood of ``item_idx`` into ``row`` of the buffers."""
        users = self.item_users_list.get(item_idx)
        if users:
            lengths[row] = len(users)
            neighbors[row, :len(users)] = users
        else:
            # Length defaults to 1 so the attention mask never sees an
            # empty row; the single slot holds the item id, as before.
            lengths[row] = 1
            neighbors[row, 0] = item_idx

    def get_data(self, batch_size, neighborhood, neg_count):
        """
        Iterate once over the shuffled training data in mini-batches.

        The yielded arrays are (views of) buffers that are reused for the
        next batch, so consume or copy each batch before advancing.

        :param batch_size: number of examples per batch; the last batch may
            be smaller
        :param neighborhood: if True also yield the neighbourhood arrays
        :param neg_count: number of negatives sampled per training row
        :return: generator of ``batch`` (columns: user, item, negative item)
            or, with ``neighborhood``, of ``(batch, pos_neighbor,
            pos_length, neg_neighbor, neg_length)``
        """
        # Allocate inputs
        batch = np.zeros((batch_size, 3), dtype=np.uint32)
        pos_neighbor = np.zeros((batch_size, self._max_user_neighbors), dtype=np.int32)
        pos_length = np.zeros(batch_size, dtype=np.int32)
        neg_neighbor = np.zeros((batch_size, self._max_user_neighbors), dtype=np.int32)
        neg_length = np.zeros(batch_size, dtype=np.int32)

        # Shuffle index
        np.random.shuffle(self._train_index)

        idx = 0
        for user_idx, item_idx in self.train_data[self._train_index]:
            for _ in range(neg_count):
                neg_item_idx = self._sample_negative_item(user_idx)
                batch[idx, :] = [user_idx, item_idx, neg_item_idx]

                # Get neighborhood information
                if neighborhood:
                    self._fill_neighborhood(pos_neighbor, pos_length, idx, item_idx)
                    self._fill_neighborhood(neg_neighbor, neg_length, idx, neg_item_idx)

                idx += 1
                # Yield batch if we filled queue
                if idx == batch_size:
                    if neighborhood:
                        max_length = max(neg_length.max(), pos_length.max())
                        yield batch, pos_neighbor[:, :max_length], pos_length, \
                            neg_neighbor[:, :max_length], neg_length
                        pos_length[:] = 1
                        neg_length[:] = 1
                    else:
                        yield batch
                    # Reset
                    idx = 0

        # Provide remainder
        if idx > 0:
            if neighborhood:
                max_length = max(neg_length[:idx].max(), pos_length[:idx].max())
                yield batch[:idx], pos_neighbor[:idx, :max_length], pos_length[:idx], \
                    neg_neighbor[:idx, :max_length], neg_length[:idx]
            else:
                yield batch[:idx]
def get_eval(scores, index, top_n=10):
    """
    Compute the hit ratio and NDCG at ``top_n``, averaged over all rows.

    Each row of ``scores`` holds the scores of one user's candidate items;
    ``index`` is the position of the correct item within every row.  If the
    last element is the correct one, then ``index = len(scores[0]) - 1``.
    Ties are broken in favour of the lower position (stable sort), so the
    result is deterministic.

    :param scores: non-empty sequence of 1-d score arrays
    :param index: position of the correct item in each row
    :param top_n: size of the ranked list that is evaluated
    :return: tuple ``(hit ratio, ndcg)``
    :raises ValueError: if ``scores`` is empty or ``index`` is out of range
    """
    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if len(scores) == 0:
        raise ValueError("scores must contain at least one row")
    if not 0 <= index < len(scores[0]):
        raise ValueError("index %s is out of range for rows of length %s" % (
            index, len(scores[0])))

    hit_count = 0.0
    ndcg_total = 0.0
    for score in scores:
        # Indices of the top n scores, best first; mergesort is stable.
        ranking = np.argsort(-np.asarray(score), kind='mergesort')[:top_n]
        positions = np.flatnonzero(ranking == index)
        if positions.size:
            # A hit at 0-based rank r contributes log(2) / log(r + 2).
            ndcg_total += np.log(2.0) / np.log(positions[0] + 2.0)
            hit_count += 1.0

    return hit_count / len(scores), ndcg_total / len(scores)
def _construct_weights(self):
    """
    Create the user and item embedding tables and look up the batch inputs.
    """
    width = self.config.embed_size
    shared = dict(initializers=self._embedding_initializers,
                  regularizers=self._embedding_regularizers)

    self.user_memory = snt.Embed(self.config.user_count, width,
                                 name='MemoryEmbed', **shared)
    self.item_memory = snt.Embed(self.config.item_count, width,
                                 name="ItemMemory", **shared)

    # Embedding lookups for the current batch: user first, then the
    # positive item and the negative item from the shared item table.
    self._cur_user = self.user_memory(self.input_users)
    self._cur_item = self.item_memory(self.input_items)
    self._cur_item_negative = self.item_memory(self.input_items_negative)
def add_to_collection(names, values):
    """
    Add one or more values to one or more graph collections.

    Every value is added to every named collection through
    ``tf.add_to_collection``.

    :param names: str or list of collection names
    :param values: a single value (e.g. a tensor) or a list/tuple of values
                   to add to each collection
    """
    if isinstance(names, str):
        names = [names]

    # Anything that is not an explicit list/tuple is treated as ONE value.
    # Previously only ``str`` was wrapped, so a lone tensor -- which the
    # documented contract allows -- was handed to the ``for`` loop below and
    # iterated instead of being added as a whole.
    if not isinstance(values, (list, tuple)):
        values = [values]

    for name in names:
        for value in values:
            tf.add_to_collection(name, value)
def get_optimizer_argparse():
    """
    Build an argument parser holding the optimizer hyperparameters.

    The parser is created with ``add_help=False`` so that it can be passed
    as a parent (``parents=[...]``) to another ``ArgumentParser``.  The tag
    in square brackets at the end of each help string names the optimizer(s)
    the flag applies to.

    :returns: argparse.ArgumentParser with an 'OPTIMIZATION' argument group
    """
    parser = argparse.ArgumentParser(add_help=False)

    optimizer_group = parser.add_argument_group('OPTIMIZATION',
                                                description='Hyperparameters')

    # Valid names are exactly the keys of the module-level OPTIMIZER mapping.
    optimizer_group.add_argument('--optimizer', default='adam', help='SGD optimizer',
                                 choices=OPTIMIZER.keys())

    optimizer_group.add_argument('--learning_rate', default=0.001, type=float,
                                 help='learning rate [All]')

    optimizer_group.add_argument('--momentum', default=0.9, type=float,
                                 help='Momentum value [Momentum/RMSProp]')

    optimizer_group.add_argument('--use_nesterov', default=False, action='store_true',
                                 help='Use nesterov momentum [Momentum]')

    optimizer_group.add_argument('--beta1', default=0.9, type=float,
                                 help='beta 1 hyperparameter [Adam]')

    # Help text previously read "beta 1", a copy/paste slip from --beta1.
    optimizer_group.add_argument('--beta2', default=0.999, type=float,
                                 help='beta 2 hyperparameter [Adam]')

    optimizer_group.add_argument('--epsilon', default=1e-08, type=float,
                                 help='Epsilon for numerical stability [Adam]')

    optimizer_group.add_argument('--decay', default=0.9, type=float,
                                 help='decay rate hyperparameter [RMSProp]')

    optimizer_group.add_argument('--rho', default=0.95, type=float,
                                 help='rho hyperparameter [Adadelta]')
    return parser
class BaseConfig(object):
    """
    Simple attribute-bag configuration that can be saved to / loaded from
    ``save_directory``.

    Settings are plain public attributes, defined either on a subclass or
    passed to ``__init__`` as keyword arguments.  ``save`` writes two files:
    ``config.json`` (human readable, non-numeric values stringified) and
    ``config.pkl`` (exact Python values).
    """

    # Directory used by save()/load(); must be set before calling either.
    save_directory = None
    # Public names that are part of the API rather than configuration values.
    _IGNORE = ['fields', 'save', 'load']

    def __init__(self, **kwargs):
        """Set arbitrary configuration values by keyword."""
        self.__dict__.update(kwargs)

    @property
    def fields(self):
        """
        Names of all public attributes of this config (class and instance
        level), excluding the entries in ``_IGNORE``.
        """
        return [m for m in dir(self)
                if not m.startswith('_') and m not in self._IGNORE]

    def save(self):
        """
        Dump the config to ``config.json`` and ``config.pkl`` inside
        ``save_directory``.
        """
        # Context managers guarantee the files are flushed and closed even if
        # serialisation fails part-way through.
        with open('%s/config.json' % self.save_directory, 'w') as fp:
            json.dump(self._get_dict(), fp, sort_keys=True, indent=2)
        with open('%s/config.pkl' % self.save_directory, 'wb') as fp:
            pickle.dump({key: getattr(self, key) for key in self.fields},
                        fp, pickle.HIGHEST_PROTOCOL)

    def load(self):
        """
        Load a previously saved config and merge it into this instance.

        The pickle is preferred because it keeps the original value types;
        if it cannot be read the JSON dump is used instead, in which case
        every non-numeric value comes back as its ``str()`` form.
        """
        try:
            # Must be opened in binary mode: in text mode unpickling fails on
            # Python 3 and the code would always fall back to the lossy JSON.
            # NOTE: pickle can execute arbitrary code -- only load configs
            # from directories you trust.
            with open('%s/config.pkl' % self.save_directory, 'rb') as fp:
                d = pickle.load(fp)
        except Exception:
            # Deliberate best-effort fallback (missing or unreadable pickle).
            with open('%s/config.json' % self.save_directory) as fp:
                d = json.load(fp)
        self.__dict__.update(d)

    def _get_dict(self):
        """
        Return a JSON-serialisable dict of all fields: ints/floats are kept
        as they are, everything else is converted with ``str``.
        """
        result = {}
        for key in self.fields:
            value = getattr(self, key)
            result[key] = value if isinstance(value, (int, float)) else str(value)
        return result

    def __repr__(self):
        return json.dumps(self._get_dict(), sort_keys=True, indent=2)

    def __str__(self):
        return json.dumps(self._get_dict(), sort_keys=True, indent=2)
class LossLayer(snt.AbstractModule):
    """
    Pairwise (BPR) loss wrapper.

    Computes ``_bpr_loss`` between positive and negative scores and adds the
    sum of everything found in ``tf.GraphKeys.REGULARIZATION_LOSSES``.
    """
    def __init__(self, name='Loss'):
        """
        :param name: name of this module
        """
        super(LossLayer, self).__init__(name=name)
        # Defined up front so the properties return None instead of raising
        # AttributeError when read before the module is connected.
        self._loss = None
        self._regularization = None
        self._loss_no_regularization = None

    def _build(self, X, y):
        """
        :param X: scores of the preferred (positive) examples
        :param y: scores of the negative examples
        :returns: BPR loss, plus the regularization penalty when any
                  regularization losses are registered in the graph
        """
        # Read the collection BEFORE anything below is added to it.
        graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

        self._loss = tf.squeeze(_bpr_loss(X, y))
        self._regularization = None
        self._loss_no_regularization = self._loss

        # Record the un-penalised loss unconditionally; it used to be
        # registered only when regularizers were present, so the collection
        # was empty for unregularized models.
        tf.add_to_collection(GraphKeys.LOSS_NO_REG, self._loss_no_regularization)

        # Add regularization
        if graph_regularizers:
            self._regularization = tf.reduce_sum(graph_regularizers)
            # NOTE(review): GraphKeys.LOSS_REG is 'regularization_losses',
            # which looks identical to tf.GraphKeys.REGULARIZATION_LOSSES, so
            # the summed penalty lands in the collection it was computed
            # from -- confirm this is intended.
            tf.add_to_collection(GraphKeys.LOSS_REG, self._regularization)
            self._loss = self._loss + self._regularization

        tf.add_to_collection(GraphKeys.LOSSES, self._loss)
        return self._loss

    @property
    def loss(self):
        """
        Total loss including regularization terms
        """
        return self._loss

    @property
    def regularization(self):
        """
        Value of the regularization/weight decay, or None if the graph had
        no regularization losses when the module was built
        """
        return self._regularization

    @property
    def loss_no_regularization(self):
        """
        The loss without the regularization penalty added
        """
        return self._loss_no_regularization
class OptimizerLayer(snt.AbstractModule):
    """
    Wraps an optimizer from ``OPTIMIZER``: computes gradients, optionally
    clips each gradient by norm, and applies them.
    """

    def __init__(self, optimizer_name, clip=None, global_step=None,
                 params=None, name='Optimizer'):
        """
        Optimizer Wrapper

        :param optimizer_name: str, name of the optimizer to use (key of OPTIMIZER)
        :param clip: float, gradient clipping value to use else None
        :param global_step: tensor, global step to use, default gets default from graph
        :param params: dict for optimizer parameters to override defaults;
                       None is treated as an empty dict
        :param name: str, name of module name space
        """
        super(OptimizerLayer, self).__init__(name=name)
        # ``params`` defaults to None, which cannot be **-expanded later.
        self._params = {} if params is None else params
        # Holds the optimizer CLASS until the first build, the instance after.
        self._optimizer = OPTIMIZER[optimizer_name]
        self._clip = clip
        if global_step is None:
            self._global_step = tf.contrib.framework.get_or_create_global_step()
        else:
            self._global_step = global_step
        self._name = name
        self._grads_vars = None
        self._train = None

    @property
    def train_op(self):
        """
        Return the operation to minimize the loss function (None until the
        module has been connected to a loss)
        """
        return self._train

    def _build(self, loss, trainable_variables=None):
        """
        Pass a tensor for the loss to be optimized

        :param loss: tensor
        :param trainable_variables: variables to optimize; defaults to
                                    ``tf.trainable_variables()``
        :returns: Operation to minimize the loss
        :raises ValueError: if a variable receives no gradient from ``loss``
        """
        # Instantiate only once: after the first build ``self._optimizer`` is
        # already an instance and must not be called like a class again.
        if isinstance(self._optimizer, type):
            self._optimizer = self._optimizer(**self._params)

        # Obtain vars to train
        tvars = trainable_variables
        if tvars is None:
            tvars = tf.trainable_variables()

        # Get gradients
        self._grads_vars = self._optimizer.compute_gradients(
            loss, tvars, colocate_gradients_with_ops=True)

        for grad, var in self._grads_vars:
            if grad is None:
                # A None gradient means ``loss`` does not depend on ``var``.
                raise ValueError(
                    'Variable may not be connected or set to be trained: %s'
                    % (var,))
            tf.add_to_collection(GraphKeys.GRADIENTS, grad)

        # Clip gradients (each gradient individually, by its own norm)
        if self._clip is not None and self._clip > 0:
            self._grads_vars = [(tf.clip_by_norm(grad, self._clip), var)
                                for grad, var in self._grads_vars]

        self._train = self._optimizer.apply_gradients(self._grads_vars,
                                                      global_step=self._global_step,
                                                      name='ApplyGradients')
        # Register the train op itself.  The loop variable holding the last
        # gradient was registered here before, which put a gradient tensor in
        # the train-op collection and failed when there were no variables.
        tf.add_to_collection(GraphKeys.TRAIN, self._train)
        return self._train