Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Converting Embedded image from Documents #158

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 112 additions & 50 deletions src/markitdown/_markitdown.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# type: ignore
from io import BytesIO
import base64
import binascii
import copy
Expand Down Expand Up @@ -66,12 +67,15 @@ class _CustomMarkdownify(markdownify.MarkdownConverter):

- Altering the default heading style to use '#', '##', etc.
- Removing javascript hyperlinks.
- Truncating images with large data:uri sources.
- Using an MLM to transcribe images when one is configured; otherwise, truncating images with large data:uri sources.
- Ensuring URIs are properly escaped, and do not conflict with Markdown syntax
"""

def __init__(self, **options: Any):
    """Set up the converter, defaulting the heading style to ATX.

    The optional ``mlm_client`` and ``mlm_model`` options are remembered on
    the instance (``None`` when absent). Every option, including those two,
    is still forwarded unchanged to the parent constructor.
    """
    # Only fills in the heading style when the caller did not supply one.
    options.setdefault("heading_style", markdownify.ATX)

    self.mlm_model = options.get("mlm_model")
    self.mlm_client = options.get("mlm_client")

    # Explicitly cast options to the expected type if necessary
    super().__init__(**options)

Expand Down Expand Up @@ -124,7 +128,8 @@ def convert_img(self, el: Any, text: str, convert_as_inline: bool) -> str:

alt = el.attrs.get("alt", None) or ""
src = el.attrs.get("src", None) or ""
title = el.attrs.get("title", None) or ""
title = el.attrs.get("title", None) or ""

title_part = ' "%s"' % title.replace('"', r"\"") if title else ""
if (
convert_as_inline
Expand All @@ -133,8 +138,13 @@ def convert_img(self, el: Any, text: str, convert_as_inline: bool) -> str:
return alt

# Remove dataURIs
if src.startswith("data:"):
src = src.split(",")[0] + "..."
if src.startswith("data:image/"):
if self.mlm_client is not None and self.mlm_model is not None:
md = ImageConverter()
result = md._convert(src, mlm_client=self.mlm_client, mlm_model=self.mlm_model)
src = result.text_content if result is not None else src.split(",")[0] + "..."
else:
src = src.split(",")[0] + "..."

return "![%s](%s%s)" % (alt, src, title_part)

Expand Down Expand Up @@ -199,11 +209,11 @@ def convert(

result = None
with open(local_path, "rt", encoding="utf-8") as fh:
result = self._convert(fh.read())
result = self._convert(fh.read(), **kwargs)

return result

def _convert(self, html_content: str) -> Union[None, DocumentConverterResult]:
def _convert(self, html_content: str, **kwargs) -> Union[None, DocumentConverterResult]:
"""Helper function that converts and HTML string."""

# Parse the string
Expand All @@ -216,10 +226,14 @@ def _convert(self, html_content: str) -> Union[None, DocumentConverterResult]:
# Print only the main content
body_elm = soup.find("body")
webpage_text = ""

# add mlm_client and mlm_model to the options
#options = copy.deepcopy(kwargs)

if body_elm:
webpage_text = _CustomMarkdownify().convert_soup(body_elm)
webpage_text = _CustomMarkdownify(**kwargs).convert_soup(body_elm)
else:
webpage_text = _CustomMarkdownify().convert_soup(soup)
webpage_text = _CustomMarkdownify(**kwargs).convert_soup(soup)

assert isinstance(webpage_text, str)

Expand Down Expand Up @@ -713,7 +727,7 @@ def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:

result = mammoth.convert_to_html(docx_file, style_map=style_map)
html_content = result.value
result = self._convert(html_content)
result = self._convert(html_content, **kwargs)

return result

Expand Down Expand Up @@ -778,7 +792,9 @@ def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
return None

md_content = ""

self._mlm_client = kwargs.get("mlm_client")
self._mlm_model = kwargs.get("mlm_model")

presentation = pptx.Presentation(local_path)
slide_num = 0
for slide in presentation.slides:
Expand All @@ -795,8 +811,8 @@ def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
try:
alt_text = shape._element._nvXxPr.cNvPr.attrib.get("descr", "")
except Exception:
pass

pass
# A placeholder name
filename = re.sub(r"\W", "", shape.name) + ".jpg"
md_content += (
Expand All @@ -806,6 +822,7 @@ def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
+ filename
+ ")\n"
)
md_content += self._convert_image_to_markdown(shape)

# Tables
if self._is_table(shape):
Expand Down Expand Up @@ -850,6 +867,29 @@ def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
text_content=md_content.strip(),
)

def _convert_image_to_markdown(self, shape) -> str:
if not self._is_picture(shape):
return ""

image_converter = ImageConverter() if (self._mlm_client is not None) and (self._mlm_model is not None) else None

if image_converter is not None:
image = shape.image
content_type = image.content_type
blob = image.blob

try:
ext = f"data:{content_type};base64"
image_base64_uri = f"{ext},{base64.b64encode(blob).decode('utf-8')}"
image_description = image_converter._convert(image_base64_uri, mlm_client=self._mlm_client, mlm_model=self._mlm_model)

return ("\n" + image_description.text_content.strip() + "\n")
except Exception as e:
print("Error converting image to markdown")
sys.stderr.write(f"Error converting image to markdown: {e}")

return ""

def _is_picture(self, shape):
if shape.shape_type == pptx.enum.shapes.MSO_SHAPE_TYPE.PICTURE:
return True
Expand Down Expand Up @@ -1037,7 +1077,37 @@ class ImageConverter(MediaConverter):
"""
Converts images to markdown via extraction of metadata (if `exiftool` is installed), OCR (if `easyocr` is installed), and description via a multimodal LLM (if an llm_client is configured).
"""
def _convert(self, data_base64_uri, **kwargs) -> Union[None, DocumentConverterResult]:
    """Describe a base64 image data URI using a multimodal LLM.

    Args:
        data_base64_uri: A ``data:<type>;base64,<payload>`` string. Only the
            jpg, jpeg and png image types are accepted; the comparison is
            case-insensitive.
        **kwargs: ``mlm_client`` and ``mlm_model`` select the model to query;
            ``mlm_prompt`` is passed through as the prompt.

    Returns:
        ``None`` when the input is not a string or is not an accepted image
        data URI. Otherwise a ``DocumentConverterResult`` whose text is the
        "# Image Description:" section, or an empty string when no
        client/model pair was supplied.
    """
    # Bail if not an image
    if not isinstance(data_base64_uri, str):
        return None

    # Only the header before the first "," / ";" matters; partition avoids
    # splitting the (potentially very large) payload.
    content_type = data_base64_uri.partition(",")[0].partition(";")[0].lower()
    if content_type not in ("data:image/jpg", "data:image/jpeg", "data:image/png"):
        return None

    # Try describing the image with GPTV
    mlm_client = kwargs.get("mlm_client")
    mlm_model = kwargs.get("mlm_model")
    md_content = ""

    if mlm_client is not None and mlm_model is not None:
        description = self._get_mlm_description(
            data_base64_uri,
            mlm_client,
            mlm_model,
            prompt=kwargs.get("mlm_prompt"),
        )
        # The reply may carry no text content (None); treat that as empty
        # rather than failing on .strip().
        md_content = "\n# Image Description:\n" + (description or "").strip() + "\n"

    return DocumentConverterResult(
        title=None,
        text_content=md_content,
    )

def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
# Bail if not an image
extension = kwargs.get("file_extension", "")
Expand All @@ -1064,39 +1134,29 @@ def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
if f in metadata:
md_content += f"{f}: {metadata[f]}\n"

# Try describing the image with GPTV
llm_client = kwargs.get("llm_client")
llm_model = kwargs.get("llm_model")
if llm_client is not None and llm_model is not None:
md_content += (
"\n# Description:\n"
+ self._get_llm_description(
local_path,
extension,
llm_client,
llm_model,
prompt=kwargs.get("llm_prompt"),
).strip()
+ "\n"
)
image_base64_uri = self._get_image_base64(local_path, extension)
md_content += self._convert(image_base64_uri, **kwargs).text_content

return DocumentConverterResult(
title=None,
text_content=md_content,
)

def _get_llm_description(self, local_path, extension, client, model, prompt=None):
if prompt is None or prompt.strip() == "":
prompt = "Write a detailed caption for this image."

data_uri = ""

def _get_image_base64(self, local_path, extension):
with open(local_path, "rb") as image_file:
content_type, encoding = mimetypes.guess_type("_dummy" + extension)
if content_type is None:
content_type = "image/jpeg"
image_base64 = base64.b64encode(image_file.read()).decode("utf-8")
data_uri = f"data:{content_type};base64,{image_base64}"

return f"data:{content_type};base64,{image_base64}"

def _get_mlm_description(self, data_base64_uri, client, model, prompt=None):
if prompt is None or prompt.strip() == "":
prompt = "Write a detailed caption for this image."

sys.stderr.write(f"MLM Prompt:\n{prompt}\n")

messages = [
{
"role": "user",
Expand All @@ -1105,7 +1165,7 @@ def _get_llm_description(self, local_path, extension, client, model, prompt=None
{
"type": "image_url",
"image_url": {
"url": data_uri,
"url": data_base64_uri,
},
},
],
Expand All @@ -1115,7 +1175,6 @@ def _get_llm_description(self, local_path, extension, client, model, prompt=None
response = client.chat.completions.create(model=model, messages=messages)
return response.choices[0].message.content


class OutlookMsgConverter(DocumentConverter):
"""Converts Outlook .msg files to markdown by extracting email metadata and content.

Expand Down Expand Up @@ -1477,6 +1536,9 @@ def convert_stream(

# Convert
result = self._convert(temp_path, extensions, **kwargs)
except Exception as e:
sys.stderr.write(f"Error converting stream to markdown: {e}")
pass
# Clean up
finally:
try:
Expand Down Expand Up @@ -1548,22 +1610,22 @@ def _convert(
) -> DocumentConverterResult:
error_trace = ""
for ext in extensions + [None]: # Try last with no extension
for converter in self._page_converters:
_kwargs = copy.deepcopy(kwargs)

# Overwrite file_extension appropriately
if ext is None:
if "file_extension" in _kwargs:
del _kwargs["file_extension"]
else:
_kwargs.update({"file_extension": ext})
_kwargs = copy.deepcopy(kwargs)
# Overwrite file_extension appropriately
if ext is None:
if "file_extension" in _kwargs:
del _kwargs["file_extension"]
else:
_kwargs.update({"file_extension": ext})

# Copy any additional global options
if "llm_client" not in _kwargs and self._llm_client is not None:
_kwargs["llm_client"] = self._llm_client
# Copy any additional global options
if "mlm_client" not in _kwargs and self._llm_client is not None:
_kwargs["mlm_client"] = self._llm_client

if "llm_model" not in _kwargs and self._llm_model is not None:
_kwargs["llm_model"] = self._llm_model
if "mlm_model" not in _kwargs and self._llm_model is not None:
_kwargs["mlm_model"] = self._llm_model

for converter in self._page_converters:

if "style_map" not in _kwargs and self._style_map is not None:
_kwargs["style_map"] = self._style_map
Expand Down
Binary file modified tests/test_files/test.docx
Binary file not shown.
Binary file modified tests/test_files/test.pptx
Binary file not shown.
23 changes: 19 additions & 4 deletions tests/test_markitdown.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#!/usr/bin/env python3 -m pytest
import io
import os
from dotenv import load_dotenv
import shutil

from openai import OpenAI, AzureOpenAI
import pytest
import requests

Expand Down Expand Up @@ -134,6 +135,7 @@
"data:image/svg+xml,%3Csvg%20width%3D",
]


CSV_CP932_TEST_STRINGS = [
"名前,年齢,住所",
"佐藤太郎,30,東京",
Expand Down Expand Up @@ -189,8 +191,20 @@ def test_markitdown_remote() -> None:
# assert test_string in result.text_content


def test_markitdown_local() -> None:
markitdown = MarkItDown()
def test_markitdown_local(use_mlm = False) -> None:
if (use_mlm):
load_dotenv()
client = AzureOpenAI(
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT")
)
llm_model="gpt-4oModel"

markitdown = MarkItDown(llm_client=client, llm_model=llm_model)
else:
markitdown = MarkItDown()


# Test XLSX processing
result = markitdown.convert(os.path.join(TEST_FILES_DIR, "test.xlsx"))
Expand Down Expand Up @@ -305,7 +319,6 @@ def test_markitdown_exiftool() -> None:
target = f"{key}: {JPG_TEST_EXIFTOOL[key]}"
assert target in result.text_content


def test_markitdown_deprecation() -> None:
try:
with catch_warnings(record=True) as w:
Expand Down Expand Up @@ -361,6 +374,8 @@ def test_markitdown_llm() -> None:

if __name__ == "__main__":
"""Runs this file's tests from the command line."""
test_markitdown_remote()
test_markitdown_local(True)
# test_markitdown_remote()
# test_markitdown_local()
test_markitdown_exiftool()
Expand Down