diff --git a/learned_optimization/research/data_driven/README.md b/learned_optimization/research/data_driven/README.md new file mode 100644 index 0000000..ee6eda1 --- /dev/null +++ b/learned_optimization/research/data_driven/README.md @@ -0,0 +1,3 @@ +# General-Purpose In-Context Learning by Meta-Learning Transformers + +Research code for the paper https://arxiv.org/abs/2212.04458 \ No newline at end of file diff --git a/learned_optimization/research/data_driven/data.py b/learned_optimization/research/data_driven/data.py new file mode 100644 index 0000000..158e7e4 --- /dev/null +++ b/learned_optimization/research/data_driven/data.py @@ -0,0 +1,322 @@ +# coding=utf-8 +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Supervised data loader.""" + +import functools +from typing import NamedTuple, Optional, Tuple, Sequence + +import gin +import haiku as hk +import jax +import jax.numpy as jnp +from learned_optimization.research.data_driven import resnet +import tensorflow.compat.v2 as tf +import tensorflow_datasets as tfds + + +def standardize( + batch: Tuple[jnp.ndarray, jnp.ndarray], + has_dataset_dim: bool = True, + subsample: int = 0, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Z-normalizes the given batch. + + Args: + batch: Tuple if images and labels. + has_dataset_dim: Whether there is a dataset dimension. + subsample: Size of the subsample in batch and sequence dimension. + + Returns: + Z-normalized batch. + """ + + imgs, labels = batch + if has_dataset_dim and subsample > 0: + # of shape [dataset, batch, sequence, ...] + mean = jnp.mean(imgs[0, :subsample, :subsample]) + std = jnp.std(imgs[0, :subsample, :subsample]) + elif subsample > 0: + # of shape [batch, sequence, ...] + mean = jnp.mean(imgs[:subsample, :subsample]) + std = jnp.std(imgs[:subsample, :subsample]) + else: + mean = jnp.mean(imgs) + std = jnp.std(imgs) + imgs = (imgs - mean) / (std + 1e-8) + return imgs, labels + + +@gin.configurable('preprocess') +class PreprocessSpec(NamedTuple): + """A specification for preprocessing input data. + + Attributes: + resize: Target width and height of image. + channel_expand: Whether to expand channels to 3 dimensions. + use_patches: Whether to create patches for vision-transformer processing. 
+ """ + + resize: Optional[int] = 14 + channel_expand: bool = False + use_patches: bool = False + standardize_sub_sample = 0 + + +@gin.configurable() +class RandomDataset: + """A dataset that associcates random observations with random class labels.""" + + def __init__( + self, + key, + batch_size: int, + dataset_size: Optional[int], + sequence_length: int, + preprocess_spec: PreprocessSpec, + normalize: bool, + bias_prob: float = 0.0, + image_shape: Sequence[int] = (14, 14), + num_datapoints: int = 10, + num_classes: int = 10, + ): + self.rng = hk.PRNGSequence(key) + self._batch_size = batch_size + self._sequence_length = sequence_length + self._dataset_size = dataset_size + self._preprocess_spec = preprocess_spec + self._normalize = normalize + self._bias_prob = bias_prob + self._bias_key = next(self.rng) + self._image_shape = image_shape + self._num_datapoints = num_datapoints + self._num_classes = num_classes + std_p = functools.partial( + standardize, + has_dataset_dim=dataset_size is not None, + subsample=preprocess_spec.standardize_sub_sample, + ) + self._standardize = jax.jit(std_p) + + if dataset_size is not None: + self._next = jax.jit(self._generate_tasks) + else: + self._next = jax.jit(self._generate_task) + + def _generate_task(self, key): + """Generate a new unique task. + + Args: + key: A jax PRNGKey + + Returns: + A Tuple of images and labels + """ + key_img, key_choice = jax.random.split(key, num=2) + del key + spec = self._preprocess_spec + images = jax.random.uniform( + key_img, + [self._num_datapoints] + list(self._image_shape), + minval=0.0, + maxval=1.0, + ) + labels = jax.nn.one_hot( + jnp.arange(self._num_datapoints) % self._num_classes, self._num_classes + ) + + if spec.channel_expand: + images = jnp.concatenate([images] * 3, axis=-1) + if not spec.use_patches: + images = jnp.reshape(images, [self._num_datapoints, -1]) + + choice_shape = (self._batch_size, self._sequence_length) + indices = jax.random.choice(key_choice, 10, choice_shape) + batched_images = images[indices] + batched_labels = labels[indices] + + return batched_images, batched_labels + + def _generate_tasks(self, key): + key_tasks, key_mask = jax.random.split(key) + del key + key_tasks = jax.random.split(key_tasks, self._dataset_size) + mask = jax.random.bernoulli( + key_mask, p=self._bias_prob, shape=(self._dataset_size,)) + key_tasks = jnp.where(mask[:, None], self._bias_key[None], key_tasks) + return jax.vmap(self._generate_task)(key_tasks) + + def __next__(self): + item = self._next(next(self.rng)) + if self._normalize: + item = self._standardize(item) + return item + + def __iter__(self): + return self + + +@gin.configurable() +class DataLoader: + """Loads a specific tensorflow dataset and processes data for experiment.""" + + DATASET_STATS = { + 'cifar10': {'mean': 0.4733630120754242, 'std': 0.2515689432621002}, + 'fashion_mnist': {'mean': 0.13066047430038452, 'std': 0.3081078827381134}, + 'mnist': {'mean': 0.13066047430038452, 'std': 0.3081078827381134}, + 'svhn_cropped': {'mean': 0.4514186382293701, 'std': 0.19929124414920807}, + 'random': {'mean': 0.0, 'std': 1.0}, + 'sum': {'mean': 0.0, 'std': 1.0}, + 'emnist': {'mean': 0.1739204376935959, 'std': 0.3319065570831299}, + 'kmnist': {'mean': 0.19176216423511505, 'std': 0.34834328293800354}, + 'omniglot': {'mean': 0.9220603108406067, 'std': 0.26807650923728943}, + 'omniglot_fewshot': { + 'mean': 0.9220603108406067, + 'std': 0.26807650923728943, + }, + } + + def __init__( + self, + dataset_name: str, + num_classes=10, + shuffle_size=10000, + 
prefetch_size=10, + sequence_length=100, + preprocess_spec=None, + normalize=True, + use_fixed_ds_stats: bool = False, + pretrained_embed: bool = False, + ): + self._num_classes = num_classes + self._dataset_name = dataset_name + self._shuffle_size = shuffle_size + self._prefetch_size = prefetch_size + self._sequence_length = sequence_length + self._preprocess_spec = preprocess_spec or PreprocessSpec() + self._normalize = normalize + self._use_fixed_ds_stats = use_fixed_ds_stats + self._pretrained_embed = pretrained_embed + + # Load pre-trained embedding params + if pretrained_embed: + self._params_embed = resnet.load_params() + self._resnet_embed = jax.jit(resnet.embed) + + def get_dataset(self, + set_name: str, + batch_size: int, + dataset_name: Optional[str] = None, + dataset_size: Optional[int] = None, + key: Optional[jax.random.KeyArray] = None): + """Create numpy iterator of dataset specified by dataset_name. + + Args: + set_name: Dataset subset to load. + batch_size: Batch size of returned data. + dataset_name: Name of dataset to load. + dataset_size: Number of datasets to return each iteration. + key: An optional key to use for jax-based datasets. + + Returns: + Numpy iterator of data. + """ + if dataset_name is None: + dataset_name = self._dataset_name + + if dataset_name == 'random': + ds = RandomDataset(key, batch_size, dataset_size, self._sequence_length, + self._preprocess_spec, self._normalize) + return iter(ds) + else: + return self._get_tf_dataset(set_name, batch_size, dataset_name, + dataset_size) + + def _get_tf_dataset(self, + set_name: str, + batch_size: int, + dataset_name: Optional[str] = None, + dataset_size: Optional[int] = None): + """Create numpy iterator of tensorflow dataset. + + Args: + set_name: Dataset subset to load. + batch_size: Batch size of returned data. + dataset_name: Name of dataset to load. + dataset_size: Number of datasets to return each iteration. + + Returns: + Numpy iterator of tensorflow dataset. + """ + ds = tfds.load(dataset_name, split=set_name) + if self._pretrained_embed: + ds = ds.map(self._process) + ds = ds.shuffle(self._shuffle_size) + + def embed(*args): + return tf.py_function(self._embed, list(args), (tf.float32, tf.float32)) + + ds = ds.batch(128).map(embed).unbatch().cache() + else: + ds = ds.map(self._process).cache() + ds = ds.repeat().shuffle(self._shuffle_size) + ds = ds.batch(self._sequence_length).batch(batch_size) + if dataset_size is not None: + ds = ds.batch(dataset_size) + ds = ds.prefetch(self._prefetch_size) + + itr = ds.as_numpy_iterator() + if not self._use_fixed_ds_stats and self._normalize: + std_p = functools.partial( + standardize, + has_dataset_dim=dataset_size is not None, + subsample=self._preprocess_spec.standardize_sub_sample, + ) + itr = map(jax.jit(std_p), itr) + return itr + + def _embed(self, x, y): + x = self._resnet_embed(self._params_embed, x.numpy()) + x = (x - jnp.mean(x)) / jnp.std(x) + x = jax.nn.tanh(x) + x = x.reshape((x.shape[0], -1)) + return x, y + + def _process(self, x): + """Preprocesses vision data. + + Args: + x: Each element of the dataset iterator. + + Returns: + A tuple of image and label after processing. 
+ """ + img = x['image'] + spec = self._preprocess_spec + if spec.resize: + img = tf.image.resize(img, [spec.resize, spec.resize]) + if spec.channel_expand and img.shape[-1] == 1: + img = tf.concat([img] * 3, axis=-1) + if not spec.use_patches: + img = tf.reshape(img, [-1]) + # TODO(lkirsch) keep uint + img = tf.cast(img, tf.float32) / 255 + + if self._use_fixed_ds_stats: + stats = self.DATASET_STATS[self._dataset_name] + img = (img - stats['mean']) / stats['std'] + + label = tf.one_hot(x['label'], self._num_classes) + return img, label diff --git a/learned_optimization/research/data_driven/mnist_projections.py b/learned_optimization/research/data_driven/mnist_projections.py new file mode 100644 index 0000000..869c64b --- /dev/null +++ b/learned_optimization/research/data_driven/mnist_projections.py @@ -0,0 +1,492 @@ +# coding=utf-8 +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Experiment for training on multiple projected datasets jointly.""" + +import functools +from typing import Callable + +import gin +import haiku as hk +import jax +from jax import tree_util +from jax.experimental import pjit +import jax.numpy as jnp +from learned_optimization import checkpoints +from learned_optimization.optimizers import gradient_accumulator +from learned_optimization.optimizers import optax_opts +from learned_optimization.research.data_driven import data +from learned_optimization.research.data_driven import models +from learned_optimization.research.data_driven import summary +from learned_optimization.summary import TensorboardWriter +import numpy as np +import optax + + +@gin.configurable +class ProjectionExperiment: + """Experiment for training on multiple projected datasets jointly. + + Attributes: + model_creator: Function that returns a models.Model. + num_tasks: Number of tasks to generated and train on. + n_batches: Number of training steps. + dataset_name: Dataset to load using tf datasets. + train_batch_size: Batch size for training. + test_batch_size: Batch size for testing. + task_sample_size: Number of tasks to sample eat each training step. + test_task_sample_size: Number of tasks to sample for meta testing. + test_dataset_names: Datasets to meta test on. + seed: Seed to use for the entire experiment. + learning_rate: Learning rate for meta learner. + eval_steps: Number of steps between each evaluation. 
+ """ + + def __init__( + self, + log_dir, + model_creator: Callable[[jnp.ndarray, jnp.ndarray], models.Model], + data_loader_creator=None, + num_tasks=1, + n_batches=50000, + dataset_name='mnist', + train_batch_size=32, + test_batch_size=128, + task_sample_size=16, + test_task_sample_size=32, + test_dataset_names=('mnist', 'fashion_mnist', 'random'), + seed=0, + learning_rate=1e-3, + eval_steps=100, + grad_accum_steps=1, + permute_labels_prob=0.0, + permute_labels_decay=0, + project_prob=1.0, + use_min_xe=False, + grad_max_norm=0, + use_softmax=False, + ): + self._log_dir = log_dir + self.num_tasks = num_tasks + self.n_batches = n_batches + self.train_batch_size = train_batch_size + self.test_batch_size = test_batch_size + self.task_sample_size = task_sample_size + self.test_task_sample_size = test_task_sample_size + self.test_dataset_names = test_dataset_names + self.seed = seed + self.learning_rate = learning_rate + self.eval_steps = eval_steps + self.model_creator = model_creator + self._grad_accum_steps = grad_accum_steps + self._permute_labels_prob = permute_labels_prob + self._permute_labels_decay = permute_labels_decay + self._project_prob = project_prob + self._use_min_xe = use_min_xe + self._grad_max_norm = grad_max_norm + self._use_softmax = use_softmax + + if data_loader_creator is None: + data_loader_creator = data.DataLoader + self._data_loader = data_loader_creator(dataset_name) + self._log_writer = summary.DictSummaryWriter( + TensorboardWriter(log_dir), log_dir) + + def _training_loss(self, params, key, batch, model): + preds = self._prediction(params, key, batch, model, is_training=True) + if self._use_min_xe: + _, labels = batch + loss = self._optimal_scale_softmax_cross_entropy(preds, labels) + return loss + else: + return self._loss(preds, batch) + + def _optimal_scale_softmax_cross_entropy(self, + logits, + labels, + min_log_scale=-4, + max_log_scale=2, + num_scales=50): + """A modificied cross entropy loss that minimizes a scaling factor. + + Args: + logits: Logits to apply loss to. + labels: One-hot target labels. + min_log_scale: Minimum scale exponent to basis 10. + max_log_scale: Maximum scale exponent to basis 10. + num_scales: Number of scales to minimize over. + + Returns: + Cross entropy loss minimizes over scales. 
+ + """ + + scales = np.logspace(min_log_scale, max_log_scale, num_scales) + logits_scaled = jnp.stack([scl * logits for scl in scales]) + labels_scaled = jnp.stack([labels for _ in scales]) + + xents = optax.softmax_cross_entropy( + logits=logits_scaled, labels=labels_scaled) + + per_scl_performance = jnp.mean(xents, axis=list(range(1, len(xents.shape)))) + return jnp.min(per_scl_performance) + + def _loss(self, predictions, batch, axis=None): + _, labels = batch + loss = -jnp.mean(jnp.sum(predictions * labels, axis=-1), axis=axis) + return loss + + def _prediction(self, params, key, batch, model, is_training): + inputs, labels = batch + preds = model(params, key, inputs, labels, is_training=is_training) + return preds + + def _accuracy(self, predictions, batch): + _, labels = batch + target_class = jnp.argmax(labels, axis=-1) + predicted_class = jnp.argmax(predictions, axis=-1) + acc = jnp.mean(predicted_class == target_class, axis=0) + return acc + + def _evaluate(self, params, key, batch, model): + preds = self._prediction(params, key, batch, model, is_training=False) + return dict( + loss=self._loss(preds, batch, axis=0), + accuracy=self._accuracy(preds, batch), + ) + + @gin.configurable('project') + def _project( + self, + inputs: jnp.ndarray, + w_std_scale: float = 1.0, + b_std: float = 0.0, + shrink_factor: int = 1, + ) -> jnp.ndarray: + """Linearily project a given dataset. + + Args: + inputs: Array to project + w_std_scale: Projection weight standard deviation adjusted for input size. + b_std: Projection bias standard deviaton. + shrink_factor: Factor by which to reduce the dimensionality + + Returns: + Projected array. + """ + + batch_dims = inputs.shape[:-1] + out = inputs.reshape((-1, *inputs.shape[-1:])) + + stddev = w_std_scale / np.sqrt(out.shape[-1]) + w_init = hk.initializers.RandomNormal(stddev=stddev) + b_init = hk.initializers.RandomNormal(stddev=b_std) + out = hk.Linear( + out.shape[-1] // shrink_factor, + with_bias=True, + w_init=w_init, + b_init=b_init, + )(out) + if self._use_softmax: + out = jax.nn.softmax(out, axis=-1) * out.shape[-1] + + out = out.reshape(batch_dims + (-1,)) + return out + + def run(self): + """Runs an mnist experiment. + + Returns: + Log dict with metrics of final model. + """ + rng = hk.PRNGSequence(self.seed) + + # Datasets + key_ds = next(rng) + train_set = self._data_loader.get_dataset( + 'train', self.train_batch_size, key=key_ds) + task_set = self._data_loader.get_dataset( + 'train', + self.train_batch_size, + dataset_size=self.task_sample_size, + key=key_ds) + # TODO(lkirsch) Use a different key_ds for testing? 
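+    # The meta-test batch stacks one dataset slice per within-distribution task, one per out-of-distribution task, and a final unprojected (identity) slice; see meta_test below.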
+ num_test_tasks = min(self.test_task_sample_size, + self.num_tasks) + self.test_task_sample_size + 1 + get_test_dataset = functools.partial( + self._data_loader.get_dataset, + 'test', + self.test_batch_size, + dataset_size=num_test_tasks, + key=key_ds) + test_batches = { + name: next(get_test_dataset(dataset_name=name)) + for name in self.test_dataset_names + } + + # Model and optimiser + dummy_batch = next(train_set) + dummy_inputs, dummy_labels = dummy_batch + + # Random projections + proj = hk.without_apply_rng(hk.transform(self._project)) + + def create_proj_params(keys): + proj_params = jax.vmap(proj.init, (0, None))(keys, dummy_inputs) + return proj_params + + dummy_proj_inputs = proj.apply( + proj.init(next(rng), dummy_inputs), dummy_inputs + ) + + # Model and optimiser + model = self.model_creator(dummy_proj_inputs, dummy_labels) + if self._grad_max_norm > 0: + opt = optax_opts.OptaxOptimizer( + optax.chain( + optax.clip_by_global_norm(max_norm=self._grad_max_norm), + optax.adam(learning_rate=self.learning_rate), + )) + else: + opt = optax_opts.Adam(learning_rate=self.learning_rate) + if self._grad_accum_steps > 1: + opt = gradient_accumulator.GradientAccumulator(opt, + self._grad_accum_steps) + + params = model.create_model(next(rng)) + opt_state = opt.init(params) + + rng_tasks_init = next(rng) + key_model_meta_test = next(rng) + + def create_checkpoint(step): + return dict( + params=params, opt_state=opt_state, rng=rng.internal_state, step=step) + + restored = checkpoints.restore_checkpoint( + self._log_dir, create_checkpoint(step=0), prefix='checkpoint_') + params = restored['params'] + opt_state = restored['opt_state'] + rng.replace_internal_state(restored['rng']) + step_init = restored['step'] + + @functools.partial( + pjit.pjit, + static_argnums=[4], + in_axis_resources=(None, None, None, jax.sharding.PartitionSpec('b')), + out_axis_resources=None, + ) + def update(params, key, opt_state, batch, model): + f_grad = jax.value_and_grad(self._training_loss) + loss, grads = f_grad(params, key, batch, model) + opt_state = opt.update(opt_state, grads, loss=loss) + new_params = opt.get_params(opt_state) + return new_params, opt_state, loss + + @jax.jit + def project_batches(key, batches, step): + inp, labels = batches + key_subset, key_mask = jax.random.split(key) + del key + + # Sample a random subset of tasks + subset = jax.random.randint( + key_subset, (self.task_sample_size,), minval=0, maxval=self.num_tasks) + + if self._project_prob < 1: + mask = jax.random.bernoulli( + key_mask, shape=subset.shape, + p=self._project_prob).astype(subset.dtype) + subset *= mask + + key = jax.vmap(jax.random.fold_in, (None, 0))(rng_tasks_init, subset) # pytype: disable=wrong-arg-types # jax-types + key_inp, key_out, key_mask = jax.vmap( + functools.partial(jax.random.split, num=3), out_axes=1)( + key) + + # Project inputs + proj_params = create_proj_params(key_inp) + projected = jax.vmap(proj.apply)(proj_params, inp) + + # Permute outputs + if self._permute_labels_prob > 0: + decay_steps = self._permute_labels_decay + permute_prob = self._permute_labels_prob + if decay_steps > 0: + permute_prob *= jnp.maximum( + step.astype(jnp.float32) / decay_steps, 1.) 
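+        # Draw one class permutation per sampled task; with probability 1 - permute_prob the permutation falls back to the identity mapping.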
+ num_classes = labels.shape[-1] + + def draw_permutation(key_perm, key_mask): + permutation = jax.random.permutation(key_perm, num_classes) + permutation = jax.nn.one_hot(permutation, num_classes=num_classes) + if self._permute_labels_prob < 1.0 or decay_steps > 0: + identity = jnp.identity(num_classes) + mask = jax.random.bernoulli(key_mask, p=permute_prob) + return jnp.where(mask, permutation, identity) + return permutation + + permutation = jax.vmap(draw_permutation)(key_out, key_mask) + labels_perm = labels @ permutation[:, None] + else: + labels_perm = labels + + # Reshape + projected = jnp.reshape(projected, (-1,) + projected.shape[2:]) + labels_perm = jnp.reshape(labels_perm, (-1,) + labels_perm.shape[2:]) + return projected, labels_perm + + @functools.partial( + pjit.pjit, + static_argnums=[2], + in_axis_resources=(None, jax.sharding.PartitionSpec(None, 'b')), + out_axis_resources=None, + ) + def meta_test(params, test_batch, permute_labels: bool): + if self.num_tasks > 0: + # Test on within distr and out of distr task + num_within_tasks = min(self.test_task_sample_size, self.num_tasks) + test_tasks = jnp.concatenate([ + jnp.arange(num_within_tasks), # wid + -jnp.arange(1, self.test_task_sample_size + 1) # ood + ]) + v_fold = jax.vmap(jax.random.fold_in, (None, 0)) + key_task = v_fold(rng_tasks_init, test_tasks) + # Generate three keys to make consistent with meta training + key_inp, key_out, _ = jax.vmap( + functools.partial(jax.random.split, num=3), out_axes=1)( + key_task) + inp, labels = test_batch + + # Project inputs + v_proj = jax.vmap(proj.apply) + projected_test_inputs = v_proj(create_proj_params(key_inp), inp[:-1]) + + # Permute outputs + if permute_labels: + num_classes = labels.shape[-1] + v_permute = jax.vmap(jax.random.permutation, (0, None)) + permutation = v_permute(key_out, num_classes) + permutation = jax.nn.one_hot(permutation, num_classes=num_classes) + labels_perm = labels[:-1] @ permutation[:, None] + else: + labels_perm = labels[:-1] + + # Add identities + last_dim = projected_test_inputs.shape[-1] + orig_dim = inp.shape[-1] + if last_dim > orig_dim: + raise ValueError( + 'Projection dimensionality is expected' + 'to be smaller than the original input.' + ) + elif last_dim < orig_dim: + print( + f'Truncating inputs {orig_dim} ' + f'to projection dimensionality {last_dim}.' + ) + id_inp = inp[-1:, ..., :last_dim] + else: + id_inp = inp[-1:] + projected_test_inputs = jnp.concatenate( + [projected_test_inputs, id_inp], axis=0 + ) + labels_perm = jnp.concatenate([labels_perm, labels[-1:]], axis=0) + projected_test_batch = (projected_test_inputs, labels_perm) + + evaluate = functools.partial(self._evaluate, model=model) + v_evaluate = jax.vmap(evaluate, (None, 0, 0)) + # Use a separate model rng key for each test task. 
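+        # One key per test task, plus one for the appended identity (unprojected) task.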
+ key_model = jax.random.split(key_model_meta_test, + test_tasks.shape[0] + 1) + eval_dict = v_evaluate(params, key_model, projected_test_batch) + else: + eval_dict = self._evaluate(params, key_model_meta_test, test_batch, + model) + + log_dict = dict() + for metric_name, metric_value in eval_dict.items(): + if self.num_tasks > 0: + test_value, meta_test_value, id_test_value = tree_util.tree_map( + functools.partial(jnp.mean, axis=0), + jnp.split(metric_value, [num_within_tasks, test_tasks.shape[0]])) + else: + test_value = meta_test_value = id_test_value = metric_value + log_dict.update({ + # Report on seen task (test), unseen task (meta_test), orig task + f'test_{metric_name}': test_value[-1], + f'meta_test_{metric_name}': meta_test_value[-1], + f'id_test_{metric_name}': id_test_value[-1], + # Report complete meta-test trajectory as well + f'test_{metric_name}_hist': test_value, + f'meta_test_{metric_name}_hist': meta_test_value, + f'id_test_{metric_name}_hist': id_test_value, + }) + return log_dict + + @jax.jit + def get_params_metrics(params): + log = dict() + for mod, name, value in hk.data_structures.traverse(params): + if isinstance(value, jnp.ndarray): + log[f'{mod}/{name}/norm'] = jnp.linalg.norm(value) + return log + + def evaluate_datasets(params): + results = {} + for permuted, plabel in zip((False, True), ('', '_permuted')): + for name in self.test_dataset_names: + test_batch = test_batches[name] + dataset_results = meta_test(params, test_batch, permuted) + results.update( + {f'{k}_{name}{plabel}': v for k, v in dataset_results.items()}) + results.update(get_params_metrics(params)) + return results + + mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('b')) + rank = jax.process_index() + loss = None + for step in range(step_init, self.n_batches): + if rank == 0: + checkpoints.periodically_save_checkpoint( + self._log_dir, step, dict(checkpoint_=create_checkpoint(step))) + if step % self.eval_steps == 0: + with jax.sharding.Mesh(mesh.devices, mesh.axis_names): + self._log_writer.dict( + { + **evaluate_datasets(params), 'training_loss': loss + }, step) + batches = next(task_set) + if self.num_tasks > 0: + projected_batches = project_batches(next(rng), batches, step) + else: + projected_batches = tree_util.tree_map( + lambda x: x.reshape((-1,) + x.shape[2:]), batches) + + with jax.sharding.Mesh(mesh.devices, mesh.axis_names): + params, opt_state, loss = update(params, next(rng), opt_state, + projected_batches, model) + + if rank == 0: + checkpoints.save_checkpoint(self._log_dir, 'checkpoint_', + create_checkpoint(self.n_batches), + self.n_batches) + + with jax.sharding.Mesh(mesh.devices, mesh.axis_names): + results = {**evaluate_datasets(params), 'training_loss': loss} + self._log_writer.dict(results, self.n_batches) + self._log_writer.flush() + + return results diff --git a/learned_optimization/research/data_driven/mnist_projections_test.py b/learned_optimization/research/data_driven/mnist_projections_test.py new file mode 100644 index 0000000..f99bbe0 --- /dev/null +++ b/learned_optimization/research/data_driven/mnist_projections_test.py @@ -0,0 +1,308 @@ +# coding=utf-8 +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for mnist_projections.""" + +import functools +import tempfile + +from absl.testing import absltest +from learned_optimization.research.data_driven import data +from learned_optimization.research.data_driven import mnist_projections +from learned_optimization.research.data_driven import model_components +from learned_optimization.research.data_driven import models + + +class MnistProjectionsTest(absltest.TestCase): + + def _get_data_loader(self, *args, **kwargs): + return functools.partial( + data.DataLoader, *args, **kwargs, sequence_length=2) + + def test_lstm(self): + """Smoke test for mnist projection experiment with LSTMs.""" + with tempfile.TemporaryDirectory() as tmpdir: + model_creator = functools.partial(models.LSTM, hidden_size=8) + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + model_creator, + self._get_data_loader(), + n_batches=1, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2) + experiment.run() + + def test_outer_lstm(self): + """Smoke test for experiment with outer product LSTMs.""" + with tempfile.TemporaryDirectory() as tmpdir: + model_creator = functools.partial( + models.LSTM, hidden_size=8, lstm_creator=model_components.OuterLSTM) + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + model_creator, + self._get_data_loader(), + n_batches=1, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2) + experiment.run() + + def test_mlp(self): + """Smoke test for mnist projection experiment with MLPs.""" + with tempfile.TemporaryDirectory() as tmpdir: + model_creator = functools.partial(models.MLP, hidden_size=8) + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + model_creator, + self._get_data_loader(), + n_batches=1, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2) + experiment.run() + + def test_no_proj(self): + """Smoke test for a no projection experiment with MLPs.""" + with tempfile.TemporaryDirectory() as tmpdir: + model_creator = functools.partial(models.MLP, hidden_size=8) + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + model_creator, + self._get_data_loader(), + n_batches=1, + train_batch_size=2, + task_sample_size=1, + num_tasks=0, + test_batch_size=2) + experiment.run() + + def test_cifar10(self): + """Smoke test for cifar10 projection experiment.""" + with tempfile.TemporaryDirectory() as tmpdir: + model_creator = functools.partial(models.MLP, hidden_size=8) + pspec = data.PreprocessSpec(resize=32, channel_expand=True) + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + model_creator, + self._get_data_loader(preprocess_spec=pspec), + dataset_name='cifar10', + test_dataset_names=('cifar10',), + n_batches=1, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2) + experiment.run() + + def test_vision_transformer(self): + """Smoke test for vision transformers.""" + with tempfile.TemporaryDirectory() as tmpdir: + pspec = data.PreprocessSpec( + resize=32, channel_expand=True, use_patches=True) + data_creator = functools.partial( + data.DataLoader, sequence_length=1, preprocess_spec=pspec) + + experiment = 
mnist_projections.ProjectionExperiment( + tmpdir, + models.VisionTransformer, + data_creator, + dataset_name='cifar10', + test_dataset_names=('cifar10', 'mnist', 'fashion_mnist'), + n_batches=1, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2) + experiment.run() + + def test_transformer(self): + """Smoke test for mnist projection experiment with Transformers.""" + with tempfile.TemporaryDirectory() as tmpdir: + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + models.Transformer, + self._get_data_loader(), + n_batches=1, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2) + experiment.run() + + def test_datasets(self): + """Smoke test for dataset loaders.""" + with tempfile.TemporaryDirectory() as tmpdir: + pspec = data.PreprocessSpec(channel_expand=True) + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + models.Transformer, + self._get_data_loader(preprocess_spec=pspec), + n_batches=1, + test_dataset_names=('mnist', 'fashion_mnist', 'random', 'kmnist', + 'cifar10', 'svhn_cropped'), + train_batch_size=2, + task_sample_size=1, + test_batch_size=2) + experiment.run() + + def test_random_dataset(self): + """Smoke test for mnist projection experiment with Transformers.""" + with tempfile.TemporaryDirectory() as tmpdir: + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + models.Transformer, + self._get_data_loader(), + dataset_name='mnist', + test_dataset_names=('random', 'mnist'), + n_batches=1, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2) + experiment.run() + + def test_transformer_xl(self): + """Smoke test for mnist projection experiment with XLTransformers.""" + with tempfile.TemporaryDirectory() as tmpdir: + model_creator = functools.partial( + models.Transformer, transformer_type='dm_xl') + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + model_creator, + self._get_data_loader(), + n_batches=1, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2) + experiment.run() + + def test_permuted_labels(self): + """Smoke test for mnist projection experiment with permuted labels.""" + with tempfile.TemporaryDirectory() as tmpdir: + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + models.Transformer, + self._get_data_loader(), + n_batches=1, + num_tasks=2, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2, + permute_labels_prob=1.0) + experiment.run() + + def test_permuted_labels_decay(self): + """Smoke test for experiment with decayed permuted labels.""" + with tempfile.TemporaryDirectory() as tmpdir: + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + models.Transformer, + self._get_data_loader(), + n_batches=1, + num_tasks=2, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2, + permute_labels_prob=1.0, + permute_labels_decay=1000) + experiment.run() + + def test_vsml(self): + """Smoke test for mnist projection experiment with VSML.""" + with tempfile.TemporaryDirectory() as tmpdir: + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + models.VSML, + self._get_data_loader(), + n_batches=1, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2) + experiment.run() + + def test_no_sym_vsml(self): + """Smoke test for VSML without symmetries.""" + with tempfile.TemporaryDirectory() as tmpdir: + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + models.NoSymVSML, + self._get_data_loader(), + n_batches=1, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2) + experiment.run() + 
+ def test_fw_memory(self): + """Smoke test for a fast weight memory model.""" + with tempfile.TemporaryDirectory() as tmpdir: + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + models.FWMemory, + self._get_data_loader(), + n_batches=1, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2) + experiment.run() + + def test_sgd(self): + """Smoke test for the SGD baseline.""" + with tempfile.TemporaryDirectory() as tmpdir: + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + models.SGD, + self._get_data_loader(), + n_batches=1, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2) + experiment.run() + + def test_maml(self): + """Smoke test for the MAML baseline.""" + with tempfile.TemporaryDirectory() as tmpdir: + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + functools.partial(models.SGD, use_maml=True), + self._get_data_loader(), + n_batches=1, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2) + experiment.run() + + def test_pretrained(self): + """Smoke test for mnist projection experiment with pretrained embeddings.""" + with tempfile.TemporaryDirectory() as tmpdir: + pspec = data.PreprocessSpec( + resize=224, channel_expand=True, use_patches=True + ) + experiment = mnist_projections.ProjectionExperiment( + tmpdir, + models.Transformer, + self._get_data_loader( + pretrained_embed=True, + use_fixed_ds_stats=True, + preprocess_spec=pspec, + ), + n_batches=1, + train_batch_size=2, + task_sample_size=1, + test_batch_size=2, + test_dataset_names=('mnist', 'fashion_mnist'), + ) + experiment.run() + + +if __name__ == '__main__': + absltest.main() diff --git a/learned_optimization/research/data_driven/model_components.py b/learned_optimization/research/data_driven/model_components.py new file mode 100644 index 0000000..fa5cdb4 --- /dev/null +++ b/learned_optimization/research/data_driven/model_components.py @@ -0,0 +1,90 @@ +# coding=utf-8 +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Model components to be used by models.py.""" + +from typing import Optional, Tuple + +import gin +import haiku as hk +import jax +from jax.ad_checkpoint import checkpoint_name +import jax.numpy as jnp +import numpy as np + + +@gin.configurable() +class OuterLSTM(hk.RNNCore): + """An outer-product based LSTM.""" + + def __init__(self, + hidden_size: int, + num_heads: int = 1, + name: Optional[str] = None): + super().__init__(name=name) + self.hidden_size = hidden_size + self.num_heads = num_heads + + def __call__( + self, + inputs: jnp.ndarray, + prev_state: hk.LSTMState, + ) -> Tuple[jnp.ndarray, hk.LSTMState]: + if len(inputs.shape) != 2 or not inputs.shape: + raise ValueError('OuterLSTM input must be rank-2.') + batch_size = inputs.shape[0] + size = self.hidden_size + x_and_h = jnp.concatenate([inputs, prev_state.hidden], axis=-1) + + gated = hk.Linear(8 * size * self.num_heads)(x_and_h) + gated = gated.reshape((batch_size, self.num_heads, 8 * size)) + gated = checkpoint_name(gated, 'gated') + + # i = input, g = cell_gate, f = forget_gate, q = query, o = output_gate + sizes = (3 * size, 3 * size, size, size) + indices = np.cumsum(sizes[:-1]) + k1, k2, q, o = jnp.split(gated, indices, axis=-1) + scale = jax.nn.softplus( + hk.get_parameter('key_scale', shape=(), dtype=k1.dtype, init=jnp.zeros)) + i, g, f = jnp.einsum('bhki,bhkj->kbhij', + jax.nn.tanh(split_axis(k1, (3, size))) * scale, + jax.nn.tanh(split_axis(k2, (3, size)))) + f = jax.nn.sigmoid(f + 1) # Forget bias, as in sonnet. + c = f * prev_state.cell + jax.nn.sigmoid(i) * g + read = jnp.einsum('bhij,bhi->bhj', c, q) + h = hk.Flatten()(jax.nn.sigmoid(o) * jnp.tanh(read)) + h = checkpoint_name(h, 'hidden') + c = checkpoint_name(c, 'context') + + return h, hk.LSTMState(h, c) + + def initial_state(self, batch_size: Optional[int]) -> hk.LSTMState: + state = hk.LSTMState( + hidden=jnp.zeros([self.num_heads * self.hidden_size]), + cell=jnp.zeros([self.num_heads, self.hidden_size, self.hidden_size])) + if batch_size is not None: + state = add_batch(state, batch_size) + return state + + +def add_batch(nest, batch_size: Optional[int]): + """Adds a batch dimension at axis 0 to the leaves of a nested structure.""" + broadcast = lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape) + return jax.tree_util.tree_map(broadcast, nest) + + +def split_axis(x: jnp.ndarray, shape=Tuple[int], axis=-1): + new_shape = x.shape[:axis] + shape + x.shape[axis:][1:] + return x.reshape(new_shape) diff --git a/learned_optimization/research/data_driven/models.py b/learned_optimization/research/data_driven/models.py new file mode 100644 index 0000000..9aa2031 --- /dev/null +++ b/learned_optimization/research/data_driven/models.py @@ -0,0 +1,748 @@ +# coding=utf-8 +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Defines various neural models for meta learning with sequential models.""" + +import abc +from typing import Optional, List, NamedTuple + +import chex +import gin +import haiku as hk +import jax +import jax.numpy as jnp +from learned_optimization.research.data_driven import model_components # pylint: disable=unused-import +from learned_optimization.research.data_driven import transformer +import optax +from vision_transformer.vit_jax import models + + +class Model(abc.ABC): + """Base class for all data_driven (sequence) models.""" + + @abc.abstractmethod + def create_model(self, key: chex.PRNGKey) -> chex.ArrayTree: + pass + + @abc.abstractmethod + def __call__(self, params: chex.ArrayTree, key: chex.PRNGKey, + inputs: jnp.ndarray, labels: jnp.ndarray, + is_training: bool) -> jnp.ndarray: + pass + + +@gin.configurable() +class LSTM(Model): + """LSTM model.""" + + def __init__(self, + dummy_inputs, + dummy_labels, + hidden_size: int, + lstm_creator=None): + self._hidden_size = hidden_size + self._dummy_inputs = dummy_inputs + self._dummy_labels = dummy_labels + self._lstm_creator = lstm_creator or hk.LSTM + + def create_model(self, key: chex.PRNGKey) -> chex.ArrayTree: + transformed = hk.transform(self.hk_forward) + params = transformed.init(key, self._dummy_inputs, self._dummy_labels) + return params + + def __call__(self, params: chex.ArrayTree, key: chex.PRNGKey, + inputs: jnp.ndarray, labels: jnp.ndarray, + is_training: bool) -> jnp.ndarray: + transformed = hk.transform(self.hk_forward) + return transformed.apply(params, key, inputs, labels) + + def hk_forward(self, inputs: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray: + """Runs (sequential) model on inputs. + + Args: + inputs: Inputs of size [batch x sequence_len x feature_size] + labels: One-hot labels of size [batch x sequence_len x num_classes] + + Returns: + Prediction of size [batch x sequence_len x num_classes] + """ + + # Shift labels by one and zero out last one. + # Useful for maximizing log-likelihood at every step. + # Alternatively could just optimize for last prediction. + labels: jnp.ndarray = jnp.roll(labels, shift=1, axis=1) + labels = labels.at[:, 0].multiply(0) + + x = jnp.concatenate([inputs, labels], axis=-1) + + mlp = hk.Sequential([ + hk.Linear(self._hidden_size), + jax.nn.relu, + hk.Linear(self._hidden_size), + jax.nn.relu, + ]) + out_transform = hk.Sequential([hk.Linear(10), jax.nn.log_softmax]) + lstm = self._lstm_creator(self._hidden_size) + + initial_state = lstm.initial_state(x.shape[0]) + hidden = mlp(x) + output, _ = hk.dynamic_unroll(lstm, hidden, initial_state, time_major=False) + return out_transform(output) + + +@gin.configurable() +class MLP(Model): + """MLP model that processes a single example. Not capable of meta learning.""" + + def __init__(self, dummy_inputs, dummy_labels, hidden_size: int): + self._hidden_size = hidden_size + self._dummy_inputs = dummy_inputs + del dummy_labels + + def create_model(self, key: chex.PRNGKey) -> chex.ArrayTree: + # TODO(lkirsch) extract to super class? + transformed = hk.transform(self.hk_forward) + params = transformed.init(key, self._dummy_inputs) + return params + + def __call__(self, params: chex.ArrayTree, key: chex.PRNGKey, + inputs: jnp.ndarray, labels: jnp.ndarray, + is_training: bool) -> jnp.ndarray: + # TODO(lkirsch) extract to super class? + transformed = hk.transform(self.hk_forward) + return transformed.apply(params, key, inputs) + + def hk_forward(self, inputs: jnp.ndarray) -> jnp.ndarray: + """Runs MLP model on inputs. 
+ + Args: + inputs: Inputs of size [batch x sequence_len x feature_size] + + Returns: + Prediction of size [batch x sequence_len x num_classes] + """ + + mlp = hk.Sequential([ + hk.Linear(self._hidden_size), + jax.nn.relu, + hk.Linear(self._hidden_size), + jax.nn.relu, + hk.Linear(10), + jax.nn.log_softmax, + ]) + return mlp(inputs) + + +@gin.configurable() +class Transformer(Model): + """Transformer model.""" + + # TODO(lkirsch) Could also create a non-causally masked version if + # only prediction on test element is evaluated. + def __init__(self, + dummy_inputs, + dummy_labels, + num_heads: int = 8, + num_layers: int = 4, + dropout_rate: float = 0.0, + model_size: int = 512, + key_size: int = 32, + ff_widening_factor: float = 4., + pos_embed_std: float = 0.1, + transformer_type: str = 'haiku'): + self._dummy_inputs = dummy_inputs + self._dummy_labels = dummy_labels + self._num_heads = num_heads + self._num_layers = num_layers + self._dropout_rate = dropout_rate + self._model_size = model_size + self._key_size = key_size + self._ff_widening_factor = ff_widening_factor + self._pos_embed_std = pos_embed_std + self._transformer_type = transformer_type + + def create_model(self, key: chex.PRNGKey) -> chex.ArrayTree: + transformed = hk.transform(self.hk_forward) + params = transformed.init(key, self._dummy_inputs, self._dummy_labels, True) + return params + + def __call__(self, params: chex.ArrayTree, key: chex.PRNGKey, + inputs: jnp.ndarray, labels: jnp.ndarray, + is_training: bool) -> jnp.ndarray: + transformed = hk.transform(self.hk_forward) + return transformed.apply(params, key, inputs, labels, is_training) + + def hk_forward(self, inputs: jnp.ndarray, labels: jnp.ndarray, + is_training: bool) -> jnp.ndarray: + """Runs (sequence) model on inputs. + + Args: + inputs: Inputs of size [batch x sequence_len x feature_size] + labels: One-hot labels of size [batch x sequence_len x num_classes] + is_training: Boolean for toggling dropout. + + Returns: + Prediction of size [batch x sequence_len x num_classes] + """ + + if inputs.ndim == 5: + shape = inputs.shape + shift = shape[2] * shape[3] + inputs = inputs.reshape((shape[0], shape[1] * shift, shape[4])) + labels = jnp.repeat(labels, repeats=shift, axis=1) + else: + shift = 1 + + # Shift labels by one and zero out last one. + # Useful for maximizing log-likelihood at every step. + # Alternatively could just optimize for last prediction. + labels: jnp.ndarray = jnp.roll(labels, shift=shift, axis=1) + labels = labels.at[:, :shift].multiply(0) + + # Project down to model size + x = jnp.concatenate([inputs, labels], axis=-1) + embedding = hk.Linear(self._model_size)(x) + + if self._transformer_type != 'dm_xl': + # Positional embeddings + # TODO(lkirsch) There is probably a better way to create pos embeddings. + # We want invariance to order apart from last input. 
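+      # Learned positional embeddings of shape [seq_length, model_size], shared across the batch and added to every token embedding.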
+ seq_length = inputs.shape[1] + embed_init = hk.initializers.TruncatedNormal(stddev=self._pos_embed_std) + positional_embeddings = hk.get_parameter( + 'pos_embs', [seq_length, self._model_size], init=embed_init) + embedding += positional_embeddings + + if self._transformer_type == 'haiku': + model = transformer.Transformer( + num_heads=self._num_heads, + num_layers=self._num_layers, + key_size=self._key_size, + widening_factor=self._ff_widening_factor, + dropout_rate=self._dropout_rate) + else: + raise ValueError(f'Invalid transformer type {self._transformer_type}') + + hidden = model(h=embedding, mask=None, is_training=is_training) + num_classes = labels.shape[-1] + prediction = jax.nn.log_softmax(hk.Linear(num_classes)(hidden)) + if shift > 1: + return prediction[:, shift - 1::shift] + return prediction + + +@gin.configurable() +class VisionTransformer(Model): + """VisionTransformer model.""" + + def __init__(self, + dummy_inputs, + dummy_labels, + name: str = 'ViT-B_16', + **kwargs): + self._dummy_inputs = dummy_inputs + self._dummy_labels = dummy_labels + num_classes = dummy_labels.shape[-1] + self._model = models.get_model(name, num_classes=num_classes, **kwargs) + + def create_model(self, key: chex.PRNGKey) -> chex.ArrayTree: + inp = self._dummy_inputs.squeeze(1) + params = self._model.init(key, inp, train=True) + return params + + def __call__(self, params: chex.ArrayTree, key: chex.PRNGKey, + inputs: jnp.ndarray, labels: jnp.ndarray, + is_training: bool) -> jnp.ndarray: + """Runs (sequence) model on inputs. + + Args: + params: Parameters of model from create_model(). + key: Jax random key. + inputs: Inputs of size [batch x sequence_len x feature_size] + labels: One-hot labels of size [batch x sequence_len x num_classes] + is_training: Boolean for toggling dropout. 
+ + Returns: + Prediction of size [batch x sequence_len x num_classes] + """ + x = inputs.squeeze(1) + out = self._model.apply(params, x, train=is_training) + prediction = jax.nn.log_softmax(out)[:, None] + return prediction + + +class LayerState(NamedTuple): + lstm_state: hk.LSTMState = None + fwd_msg: jnp.ndarray = None + bwd_msg: jnp.ndarray = None + + +@gin.configurable() +class VSMLLayer(hk.Module): + """A recurrent VSML layer with self-messaging.""" + + def __init__(self, + input_size: int, + output_size: int, + msg_size: int = 8, + hidden_size: int = 16, + micro_ticks: int = 2, + self_msg: bool = True): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.micro_ticks = micro_ticks + self.self_msg = self_msg + self._lstm = hk.LSTM(hidden_size) + self._fwd_messenger = hk.Linear(msg_size) + self._bwd_messenger = hk.Linear(msg_size) + self._tick = hk.vmap( + hk.vmap(self._tick, (0, None, 0, None), split_rng=False), + (0, 0, None, None), + split_rng=False) + + def _tick(self, lstm_state: hk.LSTMState, fwd_msg: jnp.ndarray, + bwd_msg: jnp.ndarray, aux: Optional[jnp.ndarray]): + if aux is not None: + inp = jnp.concatenate([fwd_msg, bwd_msg, aux]) + else: + inp = jnp.concatenate([fwd_msg, bwd_msg]) + out, lstm_state = self._lstm(inp, lstm_state) + return out, lstm_state + + def create_state(self) -> LayerState: + lstm_state_shape = (2, self.input_size, self.output_size, + self._lstm.hidden_size) + lstm_state = jnp.zeros(lstm_state_shape) + lstm_state = hk.LSTMState(hidden=lstm_state[0], cell=lstm_state[1]) + + fwd_msg_shape = (self.output_size, self._fwd_messenger.output_size) + fwd_msg = jnp.zeros(fwd_msg_shape) + + bwd_msg_shape = (self.input_size, self._bwd_messenger.output_size) + bwd_msg = jnp.zeros(bwd_msg_shape) + + return LayerState(lstm_state, fwd_msg, bwd_msg) + + def __call__(self, + state: LayerState, + fwd_msg: jnp.ndarray, + bwd_msg: jnp.ndarray, + aux: Optional[jnp.ndarray] = None): + + if self.self_msg: + lstm_state, self_fwd_msg, self_bwd_msg = state + fwd_msg = jnp.concatenate([self_bwd_msg, fwd_msg], axis=-1) + bwd_msg = jnp.concatenate([self_fwd_msg, bwd_msg], axis=-1) + else: + lstm_state, self_fwd_msg, self_bwd_msg = state + if fwd_msg.shape != state.bwd_msg.shape: + diff = state.bwd_msg.shape[-1] - fwd_msg.shape[-1] + fwd_msg = jnp.pad(fwd_msg, ((0, 0), (0, diff))) + if bwd_msg.shape != state.fwd_msg.shape: + diff = state.fwd_msg.shape[-1] - bwd_msg.shape[-1] + bwd_msg = jnp.pad(bwd_msg, ((0, 0), (0, diff))) + fwd_msg = jnp.concatenate([self_bwd_msg, fwd_msg], axis=-1) + bwd_msg = jnp.concatenate([self_fwd_msg, bwd_msg], axis=-1) + + # Update state + for _ in range(self.micro_ticks): + out, lstm_state = self._tick(lstm_state, fwd_msg, bwd_msg, aux) + + # Update forward messages + out_fwd_msg = self._fwd_messenger(out).mean(axis=0) + # Update backward messages + out_bwd_msg = self._bwd_messenger(out).mean(axis=1) + + return out_fwd_msg, LayerState(lstm_state, out_fwd_msg, out_bwd_msg) + + +class BiSequential(hk.Module): + """Runs the given VSMLLayers first forward, then optionally backward. + + Attributes: + layers: The list of VSMLLayers. 
+ """ + + def __init__(self, + layers: List[VSMLLayer], + backward: bool = False, + name: Optional[str] = None): + super().__init__(name) + self.layers = layers + self._backward = backward + + def create_state(self) -> List[LayerState]: + return [layer.create_state() for layer in self.layers] + + def __call__(self, + states: List[LayerState], + inp: jnp.ndarray, + inp_end: jnp.ndarray, + aux: Optional[jnp.ndarray] = None): + if len(states) != len(self.layers): + raise ValueError('Number of states must equal number of layers') + start_s = LayerState(fwd_msg=inp) + if self._backward: + # Do not include input at end until backward pass + end_s = LayerState(bwd_msg=jnp.zeros_like(inp_end)) + else: + end_s = LayerState(bwd_msg=inp_end) + + new_states = [start_s] + states + [end_s] + for i in range(len(states)): + prev_s, state, next_s = new_states[i:i + 3] + layer = self.layers[i] + out, new_states[i + 1] = layer(state, prev_s.fwd_msg, next_s.bwd_msg, aux) + + if self._backward: + new_states[-1] = LayerState(bwd_msg=inp_end) + for i in reversed(range(len(states))): + prev_s, state, next_s = new_states[i:i + 3] + layer = self.layers[i] + _, new_states[i + 1] = layer(state, prev_s.fwd_msg, next_s.bwd_msg, aux) + + return out, new_states[1:-1] + + +@gin.configurable() +class NoSymVSML(Model): + """VSML model without symmetries.""" + + def __init__(self, dummy_inputs, dummy_labels, size: int = 8): + self._dummy_inputs = dummy_inputs + self._dummy_labels = dummy_labels + self._size = size + + def create_model(self, key: chex.PRNGKey) -> chex.ArrayTree: + transformed = hk.transform(self.hk_forward) + params = transformed.init(key, self._dummy_inputs, self._dummy_labels) + return params + + def __call__(self, params: chex.ArrayTree, key: chex.PRNGKey, + inputs: jnp.ndarray, labels: jnp.ndarray, + is_training: bool) -> jnp.ndarray: + transformed = hk.transform(self.hk_forward) + return transformed.apply(params, key, inputs, labels) + + def hk_forward(self, inputs: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray: + """Runs (sequential) model on inputs. + + Args: + inputs: Inputs of size [batch x sequence_len x feature_size] + labels: One-hot labels of size [batch x sequence_len x num_classes] + + Returns: + Prediction of size [batch x sequence_len x num_classes] + """ + + # Shift labels by one and zero out last one. + # Useful for maximizing log-likelihood at every step. + # Alternatively could just optimize for last prediction. 
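+    # After the shift, position t sees input x_t together with the label of example t-1; the first position receives an all-zeros label.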
+ labels: jnp.ndarray = jnp.roll(labels, shift=1, axis=1) + labels = labels.at[:, 0].multiply(0) + + batch_size, _, num_classes = labels.shape + + # Put sequence axis first + inputs = jnp.transpose(inputs, (1, 0, 2)) + labels = jnp.transpose(labels, (1, 0, 2)) + + feature_layer = hk.Sequential([ + hk.Linear(self._size * 16, name='feature_layer'), + hk.Reshape([self._size, 16]) + ]) + label_layer = hk.Sequential([ + hk.Linear(self._size * 4, name='label_layer'), + hk.Reshape([self._size, 4]) + ]) + layer = VSMLLayer(self._size, self._size) + out_layer = hk.Sequential( + [hk.Flatten(), hk.Linear(num_classes, name='out_layer')]) + + v_layer = hk.vmap(layer, split_rng=False) + state = layer.create_state() + state = jax.tree_util.tree_map(lambda x: jnp.stack([x] * batch_size), state) + + def scan_tick(state, x): + inp, label = x + h_label = label_layer(label) + h_inp = feature_layer(inp) + h_out, new_state = v_layer(state, h_inp, h_label) + out = out_layer(h_out) + return new_state, out + + _, out = hk.scan(scan_tick, state, (inputs, labels)) + + # Put sequence axis second + out = jnp.transpose(out, (1, 0, 2)) + logits = jax.nn.log_softmax(out) + + return logits + + +@gin.configurable() +class VSML(Model): + """VSML model.""" + + def __init__(self, + dummy_inputs, + dummy_labels, + fast_hidden_size: int = 0, + num_layers: int = 1): + self._dummy_inputs = dummy_inputs + self._dummy_labels = dummy_labels + self._fast_hidden_size = fast_hidden_size + self._num_layers = num_layers + + def create_model(self, key: chex.PRNGKey) -> chex.ArrayTree: + transformed = hk.transform(self.hk_forward) + params = transformed.init(key, self._dummy_inputs, self._dummy_labels) + return params + + def __call__(self, params: chex.ArrayTree, key: chex.PRNGKey, + inputs: jnp.ndarray, labels: jnp.ndarray, + is_training: bool) -> jnp.ndarray: + transformed = hk.transform(self.hk_forward) + return transformed.apply(params, key, inputs, labels) + + def hk_forward(self, inputs: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray: + """Runs (sequential) model on inputs. + + Args: + inputs: Inputs of size [batch x sequence_len x feature_size] + labels: One-hot labels of size [batch x sequence_len x num_classes] + + Returns: + Prediction of size [batch x sequence_len x num_classes] + """ + + # Shift labels by one and zero out last one. + # Useful for maximizing log-likelihood at every step. + # Alternatively could just optimize for last prediction. 
+ labels: jnp.ndarray = jnp.roll(labels, shift=1, axis=1) + labels = labels.at[:, 0].multiply(0) + + batch_size, _, feature_size = inputs.shape + num_classes = labels.shape[-1] + + # Put sequence axis first + inputs = jnp.transpose(inputs, (1, 0, 2)) + labels = jnp.transpose(labels, (1, 0, 2)) + + if self._fast_hidden_size > 0: + sizes = ([feature_size] + [self._fast_hidden_size] * + (self._num_layers - 1) + [num_classes]) + layers = BiSequential([ + VSMLLayer(in_size, out_size, self_msg=False) + for in_size, out_size in zip(sizes[:-1], sizes[1:]) + ]) + else: + layers = VSMLLayer(feature_size, num_classes) + v_layers = hk.vmap(layers, split_rng=False) + state = layers.create_state() + state = jax.tree_util.tree_map(lambda x: jnp.stack([x] * batch_size), state) + + def scan_tick(state, x): + inp, label = x + out_fwd_msg, new_state = v_layers(state, inp[:, :, None], label[:, :, + None]) + # Read out (unnormalized) logits + out = out_fwd_msg[:, :, 0] + return new_state, out + + _, out = hk.scan(scan_tick, state, (inputs, labels)) + + # Put sequence axis second + out = jnp.transpose(out, (1, 0, 2)) + logits = jax.nn.log_softmax(out) + + return logits + + +@gin.configurable() +class SGD(Model): + """SGD model.""" + + def __init__(self, + dummy_inputs, + dummy_labels, + num_layers: int = 2, + hidden_size: int = 128, + optimizer: str = 'adam', + learning_rate: float = 1e-3, + use_maml: bool = False): + self._dummy_inputs = dummy_inputs + self._dummy_labels = dummy_labels + self._num_layers = num_layers + self._hidden_size = hidden_size + self._use_maml = use_maml + + self._grad_func = jax.grad(self._loss, has_aux=True) + self._network = hk.without_apply_rng(hk.transform(self._network)) + self._opt = getattr(optax, optimizer)(learning_rate) + + def create_model(self, key: chex.PRNGKey) -> chex.ArrayTree: + transformed = hk.transform( + hk.vmap(self.hk_forward, split_rng=not self._use_maml)) + params = transformed.init(key, self._dummy_inputs, self._dummy_labels) + return params + + def __call__(self, params: chex.ArrayTree, key: chex.PRNGKey, + inputs: jnp.ndarray, labels: jnp.ndarray, + is_training: bool) -> jnp.ndarray: + transformed = hk.transform( + hk.vmap(self.hk_forward, split_rng=not self._use_maml)) + return transformed.apply(params, key, inputs, labels) + + def _network(self, x: jnp.ndarray): + output_size = self._dummy_labels.shape[-1] + x = hk.Flatten(preserve_dims=1)(x) + for _ in range(self._num_layers - 1): + x = hk.Linear(self._hidden_size)(x) + x = jax.nn.relu(x) + x = hk.Linear(output_size)(x) + x = jax.nn.log_softmax(x) + return x + + def _loss(self, params, x, labels): + logits = self._network.apply(params, x) + loss = optax.softmax_cross_entropy(logits, labels) + return loss, logits + + def hk_forward(self, inputs: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray: + """Runs (sequential) model on inputs. 
+ + Args: + inputs: Inputs of size [sequence_len x feature_size] + labels: One-hot labels of size [sequence_len x num_classes] + + Returns: + Prediction of size [sequence_len x num_classes] + """ + + dummy_inp = inputs[0] + if self._use_maml: + key = hk.next_rng_key() if hk.running_init() else None + params = hk.lift(self._network.init, name='maml_lift')(key, dummy_inp) + else: + key = hk.next_rng_key() + params = self._network.init(key, dummy_inp) + opt_state = self._opt.init(params) + + def scan_tick(carry, x): + params, opt_state = carry + grads, out = self._grad_func(params, *x) + updates, opt_state = self._opt.update(grads, opt_state, params=params) + params = optax.apply_updates(params, updates) + return (params, opt_state), out + + _, outputs = jax.lax.scan(scan_tick, (params, opt_state), (inputs, labels)) + return outputs + + +@gin.configurable() +class FWMemory(Model): + """A fast weight memory model.""" + + def __init__(self, + dummy_inputs, + dummy_labels, + slow_size: int = 64, + memory_size: int = 16): + self._dummy_inputs = dummy_inputs + self._dummy_labels = dummy_labels + self._slow_size = slow_size + self._memory_size = memory_size + + def create_model(self, key: chex.PRNGKey) -> chex.ArrayTree: + transformed = hk.transform(self.hk_forward) + params = transformed.init(key, self._dummy_inputs, self._dummy_labels) + return params + + def __call__(self, params: chex.ArrayTree, key: chex.PRNGKey, + inputs: jnp.ndarray, labels: jnp.ndarray, + is_training: bool) -> jnp.ndarray: + transformed = hk.transform(self.hk_forward) + return transformed.apply(params, key, inputs, labels) + + def hk_forward(self, inputs: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray: + """Runs a fw memory model on inputs. + + Args: + inputs: Inputs of size [batch x sequence_len x feature_size] + labels: One-hot labels of size [batch x sequence_len x num_classes] + + Returns: + Prediction of size [batch x sequence_len x num_classes] + """ + + # Shift labels by one and zero out last one. + # Useful for maximizing log-likelihood at every step. + # Alternatively could just optimize for last prediction. 
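+    # The scan below couples a slow LSTM with a fast-weight memory of shape
+    # [batch, memory_size, memory_size**2]. Keys are flattened outer products
+    # of two tanh vectors; writes apply a sigmoid-gated delta rule on
+    # (v - v_old), and reads retrieve memory @ key, which is layer-normed,
+    # projected, and added back to the LSTM output before the classifier.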
+ labels: jnp.ndarray = jnp.roll(labels, shift=1, axis=1) + labels = labels.at[:, 0].multiply(0) + + batch_size, _, num_classes = labels.shape + lstm = hk.LSTM(self._slow_size) + output_proj = hk.Linear(num_classes) + write_head = hk.Linear(3 * self._memory_size + 1) + read_head = hk.Linear(2 * self._memory_size) + read_proj = hk.Linear(self._slow_size) + layer_norm = hk.LayerNorm(-1, create_scale=True, create_offset=False) + + lstm_state = lstm.initial_state(batch_size) + init_memory = jnp.zeros( + (batch_size, self._memory_size, self._memory_size**2)) + + # Put sequence axis first + inputs = jnp.transpose(inputs, (1, 0, 2)) + labels = jnp.transpose(labels, (1, 0, 2)) + + def scan_tick(carry, x): + lstm_state, memory = carry + inp, label = x + + inputs = jnp.concatenate([inp, label], axis=-1) + + out, lstm_state = lstm(inputs, lstm_state) + + # Write + write = write_head(out) + beta = jax.nn.sigmoid(write[:, -1]) + k1, k2, v = jnp.split(jax.nn.tanh(write[:, :-1]), 3, axis=-1) + key = jnp.einsum('bi,bj->bij', k1, k2).reshape((batch_size, -1)) + v_old = jnp.einsum('bmn,bn->bm', memory, key) + v_write = jnp.einsum('bm,bn->bmn', v - v_old, key) + memory += beta[:, None, None] * v_write + + # Read + # TODO(lkirsch) optionally add multiple readouts + k1, k2 = jnp.split(jax.nn.tanh(read_head(out)), 2, axis=-1) + key = jnp.einsum('bi,bj->bij', k1, k2).reshape((batch_size, -1)) + v_read = jnp.einsum('bmn,bn->bm', memory, key) + readout = read_proj(layer_norm(v_read)) + out += readout + + out = output_proj(out) + return (lstm_state, memory), out + + _, out = hk.scan(scan_tick, (lstm_state, init_memory), (inputs, labels)) + + # Put sequence axis second + out = jnp.transpose(out, (1, 0, 2)) + logits = jax.nn.log_softmax(out) + + return logits diff --git a/learned_optimization/research/data_driven/resnet.py b/learned_optimization/research/data_driven/resnet.py new file mode 100644 index 0000000..f6a1bb5 --- /dev/null +++ b/learned_optimization/research/data_driven/resnet.py @@ -0,0 +1,177 @@ +# coding=utf-8 +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""ResNet pre-trained networks.""" + +# pylint: disable=g-importing-member +from functools import partial +from typing import Any, Callable, Sequence, Tuple + +from flax import linen as nn +from flax.training import checkpoints +import jax.numpy as jnp + +ModuleDef = Any + + +class ResNetBlock(nn.Module): + """ResNet block.""" + + filters: int + conv: ModuleDef + norm: ModuleDef + act: Callable[[jnp.ndarray], jnp.ndarray] + strides: Tuple[int, int] = (1, 1) + + @nn.compact + def __call__( + self, + x, + ): + residual = x + y = self.conv(self.filters, (3, 3), self.strides)(x) + y = self.norm()(y) + y = self.act(y) + y = self.conv(self.filters, (3, 3))(y) + y = self.norm(scale_init=nn.initializers.zeros)(y) + + if residual.shape != y.shape: + residual = self.conv( + self.filters, (1, 1), self.strides, name='conv_proj' + )(residual) + residual = self.norm(name='norm_proj')(residual) + + return self.act(residual + y) + + +class BottleneckResNetBlock(nn.Module): + """Bottleneck ResNet block.""" + + filters: int + conv: ModuleDef + norm: ModuleDef + act: Callable[[jnp.ndarray], jnp.ndarray] + strides: Tuple[int, int] = (1, 1) + + @nn.compact + def __call__(self, x): + residual = x + y = self.conv(self.filters, (1, 1))(x) + y = self.norm()(y) + y = self.act(y) + y = self.conv(self.filters, (3, 3), self.strides)(y) + y = self.norm()(y) + y = self.act(y) + y = self.conv(self.filters * 4, (1, 1))(y) + y = self.norm(scale_init=nn.initializers.zeros)(y) + + if residual.shape != y.shape: + residual = self.conv( + self.filters * 4, (1, 1), self.strides, name='conv_proj' + )(residual) + residual = self.norm(name='norm_proj')(residual) + + return self.act(residual + y) + + +class ResNet(nn.Module): + """ResNetV1.""" + + stage_sizes: Sequence[int] + block_cls: ModuleDef + num_classes: int + num_filters: int = 64 + dtype: Any = jnp.float32 + act: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + conv: ModuleDef = nn.Conv + + @nn.compact + def __call__(self, x, train: bool = True): + conv = partial(self.conv, use_bias=False, dtype=self.dtype) + norm = partial( + nn.BatchNorm, + use_running_average=not train, + momentum=0.9, + epsilon=1e-5, + dtype=self.dtype, + ) + + x = conv( + self.num_filters, + (7, 7), + (2, 2), + padding=[(3, 3), (3, 3)], + name='conv_init', + )(x) + x = norm(name='bn_init')(x) + x = nn.relu(x) + x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME') + for i, block_size in enumerate(self.stage_sizes): + for j in range(block_size): + strides = (2, 2) if i > 0 and j == 0 else (1, 1) + x = self.block_cls( + self.num_filters * 2**i, + strides=strides, + conv=conv, + norm=norm, + act=self.act, + )(x) + x = jnp.mean(x, axis=(1, 2)) + # x = nn.Dense(self.num_classes, dtype=self.dtype)(x) + x = jnp.asarray(x, self.dtype) + return x + + +ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock) +ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock) +ResNet50 = partial( + ResNet, stage_sizes=[3, 4, 6, 3], block_cls=BottleneckResNetBlock +) +ResNet101 = partial( + ResNet, stage_sizes=[3, 4, 23, 3], block_cls=BottleneckResNetBlock +) +ResNet152 = partial( + ResNet, stage_sizes=[3, 8, 36, 3], block_cls=BottleneckResNetBlock +) +ResNet200 = partial( + ResNet, stage_sizes=[3, 24, 36, 3], block_cls=BottleneckResNetBlock +) + + +ResNet18Local = partial( + ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock, conv=nn.ConvLocal +) + + +# Load model checkpoint from cloud. 
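+# `load_params` restores a Flax checkpoint dict with 'params' and 'batch_stats';
+# `embed` applies the ResNet with its classification head removed (the final
+# nn.Dense above is commented out), returning mean-pooled features.
+#
+# Illustrative use (assuming `pretrained_path` points at a valid checkpoint):
+#   state = load_params()
+#   features = embed(state, images)  # e.g. [N, 224, 224, 3] -> [N, 2048] for ResNet50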
+def load_params(config_name='v100_x8'):
+  """Loads the parameters of a pre-trained ResNet50 based on the given config_name.
+
+  Args:
+    config_name: The trained ResNet to load.
+
+  Returns:
+    The parameters to be used with the respective model.
+  """
+
+  # Assumes `pretrained_path` (the checkpoint location) is defined at module level.
+  chkp = checkpoints.restore_checkpoint(pretrained_path, target=None)
+  state = {'params': chkp['params'], 'batch_stats': chkp['batch_stats']}
+  return state
+
+
+def embed(params, x, model_cls=ResNet50, num_classes=1000):
+  """Computes pooled ResNet features for a batch of images."""
+  model = model_cls(num_classes=num_classes)
+  y = model.apply(params, x, train=False)
+  return y
diff --git a/learned_optimization/research/data_driven/run_mnist_projections.py b/learned_optimization/research/data_driven/run_mnist_projections.py
new file mode 100644
index 0000000..6a27f80
--- /dev/null
+++ b/learned_optimization/research/data_driven/run_mnist_projections.py
@@ -0,0 +1,58 @@
+# coding=utf-8
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Runs MNIST projection experiment with gin configuration."""
+
+from absl import app
+import jax
+from learned_optimization import filesystem
+from learned_optimization import setup_experiment
+from learned_optimization.research.data_driven import mnist_projections
+import numpy as np
+import yaml
+
+
+def main(_) -> None:
+  rank = jax.process_index()
+  train_log_dir = setup_experiment.setup_experiment(make_dir=(rank == 0))
+  train(train_log_dir)
+
+
+def train(training_log_directory: str):
+  """Runs a projection experiment.
+
+  Args:
+    training_log_directory: Directory to store log data to.
+  """
+
+  experiment = mnist_projections.ProjectionExperiment(training_log_directory)
+  log_dict = experiment.run()
+
+  if jax.process_index() == 0:
+    yaml_file_name = f'{training_log_directory}/results.yaml'
+    with filesystem.file_open(yaml_file_name, 'w') as f:
+      yaml.dump(
+          {
+              k: np.asarray(v).item()
+              for k, v in log_dict.items()
+              if np.asarray(v).size == 1
+          }, f)
+    np_file_name = f'{training_log_directory}/results.npy'
+    with filesystem.file_open(np_file_name, 'wb') as f:
+      np.save(f, log_dict)
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/learned_optimization/research/data_driven/summary.py b/learned_optimization/research/data_driven/summary.py
new file mode 100644
index 0000000..9ed2e51
--- /dev/null
+++ b/learned_optimization/research/data_driven/summary.py
@@ -0,0 +1,74 @@
+# coding=utf-8
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for logging of summary statistics."""
+
+from concurrent import futures
+import jax
+
+from learned_optimization import summary as lo_summary
+from learned_optimization.baselines import utils
+import numpy as np
+
+
+def only_first_rank(func):
+  """Only runs the function on rank 0."""
+
+  def wrappee(*args, **kwargs):
+    if jax.process_index() == 0:
+      return func(*args, **kwargs)
+    return None
+
+  return wrappee
+
+
+class DictSummaryWriter(lo_summary.SummaryWriterBase):
+  """A summary writer that stores entire dicts as scalars and npy files."""
+
+  FLUSH_TIMEOUT_SECS = 10
+
+  def __init__(self, base_writer: lo_summary.SummaryWriterBase, log_dir: str):
+    self._writer = base_writer
+    self._log_dir = log_dir
+    self._thread_pool = futures.ThreadPoolExecutor(max_workers=2)
+    self._pending_dict_writes = []
+
+  @only_first_rank
+  def scalar(self, name, value, step):
+    return self._writer.scalar(name, value, step)
+
+  @only_first_rank
+  def histogram(self, name, value, step):
+    return self._writer.histogram(name, value, step)
+
+  @only_first_rank
+  def flush(self):
+    futures.wait(self._pending_dict_writes, timeout=self.FLUSH_TIMEOUT_SECS)
+    self._pending_dict_writes.clear()
+    return self._writer.flush()
+
+  @only_first_rank
+  def dict(self, dict_value, step):
+    # Store scalars in tf summary
+    for k, v in dict_value.items():
+      if v is None:
+        continue
+      if np.isscalar(v) or v.size == 1:
+        self.scalar(k, v, step)
+
+    # Store entire dictionary as npy file
+    file_name = f'{self._log_dir}/summary_{step}.npy'
+    task = self._thread_pool.submit(utils.write_npz, file_name, dict_value)
+    self._pending_dict_writes.append(task)
diff --git a/learned_optimization/research/data_driven/transformer.py b/learned_optimization/research/data_driven/transformer.py
new file mode 100644
index 0000000..df3adfa
--- /dev/null
+++ b/learned_optimization/research/data_driven/transformer.py
@@ -0,0 +1,148 @@
+# coding=utf-8
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +"""Transformer model components.""" + +from typing import Optional + +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + + +class CausalSelfAttention(hk.MultiHeadAttention): + """Self attention with a causal mask applied.""" + + def __call__( + self, + query: jnp.ndarray, + key: Optional[jnp.ndarray] = None, + value: Optional[jnp.ndarray] = None, + mask: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: + key = key if key is not None else query + value = value if value is not None else query + + if query.ndim != 3: + raise ValueError('Expect queries of shape [B, T, D].') + + seq_len = query.shape[1] + causal_mask = np.tril(np.ones((1, 1, seq_len, seq_len))) + mask = mask * causal_mask if mask is not None else causal_mask + + return super().__call__(query, key, value, mask) + + +class DenseBlock(hk.Module): + """A 2-layer MLP which widens then narrows the input.""" + + def __init__(self, + init_scale: float, + widening_factor: float = 4., + name: Optional[str] = None): + super().__init__(name=name) + self._init_scale = init_scale + self._widening_factor = widening_factor + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + hiddens = x.shape[-1] + initializer = hk.initializers.VarianceScaling(self._init_scale) + x = hk.Linear(int(self._widening_factor * hiddens), w_init=initializer)(x) + x = jax.nn.gelu(x) + return hk.Linear(hiddens, w_init=initializer)(x) + + +class Transformer(hk.Module): + """A transformer stack.""" + + def __init__(self, + num_heads: int, + num_layers: int, + key_size: int, + dropout_rate: float, + widening_factor: float, + name: Optional[str] = None): + super().__init__(name=name) + self._num_layers = num_layers + self._num_heads = num_heads + self._key_size = key_size + self._widening_factor = widening_factor + self._dropout_rate = dropout_rate + + def __call__(self, + h: jnp.ndarray, + mask: Optional[jnp.ndarray], + is_training: bool) -> jnp.ndarray: + """Connects the transformer. + + Args: + h: Inputs, [B, T, D]. + mask: Padding mask, [B, T]. + is_training: Whether we're training or not. + + Returns: + Array of shape [B, T, D]. + """ + + init_scale = 2. / self._num_layers + dropout_rate = self._dropout_rate if is_training else 0. + if mask is not None: + mask = mask[:, None, None, :] + + # Note: names chosen to approximately match those used in the GPT-2 code; + # see https://github.com/openai/gpt-2/blob/master/src/model.py. + for i in range(self._num_layers): + h_norm = layer_norm(h, name=f'h{i}_ln_1') + h_attn = CausalSelfAttention( + num_heads=self._num_heads, + key_size=self._key_size, + model_size=h.shape[-1], + w_init_scale=init_scale, + name=f'h{i}_attn')( + h_norm, mask=mask) + h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn) + h = h + h_attn + h_norm = layer_norm(h, name=f'h{i}_ln_2') + h_dense = DenseBlock( + init_scale, self._widening_factor, name=f'h{i}_mlp')( + h_norm) + h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense) + h = h + h_dense + h = layer_norm(h, name='ln_f') + + return h + + +def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray: + """Apply a unique LayerNorm to x with default settings.""" + return hk.LayerNorm(axis=-1, + create_scale=True, + create_offset=True, + name=name)(x)
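For orientation, a minimal sketch (not part of the diff above) of how this Transformer stack might be driven from Haiku; the embedding layer, vocabulary size, and all hyper-parameters below are illustrative assumptions rather than values used in the paper.

import haiku as hk
import jax
import jax.numpy as jnp

from learned_optimization.research.data_driven import transformer


def forward(tokens: jnp.ndarray, is_training: bool) -> jnp.ndarray:
  # Embed integer tokens into a 64-dim space, then run the causal stack.
  h = hk.Embed(vocab_size=128, embed_dim=64)(tokens)
  return transformer.Transformer(
      num_heads=4,
      num_layers=2,
      key_size=16,
      dropout_rate=0.1,
      widening_factor=4)(h, mask=None, is_training=is_training)


model = hk.transform(forward)
rng = jax.random.PRNGKey(0)
tokens = jnp.zeros((2, 10), dtype=jnp.int32)  # [B, T]
params = model.init(rng, tokens, is_training=True)
out = model.apply(params, rng, tokens, is_training=True)  # [B, T, 64]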