Skip to content

Commit

Permalink
rfctr(part): add new decorator to replace four (#3650)
Browse files Browse the repository at this point in the history
**Summary**
In preparation for pluggable auto-partitioners, add a new metadata
decorator to replace the four existing ones.

**Additional Context**
"Global" metadata items, those applied to all elements on all
partitioners, are applied using a decorator.

Currently there are four decorators where there only needs to be one.
Consolidate those into a single metadata decorator.
One or two additional behaviors of the new decorator will allow us to
remove decorators from delegating partitioners, which is a prerequisite
for pluggable auto-partitioners.
  • Loading branch information
scanny authored Sep 25, 2024
1 parent 44bad21 commit 50d75c4
Show file tree
Hide file tree
Showing 19 changed files with 388 additions and 69 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
## 0.15.14-dev3
## 0.15.14-dev4

### Enhancements

### Features

* **Add (but do not install) a new post-partitioning decorator to handle metadata added for all file-types, like `.filename`, `.filetype` and `.languages`.** This will be installed in a closely following PR to replace the four currently being used for this purpose.

### Fixes

* **Update Python SDK usage in `partition_via_api`.** Make a minor syntax change to ensure forward compatibility with the upcoming 0.26.0 Python SDK.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@

import os
import pathlib
from typing import Union

import pytest

from test_unstructured.unit_utils import LogCaptureFixture
from unstructured.documents.elements import (
NarrativeText,
PageBreak,
)
from unstructured.partition.lang import (
from unstructured.partition.common.lang import (
_clean_ocr_languages_arg,
_convert_language_code_to_pytesseract_lang_code,
apply_lang_metadata,
Expand Down Expand Up @@ -61,13 +61,13 @@ def test_prepare_languages_for_tesseract_with_multiple_languages():
assert prepare_languages_for_tesseract(languages) == "jpn+jpn_vert+afr+eng+equ"


def test_prepare_languages_for_tesseract_warns_nonstandard_language(caplog):
def test_prepare_languages_for_tesseract_warns_nonstandard_language(caplog: LogCaptureFixture):
languages = ["zzz", "chi"]
assert prepare_languages_for_tesseract(languages) == "chi_sim+chi_sim_vert+chi_tra+chi_tra_vert"
assert "not a valid standard language code" in caplog.text


def test_prepare_languages_for_tesseract_warns_non_tesseract_language(caplog):
def test_prepare_languages_for_tesseract_warns_non_tesseract_language(caplog: LogCaptureFixture):
languages = ["kbd", "eng"]
assert prepare_languages_for_tesseract(languages) == "eng"
assert "not a language supported by Tesseract" in caplog.text
Expand All @@ -79,7 +79,7 @@ def test_prepare_languages_for_tesseract_None_languages():
prepare_languages_for_tesseract(languages)


def test_prepare_languages_for_tesseract_no_valid_languages(caplog):
def test_prepare_languages_for_tesseract_no_valid_languages(caplog: LogCaptureFixture):
languages = [""]
assert prepare_languages_for_tesseract(languages) == "eng"
assert "Failed to find any valid standard language code from languages" in caplog.text
Expand All @@ -96,11 +96,11 @@ def test_prepare_languages_for_tesseract_no_valid_languages(caplog):
("kor", "korean"),
],
)
def test_tesseract_to_paddle_language_valid_codes(tesseract_lang, expected_lang):
def test_tesseract_to_paddle_language_valid_codes(tesseract_lang: str, expected_lang: str):
assert expected_lang == tesseract_to_paddle_language(tesseract_lang)


def test_tesseract_to_paddle_language_invalid_codes(caplog):
def test_tesseract_to_paddle_language_invalid_codes(caplog: LogCaptureFixture):
tesseract_lang = "unsupported_lang"
assert tesseract_to_paddle_language(tesseract_lang) == "en"
assert "unsupported_lang is not a language code supported by PaddleOCR," in caplog.text
Expand All @@ -114,7 +114,7 @@ def test_tesseract_to_paddle_language_invalid_codes(caplog):
("DEU", "german"),
],
)
def test_tesseract_to_paddle_language_case_sensitivity(tesseract_lang, expected_lang):
def test_tesseract_to_paddle_language_case_sensitivity(tesseract_lang: str, expected_lang: str):
assert expected_lang == tesseract_to_paddle_language(tesseract_lang)


Expand All @@ -139,7 +139,7 @@ def test_detect_languages_gets_multiple_languages():
assert detect_languages(text) == ["ces", "pol", "slk"]


