Merge pull request #436 from martinpopel/bleu
Proper BLEU evaluation
lukaszkaiser authored Dec 1, 2017
2 parents e3cd447 + 7ba78a2 commit bb1173a
Showing 10 changed files with 285 additions and 10 deletions.
7 changes: 5 additions & 2 deletions README.md
@@ -126,10 +126,13 @@ t2t-decoder \
  --output_dir=$TRAIN_DIR \
  --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
  --decode_from_file=$DECODE_FILE \
cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes
  --decode_to_file=translation.en
```

# Eval BLEU

```
t2t-bleu --translation=translation.en --reference=ref-translation.de
```

---

## Installation
1 change: 1 addition & 0 deletions setup.py
@@ -24,6 +24,7 @@
        'tensor2tensor/bin/t2t-datagen',
        'tensor2tensor/bin/t2t-decoder',
        'tensor2tensor/bin/t2t-make-tf-configs',
        'tensor2tensor/bin/t2t-bleu',
    ],
    install_requires=[
        'bz2file',
200 changes: 200 additions & 0 deletions tensor2tensor/bin/t2t-bleu
@@ -0,0 +1,200 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluate BLEU score for all checkpoints in a given directory.
This script can be used in two ways.
To evaluate an already translated file:
`t2t-bleu --translation=my-wmt13.de --reference=wmt13_deen.de`
To evaluate all checkpoints in a given directory:
`t2t-bleu
--model_dir=t2t_train
--data_dir=t2t_data
--translations_dir=my-translations
--problems=translate_ende_wmt32k
--hparams_set=transformer_big_single_gpu
--source=wmt13_deen.en
--reference=wmt13_deen.de`
In addition to the above-mentioned compulsory parameters,
there are optional parameters:
* bleu_variant: cased (case-sensitive), uncased, both (default).
* translations_dir: Where to store the translated files? Default="translations".
* even_subdir: Where in the model_dir to store the even file? Default="",
which means TensorBoard will show it as the same run as the training, but it will warn
about "more than one metagraph event per run". event_subdir can be used e.g. if running
this script several times with different `--decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA"`.
* tag_suffix: Default="", so the tags will be BLEU_cased and BLEU_uncased. Again, tag_suffix
can be used e.g. for different beam sizes if these should be plotted in different graphs.
* min_steps: Don't evaluate checkpoints with less steps.
Default=-1 means check the `last_evaluated_step.txt` file, which contains the number of steps
of the last successfully evaluated checkpoint.
* report_zero: Store BLEU=0 and guess its time based on flags.txt. Default=True.
This is useful, so TensorBoard reports correct relative time for the remaining checkpoints.
This flag is set to False if min_steps is > 0.
* wait_secs: Wait upto N seconds for a new checkpoint. Default=0.
This is useful for continuous evaluation of a running training,
in which case this should be equal to save_checkpoints_secs plus some reserve.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from collections import namedtuple
from tensor2tensor.utils import decoding
from tensor2tensor.utils import trainer_utils
from tensor2tensor.utils import usr_dir
from tensor2tensor.utils import bleu_hook
import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

# t2t-bleu specific options
flags.DEFINE_string("bleu_variant", "both", "Possible values: cased(case-sensitive), uncased, both(default).")
flags.DEFINE_string("model_dir", "", "Directory to load model checkpoints from.")
flags.DEFINE_string("translation", None, "Path to the MT system translation file")
flags.DEFINE_string("source", None, "Path to the source-language file to be translated")
flags.DEFINE_string("reference", None, "Path to the reference translation file")
flags.DEFINE_string("translations_dir", "translations", "Where to store the translated files")
flags.DEFINE_string("event_subdir", "", "Where in model_dir to store the event file")
flags.DEFINE_string("tag_suffix", "", "What to add to BLEU_cased and BLEU_uncased tags. Default=''.")
flags.DEFINE_integer("min_steps", -1, "Don't evaluate checkpoints with less steps.")
flags.DEFINE_integer("wait_secs", 0, "Wait upto N seconds for a new checkpoint, cf. save_checkpoints_secs.")
flags.DEFINE_bool("report_zero", None, "Store BLEU=0 and guess its time based on flags.txt")

# options derived from t2t-decode
flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")
flags.DEFINE_string("t2t_usr_dir", "",
"Path to a Python module that will be imported. The "
"__init__.py file should include the necessary imports. "
"The imported files should contain registrations, "
"e.g. @registry.register_model calls, that will then be "
"available to the t2t-decoder.")
flags.DEFINE_string("master", "", "Address of TensorFlow master.")
flags.DEFINE_string("schedule", "train_and_evaluate",
"Must be train_and_evaluate for decoding.")

Model = namedtuple('Model', 'filename time steps')


def read_checkpoints_list(model_dir, min_steps):
  models = [Model(x[:-6], os.path.getctime(x), int(x[:-6].rsplit('-')[-1]))
            for x in tf.gfile.Glob(os.path.join(model_dir, 'model.ckpt-*.index'))]
  return sorted((x for x in models if x.steps > min_steps), key=lambda x: x.steps)
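# For example, a file "t2t_train/model.ckpt-250000.index" (hypothetical) yields
# Model(filename=".../model.ckpt-250000", time=<file ctime>, steps=250000):
# x[:-6] strips the ".index" suffix and rsplit('-')[-1] extracts the step count.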

def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.translation:
    if FLAGS.model_dir:
      raise ValueError('Cannot specify both --translation and --model_dir.')
    if FLAGS.bleu_variant in ('uncased', 'both'):
      bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, FLAGS.translation, case_sensitive=False)
      print("BLEU_uncased = %6.2f" % bleu)
    if FLAGS.bleu_variant in ('cased', 'both'):
      bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, FLAGS.translation, case_sensitive=True)
      print("BLEU_cased = %6.2f" % bleu)
    return

  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  FLAGS.model = FLAGS.model or 'transformer'
  FLAGS.output_dir = FLAGS.model_dir
  trainer_utils.log_registry()
  trainer_utils.validate_flags()
  assert FLAGS.schedule == "train_and_evaluate"
  data_dir = os.path.expanduser(FLAGS.data_dir)
  model_dir = os.path.expanduser(FLAGS.model_dir)

  hparams = trainer_utils.create_hparams(
      FLAGS.hparams_set, data_dir, passed_hparams=FLAGS.hparams)
  trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
  estimator, _ = trainer_utils.create_experiment_components(
      data_dir=data_dir,
      model_name=FLAGS.model,
      hparams=hparams,
      run_config=trainer_utils.create_run_config(model_dir))

  decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
  decode_hp.add_hparam("shards", FLAGS.decode_shards)
  decode_hp.add_hparam("shard_id", FLAGS.worker_id)

  os.makedirs(FLAGS.translations_dir, exist_ok=True)
  translated_base_file = os.path.join(FLAGS.translations_dir, FLAGS.problems)
  event_dir = os.path.join(FLAGS.model_dir, FLAGS.event_subdir)
  last_step_file = os.path.join(event_dir, 'last_evaluated_step.txt')
  if FLAGS.min_steps == -1:
    try:
      with open(last_step_file) as ls_file:
        FLAGS.min_steps = int(ls_file.read())
    except FileNotFoundError:
      FLAGS.min_steps = 0
  if FLAGS.report_zero is None:
    FLAGS.report_zero = FLAGS.min_steps == 0

  models = read_checkpoints_list(model_dir, FLAGS.min_steps)
  tf.logging.info("Found %d models with steps: %s" % (len(models), ", ".join(str(x.steps) for x in models)))

  writer = tf.summary.FileWriter(event_dir)
  if FLAGS.report_zero:
    start_time = os.path.getctime(os.path.join(model_dir, 'flags.txt'))
    values = []
    if FLAGS.bleu_variant in ('uncased', 'both'):
      values.append(tf.Summary.Value(tag='BLEU_uncased' + FLAGS.tag_suffix, simple_value=0))
    if FLAGS.bleu_variant in ('cased', 'both'):
      values.append(tf.Summary.Value(tag='BLEU_cased' + FLAGS.tag_suffix, simple_value=0))
    writer.add_event(tf.summary.Event(summary=tf.Summary(value=values), wall_time=start_time, step=0))

  exit_time = time.time() + FLAGS.wait_secs
  min_steps = FLAGS.min_steps
  while True:
    if not models and FLAGS.wait_secs:
      tf.logging.info('All checkpoints evaluated. Waiting till %s if a new checkpoint appears' % time.asctime(time.localtime(exit_time)))
      while True:
        time.sleep(10)
        models = read_checkpoints_list(model_dir, min_steps)
        if models or time.time() > exit_time:
          break
    if not models:
      return

    model = models.pop(0)
    exit_time, min_steps = model.time + FLAGS.wait_secs, model.steps
    tf.logging.info("Evaluating " + model.filename)
    out_file = translated_base_file + '-' + str(model.steps)
    tf.logging.set_verbosity(tf.logging.ERROR)  # decode_from_file logs all the translations as INFO
    decoding.decode_from_file(estimator, FLAGS.source, decode_hp, out_file, checkpoint_path=model.filename)
    tf.logging.set_verbosity(tf.logging.INFO)
    values = []
    if FLAGS.bleu_variant in ('uncased', 'both'):
      bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=False)
      values.append(tf.Summary.Value(tag='BLEU_uncased' + FLAGS.tag_suffix, simple_value=bleu))
      tf.logging.info("%s: BLEU_uncased = %6.2f" % (model.filename, bleu))
    if FLAGS.bleu_variant in ('cased', 'both'):
      bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=True)
      values.append(tf.Summary.Value(tag='BLEU_cased' + FLAGS.tag_suffix, simple_value=bleu))
      tf.logging.info("%s: BLEU_cased = %6.2f" % (model.filename, bleu))
    writer.add_event(tf.summary.Event(summary=tf.Summary(value=values), wall_time=model.time, step=model.steps))
    writer.flush()
    with open(last_step_file, 'w') as ls_file:
      ls_file.write(str(model.steps) + '\n')


if __name__ == "__main__":
tf.app.run()
Empty file modified tensor2tensor/bin/t2t-datagen (100644 → 100755)
7 changes: 5 additions & 2 deletions tensor2tensor/bin/t2t-decoder
100644 → 100755
@@ -46,7 +46,10 @@ import tensorflow as tf
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("output_dir", "", "Training directory to load from.")
flags.DEFINE_string("output_dir", "",
"Training directory where the latest checkpoint is used.")
flags.DEFINE_string("checkpoint_path", None,
"Path to the model checkpoint. Overrides output_dir.")
flags.DEFINE_string("decode_from_file", None,
"Path to the source file for decoding")
flags.DEFINE_string("decode_to_file", None,
@@ -90,7 +93,7 @@ def main(_):
    decoding.decode_interactively(estimator, decode_hp)
  elif FLAGS.decode_from_file:
    decoding.decode_from_file(estimator, FLAGS.decode_from_file, decode_hp,
                              FLAGS.decode_to_file)
                              FLAGS.decode_to_file, checkpoint_path=FLAGS.checkpoint_path)
  else:
    decoding.decode_from_dataset(
        estimator,
Empty file modified tensor2tensor/bin/t2t-make-tf-configs (100644 → 100755)
Empty file modified tensor2tensor/bin/t2t-trainer (100644 → 100755)
68 changes: 67 additions & 1 deletion tensor2tensor/utils/bleu_hook.py
@@ -20,13 +20,17 @@

import collections
import math
import re
import sys
import unicodedata

# Dependency imports

import numpy as np
# pylint: disable=redefined-builtin
from six.moves import xrange
from six.moves import zip
import six
# pylint: enable=redefined-builtin

import tensorflow as tf
@@ -92,10 +96,17 @@ def compute_bleu(reference_corpus,
      matches_by_order[len(ngram) - 1] += overlap[ngram]
    for ngram in translation_ngram_counts:
      possible_matches_by_order[len(ngram)-1] += translation_ngram_counts[ngram]
  assert reference_length, "no reference provided"
  assert translation_length, "no translation provided"
  precisions = [0] * max_order
  smooth = 1.0
  for i in xrange(0, max_order):
    if possible_matches_by_order[i] > 0:
      precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
      if matches_by_order[i] > 0:
        precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
      else:
        smooth *= 2
        precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
    else:
      precisions[i] = 0.0
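  # Smoothing example: an order with 0 matches out of 4 possible n-grams gets
  # precision 1/(2*4) = 0.125; a second zero-match order gets 1/(4*possible),
  # halving the credit each time instead of zeroing the whole score.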

@@ -131,3 +142,58 @@ def bleu_score(predictions, labels, **unused_kwargs):

  bleu = tf.py_func(compute_bleu, (labels, outputs), tf.float32)
  return bleu, tf.constant(1.0)


class UnicodeRegex:
  """Ad-hoc hack to recognize all punctuation and symbols,
  without depending on https://pypi.python.org/pypi/regex/."""
  def _property_chars(prefix):
    return ''.join(six.unichr(x) for x in range(sys.maxunicode)
                   if unicodedata.category(six.unichr(x)).startswith(prefix))
  punctuation = _property_chars('P')
  nondigit_punct_re = re.compile(r'([^\d])([' + punctuation + r'])')
  punct_nondigit_re = re.compile(r'([' + punctuation + r'])([^\d])')
  symbol_re = re.compile('([' + _property_chars('S') + '])')


def bleu_tokenize(string):
  r"""Tokenize a string following the official BLEU implementation.

  See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983

  In our case, the input string is expected to be just one line
  and no HTML entities de-escaping is needed.
  So we just tokenize on punctuation and symbols,
  except when a punctuation is preceded and followed by a digit
  (e.g. a comma/dot as a thousand/decimal separator).

  Note that a number (e.g. a year) followed by a dot at the end of a sentence is NOT tokenized,
  i.e. the dot stays with the number, because `s/(\p{P})(\P{N})/ $1 $2/g`
  does not match this case (unless we add a space after each sentence).
  However, this error is already in the original mteval-v14.pl
  and we want to be consistent with it.

  Args:
    string: the input string

  Returns:
    a list of tokens
  """
  string = UnicodeRegex.nondigit_punct_re.sub(r'\1 \2 ', string)
  string = UnicodeRegex.punct_nondigit_re.sub(r' \1 \2', string)
  string = UnicodeRegex.symbol_re.sub(r' \1 ', string)
  return string.split()
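# For example (illustrative): bleu_tokenize(u"Hello, world! It costs 1,234.50")
# returns ['Hello', ',', 'world', '!', 'It', 'costs', '1,234.50'] -- punctuation
# is split off except when it sits between two digits.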


def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
  """Compute BLEU for two files (reference and hypothesis translation)."""
  # TODO: Does anyone care about Python2 compatibility?
  ref_lines = open(ref_filename, 'rt', encoding='utf-8').read().splitlines()
  hyp_lines = open(hyp_filename, 'rt', encoding='utf-8').read().splitlines()
  assert len(ref_lines) == len(hyp_lines)
  if not case_sensitive:
    ref_lines = [x.lower() for x in ref_lines]
    hyp_lines = [x.lower() for x in hyp_lines]
  ref_tokens = [bleu_tokenize(x) for x in ref_lines]
  hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
  return compute_bleu(ref_tokens, hyp_tokens)
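# Example usage (hypothetical file names); the result is in [0, 1], so t2t-bleu
# multiplies it by 100 for reporting:
#   uncased_bleu = 100 * bleu_wrapper('wmt13_deen.de', 'translation.de')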
8 changes: 5 additions & 3 deletions tensor2tensor/utils/bleu_hook_test.py
@@ -39,8 +39,9 @@ def testComputeNotEqual(self):
    translation_corpus = [[1, 2, 3, 4]]
    reference_corpus = [[5, 6, 7, 8]]
    bleu = bleu_hook.compute_bleu(reference_corpus, translation_corpus)
    actual_bleu = 0.0
    self.assertEqual(bleu, actual_bleu)
    # The smoothing prevents 0 for small corpora
    actual_bleu = 0.0798679
    self.assertAllClose(bleu, actual_bleu, atol=1e-03)
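    # With zero matches at every order, the smoothed precisions are
    # 1/(2*4), 1/(4*3), 1/(8*2) and 1/(16*1), whose geometric mean is
    # (1/8 * 1/12 * 1/16 * 1/16)**0.25 ~= 0.0799 (brevity penalty is 1).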

  def testComputeMultipleBatch(self):
    translation_corpus = [[1, 2, 3, 4], [5, 6, 7, 0]]
@@ -53,8 +54,9 @@ def testComputeMultipleNgrams(self):
    reference_corpus = [[1, 2, 1, 13], [12, 6, 7, 4, 8, 9, 10]]
    translation_corpus = [[1, 2, 1, 3], [5, 6, 7, 4]]
    bleu = bleu_hook.compute_bleu(reference_corpus, translation_corpus)
    actual_bleu = 0.486
    actual_bleu = 0.3436
    self.assertAllClose(bleu, actual_bleu, atol=1e-03)


if __name__ == '__main__':
  tf.test.main()
4 changes: 2 additions & 2 deletions tensor2tensor/utils/decoding.py
@@ -200,7 +200,7 @@ def decode_from_dataset(estimator,
tf.logging.info("Completed inference on %d samples." % num_predictions) # pylint: disable=undefined-loop-variable


def decode_from_file(estimator, filename, decode_hp, decode_to_file=None):
def decode_from_file(estimator, filename, decode_hp, decode_to_file=None, checkpoint_path=None):
"""Compute predictions on entries in filename and write them out."""
if not decode_hp.batch_size:
decode_hp.batch_size = 32
@@ -230,7 +230,7 @@ def input_fn():
    return _decode_input_tensor_to_features_dict(example, hparams)

  decodes = []
  result_iter = estimator.predict(input_fn)
  result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)
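  # Note: with checkpoint_path=None, Estimator.predict falls back to the latest
  # checkpoint in the model directory, so existing call sites keep their
  # previous behavior.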
  for result in result_iter:
    if decode_hp.return_beams:
      beam_decodes = []
