This repository has been archived by the owner on Jul 7, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #436 from martinpopel/bleu
Proper BLEU evaluation
- Loading branch information
Showing
10 changed files
with
285 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
#!/usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright 2017 The Tensor2Tensor Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Evaluate BLEU score for all checkpoints in a given directory. | ||
This script can be used in two ways. | ||
To evaluate an already translated file: | ||
`t2t-bleu --translation=my-wmt13.de --reference=wmt13_deen.de` | ||
To evaluate all checkpoints in a given directory: | ||
`t2t-bleu | ||
--model_dir=t2t_train | ||
--data_dir=t2t_data | ||
--translations_dir=my-translations | ||
--problems=translate_ende_wmt32k | ||
--hparams_set=transformer_big_single_gpu | ||
--source=wmt13_deen.en | ||
--reference=wmt13_deen.de` | ||
In addition to the above-mentioned compulsory parameters, | ||
there are optional parameters: | ||
* bleu_variant: cased (case-sensitive), uncased, both (default). | ||
* translations_dir: Where to store the translated files? Default="translations". | ||
* even_subdir: Where in the model_dir to store the even file? Default="", | ||
which means TensorBoard will show it as the same run as the training, but it will warn | ||
about "more than one metagraph event per run". event_subdir can be used e.g. if running | ||
this script several times with different `--decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA"`. | ||
* tag_suffix: Default="", so the tags will be BLEU_cased and BLEU_uncased. Again, tag_suffix | ||
can be used e.g. for different beam sizes if these should be plotted in different graphs. | ||
* min_steps: Don't evaluate checkpoints with less steps. | ||
Default=-1 means check the `last_evaluated_step.txt` file, which contains the number of steps | ||
of the last successfully evaluated checkpoint. | ||
* report_zero: Store BLEU=0 and guess its time based on flags.txt. Default=True. | ||
This is useful, so TensorBoard reports correct relative time for the remaining checkpoints. | ||
This flag is set to False if min_steps is > 0. | ||
* wait_secs: Wait upto N seconds for a new checkpoint. Default=0. | ||
This is useful for continuous evaluation of a running training, | ||
in which case this should be equal to save_checkpoints_secs plus some reserve. | ||
""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
import os | ||
import time | ||
from collections import namedtuple | ||
from tensor2tensor.utils import decoding | ||
from tensor2tensor.utils import trainer_utils | ||
from tensor2tensor.utils import usr_dir | ||
from tensor2tensor.utils import bleu_hook | ||
import tensorflow as tf | ||
|
||
flags = tf.flags | ||
FLAGS = flags.FLAGS | ||
|
||
# t2t-bleu specific options | ||
flags.DEFINE_string("bleu_variant", "both", "Possible values: cased(case-sensitive), uncased, both(default).") | ||
flags.DEFINE_string("model_dir", "", "Directory to load model checkpoints from.") | ||
flags.DEFINE_string("translation", None, "Path to the MT system translation file") | ||
flags.DEFINE_string("source", None, "Path to the source-language file to be translated") | ||
flags.DEFINE_string("reference", None, "Path to the reference translation file") | ||
flags.DEFINE_string("translations_dir", "translations", "Where to store the translated files") | ||
flags.DEFINE_string("event_subdir", "", "Where in model_dir to store the event file") | ||
flags.DEFINE_string("tag_suffix", "", "What to add to BLEU_cased and BLEU_uncased tags. Default=''.") | ||
flags.DEFINE_integer("min_steps", -1, "Don't evaluate checkpoints with less steps.") | ||
flags.DEFINE_integer("wait_secs", 0, "Wait upto N seconds for a new checkpoint, cf. save_checkpoints_secs.") | ||
flags.DEFINE_bool("report_zero", None, "Store BLEU=0 and guess its time based on flags.txt") | ||
|
||
# options derived from t2t-decode | ||
flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.") | ||
flags.DEFINE_string("t2t_usr_dir", "", | ||
"Path to a Python module that will be imported. The " | ||
"__init__.py file should include the necessary imports. " | ||
"The imported files should contain registrations, " | ||
"e.g. @registry.register_model calls, that will then be " | ||
"available to the t2t-decoder.") | ||
flags.DEFINE_string("master", "", "Address of TensorFlow master.") | ||
flags.DEFINE_string("schedule", "train_and_evaluate", | ||
"Must be train_and_evaluate for decoding.") | ||
|
||
Model = namedtuple('Model', 'filename time steps') | ||
|
||
|
||
def read_checkpoints_list(model_dir, min_steps): | ||
models = [Model(x[:-6], os.path.getctime(x), int(x[:-6].rsplit('-')[-1])) | ||
for x in tf.gfile.Glob(os.path.join(model_dir, 'model.ckpt-*.index'))] | ||
return sorted((x for x in models if x.steps > min_steps), key=lambda x: x.steps) | ||
|
||
def main(_): | ||
tf.logging.set_verbosity(tf.logging.INFO) | ||
if FLAGS.translation: | ||
if FLAGS.model_dir: | ||
raise ValueError('Cannot specify both --translation and --model_dir.') | ||
if FLAGS.bleu_variant in ('uncased', 'both'): | ||
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, FLAGS.translation, case_sensitive=False) | ||
print("BLEU_uncased = %6.2f" % bleu) | ||
if FLAGS.bleu_variant in ('cased', 'both'): | ||
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, FLAGS.translation, case_sensitive=True) | ||
print("BLEU_cased = %6.2f" % bleu) | ||
return | ||
|
||
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) | ||
FLAGS.model = FLAGS.model or 'transformer' | ||
FLAGS.output_dir = FLAGS.model_dir | ||
trainer_utils.log_registry() | ||
trainer_utils.validate_flags() | ||
assert FLAGS.schedule == "train_and_evaluate" | ||
data_dir = os.path.expanduser(FLAGS.data_dir) | ||
model_dir = os.path.expanduser(FLAGS.model_dir) | ||
|
||
hparams = trainer_utils.create_hparams( | ||
FLAGS.hparams_set, data_dir, passed_hparams=FLAGS.hparams) | ||
trainer_utils.add_problem_hparams(hparams, FLAGS.problems) | ||
estimator, _ = trainer_utils.create_experiment_components( | ||
data_dir=data_dir, | ||
model_name=FLAGS.model, | ||
hparams=hparams, | ||
run_config=trainer_utils.create_run_config(model_dir)) | ||
|
||
decode_hp = decoding.decode_hparams(FLAGS.decode_hparams) | ||
decode_hp.add_hparam("shards", FLAGS.decode_shards) | ||
decode_hp.add_hparam("shard_id", FLAGS.worker_id) | ||
|
||
os.makedirs(FLAGS.translations_dir, exist_ok=True) | ||
translated_base_file = os.path.join(FLAGS.translations_dir, FLAGS.problems) | ||
event_dir = os.path.join(FLAGS.model_dir, FLAGS.event_subdir) | ||
last_step_file = os.path.join(event_dir, 'last_evaluated_step.txt') | ||
if FLAGS.min_steps == -1: | ||
try: | ||
with open(last_step_file) as ls_file: | ||
FLAGS.min_steps = int(ls_file.read()) | ||
except FileNotFoundError: | ||
FLAGS.min_steps = 0 | ||
if FLAGS.report_zero is None: | ||
FLAGS.report_zero = FLAGS.min_steps == 0 | ||
|
||
models = read_checkpoints_list(model_dir, FLAGS.min_steps) | ||
tf.logging.info("Found %d models with steps: %s" % (len(models), ", ".join(str(x.steps) for x in models))) | ||
|
||
writer = tf.summary.FileWriter(event_dir) | ||
if FLAGS.report_zero: | ||
start_time = os.path.getctime(os.path.join(model_dir, 'flags.txt')) | ||
values = [] | ||
if FLAGS.bleu_variant in ('uncased', 'both'): | ||
values.append(tf.Summary.Value(tag='BLEU_uncased' + FLAGS.tag_suffix, simple_value=0)) | ||
if FLAGS.bleu_variant in ('cased', 'both'): | ||
values.append(tf.Summary.Value(tag='BLEU_cased' + FLAGS.tag_suffix, simple_value=0)) | ||
writer.add_event(tf.summary.Event(summary=tf.Summary(value=values), wall_time=start_time, step=0)) | ||
|
||
exit_time = time.time() + FLAGS.wait_secs | ||
min_steps = FLAGS.min_steps | ||
while True: | ||
if not models and FLAGS.wait_secs: | ||
tf.logging.info('All checkpoints evaluated. Waiting till %s if a new checkpoint appears' % time.asctime(time.localtime(exit_time))) | ||
while True: | ||
time.sleep(10) | ||
models = read_checkpoints_list(model_dir, min_steps) | ||
if models or time.time() > exit_time: | ||
break | ||
if not models: | ||
return | ||
|
||
model = models.pop(0) | ||
exit_time, min_steps = model.time + FLAGS.wait_secs, model.steps | ||
tf.logging.info("Evaluating " + model.filename) | ||
out_file = translated_base_file + '-' + str(model.steps) | ||
tf.logging.set_verbosity(tf.logging.ERROR) # decode_from_file logs all the translations as INFO | ||
decoding.decode_from_file(estimator, FLAGS.source, decode_hp, out_file, checkpoint_path=model.filename) | ||
tf.logging.set_verbosity(tf.logging.INFO) | ||
values = [] | ||
if FLAGS.bleu_variant in ('uncased', 'both'): | ||
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=False) | ||
values.append(tf.Summary.Value(tag='BLEU_uncased' + FLAGS.tag_suffix, simple_value=bleu)) | ||
tf.logging.info("%s: BLEU_uncased = %6.2f" % (model.filename, bleu)) | ||
if FLAGS.bleu_variant in ('cased', 'both'): | ||
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=True) | ||
values.append(tf.Summary.Value(tag='BLEU_cased' + FLAGS.tag_suffix, simple_value=bleu)) | ||
tf.logging.info("%s: BLEU_cased = %6.2f" % (model.filename, bleu)) | ||
writer.add_event(tf.summary.Event(summary=tf.Summary(value=values), wall_time=model.time, step=model.steps)) | ||
writer.flush() | ||
with open(last_step_file, 'w') as ls_file: | ||
ls_file.write(str(model.steps) + '\n') | ||
|
||
|
||
if __name__ == "__main__": | ||
tf.app.run() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters