add amazon and azure AI services and perform refactorings
Showing 42 changed files with 2,093 additions and 499 deletions. Two of the new files are reproduced below.
New file with 175 added lines (inferred from an import later in this commit to be bpm_ai_core/llm/common/blob.py):

@@ -0,0 +1,175 @@
import mimetypes
import os
from io import BytesIO
from pathlib import PurePath
from typing import Union, Optional, Dict, Any, Self, cast  # typing.Self requires Python 3.11+

import requests
from pydantic import BaseModel, Field, model_validator

from bpm_ai_core.util.storage import read_file_from_azure_blob, read_file_from_s3, is_s3_url, is_azure_blob_url

class Blob(BaseModel):
    """Blob represents raw data by either reference or value.

    Provides an interface to materialize the blob in different representations.
    Based on:
    https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/document_loaders/blob_loaders.py
    """

    data: Union[bytes, str, None]
    """Raw data associated with the blob."""

    mimetype: Optional[str] = None
    """MIME type of the data, not to be confused with a file extension."""

    path: Optional[Union[str, PurePath]] = None
    """Location where the original content was found."""

    metadata: Dict[str, Any] = Field(default_factory=dict)
    """Metadata about the blob (e.g., source)."""

    class Config:
        arbitrary_types_allowed = True
        frozen = True

    @property
    def source(self) -> Optional[str]:
        """The source location of the blob as a string if known, otherwise None.

        Defaults to the path location if a path is associated with the blob,
        unless a metadata field called "source" is explicitly set, in which
        case that value is used instead.
        """
        if self.metadata and "source" in self.metadata:
            return cast(Optional[str], self.metadata["source"])
        return str(self.path) if self.path else None

    @model_validator(mode='after')
    def check_blob_is_valid(self) -> Self:
        """Verify that either data or path is provided."""
        if not self.data and not self.path:
            raise ValueError("Either data or path must be provided")
        return self

    def is_image(self) -> bool:
        return self.mimetype.startswith("image/") if self.mimetype else False

    def is_pdf(self) -> bool:
        return self.mimetype == "application/pdf" if self.mimetype else False

    def is_audio(self) -> bool:
        return self.mimetype.startswith("audio/") if self.mimetype else False

    def is_video(self) -> bool:
        return self.mimetype.startswith("video/") if self.mimetype else False

    def is_text(self) -> bool:
        app_text_mimetypes = [
            'application/json',
            'application/javascript',
            'application/manifest+json',
            'application/xml',
            'application/x-sh',
            'application/x-python',
        ]
        return (self.mimetype.startswith("text/") or self.mimetype in app_text_mimetypes) if self.mimetype else False

    async def as_bytes(self) -> bytes:
        """Read the data as bytes, fetching it first if only a reference is held."""
        # self.path may be a PurePath, which has no startswith(); normalize to str first
        path = str(self.path) if self.path else None
        if self.data is None and path and (path.startswith('http://') or path.startswith('https://')):
            # Note: requests is synchronous and blocks the event loop for large downloads
            response = requests.get(path)
            response.raise_for_status()
            return response.content
        elif self.data is None and path and is_s3_url(path):
            return await read_file_from_s3(path)
        elif self.data is None and path and is_azure_blob_url(path):
            return await read_file_from_azure_blob(path)
        elif isinstance(self.data, bytes):
            return self.data
        elif isinstance(self.data, str):
            return self.data.encode("utf-8")
        elif self.data is None and path:
            with open(path, "rb") as f:
                return f.read()
        else:
            raise ValueError(f"Unable to get bytes for blob {self}")

    async def as_bytes_io(self) -> BytesIO:
        """Read the data as a BytesIO stream."""
        return BytesIO(await self.as_bytes())

    @classmethod
    def from_path_or_url(
        cls,
        path: Union[str, PurePath],
        *,
        mime_type: Optional[str] = None,
        guess_type: bool = True,
        metadata: Optional[dict] = None,
    ) -> "Blob":
        """Load the blob from a path-like object or URL.

        Args:
            path: path-like object or URL of the file to be read
            mime_type: if provided, will be set as the mime-type of the data
            guess_type: if True and no mime-type was provided, the mime-type
                will be guessed from the file extension
            metadata: metadata to associate with the blob

        Returns:
            Blob instance
        """
        if mime_type is None and guess_type:
            _mimetype = mimetypes.guess_type(str(path))[0]
        else:
            _mimetype = mime_type

        # Convert a local path to an absolute path
        if os.path.isfile(path):
            path = os.path.abspath(path)

        # We do not load the data immediately; instead we treat the blob as a
        # reference to the underlying data.
        return cls(
            data=None,
            mimetype=_mimetype,
            path=path,
            metadata=metadata if metadata is not None else {},
        )

    @classmethod
    def from_data(
        cls,
        data: Union[str, bytes],
        *,
        mime_type: str,
        path: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> "Blob":
        """Initialize the blob from in-memory data.

        Args:
            data: the in-memory data associated with the blob
            mime_type: will be set as the mime-type of the data
            path: if provided, will be set as the source from which the data came
            metadata: metadata to associate with the blob

        Returns:
            Blob instance
        """
        return cls(
            data=data,
            mimetype=mime_type,
            path=path,
            metadata=metadata if metadata is not None else {},
        )

    def __repr__(self) -> str:
        """Define the blob representation."""
        str_repr = f"Blob {id(self)}"
        if self.source:
            str_repr += f" {self.source}"
        return str_repr
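For reference, a minimal usage sketch for this class (not part of the commit), assuming bpm_ai_core is installed and importable; the file path is hypothetical:

import asyncio

from bpm_ai_core.llm.common.blob import Blob


async def main():
    # Wrap in-memory data; mime_type must be given explicitly
    note = Blob.from_data("hello blob", mime_type="text/plain")
    assert note.is_text()
    print((await note.as_bytes()).decode())

    # Reference a file by path (hypothetical); the mimetype is guessed from the
    # extension and the bytes are read only when as_bytes() is awaited
    report = Blob.from_path_or_url("reports/q1.pdf")
    assert report.is_pdf()


asyncio.run(main())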
New file with 138 added lines, defining AmazonTextractOCR, an Amazon Textract-based OCR implementation:

@@ -0,0 +1,138 @@
import asyncio
from typing import Optional

from typing_extensions import override

from bpm_ai_core.llm.common.blob import Blob
from bpm_ai_core.ocr.ocr import OCR, OCRResult, OCRPage
from bpm_ai_core.util.image import blob_as_images
from bpm_ai_core.util.storage import is_s3_url, parse_s3_url

try:
    from aiobotocore.session import get_session
    from textractprettyprinter.t_pretty_print import get_text_from_layout_json

    has_textract = True
except ImportError:
    has_textract = False

IMAGE_FORMATS = ["png", "jpeg", "tiff"]


class AmazonTextractOCR(OCR):
    """OCR implementation backed by Amazon Textract."""

    def __init__(self, region_name: Optional[str] = None):
        if not has_textract:
            raise ImportError('aiobotocore and/or amazon-textract-prettyprinter are not installed')
        self.region_name = region_name

    @override
    async def _do_process(
        self,
        blob: Blob,
        language: Optional[str] = None  # Textract's AnalyzeDocument API takes no language hint, so this is unused
    ) -> OCRResult:
        if not (blob.is_pdf() or blob.is_image()):
            raise ValueError("Blob must be a PDF or an image")

        if blob.path and is_s3_url(str(blob.path)):
            # Documents already in S3 are analyzed in place via Textract's
            # asynchronous API, without downloading them first
            pages = await self._get_pages_async(str(blob.path))
        else:
            pages = await self._get_pages_sync(blob)

        return OCRResult(pages=pages)

    async def _get_pages_sync(self, document: Blob):
        """Analyze a document via Textract's synchronous API (single request).

        "Sync" refers to the Textract API variant, not Python concurrency.
        """
        if document.is_pdf():
            _bytes = await document.as_bytes()
        else:
            _bytes = (await blob_as_images(document, accept_formats=IMAGE_FORMATS, return_bytes=True))[0]
        async with get_session().create_client("textract", region_name=self.region_name) as client:
            # Call the Amazon Textract API
            response = await client.analyze_document(
                Document={"Bytes": _bytes},
                FeatureTypes=["TABLES", "FORMS", "LAYOUT"]
            )
        # Convert the Textract response to markdown using amazon-textract-prettyprinter
        markdown_pages = get_text_from_layout_json(
            textract_json=response,
            table_format="github",
            generate_markdown=True
        )
        return self.parse_pages(markdown_pages, response)

    async def _get_pages_async(self, s3_url: str):
        """Analyze a document stored in S3 via Textract's asynchronous (job-based) API."""
        bucket_name, file_path = await parse_s3_url(s3_url)
        async with get_session().create_client("textract", region_name=self.region_name) as client:
            # Start an asynchronous analysis job for the S3 object
            response = await client.start_document_analysis(
                DocumentLocation={'S3Object': {
                    'Bucket': bucket_name,
                    'Name': file_path
                }},
                FeatureTypes=["TABLES", "FORMS", "LAYOUT"]
            )

            # Get the job ID from the response
            job_id = response["JobId"]

            # Poll until the job completes
            while True:
                response = await client.get_document_analysis(JobId=job_id)
                status = response["JobStatus"]
                if status in ["SUCCEEDED", "FAILED"]:
                    break
                await asyncio.sleep(1)  # wait 1 second before checking the status again

            if status == "FAILED":
                raise Exception(f"Document analysis failed with error: {response['StatusMessage']}")

            # Retrieve the full result set from the completed job, following pagination
            blocks = []
            next_token = None
            while True:
                if next_token:
                    response = await client.get_document_analysis(JobId=job_id, NextToken=next_token)
                else:
                    response = await client.get_document_analysis(JobId=job_id)
                blocks.extend(response["Blocks"])
                next_token = response.get("NextToken")
                if not next_token:
                    break

        # Merge the paginated blocks and convert them in a single pass, so page
        # indices line up across pagination boundaries
        full_response = {**response, "Blocks": blocks}
        markdown_pages = get_text_from_layout_json(
            textract_json=full_response,
            table_format="github",
            generate_markdown=True
        )
        return self.parse_pages(markdown_pages, full_response)

    @staticmethod
    def parse_pages(markdown_pages, response):
        # Index blocks by id once, instead of a linear scan per referenced block
        blocks_by_id = {b["Id"]: b for b in response["Blocks"]}
        pages = []
        page_idx = -1
        for block in response["Blocks"]:
            if block["BlockType"] != "PAGE":
                continue
            page_idx += 1
            bboxes = []
            words = []
            relationships = block.get("Relationships") or []
            child_ids = relationships[0]["Ids"] if relationships else []
            for child_id in child_ids:
                child_block = blocks_by_id[child_id]
                if child_block["BlockType"] != "LINE":
                    continue
                line_relationships = child_block.get("Relationships") or []
                word_ids = line_relationships[0]["Ids"] if line_relationships else []
                for word_block_id in word_ids:
                    word_block = blocks_by_id[word_block_id]
                    if word_block["BlockType"] == "WORD":
                        # Textract bounding boxes are normalized (left, top, width, height);
                        # convert to (x_min, y_min, x_max, y_max)
                        bbox = word_block["Geometry"]["BoundingBox"]
                        x, y, w, h = bbox["Left"], bbox["Top"], bbox["Width"], bbox["Height"]
                        bboxes.append((x, y, x + w, y + h))
                        words.append(word_block["Text"])
            page_data = OCRPage(
                text=list(markdown_pages.values())[page_idx],
                words=words,
                bboxes=bboxes
            )
            pages.append(page_data)
        return pages
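A usage sketch for the new OCR backend (not part of the commit), run alongside the class defined above. It assumes the OCR base class, which is not shown in this diff, exposes a public async process method that delegates to _do_process, and that AWS credentials and region are configured; the S3 URL is hypothetical:

import asyncio

from bpm_ai_core.llm.common.blob import Blob


async def main():
    ocr = AmazonTextractOCR(region_name="eu-central-1")  # example region
    # S3-hosted documents are analyzed in place via Textract's asynchronous API
    blob = Blob.from_path_or_url("s3://my-bucket/scans/contract.pdf")  # hypothetical URL
    result = await ocr.process(blob)  # assumed public entry point on the OCR base class
    for page in result.pages:
        print(page.text[:200])  # markdown rendering of the page layout


asyncio.run(main())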