Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge of PR #1773
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 316746422
  • Loading branch information
AgoloCuongHoang authored and copybara-github committed Jun 16, 2020
1 parent 94a3c0e commit ceba665
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 81 deletions.
181 changes: 104 additions & 77 deletions tensor2tensor/utils/multistep_with_adamoptimizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2019 The Tensor2Tensor Authors.
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2019 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-step optimizers simulating large batches.
Optimizer variants which make it possible to use very large batch sizes with
Expand All @@ -26,26 +39,26 @@
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
import tensorflow.compat.v1 as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export
from tensorflow.keras import backend as K
# pylint: enable=g-direct-tensorflow-import


class MultistepAdamOptimizer(optimizer.Optimizer):
class MultistepAdamOptimizer(tf.train.Optimizer):
"""Adam with SGD updates every n steps with accumulated gradients."""

def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
use_locking=False, name="Adam", n=1):
super(MultistepAdamOptimizer, self).__init__(use_locking=use_locking, name=name)
def __init__(self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
use_locking=False,
name="Adam",
n=1):
super(MultistepAdamOptimizer, self).__init__(
use_locking=use_locking, name=name)
self._lr = learning_rate
self._beta1 = beta1
self._beta2 = beta2
Expand All @@ -59,43 +72,46 @@ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
self._n_t = None # n as tensor

def _get_beta_accumulators(self):
with ops.init_scope():
if context.executing_eagerly():
with tf.init_scope():
if tf.executing_eagerly():
graph = None
else:
graph = ops.get_default_graph()
graph = tf.get_default_graph()
return (self._get_non_slot_variable("beta1_power", graph=graph),
self._get_non_slot_variable("beta2_power", graph=graph))

def _create_slots(self, var_list):
"""Create slot variables for Adam with accumulated gradients."""
first_var = min(var_list, key=lambda x: x.name)
self._create_non_slot_variable(initial_value=self._beta1, name="beta1_power", colocate_with=first_var)
self._create_non_slot_variable(initial_value=self._beta2, name="beta2_power", colocate_with=first_var)
#if iter is initialized as an int32, this optimizer could not run
#with tensorflow_hub with a tensorflow-gpu version
self._create_non_slot_variable(initial_value=0.0 if self._n == 1 else 1.0, name="iter", colocate_with=first_var)
self._create_non_slot_variable(
initial_value=self._beta1, name="beta1_power", colocate_with=first_var)
self._create_non_slot_variable(
initial_value=self._beta2, name="beta2_power", colocate_with=first_var)
# if iter is initialized as an int32, this optimizer could not run
# with tensorflow_hub with a tensorflow-gpu version
self._create_non_slot_variable(
initial_value=0.0 if self._n == 1 else 1.0,
name="iter",
colocate_with=first_var)
# Create slots for the first and second moments, as well as grad_acc.
for v in var_list:
self._zeros_slot(v, "m", self._name)
self._zeros_slot(v, "v", self._name)
self._zeros_slot(v, "grad_acc", self._name)


def _get_iter_variable(self):
graph = (
None if tf.executing_eagerly() else tf.get_default_graph())
graph = (None if tf.executing_eagerly() else tf.get_default_graph())
return self._get_non_slot_variable("iter", graph=graph)

def _prepare(self):
lr = self._call_if_callable(self._lr)
beta1 = self._call_if_callable(self._beta1)
beta2 = self._call_if_callable(self._beta2)
epsilon = self._call_if_callable(self._epsilon)
self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
self._beta1_t = tf.convert_to_tensor(beta1, name="beta1")
self._beta2_t = tf.convert_to_tensor(beta2, name="beta2")
self._lr_t = tf.convert_to_tensor(lr, name="learning_rate")
self._epsilon_t = tf.convert_to_tensor(epsilon, name="epsilon")
self._n_t = tf.convert_to_tensor(self._n, name="n")

def _apply_cond(self, apply_fn, grad, var, *args, **kwargs):
Expand All @@ -106,8 +122,8 @@ def apply_adam(grad_acc, apply_fn, grad, var, *args, **kwargs):
total_grad = (grad_acc + grad) / tf.cast(self._n_t, grad.dtype)
adam_op = apply_fn(total_grad, var, *args, **kwargs)
with tf.control_dependencies([adam_op]):
grad_acc_to_zero_op = grad_acc.assign(tf.zeros_like(grad_acc),
use_locking=self._use_locking)
grad_acc_to_zero_op = grad_acc.assign(
tf.zeros_like(grad_acc), use_locking=self._use_locking)
return tf.group(adam_op, grad_acc_to_zero_op)

