diff --git a/bertalign/aligner.py b/bertalign/aligner.py index 6c9f916..b3678ad 100644 --- a/bertalign/aligner.py +++ b/bertalign/aligner.py @@ -1,52 +1,89 @@ import numpy as np + from bertalign import model from bertalign.corelib import * from bertalign.utils import * class Bertalign: def __init__(self, - src, - tgt, + src_raw, + tgt_raw, max_align=5, top_k=3, win=5, skip=-0.1, margin=True, len_penalty=True, - is_split=False, + input_type='raw', + src_lang=None, + tgt_lang=None, ): - + self.max_align = max_align self.top_k = top_k self.win = win self.skip = skip self.margin = margin self.len_penalty = len_penalty - - src = clean_text(src) - tgt = clean_text(tgt) - src_lang = detect_lang(src) - tgt_lang = detect_lang(tgt) - - if is_split: + + input_types = ['raw', 'lines', 'tokenized'] + if input_type not in input_types: + raise ValueError("Invalid input type '%s'. Expected one of: %s" % (input_type, input_types)) + + if input_type == 'lines': + # need to split + src = clean_text(src_raw) + tgt = clean_text(tgt_raw) src_sents = src.splitlines() tgt_sents = tgt.splitlines() - else: + + if not src_lang: + src_lang = detect_lang(src) + if not tgt_lang: + tgt_lang = detect_lang(tgt) + + + elif input_type == 'raw': + src = clean_text(src_raw) + tgt = clean_text(tgt_raw) + + if not src_lang: + src_lang = detect_lang(src) + if not tgt_lang: + tgt_lang = detect_lang(tgt) + src_sents = split_sents(src, src_lang) tgt_sents = split_sents(tgt, tgt_lang) - + + elif input_type == 'tokenized': + + if not src_lang: + src_lang = detect_lang(' '.join(src_raw)) + if not tgt_lang: + tgt_lang = detect_lang(' '.join(tgt_raw)) + + src_sents = src_raw + tgt_sents = tgt_raw + + if not src_lang: + src_lang = detect_lang(' '.join(src_sents)) + if not tgt_lang: + tgt_lang = detect_lang(' '.join(tgt_sents)) + + src_num = len(src_sents) tgt_num = len(tgt_sents) - + src_lang = LANG.ISO[src_lang] tgt_lang = LANG.ISO[tgt_lang] - + print("Source language: {}, Number of sentences: {}".format(src_lang, src_num)) print("Target 
language: {}, Number of sentences: {}".format(tgt_lang, tgt_num)) - print("Embedding source and target text using {} ...".format(model.model_name)) + print("Embedding source text using {} ...".format(model.model_name)) src_vecs, src_lens = model.transform(src_sents, max_align - 1) + print("Embedding target text using {} ...".format(model.model_name)) tgt_vecs, tgt_lens = model.transform(tgt_sents, max_align - 1) char_ratio = np.sum(src_lens[0,]) / np.sum(tgt_lens[0,]) @@ -62,7 +99,7 @@ def __init__(self, self.char_ratio = char_ratio self.src_vecs = src_vecs self.tgt_vecs = tgt_vecs - + def align_sents(self): print("Performing first-step alignment ...") @@ -71,7 +108,7 @@ def align_sents(self): first_w, first_path = find_first_search_path(self.src_num, self.tgt_num) first_pointers = first_pass_align(self.src_num, self.tgt_num, first_w, first_path, first_alignment_types, D, I) first_alignment = first_back_track(self.src_num, self.tgt_num, first_pointers, first_path, first_alignment_types) - + print("Performing second-step alignment ...") second_alignment_types = get_alignment_types(self.max_align) second_w, second_path = find_second_search_path(first_alignment, self.win, self.src_num, self.tgt_num) @@ -79,10 +116,10 @@ def align_sents(self): second_w, second_path, second_alignment_types, self.char_ratio, self.skip, margin=self.margin, len_penalty=self.len_penalty) second_alignment = second_back_track(self.src_num, self.tgt_num, second_pointers, second_path, second_alignment_types) - + print("Finished! 
Successfully aligning {} {} sentences to {} {} sentences\n".format(self.src_num, self.src_lang, self.tgt_num, self.tgt_lang)) self.result = second_alignment - + def print_sents(self): for bead in (self.result): src_line = self._get_line(bead[0], self.src_sents) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..8801e6f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,37 @@ +import pytest +import json +import os + +def load_json(fpath): + with open(fpath) as json_file: + data = json.load(json_file) + return data + + +@pytest.fixture +def text_and_berg_expected_results(): + """Fixture for the Text und Berg expected result.""" + + cur_dir = os.path.dirname(os.path.realpath(__file__)) + fname = 'gold_standard_text_und_berg.json' + fpath = os.path.join(cur_dir, fname) + data = load_json(fpath) + yield data + + + +@pytest.fixture +def text_and_berg_inputs(): + r"""Input data for Text and Berg.""" + + src_dir = 'text+berg/de' + tgt_dir = 'text+berg/fr' + gold_dir = 'text+berg/gold' + + data = [] + for file in os.listdir(src_dir): + src_file = os.path.join(src_dir, file).replace("\\","/") + tgt_file = os.path.join(tgt_dir, file).replace("\\","/") + data.append((file, src_file, tgt_file, gold_dir)) + + yield data diff --git a/tests/gold_standard_text_und_berg.json b/tests/gold_standard_text_und_berg.json new file mode 100644 index 0000000..7a85923 --- /dev/null +++ b/tests/gold_standard_text_und_berg.json @@ -0,0 +1,58 @@ +{ + "002": { + "recall_strict": 0.9588477366255144, + "recall_lax": 0.9917695473251029, + "precision_strict": 0.9505703422053232, + "precision_lax": 0.9847908745247148, + "f1_strict": 0.9546910980176844, + "f1_lax": 0.9882678910702977 + }, + "006": { + "recall_strict": 0.9694444444444444, + "recall_lax": 0.9944444444444445, + "precision_strict": 0.9607329842931938, + "precision_lax": 0.9869109947643979, + "f1_strict": 0.9650690556740179, + "f1_lax": 0.9906633978772443 + }, + "001": { + "recall_strict": 
0.9553191489361702, + "recall_lax": 0.9957446808510638, + "precision_strict": 0.9496981891348089, + "precision_lax": 0.9879275653923542, + "f1_strict": 0.9525003764104154, + "f1_lax": 0.991820720553515 + }, + "005": { + "recall_strict": 0.9502982107355865, + "recall_lax": 0.9960238568588469, + "precision_strict": 0.9453860640301318, + "precision_lax": 0.9887005649717514, + "f1_strict": 0.9478357731413087, + "f1_lax": 0.9923487000713064 + }, + "007": { + "recall_strict": 0.937592867756315, + "recall_lax": 0.9910846953937593, + "precision_strict": 0.9265536723163842, + "precision_lax": 0.9830508474576272, + "f1_strict": 0.9320405838088075, + "f1_lax": 0.9870514243433223 + }, + "004": { + "recall_strict": 0.9404145077720207, + "recall_lax": 0.9896373056994818, + "precision_strict": 0.9320148331273177, + "precision_lax": 0.9851668726823238, + "f1_strict": 0.9361958300767388, + "f1_lax": 0.9873970292534215 + }, + "003": { + "recall_strict": 0.9405594405594405, + "recall_lax": 0.9906759906759907, + "precision_strict": 0.9319955406911928, + "precision_lax": 0.9866220735785953, + "f1_strict": 0.9362579076540054, + "f1_lax": 0.9886448763947483 + } +} diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 0000000..e079f8a --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1 @@ +pytest diff --git a/tests/test_results.py b/tests/test_results.py new file mode 100644 index 0000000..730dc27 --- /dev/null +++ b/tests/test_results.py @@ -0,0 +1,78 @@ +import os + +import pytest + +from bertalign import Bertalign +from bertalign.eval import read_alignments +from bertalign.eval import score_multiple +from bertalign.eval import log_final_scores + + +def align_text_and_berg(filespec, aligner_spec): + r"""Align Text and Berg using the original aligner.""" + + test_alignments = [] + gold_alignments = [] + + results = {} + + for test_data in filespec: + + file, src_file, tgt_file, gold_dir = test_data + src = open(src_file, "rt", encoding="utf-8").read() 
+ tgt = open(tgt_file, "rt", encoding="utf-8").read() + + print("Start aligning {} to {}".format(src_file, tgt_file)) + # aligner = Bertalign(src, tgt, is_split=True) + aligner = Bertalign(src, tgt, **aligner_spec) + aligner.align_sents() + test_alignments.append(aligner.result) + + gold_file = os.path.join(gold_dir, file) + gold_alignments.append(read_alignments(gold_file)) + + scores = score_multiple(gold_list=gold_alignments, test_list=test_alignments) + log_final_scores(scores) + results[file] = scores + return results + + +@pytest.mark.skip(reason="is_split is removed at the moment.") +def test_aligner_original(text_and_berg_expected_results, text_and_berg_inputs): + r"""Test results for the original aligner using is_split.""" + + aligner_spec = {"is_split": True} + result = align_text_and_berg(text_and_berg_inputs, aligner_spec) + + for file in result: + expected = text_and_berg_expected_results[file] + calculated = result[file] + for metric in expected: + assert expected[metric] == calculated[metric], "Result mismatch" + + +aligner_spec_explicit = { + "input_type": "lines", + "src_lang": "de", + "tgt_lang": "fr", +} + + +aligner_spec_detect = { + "input_type": "lines", +} + +# @pytest.mark.parametrize("aligner_spec", [aligner_spec_detect]) +@pytest.mark.parametrize("aligner_spec", [aligner_spec_explicit]) +def test_aligner_altered_parametrization( + text_and_berg_expected_results, text_and_berg_inputs, aligner_spec +): + r"""Test results for the aligner using input_type and languages.""" + + result = align_text_and_berg(text_and_berg_inputs, aligner_spec) + + for file in result: + expected = text_and_berg_expected_results[file] + calculated = result[file] + for metric in expected: + assert expected[metric] == calculated[metric], "Result mismatch"