Skip to content

Commit

Permalink
Merge pull request #179 from TomNong/connector-update
Browse files Browse the repository at this point in the history
Fix *StochasticConnector when `transform=False`
  • Loading branch information
ZhitingHu authored Jul 17, 2019
2 parents 5a8fb32 + 1d71c15 commit 962294b
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* Fix [GPT-2](https://github.com/asyml/texar/tree/master/examples/gpt-2) tokenization loading path. ([#165](https://github.com/asyml/texar/pull/165))
* Fix [examples/vae_text](https://github.com/asyml/texar/tree/master/examples/vae_text) EOS bug. ([#168](https://github.com/asyml/texar/pull/168))
* Fix transformer [bleu_tool.py](https://github.com/asyml/texar/blob/master/examples/transformer/bleu_tool.py) when `translation_length` is 0. ([#176](https://github.com/asyml/texar/pull/176))
* Fix `StochasticConnector` and `ReparameterizedStochasticConnector` when `transform=False`. ([#179](https://github.com/asyml/texar/pull/179))

## [v0.2.0](https://github.com/asyml/texar/releases/tag/v0.2.0) (2019-04-09)

Expand Down
1 change: 0 additions & 1 deletion texar/modules/connectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,3 @@

from texar.modules.connectors.connector_base import *
from texar.modules.connectors.connectors import *

37 changes: 22 additions & 15 deletions texar/modules/connectors/connectors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2018 The Texar Authors. All Rights Reserved.
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -54,7 +54,10 @@ def _assert_same_size(outputs, output_size):
flat_output = nest.flatten(outputs)

for (output, size) in zip(flat_output, flat_output_size):
if output[0].shape != tf.TensorShape(size):
if isinstance(size, tf.TensorShape):
if output.shape == size:
pass
elif output[0].shape != tf.TensorShape(size):
raise ValueError(
"The output size does not match the the required output_size")

Expand Down Expand Up @@ -518,7 +521,8 @@ class instance.
- output: A Tensor or a (nested) tuple of Tensors with the same \
structure and size of :attr:`output_size`. The batch dimension \
equals :attr:`num_samples` if specified, or is determined by the \
distribution dimensionality.
distribution dimensionality. If :attr:`transform` is `False`, \
:attr:`output` will be equal to :attr:`sample`.
- sample: The sample from the distribution, prior to transformation.
Raises:
Expand Down Expand Up @@ -549,9 +553,10 @@ class instance.
fn_modules = ['tensorflow', 'tensorflow.nn', 'texar.custom']
activation_fn = get_function(self.hparams.activation_fn, fn_modules)
output = _mlp_transform(sample, self._output_size, activation_fn)
else:
output = sample

_assert_same_size(output, self._output_size)

if not self._built:
self._add_internal_trainable_variables()
self._built = True
Expand Down Expand Up @@ -616,7 +621,7 @@ def default_hparams():
def _build(self,
distribution='MultivariateNormalDiag',
distribution_kwargs=None,
transform=False,
transform=True,
num_samples=None):
"""Samples from a distribution and optionally performs transformation
with an MLP layer.
Expand Down Expand Up @@ -649,7 +654,8 @@ class instance.
- output: A Tensor or a (nested) tuple of Tensors with the same \
structure and size of :attr:`output_size`. The batch dimension \
equals :attr:`num_samples` if specified, or is determined by the \
distribution dimensionality.
distribution dimensionality. If :attr:`transform` is `False`, \
:attr:`output` will be equal to :attr:`sample`.
- sample: The sample from the distribution, prior to transformation.
Raises:
Expand All @@ -661,31 +667,32 @@ class instance.
"tensorflow.contrib.distributions", "texar.custom"])

if num_samples:
output = dstr.sample(num_samples)
sample = dstr.sample(num_samples)
else:
output = dstr.sample()
sample = dstr.sample()

if dstr.event_shape == []:
output = tf.reshape(output,
output.shape.concatenate(tf.TensorShape(1)))
sample = tf.reshape(sample,
sample.shape.concatenate(tf.TensorShape(1)))

# Disable gradients through samples
output = tf.stop_gradient(output)
sample = tf.stop_gradient(sample)

output = tf.cast(output, tf.float32)
sample = tf.cast(sample, tf.float32)

if transform:
fn_modules = ['tensorflow', 'tensorflow.nn', 'texar.custom']
activation_fn = get_function(self.hparams.activation_fn, fn_modules)
output = _mlp_transform(output, self._output_size, activation_fn)
output = _mlp_transform(sample, self._output_size, activation_fn)
else:
output = sample

_assert_same_size(output, self._output_size)

if not self._built:
self._add_internal_trainable_variables()
self._built = True

return output
return output, sample


#class ConcatConnector(ConnectorBase):
Expand Down
33 changes: 32 additions & 1 deletion texar/modules/connectors/connectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from texar.core import layers
from texar.modules import ConstantConnector
from texar.modules import MLPTransformConnector
from texar.modules import ReparameterizedStochasticConnector
from texar.modules import (ReparameterizedStochasticConnector,
StochasticConnector)
from texar.modules.connectors.connectors import _assert_same_size

# pylint: disable=too-many-locals, invalid-name
Expand Down Expand Up @@ -132,6 +133,36 @@ def test_reparameterized_stochastic_connector(self):
# self.assertAlmostEqual(0, sample_mu[i], delta=0.2)
# self.assertAlmostEqual(1, sample_var[i], delta=0.2)

def test_stochastic_connector(self):
"""Tests the logic of
:class:`~texar.modules.StochasticConnector`.
"""
state_size = (10, 10)
variable_size = 100
state_size_ts = tf.TensorShape([self._batch_size, variable_size])
gauss_connector = StochasticConnector(state_size)
mu = tf.zeros([self._batch_size, variable_size])
var = tf.ones([self._batch_size, variable_size])
gauss_ds = tfpd.MultivariateNormalDiag(loc=mu, scale_diag=var)
output_1, _ = gauss_connector(gauss_ds)

gauss_connector_2 = StochasticConnector(state_size_ts)
output_2, sample2 = gauss_connector_2(
distribution="MultivariateNormalDiag",
distribution_kwargs={"loc": mu, "scale_diag": var}, transform=False)
test_list = [output_1, output_2, sample2]

with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
out_list = sess.run(test_list)
out1 = out_list[0]
out2 = out_list[1]
sample2 = out_list[2]
self.assertEqual(out1[0].shape,
tf.TensorShape([self._batch_size, state_size[0]]))
self.assertEqual(out2.shape, state_size_ts)
self.assertEqual(out2.shape, sample2.shape)

#def test_concat_connector(self): # pylint: disable=too-many-locals
# """Tests the logic of
# :class:`~texar.modules.connectors.ConcatConnector`.
Expand Down

0 comments on commit 962294b

Please sign in to comment.