Merge pull request #176 from haoransh/master
Update SinusoidsPositionEmbedder and bleu_tool

* SinusoidsPositionEmbedder supports arbitrarily large or negative position indexes
* bleu_tool.py works when ratio == 0 (i.e., when translation_length is 0)

Fix #174
Fix #175
ZhitingHu authored Jul 2, 2019
2 parents 821c7db + 192d67a commit 5a8fb32
Showing 4 changed files with 143 additions and 66 deletions.
6 changes: 4 additions & 2 deletions CHANGELOG.md
@@ -12,13 +12,15 @@
* Allow passing a Tensor to `output_layer` of decoders' constructors -- used for weight tying between the output layer and the input embedding matrix. ([#126](https://github.com/asyml/texar/pull/126))
* `TransformerDecoder` constructor interface made exactly the same as the RNN decoders' constructor interfaces. ([#126](https://github.com/asyml/texar/pull/126))
* Refactor decoder `Helper`s to allow two-argument `embedding_fn` (to support position embedding). ([#126](https://github.com/asyml/texar/pull/126))
* Refactor `SinusoidsPositionEmbedder` to support arbitrarily large or negative position indexes. ([#176](https://github.com/asyml/texar/pull/176))

### Fixes

* Fix `texar.losses.reduce_batch_time` when `sequence` has dtype other than `tf.float32`. ([#143](https://github.com/asyml/texar/issues/143))
* Fix `texar.losses.reduce_dimensions` when `average_axes` or `sum_axes` is `int`. ([#141](https://github.com/asyml/texar/pull/141))
* Fix [GPT-2](https://github.com/asyml/texar/tree/master/examples/gpt-2) tokenization loading path. ([165](https://github.com/asyml/texar/pull/165))
* Fix [examples/vae_text](https://github.com/asyml/texar/tree/master/examples/vae_text) EOS bug. ([168](https://github.com/asyml/texar/pull/168))
* Fix [GPT-2](https://github.com/asyml/texar/tree/master/examples/gpt-2) tokenization loading path. ([#165](https://github.com/asyml/texar/pull/165))
* Fix [examples/vae_text](https://github.com/asyml/texar/tree/master/examples/vae_text) EOS bug. ([#168](https://github.com/asyml/texar/pull/168))
* Fix transformer [bleu_tool.py](https://github.com/asyml/texar/blob/master/examples/transformer/bleu_tool.py) when `translation_length` is 0. ([#176](https://github.com/asyml/texar/pull/176))

## [v0.2.0](https://github.com/asyml/texar/releases/tag/v0.2.0) (2019-04-09)

7 changes: 5 additions & 2 deletions examples/transformer/bleu_tool.py
@@ -132,9 +132,12 @@ def compute_bleu(reference_corpus,

if use_bp:
ratio = translation_length / reference_length
if ratio == 0:
if ratio <= 0:
bp = 0
bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
elif ratio < 1.0:
bp = math.exp(1 - 1. / ratio)
else:
bp = 1.0
bleu = geo_mean * bp
return np.float32(bleu)

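For clarity, here is a minimal standalone sketch of the brevity-penalty logic after this change. The helper name is hypothetical (`compute_bleu` computes `ratio` inline as shown above); the `ratio <= 0` branch is what avoids the `ZeroDivisionError` previously raised by `math.exp(1 - 1. / ratio)` when `translation_length` is 0:

```python
import math

def brevity_penalty(translation_length, reference_length):
    """Hypothetical helper mirroring the branch structure in compute_bleu."""
    ratio = translation_length / reference_length
    if ratio <= 0:
        # Empty translation: return 0 without ever evaluating 1. / ratio.
        return 0.0
    elif ratio < 1.0:
        # Translation shorter than the reference: exponential penalty.
        return math.exp(1 - 1. / ratio)
    # Translation at least as long as the reference: no penalty.
    return 1.0

assert brevity_penalty(0, 10) == 0.0      # previously a ZeroDivisionError
assert brevity_penalty(10, 10) == 1.0
assert 0.0 < brevity_penalty(5, 10) < 1.0
```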
50 changes: 32 additions & 18 deletions examples/transformer/utils/data_utils.py
@@ -20,16 +20,27 @@

# pylint: disable=no-member


def load_data_numpy(input_dir, prefix):
train_data = np.load(os.path.join(input_dir,\
prefix + 'train.npy'), encoding='latin1').tolist()
dev_data = np.load(os.path.join(input_dir,\
prefix + 'valid.npy'), encoding='latin1').tolist()
test_data = np.load(os.path.join(input_dir,\
prefix + 'test.npy'), encoding='latin1').tolist()
print('train data size:{}'.format(len(train_data)))
train_data = np.load(
os.path.join(input_dir, prefix + "train.npy"),
encoding="latin1",
allow_pickle=True,
).tolist()
dev_data = np.load(
os.path.join(input_dir, prefix + "valid.npy"),
encoding="latin1",
allow_pickle=True,
).tolist()
test_data = np.load(
os.path.join(input_dir, prefix + "test.npy"),
encoding="latin1",
allow_pickle=True,
).tolist()
print("train data size:{}".format(len(train_data)))
return train_data, dev_data, test_data


def seq2seq_pad_concat_convert(xy_batch, eos_id=2, bos_id=1):
"""
Args:
@@ -55,36 +66,39 @@ def seq2seq_pad_concat_convert(xy_batch, eos_id=2, bos_id=1):
y_block = _concat_examples(y_seqs, padding=0)

# Add EOS
x_block = np.pad(x_block, ((0, 0), (0, 1)), 'constant',
constant_values=0)
x_block = np.pad(x_block, ((0, 0), (0, 1)), "constant", constant_values=0)
for i_batch, seq in enumerate(x_seqs):
x_block[i_batch, len(seq)] = eos_id

y_out_block = np.pad(y_block, ((0, 0), (0, 1)), 'constant',
constant_values=0)
y_out_block = np.pad(
y_block, ((0, 0), (0, 1)), "constant", constant_values=0
)
for i_batch, seq in enumerate(y_seqs):
y_out_block[i_batch, len(seq)] = eos_id

# Add BOS in target language
y_in_block = np.pad(y_block, ((0, 0), (1, 0)), 'constant',
constant_values=bos_id)
y_in_block = np.pad(
y_block, ((0, 0), (1, 0)), "constant", constant_values=bos_id
)
return x_block, y_in_block, y_out_block


def source_pad_concat_convert(x_seqs, eos_id=2, bos_id=1):
"""
This function is used when testing the model without target input.
"""
x_block = _concat_examples(x_seqs, padding=0)

# add EOS
x_block = np.pad(x_block, ((0, 0), (0, 1)), 'constant', constant_values=0)
x_block = np.pad(x_block, ((0, 0), (0, 1)), "constant", constant_values=0)
for i_batch, seq in enumerate(x_seqs):
x_block[i_batch, len(seq)] = eos_id
return x_block


def _concat_examples(arrays, padding=0):
if len(arrays) == 0:
raise ValueError('batch is empty')
raise ValueError("batch is empty")

first_elem = arrays[0]
assert isinstance(first_elem, np.ndarray)
@@ -102,8 +116,8 @@ def _concat_examples(arrays, padding=0):
result[(i,) + slices] = src
return result


def write_words(words_list, filename):
with codecs.open(filename, 'w+', 'utf-8') as myfile:
with codecs.open(filename, "w+", "utf-8") as myfile:
for words in words_list:
myfile.write(' '.join(words) + '\n')

myfile.write(" ".join(words) + "\n")
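A note on the `allow_pickle=True` arguments added above: since NumPy 1.16.3, `np.load` defaults to `allow_pickle=False` and refuses to load object arrays such as the ragged token-id lists stored in the `*.npy` files, so the explicit flag keeps `load_data_numpy` working on newer NumPy releases. A minimal sketch (the file name and toy data are hypothetical):

```python
import numpy as np

# Toy stand-in for the ragged token-id lists stored in prefix + 'train.npy'.
toy = np.array([[1, 5, 7, 2], [1, 9, 2]], dtype=object)
np.save("toy_train.npy", toy)

# Object arrays are pickled on save; with NumPy >= 1.16.3, loading them
# without allow_pickle=True raises "Object arrays cannot be loaded ...".
data = np.load("toy_train.npy", allow_pickle=True).tolist()
print("train data size:{}".format(len(data)))  # train data size:2
```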
146 changes: 102 additions & 44 deletions texar/modules/embedders/position_embedders.py
@@ -27,14 +27,16 @@
from texar.modules.embedders import embedder_utils
from texar.utils.mode import is_train_mode
from texar.utils.shapes import mask_sequences
from texar.utils.shapes import shape_list

# pylint: disable=arguments-differ, invalid-name

__all__ = [
"PositionEmbedder",
"SinusoidsPositionEmbedder",
"SinusoidsPositionEmbedder"
]


class PositionEmbedder(EmbedderBase):
"""Simple position embedder that maps position indexes into embeddings
via lookup.
@@ -68,18 +70,21 @@ def __init__(self, init_value=None, position_size=None, hparams=None):

if init_value is None and position_size is None:
raise ValueError(
"Either `init_value` or `position_size` is required.")
"Either `init_value` or `position_size` is required."
)

self._init_parameterized_embedding(init_value, position_size,
self._hparams)
self._init_parameterized_embedding(
init_value, position_size, self._hparams
)

self._position_size = position_size
if position_size is None:
self._position_size = self._num_embeds
if self._position_size != self._num_embeds:
raise ValueError(
'position_size must equal to init_value.shape[0].'
'Got %d and %d' % (self._position_size, self._num_embeds))
"position_size must equal to init_value.shape[0]."
"Got %d and %d" % (self._position_size, self._num_embeds)
)

self._built = True

@@ -148,11 +153,13 @@ def _build(self, positions=None, sequence_length=None, mode=None, **kwargs):
A `Tensor` of shape `shape(inputs) + embedding dimension`.
"""
# Gets embedder inputs
# pylint:disable=too-many-locals
inputs = positions
if positions is None:
if sequence_length is None:
raise ValueError(
'Either `positions` or `sequence_length` is required.')
"Either `positions` or `sequence_length` is required."
)
max_length = tf.reduce_max(sequence_length)
single_inputs = tf.range(start=0, limit=max_length, dtype=tf.int32)
# Expands `single_inputs` to have shape [batch_size, max_length]
@@ -166,38 +173,46 @@ def _build(self, positions=None, sequence_length=None, mode=None, **kwargs):

# Gets dropout strategy
st = self._hparams.dropout_strategy
if positions is None and st == 'item':
if positions is None and st == "item":
# If `inputs` is based on `sequence_length`, then dropout
# strategies 'item' and 'item_type' have the same effect, we
# use 'item_type' to avoid unknown noise_shape in the 'item'
# strategy
st = 'item_type'
st = "item_type"

# Dropouts as 'item_type' before embedding
if st == 'item_type':
if st == "item_type":
dropout_layer = self._get_dropout_layer(
self._hparams, dropout_strategy=st)
self._hparams, dropout_strategy=st
)
if dropout_layer:
embedding = dropout_layer.apply(inputs=embedding,
training=is_training)
embedding = dropout_layer.apply(
inputs=embedding, training=is_training
)

# Embeds
outputs = tf.nn.embedding_lookup(embedding, inputs, **kwargs)

# Dropouts as 'item' or 'elements' after embedding
if st != 'item_type':
if st != "item_type":
dropout_layer = self._get_dropout_layer(
self._hparams, ids_rank=ids_rank, dropout_input=outputs,
dropout_strategy=st)
self._hparams,
ids_rank=ids_rank,
dropout_input=outputs,
dropout_strategy=st,
)
if dropout_layer:
outputs = dropout_layer.apply(inputs=outputs,
training=is_training)
outputs = dropout_layer.apply(
inputs=outputs, training=is_training
)

# Optionally masks
if sequence_length is not None:
outputs = mask_sequences(
outputs, sequence_length,
tensor_rank=len(inputs.shape.dims) + self._dim_rank)
outputs,
sequence_length,
tensor_rank=len(inputs.shape.dims) + self._dim_rank,
)

return outputs

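For orientation, a hedged usage sketch of `PositionEmbedder` as refactored above, assuming the usual Texar TF convention that calling a module invokes `_build`; the `position_size`, `dim`, and input values are illustrative only:

```python
import tensorflow as tf
from texar.modules.embedders.position_embedders import PositionEmbedder

# Lookup-table position embedder over positions 0 .. position_size - 1.
embedder = PositionEmbedder(position_size=100, hparams={"dim": 512})

# Either pass explicit position indexes ...
positions = tf.constant([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=tf.int32)
out_a = embedder(positions=positions)            # shape [2, 4, 512]

# ... or sequence lengths, which are expanded to [batch_size, max_length]
# position indexes internally and masked past each sequence's length.
lengths = tf.constant([4, 2], dtype=tf.int32)
out_b = embedder(sequence_length=lengths)        # shape [2, 4, 512]
```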
@@ -243,32 +258,45 @@ class SinusoidsPositionEmbedder(EmbedderBase):
Args:
position_size (int): The number of possible positions, e.g., the maximum
sequence length.
sequence length. Set `position_size=None` and
`hparams['cache_embeddings']=False` to enable infinite large or
negative position indexes.
.. document private functions
.. automethod:: _build
"""

def __init__(self, position_size, hparams=None):
EmbedderBase.__init__(self, hparams=hparams)

dim = self._hparams.dim
num_timescales = dim // 2
self._num_embeds = position_size
self._dim = self._hparams.dim
self._cache_embeddings = self._hparams.cache_embeddings

num_timescales = self._dim // 2
min_timescale = self._hparams.min_timescale
max_timescale = self._hparams.max_timescale

positions = tf.to_float(tf.range(position_size, dtype=tf.int32))
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
(tf.to_float(num_timescales) - 1))
log_timescale_increment = math.log(
float(max_timescale) / float(min_timescale)
) / (tf.to_float(num_timescales) - 1)
inv_timescales = min_timescale * tf.exp(
tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
scaled_time = tf.expand_dims(positions, 1) \
* tf.expand_dims(inv_timescales, 0)
signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
signal = tf.pad(signal, [[0, 0], [0, tf.mod(dim, 2)]])
self.signal = signal

def default_hparams(self):
tf.to_float(tf.range(num_timescales)) * -log_timescale_increment
)
self.inv_timescales = inv_timescales

if self._cache_embeddings:
if position_size is None:
raise ValueError(
"'position_size' must not be None when "
"'cache_embeddings' is set to True"
)
positions = tf.to_float(tf.range(position_size, dtype=tf.int32))
signal = self._compute_embeddings(positions)
self.signal = signal

@staticmethod
def default_hparams():
"""Returns a dictionary of hyperparameters with default values
We use a geometric sequence of timescales starting with
min_timescale and ending with max_timescale. The number of different
@@ -280,17 +308,42 @@ def default_hparams(self):
'min_timescale': 1.0,
'max_timescale': 10000.0,
'dim': 512,
'cache_embeddings': True,
'name':'sinusoid_posisiton_embedder',
}
Here:
`"cache_embeddings"`: bool
If `True`, precompute embeddings for positions in range
`[0, position_size - 1]`. This leads to faster lookup but
requires lookup indices to be within this range.
If `False`, embeddings are computed on-the-fly during lookup.
Set to `False` if your application needs to handle sequences
of arbitrary length, or requires embeddings at negative
positions.
"""
hparams = {
'min_timescale': 1.0,
'max_timescale': 1.0e4,
'dim': 512,
'name':'sinusoid_posisiton_embedder',
"min_timescale": 1.0,
"max_timescale": 1.0e4,
"dim": 512,
"cache_embeddings": True,
"name": "sinusoid_posisiton_embedder",
}
return hparams

def _compute_embeddings(self, positions):
inv_timescales = self.inv_timescales
scaled_time = tf.reshape(tf.cast(positions, inv_timescales.dtype),
(-1, 1)) * tf.expand_dims(inv_timescales, 0)
signal = tf.concat(
[tf.sin(scaled_time), tf.cos(scaled_time)], axis=1
)
signal = tf.pad(signal, [[0, 0], [0, tf.mod(self._dim, 2)]])
signal = tf.reshape(signal, shape_list(positions) + [self._dim])
return signal

def _build(self, positions=None, sequence_length=None):
"""Embeds.
Either :attr:`positions` or :attr:`sequence_length` is required:
@@ -312,18 +365,23 @@ def _build(self, positions=None, sequence_length=None):
Returns:
A `Tensor` of shape `[batch_size, max_time, dim]`.
"""
inputs = positions

if positions is None:
if sequence_length is None:
raise ValueError(
'Either `positions` or `sequence_length` is required.')
"Either `positions` or `sequence_length` is required."
)
max_length = tf.reduce_max(sequence_length)
single_inputs = tf.range(start=0, limit=max_length, dtype=tf.int32)
# Expands `single_inputs` to have shape [batch_size, max_length]
expander = tf.expand_dims(tf.ones_like(sequence_length), -1)
inputs = expander * tf.expand_dims(single_inputs, 0)
else:
inputs = positions

embedding = self.signal
outputs = tf.nn.embedding_lookup(embedding, inputs)
return outputs
if self._cache_embeddings:
outputs = tf.nn.embedding_lookup(self.signal, inputs)
else:
outputs = self._compute_embeddings(inputs)

return outputs

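Finally, a hedged sketch of how the refactored `SinusoidsPositionEmbedder` can be used. With `hparams['cache_embeddings'] = False` the embeddings are computed on the fly by `_compute_embeddings`, so position indexes are no longer restricted to `[0, position_size - 1]` and may be arbitrarily large or negative (the behavior requested in #174); with the default `cache_embeddings = True`, a `[position_size, dim]` table is precomputed and looked up as before. The index values below are illustrative, and the calling convention (module call dispatching to `_build`) is assumed:

```python
import tensorflow as tf
from texar.modules.embedders.position_embedders import SinusoidsPositionEmbedder

# On-the-fly mode: no cached table, so position_size may be None and
# indexes outside any fixed range (including negatives) are allowed.
embedder = SinusoidsPositionEmbedder(
    position_size=None,
    hparams={"dim": 512, "cache_embeddings": False},
)
positions = tf.constant([[-3, 0, 10000, 123456]], dtype=tf.int32)
outputs = embedder(positions=positions)          # shape [1, 4, 512]

# Cached mode (default): precomputes the sinusoid table once and falls back
# to a plain embedding lookup, so indexes must stay within [0, 256).
cached = SinusoidsPositionEmbedder(position_size=256, hparams={"dim": 512})
cached_out = cached(sequence_length=tf.constant([16, 8], dtype=tf.int32))
```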