diff --git a/CHANGELOG.md b/CHANGELOG.md
index f15235d3..5013cc5f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,7 @@
 * Fix [GPT-2](https://github.com/asyml/texar/tree/master/examples/gpt-2) tokenization loading path. ([#165](https://github.com/asyml/texar/pull/165))
 * Fix [examples/vae_text](https://github.com/asyml/texar/tree/master/examples/vae_text) EOS bug. ([#168](https://github.com/asyml/texar/pull/168))
 * Fix transformer [bleu_tool.py](https://github.com/asyml/texar/blob/master/examples/transformer/bleu_tool.py) when `translation_length` is 0. ([#176](https://github.com/asyml/texar/pull/176))
+* Fix `StochasticConnector` and `ReparameterizedStochasticConnector` when `transform=False`. ([#179](https://github.com/asyml/texar/pull/179))
 
 ## [v0.2.0](https://github.com/asyml/texar/releases/tag/v0.2.0) (2019-04-09)
 
diff --git a/texar/modules/connectors/__init__.py b/texar/modules/connectors/__init__.py
index 01dc4c3f..99ec742d 100644
--- a/texar/modules/connectors/__init__.py
+++ b/texar/modules/connectors/__init__.py
@@ -23,4 +23,3 @@
 
 from texar.modules.connectors.connector_base import *
 from texar.modules.connectors.connectors import *
-
diff --git a/texar/modules/connectors/connectors.py b/texar/modules/connectors/connectors.py
index 99c336f8..287338f0 100644
--- a/texar/modules/connectors/connectors.py
+++ b/texar/modules/connectors/connectors.py
@@ -1,4 +1,4 @@
-# Copyright 2018 The Texar Authors. All Rights Reserved.
+# Copyright 2019 The Texar Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -54,7 +54,10 @@ def _assert_same_size(outputs, output_size):
     flat_output = nest.flatten(outputs)
 
     for (output, size) in zip(flat_output, flat_output_size):
-        if output[0].shape != tf.TensorShape(size):
+        if isinstance(size, tf.TensorShape):
+            if output.shape == size:
+                pass
+        elif output[0].shape != tf.TensorShape(size):
             raise ValueError(
                 "The output size does not match the the required output_size")
 
@@ -518,7 +521,8 @@ class instance.
            - output: A Tensor or a (nested) tuple of Tensors with the same \
            structure and size of :attr:`output_size`. The batch dimension \
            equals :attr:`num_samples` if specified, or is determined by the \
-            distribution dimensionality.
+            distribution dimensionality. If :attr:`transform` is `False`, \
+            :attr:`output` will be equal to :attr:`sample`.
            - sample: The sample from the distribution, prior to transformation.
 
         Raises:
@@ -549,9 +553,10 @@ class instance.
             fn_modules = ['tensorflow', 'tensorflow.nn', 'texar.custom']
             activation_fn = get_function(self.hparams.activation_fn,
                                          fn_modules)
             output = _mlp_transform(sample, self._output_size, activation_fn)
+        else:
+            output = sample
         _assert_same_size(output, self._output_size)
-
         if not self._built:
             self._add_internal_trainable_variables()
             self._built = True
@@ -616,7 +621,7 @@ def default_hparams():
     def _build(self,
                distribution='MultivariateNormalDiag',
                distribution_kwargs=None,
-               transform=False,
+               transform=True,
                num_samples=None):
         """Samples from a distribution and optionally performs transformation
         with an MLP layer.
@@ -649,7 +654,8 @@ class instance.
            - output: A Tensor or a (nested) tuple of Tensors with the same \
            structure and size of :attr:`output_size`. The batch dimension \
            equals :attr:`num_samples` if specified, or is determined by the \
-            distribution dimensionality.
+            distribution dimensionality. If :attr:`transform` is `False`, \
+            :attr:`output` will be equal to :attr:`sample`.
            - sample: The sample from the distribution, prior to transformation.
 
         Raises:
@@ -661,31 +667,32 @@ class instance.
             "tensorflow.contrib.distributions", "texar.custom"])
 
         if num_samples:
-            output = dstr.sample(num_samples)
+            sample = dstr.sample(num_samples)
         else:
-            output = dstr.sample()
+            sample = dstr.sample()
 
         if dstr.event_shape == []:
-            output = tf.reshape(output,
-                                output.shape.concatenate(tf.TensorShape(1)))
+            sample = tf.reshape(sample,
+                                sample.shape.concatenate(tf.TensorShape(1)))
 
         # Disable gradients through samples
-        output = tf.stop_gradient(output)
+        sample = tf.stop_gradient(sample)
 
-        output = tf.cast(output, tf.float32)
+        sample = tf.cast(sample, tf.float32)
 
         if transform:
             fn_modules = ['tensorflow', 'tensorflow.nn', 'texar.custom']
             activation_fn = get_function(self.hparams.activation_fn,
                                          fn_modules)
-            output = _mlp_transform(output, self._output_size, activation_fn)
+            output = _mlp_transform(sample, self._output_size, activation_fn)
+        else:
+            output = sample
         _assert_same_size(output, self._output_size)
-
         if not self._built:
             self._add_internal_trainable_variables()
             self._built = True
 
-        return output
+        return output, sample
 
 
 #class ConcatConnector(ConnectorBase):
diff --git a/texar/modules/connectors/connectors_test.py b/texar/modules/connectors/connectors_test.py
index 4049dfe8..0f837ddd 100644
--- a/texar/modules/connectors/connectors_test.py
+++ b/texar/modules/connectors/connectors_test.py
@@ -15,7 +15,8 @@
 from texar.core import layers
 from texar.modules import ConstantConnector
 from texar.modules import MLPTransformConnector
-from texar.modules import ReparameterizedStochasticConnector
+from texar.modules import (ReparameterizedStochasticConnector,
+                           StochasticConnector)
 from texar.modules.connectors.connectors import _assert_same_size
 
 # pylint: disable=too-many-locals, invalid-name
@@ -132,6 +133,36 @@ def test_reparameterized_stochastic_connector(self):
 #            self.assertAlmostEqual(0, sample_mu[i], delta=0.2)
 #            self.assertAlmostEqual(1, sample_var[i], delta=0.2)
 
+    def test_stochastic_connector(self):
+        """Tests the logic of
+        :class:`~texar.modules.StochasticConnector`.
+        """
+        state_size = (10, 10)
+        variable_size = 100
+        state_size_ts = tf.TensorShape([self._batch_size, variable_size])
+        gauss_connector = StochasticConnector(state_size)
+        mu = tf.zeros([self._batch_size, variable_size])
+        var = tf.ones([self._batch_size, variable_size])
+        gauss_ds = tfpd.MultivariateNormalDiag(loc=mu, scale_diag=var)
+        output_1, _ = gauss_connector(gauss_ds)
+
+        gauss_connector_2 = StochasticConnector(state_size_ts)
+        output_2, sample2 = gauss_connector_2(
+            distribution="MultivariateNormalDiag",
+            distribution_kwargs={"loc": mu, "scale_diag": var}, transform=False)
+        test_list = [output_1, output_2, sample2]
+
+        with self.test_session() as sess:
+            sess.run(tf.global_variables_initializer())
+            out_list = sess.run(test_list)
+            out1 = out_list[0]
+            out2 = out_list[1]
+            sample2 = out_list[2]
+            self.assertEqual(out1[0].shape,
+                             tf.TensorShape([self._batch_size, state_size[0]]))
+            self.assertEqual(out2.shape, state_size_ts)
+            self.assertEqual(out2.shape, sample2.shape)
+
    #def test_concat_connector(self): # pylint: disable=too-many-locals
    #    """Tests the logic of
    #    :class:`~texar.modules.connectors.ConcatConnector`.
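For reference, below is a minimal usage sketch (not part of the patch) of the behavior this diff introduces: after the change, `StochasticConnector` returns an `(output, sample)` pair, and when `transform=False` the `output` is simply the untransformed sample. The batch and variable sizes are illustrative values and a TF 1.x session is assumed, mirroring the new unit test above.

# Illustrative sketch; names and sizes are assumptions, not part of the diff.
import tensorflow as tf

from texar.modules import StochasticConnector

batch_size = 16       # illustrative
variable_size = 100   # illustrative
output_size = tf.TensorShape([batch_size, variable_size])

connector = StochasticConnector(output_size)
output, sample = connector(
    distribution="MultivariateNormalDiag",
    distribution_kwargs={
        "loc": tf.zeros([batch_size, variable_size]),
        "scale_diag": tf.ones([batch_size, variable_size])},
    transform=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out, smp = sess.run([output, sample])
    # With transform=False, output is the raw (stop-gradient) sample,
    # so the two fetched arrays are identical.
    assert (out == smp).all()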