def test_detect_languages_warns_for_auto_and_other_input(caplog):
def test_detect_languages_warns_for_auto_and_other_input(caplog: LogCaptureFixture):
text = "This is another short sentence."
languages = ["en", "auto", "rus"]
assert detect_languages(text, languages) == ["eng"]
Expand All @@ -149,10 +149,10 @@ def test_detect_languages_warns_for_auto_and_other_input(caplog):
def test_detect_languages_raises_TypeError_for_invalid_languages():
with pytest.raises(TypeError):
text = "This is a short sentence."
detect_languages(text, languages="eng") == ["eng"]
detect_languages(text, languages="eng") == ["eng"] # type: ignore


def test_apply_lang_metadata_has_no_warning_for_PageBreak(caplog):
def test_apply_lang_metadata_has_no_warning_for_PageBreak(caplog: LogCaptureFixture):
elements = [NarrativeText("Sample text."), PageBreak("")]
elements = list(
apply_lang_metadata(
Expand All @@ -171,7 +171,7 @@ def test_apply_lang_metadata_has_no_warning_for_PageBreak(caplog):
("fr", "fra"),
],
)
def test_convert_language_code_to_pytesseract_lang_code(lang_in, expected_lang):
def test_convert_language_code_to_pytesseract_lang_code(lang_in: str, expected_lang: str):
assert expected_lang == _convert_language_code_to_pytesseract_lang_code(lang_in)


Expand All @@ -187,7 +187,7 @@ def test_convert_language_code_to_pytesseract_lang_code(lang_in, expected_lang):
("deu+spa", "deu+spa"), # correct input
],
)
def test_clean_ocr_languages_arg(input_ocr_langs, expected):
def test_clean_ocr_languages_arg(input_ocr_langs: str, expected: str):
assert _clean_ocr_languages_arg(input_ocr_langs) == expected


Expand All @@ -209,12 +209,15 @@ def test_detect_languages_handles_spelled_out_languages():
],
)
def test_check_language_args_uses_languages_when_ocr_languages_and_languages_are_both_defined(
languages: Union[list[str], str],
ocr_languages: Union[list[str], str, None],
languages: list[str],
ocr_languages: list[str] | str,
expected_langs: list[str],
caplog,
caplog: LogCaptureFixture,
):
returned_langs = check_language_args(languages=languages, ocr_languages=ocr_languages)
returned_langs = check_language_args(
languages=languages,
ocr_languages=ocr_languages,
)
for lang in returned_langs: # type: ignore
assert lang in expected_langs
assert "ocr_languages" in caplog.text
Expand All @@ -231,10 +234,10 @@ def test_check_language_args_uses_languages_when_ocr_languages_and_languages_are
],
)
def test_check_language_args_uses_ocr_languages_when_languages_is_empty_or_None(
languages: Union[list[str], str],
ocr_languages: Union[list[str], str, None],
languages: list[str],
ocr_languages: str,
expected_langs: list[str],
caplog,
caplog: LogCaptureFixture,
):
returned_langs = check_language_args(languages=languages, ocr_languages=ocr_languages)
for lang in returned_langs: # type: ignore
Expand All @@ -250,19 +253,15 @@ def test_check_language_args_uses_ocr_languages_when_languages_is_empty_or_None(
],
)
def test_check_language_args_returns_None(
languages: Union[list[str], str, None],
ocr_languages: Union[list[str], str, None],
languages: list[str],
ocr_languages: None,
):
returned_langs = check_language_args(languages=languages, ocr_languages=ocr_languages)
assert returned_langs is None


def test_check_language_args_returns_auto(
languages=["eng", "spa", "auto"],
ocr_languages=None,
):
returned_langs = check_language_args(languages=languages, ocr_languages=ocr_languages)
assert returned_langs == ["auto"]
def test_check_language_args_returns_auto():
assert check_language_args(languages=["eng", "spa", "auto"], ocr_languages=None) == ["auto"]


@pytest.mark.parametrize(
Expand All @@ -273,8 +272,11 @@ def test_check_language_args_returns_auto(
],
)
def test_check_language_args_raises_error_when_ocr_languages_contains_auto(
languages: Union[list[str], str, None],
ocr_languages: Union[list[str], str, None],
languages: list[str],
ocr_languages: str | list[str],
):
with pytest.raises(ValueError):
check_language_args(languages=languages, ocr_languages=ocr_languages)
check_language_args(
languages=languages,
ocr_languages=ocr_languages,
)
Loading

0 comments on commit 50d75c4

Please sign in to comment.