def accumulate_gradient(grad_acc, grad):
Expand All @@ -126,14 +142,17 @@ def _apply_dense_in_action(self, grad, var):
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
beta1_power, beta2_power = self._get_beta_accumulators()
return training_ops.apply_adam(var, m, v,
math_ops.cast(beta1_power, var.dtype.base_dtype),
math_ops.cast(beta2_power, var.dtype.base_dtype),
math_ops.cast(self._lr_t, var.dtype.base_dtype),
math_ops.cast(self._beta1_t, var.dtype.base_dtype),
math_ops.cast(self._beta2_t, var.dtype.base_dtype),
math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
grad,
return training_ops.apply_adam(
var,
m,
v,
tf.cast(beta1_power, var.dtype.base_dtype),
tf.cast(beta2_power, var.dtype.base_dtype),
tf.cast(self._lr_t, var.dtype.base_dtype),
tf.cast(self._beta1_t, var.dtype.base_dtype),
tf.cast(self._beta2_t, var.dtype.base_dtype),
tf.cast(self._epsilon_t, var.dtype.base_dtype),
grad,
use_locking=self._use_locking).op

def _resource_apply_dense(self, grad, var):
Expand All @@ -143,41 +162,44 @@ def _resource_apply_dense_in_action(self, grad, var):
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
beta1_power, beta2_power = self._get_beta_accumulators()
return training_ops.resource_apply_adam(var.handle,
m.handle,
return training_ops.resource_apply_adam(
var.handle,
m.handle,
v.handle,
math_ops.cast(beta1_power, grad.dtype.base_dtype),
math_ops.cast(beta2_power, grad.dtype.base_dtype),
math_ops.cast(self._lr_t, var.dtype.base_dtype),
math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
grad, use_locking=self._use_locking)
tf.cast(beta1_power, grad.dtype.base_dtype),
tf.cast(beta2_power, grad.dtype.base_dtype),
tf.cast(self._lr_t, var.dtype.base_dtype),
tf.cast(self._beta1_t, grad.dtype.base_dtype),
tf.cast(self._beta2_t, grad.dtype.base_dtype),
tf.cast(self._epsilon_t, grad.dtype.base_dtype),
grad,
use_locking=self._use_locking)

def _apply_sparse_shared(self, grad, var, indices, scatter_add):
beta1_power, beta2_power = self._get_beta_accumulators()
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
beta1_power = tf.cast(beta1_power, var.dtype.base_dtype)
beta2_power = tf.cast(beta2_power, var.dtype.base_dtype)
lr_t = tf.cast(self._lr_t, var.dtype.base_dtype)
beta1_t = tf.cast(self._beta1_t, var.dtype.base_dtype)
beta2_t = tf.cast(self._beta2_t, var.dtype.base_dtype)
epsilon_t = tf.cast(self._epsilon_t, var.dtype.base_dtype)
lr = (lr_t * tf.sqrt(1 - beta2_power) / (1 - beta1_power))
# m_t = beta1 * m + (1 - beta1) * g_t
m = self.get_slot(var, "m")
m_scaled_g_values = grad * (1 - beta1_t)
m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
with ops.control_dependencies([m_t]):
m_t = tf.assign(m, m * beta1_t, use_locking=self._use_locking)
with tf.control_dependencies([m_t]):
m_t = scatter_add(m, indices, m_scaled_g_values)
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
v = self.get_slot(var, "v")
v_scaled_g_values = (grad * grad) * (1 - beta2_t)
v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
with ops.control_dependencies([v_t]):
v_t = tf.assign(v, v * beta2_t, use_locking=self._use_locking)
with tf.control_dependencies([v_t]):
v_t = scatter_add(v, indices, v_scaled_g_values)
v_sqrt = math_ops.sqrt(v_t)
var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
return control_flow_ops.group(*[var_update, m_t, v_t])
v_sqrt = tf.sqrt(v_t)
var_update = tf.assign_sub(
var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
return tf.group(*[var_update, m_t, v_t])

def _apply_sparse(self, grad, var):
# TODO(fstahlberg): Implement a sparse version
Expand All @@ -191,39 +213,44 @@ def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
# correctly (summing them). A real sparse implementation will probably want
# to override _resource_apply_sparse instead so it gets them de-duplicated
# automatically.
dense_grad = tf.convert_to_tensor(tf.IndexedSlices(values=grad,
indices=indices, dense_shape=tf.shape(var)))
return self._apply_cond(self._resource_apply_dense_in_action, dense_grad, var)
dense_grad = tf.convert_to_tensor(
tf.IndexedSlices(
values=grad, indices=indices, dense_shape=tf.shape(var)))
return self._apply_cond(self._resource_apply_dense_in_action, dense_grad,
var)

def _resource_scatter_add(self, x, i, v):
with ops.control_dependencies(
with tf.control_dependencies(
[resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
return x.value()

def _resource_apply_sparse(self, grad, var, indices):
return self._apply_sparse_shared(grad, var, indices, self._resource_scatter_add)
return self._apply_sparse_shared(grad, var, indices,
self._resource_scatter_add)

def _finish(self, update_ops, name_scope):
"""Updates beta_power variables every n batches and incrs counter."""
iter_ = self._get_iter_variable()
beta1_power, beta2_power = self._get_beta_accumulators()
with tf.control_dependencies(update_ops):
with tf.colocate_with(iter_):

def update_beta_op():
update_beta1 = beta1_power.assign(
beta1_power * self._beta1_t,
use_locking=self._use_locking)
beta1_power * self._beta1_t, use_locking=self._use_locking)
update_beta2 = beta2_power.assign(
beta2_power * self._beta2_t,
use_locking=self._use_locking)
beta2_power * self._beta2_t, use_locking=self._use_locking)
return tf.group(update_beta1, update_beta2)

maybe_update_beta = tf.cond(
tf.equal(iter_, 0), update_beta_op, tf.no_op)
with tf.control_dependencies([maybe_update_beta]):
#TODO(Cuong): It is suboptimal here because we have to cast twice (float to int,
#and then int to float)
update_iter = iter_.assign(K.cast(tf.mod(K.cast(iter_ + 1.0, dtype=dtypes.int32), self._n_t), dtype=dtypes.float32),
use_locking=self._use_locking)
# TODO(cuong): It is suboptimal here because we have to cast twice
# (float to int, and then int to float)
update_iter = iter_.assign(
tf.cast(
tf.mod(tf.cast(iter_ + 1.0, tf.int32), self._n_t),
tf.float32),
use_locking=self._use_locking)
return tf.group(
*update_ops + [update_iter, maybe_update_beta], name=name_scope)

22 changes: 18 additions & 4 deletions tensor2tensor/utils/multistep_with_adamoptimizer_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2019 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -19,8 +33,8 @@
from __future__ import print_function

import numpy as np
from tensor2tensor.utils import multistep_optimizer
import tensorflow as tf
from tensor2tensor.utils import multistep_with_adamoptimizer
import tensorflow.compat.v1 as tf


class MultistepAdamOptimizerTest(tf.test.TestCase):
Expand Down Expand Up @@ -56,7 +70,7 @@ def testMultistep(self):

singlestep_opt = tf.train.AdamOptimizer(
beta1=beta1, beta2=beta2, learning_rate=alpha)
multistep_opt = multistep_optimizer.MultistepAdamOptimizer(
multistep_opt = multistep_with_adamoptimizer.MultistepAdamOptimizer(
n=n, beta1=beta1, beta2=beta2, learning_rate=alpha)

singlestep_update = singlestep_opt.apply_gradients([
Expand Down Expand Up @@ -100,7 +114,7 @@ def testResourceVariables(self):
tape.watch([v1, v2])
loss = tf.reduce_sum(tf.gather(params=v1, indices=[0]) + v2)
v1_grad, v2_grad = tape.gradient(loss, [v1, v2])
multistep_opt = multistep_optimizer.MultistepAdamOptimizer(0.1)
multistep_opt = multistep_with_adamoptimizer.MultistepAdamOptimizer(0.1)
multistep_opt.apply_gradients(((v1_grad, v1), (v2_grad, v2)))


Expand Down

0 comments on commit ceba665

Please sign in to comment.