diff --git a/nmt/inference.py b/nmt/inference.py
index 6f589337a..8ef712a94 100644
--- a/nmt/inference.py
+++ b/nmt/inference.py
@@ -129,6 +129,18 @@ def single_worker_inference(infer_model,
 
   # Read data
   infer_data = load_data(inference_input_file, hparams)
+  infer_data_feed = infer_data
+
+  # Sort the input by length when hparams.inference_indices is not defined.
+  index_pair = {}
+  new_input = []
+  if hparams.inference_indices is None:
+    input_length = [(len(line.split()), i) for i, line in enumerate(infer_data)]
+    sorted_input_bylens = sorted(input_length)
+    for ni, (_, oi) in enumerate(sorted_input_bylens):
+      new_input.append(infer_data[oi])
+      index_pair[oi] = ni
+    infer_data_feed = new_input
 
   with tf.Session(
       graph=infer_model.graph, config=utils.get_config_proto()) as sess:
@@ -137,7 +149,7 @@ def single_worker_inference(infer_model,
     sess.run(
         infer_model.iterator.initializer,
         feed_dict={
-            infer_model.src_placeholder: infer_data,
+            infer_model.src_placeholder: infer_data_feed,
             infer_model.batch_size_placeholder: hparams.infer_batch_size
         })
     # Decode
@@ -162,7 +174,8 @@ def single_worker_inference(infer_model,
         subword_option=hparams.subword_option,
         beam_width=hparams.beam_width,
         tgt_eos=hparams.eos,
-        num_translations_per_input=hparams.num_translations_per_input)
+        num_translations_per_input=hparams.num_translations_per_input,
+        index_pair=index_pair)
 
 
 def multi_worker_inference(infer_model,
diff --git a/nmt/utils/nmt_utils.py b/nmt/utils/nmt_utils.py
index 72f71b5c2..ee9976947 100644
--- a/nmt/utils/nmt_utils.py
+++ b/nmt/utils/nmt_utils.py
@@ -37,7 +37,8 @@ def decode_and_evaluate(name,
                         beam_width,
                         tgt_eos,
                         num_translations_per_input=1,
-                        decode=True):
+                        decode=True,
+                        index_pair=None):
   """Decode a test set and compute a score according to the evaluation task."""
   # Decode
   if decode:
@@ -51,6 +52,7 @@ def decode_and_evaluate(name,
       num_translations_per_input = max(
           min(num_translations_per_input, beam_width), 1)
 
+      translation = []
       while True:
         try:
           nmt_outputs, _ = model.decode(sess)
@@ -62,17 +64,22 @@ def decode_and_evaluate(name,
 
           for sent_id in range(batch_size):
             for beam_id in range(num_translations_per_input):
-              translation = get_translation(
+              translation.append(get_translation(
                   nmt_outputs[beam_id],
                   sent_id,
                   tgt_eos=tgt_eos,
-                  subword_option=subword_option)
-              trans_f.write((translation + b"\n").decode("utf-8"))
+                  subword_option=subword_option))
         except tf.errors.OutOfRangeError:
           utils.print_time(
               "  done, num sentences %d, num translations per input %d" %
               (num_sentences, num_translations_per_input), start_time)
           break
+      if not index_pair:
+        for sentence in translation:
+          trans_f.write((sentence + b"\n").decode("utf-8"))
+      else:
+        # Iterate original indices in order to restore the input ordering.
+        for i in sorted(index_pair):
+          trans_f.write((translation[index_pair[i]] + b"\n").decode("utf-8"))
 
   # Evaluation
   evaluation_scores = {}
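
For reference, a minimal standalone sketch of the reordering scheme this diff implements: sorting inputs by token count groups similar-length sentences into the same batch (reducing padding during batched decoding), while index_pair records where each original sentence moved so the translations can be written back in the original order. The helper names sort_by_length and restore_order are illustrative, not part of the nmt codebase, and the sketch assumes one translation per input, which is also what the restore branch above assumes.

# Sketch of the sort-then-restore bookkeeping (illustrative, not nmt code).

def sort_by_length(lines):
  """Return (sorted_lines, index_pair) with index_pair[original_i] = sorted_i."""
  order = sorted(range(len(lines)), key=lambda i: len(lines[i].split()))
  index_pair = {oi: ni for ni, oi in enumerate(order)}
  return [lines[oi] for oi in order], index_pair

def restore_order(translations, index_pair):
  """Map translations of the sorted inputs back to the original input order."""
  return [translations[index_pair[i]] for i in sorted(index_pair)]

# Round trip: translations come back in the order the sentences were given.
src = ["a b c", "a", "a b"]
sorted_src, index_pair = sort_by_length(src)      # ["a", "a b", "a b c"]
fake_translations = [s.upper() for s in sorted_src]
assert restore_order(fake_translations, index_pair) == ["A B C", "A", "A B"]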