Skip to content

Commit

Permalink
fix: bring back embedded images in pdf (#198)
Browse files Browse the repository at this point in the history
### Summary
- Fix inferred layout visualization
- Add functionality to check if the extracted image is a full-page image

### Testing
from unstructured_inference.inference.layout import DocumentLayout
doc = DocumentLayout.from_file("sample-docs/embedded-images.pdf")

### Evaluation
The Python script (or Jupyter Notebook) for "layout analysis" can be
used to evaluate the feature implemented in this branch.

PYTHONPATH=. python examples/layout_analysis/visualization.py sample-docs/embedded-images.pdf
  • Loading branch information
christinestraub authored Sep 7, 2023
1 parent 5c295c5 commit bdee102
Show file tree
Hide file tree
Showing 9 changed files with 201 additions and 15 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.5.23

* Add functionality to bring back embedded images in PDF

## 0.5.22

* Add object-detection classification probabilities to LayoutElement for all currently implemented object detection models
Expand Down
Binary file added sample-docs/embedded-images.pdf
Binary file not shown.
61 changes: 58 additions & 3 deletions test_unstructured_inference/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import pytest
from PIL import Image

from unstructured_inference.inference.elements import EmbeddedTextRegion
from unstructured_inference.inference.elements import EmbeddedTextRegion, Rectangle, TextRegion
from unstructured_inference.inference.layoutelement import LayoutElement


@pytest.fixture()
Expand All @@ -15,9 +16,23 @@ def mock_numpy_image():
return np.zeros((50, 50, 3), np.uint8)


# TODO(alan): Make a better test layout
@pytest.fixture()
def sample_layout():
def mock_rectangle():
    """A plain Rectangle spanning (100, 100) to (300, 300)."""
    corners = (100, 100, 300, 300)
    return Rectangle(*corners)


@pytest.fixture()
def mock_text_region():
    """A TextRegion covering (100, 100)-(300, 300) with sample text."""
    region = TextRegion(100, 100, 300, 300, text="Sample text")
    return region


@pytest.fixture()
def mock_layout_element():
    """A single Text-typed LayoutElement covering (100, 100)-(300, 300)."""
    element = LayoutElement(100, 100, 300, 300, text="Sample text", type="Text")
    return element


@pytest.fixture()
def mock_embedded_text_regions():
return [
EmbeddedTextRegion(
x1=453.00277777777774,
Expand Down Expand Up @@ -90,3 +105,43 @@ def sample_layout():
text="Image",
),
]


@pytest.fixture()
def mock_ocr_regions():
    """Three disjoint OCR-derived text regions with distinct text payloads."""
    region_specs = [
        (10, 10, 90, 90, "0"),
        (200, 200, 300, 300, "1"),
        (500, 320, 600, 350, "3"),
    ]
    return [EmbeddedTextRegion(*spec) for spec in region_specs]


# TODO(alan): Make a better test layout
@pytest.fixture()
def mock_layout(mock_embedded_text_regions):
    """LayoutElements mirroring the embedded text regions, typed UncategorizedText."""
    elements = []
    for region in mock_embedded_text_regions:
        elements.append(
            LayoutElement(
                region.x1,
                region.y1,
                region.x2,
                region.y2,
                text=region.text,
                type="UncategorizedText",
            ),
        )
    return elements


@pytest.fixture()
def mock_inferred_layout(mock_embedded_text_regions):
    """Inferred LayoutElements over the same boxes, typed Text with no text set."""
    inferred = []
    for region in mock_embedded_text_regions:
        inferred.append(
            LayoutElement(
                region.x1,
                region.y1,
                region.x2,
                region.y2,
                text=None,
                type="Text",
            ),
        )
    return inferred
120 changes: 116 additions & 4 deletions test_unstructured_inference/inference/test_layout_element.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import pytest
from layoutparser.elements import TextBlock
from layoutparser.elements.layout_elements import Rectangle as LPRectangle

from unstructured_inference.constants import SUBREGION_THRESHOLD_FOR_OCR
from unstructured_inference.inference.elements import TextRegion
from unstructured_inference.inference.layoutelement import (
LayoutElement,
aggregate_ocr_text_by_block,
get_elements_from_ocr_regions,
merge_inferred_layout_with_ocr_layout,
merge_text_regions,
supplement_layout_with_ocr_elements,
)


Expand All @@ -21,7 +28,7 @@ def test_aggregate_ocr_text_by_block():
assert text == expected


def test_merge_text_regions(sample_layout):
def test_merge_text_regions(mock_embedded_text_regions):
expected = TextRegion(
x1=437.83888888888885,
y1=317.319341111111,
Expand All @@ -30,11 +37,11 @@ def test_merge_text_regions(sample_layout):
text="LayoutParser: A Unified Toolkit for Deep Learning Based Document Image",
)

merged_text_region = merge_text_regions(sample_layout)
merged_text_region = merge_text_regions(mock_embedded_text_regions)
assert merged_text_region == expected


def test_get_elements_from_ocr_regions(sample_layout):
def test_get_elements_from_ocr_regions(mock_embedded_text_regions):
expected = [
LayoutElement(
x1=437.83888888888885,
Expand All @@ -46,5 +53,110 @@ def test_get_elements_from_ocr_regions(sample_layout):
),
]

elements = get_elements_from_ocr_regions(sample_layout)
elements = get_elements_from_ocr_regions(mock_embedded_text_regions)
assert elements == expected


def test_supplement_layout_with_ocr_elements(mock_layout, mock_ocr_regions):
    """supplement_layout_with_ocr_elements keeps the original layout, adds OCR
    elements, and drops OCR elements that are near-subregions of layout elements."""
    ocr_elements = []
    for region in mock_ocr_regions:
        ocr_elements.append(
            LayoutElement(
                region.x1,
                region.y1,
                region.x2,
                region.y2,
                text=region.text,
                type="UncategorizedText",
            ),
        )

    final_layout = supplement_layout_with_ocr_elements(mock_layout, mock_ocr_regions)

    # Every original layout element must survive the merge.
    assert all(layout_element in final_layout for layout_element in mock_layout)

    # At least one OCR-derived element must have been added.
    assert any(candidate in final_layout for candidate in ocr_elements)

    # OCR elements that overlap a layout element beyond the threshold are removed.
    for layout_element in mock_layout:
        for candidate in ocr_elements:
            if candidate.is_almost_subregion_of(layout_element, SUBREGION_THRESHOLD_FOR_OCR):
                assert candidate not in final_layout


def test_merge_inferred_layout_with_ocr_layout(mock_inferred_layout, mock_ocr_regions):
    """merge_inferred_layout_with_ocr_layout fills inferred-element text from OCR
    and keeps both inferred and non-overlapping OCR-derived elements."""
    ocr_elements = []
    for region in mock_ocr_regions:
        ocr_elements.append(
            LayoutElement(
                region.x1,
                region.y1,
                region.x2,
                region.y2,
                text=region.text,
                type="UncategorizedText",
            ),
        )

    final_layout = merge_inferred_layout_with_ocr_layout(mock_inferred_layout, mock_ocr_regions)

    # The first inferred element picks up the aggregated OCR text.
    assert final_layout[0].text == mock_ocr_regions[2].text

    # All inferred elements survive; at least one OCR-derived element is added.
    assert all(element in final_layout for element in mock_inferred_layout)
    assert any(element in final_layout for element in ocr_elements)


@pytest.mark.parametrize("is_table", [False, True])
def test_layout_element_extract_text(
    mock_layout_element,
    mock_text_region,
    mock_pil_image,
    is_table,
):
    """extract_text returns the contained region's text; Table elements also
    gain a text_as_html attribute when extract_tables is enabled."""
    if is_table:
        mock_layout_element.type = "Table"

    result = mock_layout_element.extract_text(
        objects=[mock_text_region],
        image=mock_pil_image,
        extract_tables=True,
    )

    assert isinstance(result, str)
    assert "Sample text" in result

    if mock_layout_element.type == "Table":
        assert hasattr(mock_layout_element, "text_as_html")


def test_layout_element_do_dict(mock_layout_element):
    """to_dict serializes the corner coordinates, text, type, and prob."""
    corners = ((100, 100), (100, 300), (300, 300), (300, 100))
    assert mock_layout_element.to_dict() == {
        "coordinates": corners,
        "text": "Sample text",
        "type": "Text",
        "prob": None,
    }


def test_layout_element_from_region(mock_rectangle):
    """from_region copies the box coordinates and leaves text/type unset."""
    element = LayoutElement.from_region(mock_rectangle)
    assert element == LayoutElement(100, 100, 300, 300, None, None)


def test_layout_element_from_lp_textblock():
    """from_lp_textblock carries over the box, text, type, and score."""
    text_block = TextBlock(
        block=LPRectangle(100, 100, 300, 300),
        text="Sample Text",
        type="Text",
        score=0.99,
    )

    element = LayoutElement.from_lp_textblock(text_block)
    assert element == LayoutElement(100, 100, 300, 300, "Sample Text", "Text", 0.99)
4 changes: 2 additions & 2 deletions test_unstructured_inference/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def test_minimal_containing_rect():
assert rect2.is_in(big_rect)


def test_partition_groups_from_regions(sample_layout):
words = sample_layout
def test_partition_groups_from_regions(mock_embedded_text_regions):
words = mock_embedded_text_regions
groups = elements.partition_groups_from_regions(words)
assert len(groups) == 1
sorted_groups = sorted(groups, key=lambda group: group[0].y1)
Expand Down
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.5.22" # pragma: no cover
__version__ = "0.5.23" # pragma: no cover
1 change: 1 addition & 0 deletions unstructured_inference/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ class AnnotationResult(Enum):


SUBREGION_THRESHOLD_FOR_OCR = 0.5
FULL_PAGE_REGION_THRESHOLD = 0.99
9 changes: 6 additions & 3 deletions unstructured_inference/inference/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,10 @@ def get_elements_with_detection_model(
and "R_50" not in self.detection_model.model_path
):
threshold_kwargs = {"same_region_threshold": 0.5, "subregion_threshold": 0.5}
inferred_layout = merge_inferred_layout_with_extracted_layout(
merged_layout = merge_inferred_layout_with_extracted_layout(
inferred_layout=inferred_layout,
extracted_layout=self.layout,
page_image_size=self.image.size,
ocr_layout=ocr_layout,
supplement_with_ocr_elements=self.supplement_with_ocr_elements,
**threshold_kwargs,
Expand All @@ -301,14 +302,16 @@ def get_elements_with_detection_model(
and "R_50" not in self.detection_model.model_path
):
threshold_kwargs = {"subregion_threshold": 0.3}
inferred_layout = merge_inferred_layout_with_ocr_layout(
merged_layout = merge_inferred_layout_with_ocr_layout(
inferred_layout=inferred_layout,
ocr_layout=ocr_layout,
supplement_with_ocr_elements=self.supplement_with_ocr_elements,
**threshold_kwargs,
)
else:
merged_layout = inferred_layout

elements = self.get_elements_from_layout(cast(List[TextRegion], inferred_layout))
elements = self.get_elements_from_layout(cast(List[TextRegion], merged_layout))

if self.analysis:
self.inferred_layout = inferred_layout
Expand Down
15 changes: 13 additions & 2 deletions unstructured_inference/inference/layoutelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from layoutparser.elements.layout import TextBlock
from PIL import Image

from unstructured_inference.constants import SUBREGION_THRESHOLD_FOR_OCR
from unstructured_inference.constants import FULL_PAGE_REGION_THRESHOLD, SUBREGION_THRESHOLD_FOR_OCR
from unstructured_inference.inference.elements import (
ImageTextRegion,
Rectangle,
Expand Down Expand Up @@ -84,6 +84,7 @@ def interpret_table_block(text_block: TextRegion, image: Image.Image) -> str:
def merge_inferred_layout_with_extracted_layout(
inferred_layout: Collection[LayoutElement],
extracted_layout: Collection[TextRegion],
page_image_size: tuple,
ocr_layout: Optional[List[TextRegion]] = None,
supplement_with_ocr_elements: bool = True,
same_region_threshold: float = 0.75,
Expand All @@ -92,11 +93,21 @@ def merge_inferred_layout_with_extracted_layout(
"""Merge two layouts to produce a single layout."""
extracted_elements_to_add: List[TextRegion] = []
inferred_regions_to_remove = []
w, h = page_image_size
full_page_region = Rectangle(0, 0, w, h)
for extracted_region in extracted_layout:
if isinstance(extracted_region, ImageTextRegion):
# Skip extracted images for this purpose, we don't have the text from them and they
# don't provide good text bounding boxes.
continue

is_full_page_image = region_bounding_boxes_are_almost_the_same(
extracted_region,
full_page_region,
FULL_PAGE_REGION_THRESHOLD,
)

if is_full_page_image:
continue
region_matched = False
for inferred_region in inferred_layout:
if inferred_region.intersects(extracted_region):
Expand Down

0 comments on commit bdee102

Please sign in to comment.