Skip to content

Commit

Permalink
add amazon and azure AI services and perform refactorings
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjoyo committed Mar 26, 2024
1 parent 69a1e17 commit cec8c38
Show file tree
Hide file tree
Showing 42 changed files with 2,093 additions and 499 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from bpm_ai_core.llm.common.tool import Tool
from bpm_ai_core.prompt.prompt import Prompt
from bpm_ai_core.tracing.tracing import Tracing
from bpm_ai_core.util.json import expand_simplified_json_schema
from bpm_ai_core.util.json_schema import expand_simplified_json_schema

logger = logging.getLogger(__name__)

Expand Down
175 changes: 175 additions & 0 deletions bpm-ai-core/bpm_ai_core/llm/common/blob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import contextlib
import mimetypes
import os
from io import BytesIO, BufferedReader
from pathlib import PurePath
from typing import Union, Optional, Dict, Any, Self, cast, Generator

import requests
from pydantic import BaseModel, Field, model_validator

from bpm_ai_core.util.storage import read_file_from_azure_blob, read_file_from_s3, is_s3_url, is_azure_blob_url


class Blob(BaseModel):
    """Blob represents raw data by either reference or value.
    Provides an interface to materialize the blob in different representations.
    Based on:
    https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/document_loaders/blob_loaders.py
    """

    data: Union[bytes, str, None]
    """Raw data associated with the blob."""

    mimetype: Optional[str] = None
    """MimeType not to be confused with a file extension."""

    path: Optional[Union[str, PurePath]] = None
    """Location where the original content was found."""

    metadata: Dict[str, Any] = Field(default_factory=dict)
    """Metadata about the blob (e.g., source)"""

    class Config:
        arbitrary_types_allowed = True
        frozen = True

    @property
    def source(self) -> Optional[str]:
        """The source location of the blob as string if known otherwise none.
        If a path is associated with the blob, it will default to the path location.
        Unless explicitly set via a metadata field called "source", in which
        case that value will be used instead.
        """
        if self.metadata and "source" in self.metadata:
            return cast(Optional[str], self.metadata["source"])
        return str(self.path) if self.path else None

    @model_validator(mode='after')
    def check_blob_is_valid(self) -> Self:
        """Verify that either data or path is provided."""
        if not self.data and not self.path:
            raise ValueError("Either data or path must be provided")
        return self

    def is_image(self) -> bool:
        """Return True if the mimetype denotes an image."""
        return self.mimetype.startswith("image/") if self.mimetype else False

    def is_pdf(self) -> bool:
        """Return True if the mimetype denotes a PDF document."""
        return self.mimetype == "application/pdf" if self.mimetype else False

    def is_audio(self) -> bool:
        """Return True if the mimetype denotes audio content."""
        return self.mimetype.startswith("audio/") if self.mimetype else False

    def is_video(self) -> bool:
        """Return True if the mimetype denotes video content."""
        return self.mimetype.startswith("video/") if self.mimetype else False

    def is_text(self) -> bool:
        """Return True if the mimetype denotes textual content.

        Besides all text/* types, a small set of application/* types that
        are text-based in practice (JSON, XML, scripts) is accepted.
        """
        app_text_mimetypes = [
            'application/json',
            'application/javascript',
            'application/manifest+json',
            'application/xml',
            'application/x-sh',
            'application/x-python',
        ]
        return (self.mimetype.startswith("text/") or self.mimetype in app_text_mimetypes) if self.mimetype else False

    async def as_bytes(self) -> bytes:
        """Read data as bytes.

        Resolution order: in-memory data first, then remote sources
        (HTTP(S), S3, Azure Blob), then the local filesystem.

        Raises:
            ValueError: if neither data nor a readable path is available.
        """
        if isinstance(self.data, bytes):
            return self.data
        if isinstance(self.data, str):
            return self.data.encode("utf-8")
        # data is None from here on; the model validator guarantees that
        # a path exists. Normalize to str so the prefix checks also work
        # when path is a PurePath (PurePath has no startswith()).
        path = str(self.path) if self.path else None
        if path is None:
            raise ValueError(f"Unable to get bytes for blob {self}")
        if path.startswith(("http://", "https://")):
            response = requests.get(path)
            # Fail loudly on HTTP errors instead of silently returning
            # the error page body as blob content.
            response.raise_for_status()
            return response.content
        if is_s3_url(path):
            return await read_file_from_s3(path)
        if is_azure_blob_url(path):
            return await read_file_from_azure_blob(path)
        with open(path, "rb") as f:
            return f.read()

    async def as_bytes_io(self) -> BytesIO:
        """Read data as an in-memory binary stream."""
        return BytesIO(await self.as_bytes())

    @classmethod
    def from_path_or_url(
        cls,
        path: Union[str, PurePath],
        *,
        mime_type: Optional[str] = None,
        guess_type: bool = True,
        metadata: Optional[dict] = None,
    ) -> "Blob":
        """Load the blob from a path like object.
        Args:
            path: path like object to file to be read
            mime_type: if provided, will be set as the mime-type of the data
            guess_type: If True, the mimetype will be guessed from the file extension,
                        if a mime-type was not provided
            metadata: Metadata to associate with the blob
        Returns:
            Blob instance
        """
        if mime_type is None and guess_type:
            # guess_type() accepts str; convert in case a PurePath was given
            _mimetype = mimetypes.guess_type(str(path))[0]
        else:
            _mimetype = mime_type

        # Convert a path to an absolute path
        if os.path.isfile(path):
            path = os.path.abspath(path)

        # We do not load the data immediately, instead we treat the blob as a
        # reference to the underlying data.
        return cls(
            data=None,
            mimetype=_mimetype,
            path=path,
            metadata=metadata if metadata is not None else {},
        )

    @classmethod
    def from_data(
        cls,
        data: Union[str, bytes],
        *,
        mime_type: str,
        path: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> "Blob":
        """Initialize the blob from in-memory data.
        Args:
            data: the in-memory data associated with the blob
            mime_type: if provided, will be set as the mime-type of the data
            path: if provided, will be set as the source from which the data came
            metadata: Metadata to associate with the blob
        Returns:
            Blob instance
        """
        return cls(
            data=data,
            mimetype=mime_type,
            path=path,
            metadata=metadata if metadata is not None else {},
        )

    def __repr__(self) -> str:
        """Define the blob representation."""
        str_repr = f"Blob {id(self)}"
        if self.source:
            str_repr += f" {self.source}"
        return str_repr
6 changes: 3 additions & 3 deletions bpm-ai-core/bpm_ai_core/llm/common/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
import json
from typing import Optional, Literal, Any, Union, List

from PIL import Image
from pydantic import BaseModel, Field, ConfigDict

from bpm_ai_core.llm.common.blob import Blob
from bpm_ai_core.llm.common.tool import Tool
from bpm_ai_core.tracing.tracing import Tracing


class ChatMessage(BaseModel):
content: Optional[Union[str, dict, List[Union[str, Image.Image]]]] = None
content: Optional[Union[str, dict, List[Union[str, Blob]]]] = None
"""
The contents of the message.
Either a string for normal completions,
or a list of strings and images for multimodal completions,
or a list of strings and blobs for multimodal completions,
or a dict for prediction with output schema.
"""

Expand Down
2 changes: 1 addition & 1 deletion bpm-ai-core/bpm_ai_core/llm/common/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel, validate_call, create_model

from bpm_ai_core.util.json import expand_simplified_json_schema
from bpm_ai_core.util.json_schema import expand_simplified_json_schema


def _create_subset_model(
Expand Down
138 changes: 138 additions & 0 deletions bpm-ai-core/bpm_ai_core/ocr/amazon_textract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import asyncio

from typing_extensions import override

from bpm_ai_core.llm.common.blob import Blob
from bpm_ai_core.ocr.ocr import OCR, OCRResult, OCRPage
from bpm_ai_core.util.image import blob_as_images
from bpm_ai_core.util.storage import is_s3_url, parse_s3_url

try:
from aiobotocore.session import get_session
from textractprettyprinter.t_pretty_print import get_text_from_layout_json

has_textract = True
except ImportError:
has_textract = False

IMAGE_FORMATS = ["png", "jpeg", "tiff"]


class AmazonTextractOCR(OCR):
    """OCR implementation backed by Amazon Textract.

    Documents stored in S3 are processed with the asynchronous Textract
    job API (start/get_document_analysis); all other documents are sent
    inline as bytes via analyze_document.
    """

    def __init__(self, region_name: str = None):
        """
        Args:
            region_name: AWS region to create the Textract client in;
                         None falls back to the environment/session default.
        """
        if not has_textract:
            raise ImportError('aiobotocore and/or amazon-textract-prettyprinter are not installed')
        self.region_name = region_name

    @override
    async def _do_process(
        self,
        blob: Blob,
        language: str = None
    ) -> OCRResult:
        """Run Textract OCR on the given blob.

        Args:
            blob: a PDF or image blob.
            language: unused here; Textract detects language automatically.

        Raises:
            ValueError: if the blob is neither a PDF nor an image.
        """
        if not (blob.is_pdf() or blob.is_image()):
            raise ValueError("Blob must be a PDF or an image")

        if is_s3_url(blob.path):
            # S3-hosted documents can use the asynchronous job-based API
            pages = await self._get_pages_async(blob.path)
        else:
            pages = await self._get_pages_sync(blob)

        return OCRResult(pages=pages)

    async def _get_pages_sync(self, document: Blob):
        """Analyze an in-memory document with the synchronous analyze_document API."""
        if document.is_pdf():
            _bytes = await document.as_bytes()
        else:
            # Normalize the image into a format Textract accepts
            _bytes = (await blob_as_images(document, accept_formats=IMAGE_FORMATS, return_bytes=True))[0]
        async with get_session().create_client("textract", region_name=self.region_name) as client:
            # Call Amazon Textract API asynchronously
            response = await client.analyze_document(
                Document={"Bytes": _bytes},
                FeatureTypes=["TABLES", "FORMS", "LAYOUT"]
            )
            # Convert Textract response to markdown using amazon-textract-prettyprinter
            markdown_pages = get_text_from_layout_json(
                textract_json=response,
                table_format="github",
                generate_markdown=True
            )
            return self.parse_pages(markdown_pages, response)

    async def _get_pages_async(self, s3_url: str):
        """Analyze an S3-hosted document with the asynchronous Textract job API.

        Starts a document-analysis job, polls until it terminates, then
        collects all (paginated) result blocks before parsing.
        """
        bucket_name, file_path = await parse_s3_url(s3_url)
        async with get_session().create_client("textract", region_name=self.region_name) as client:
            response = await client.start_document_analysis(
                DocumentLocation={'S3Object': {
                    'Bucket': bucket_name,
                    'Name': file_path
                }},
                FeatureTypes=["TABLES", "FORMS", "LAYOUT"]
            )
            job_id = response["JobId"]

            # Poll until the job reaches a terminal state
            while True:
                response = await client.get_document_analysis(JobId=job_id)
                status = response["JobStatus"]
                if status in ["SUCCEEDED", "FAILED"]:
                    break
                await asyncio.sleep(1)  # Wait for 1 second before checking the status again

            if status == "FAILED":
                raise Exception(f"Document analysis failed with error: {response['StatusMessage']}")

            # Accumulate every result page of the (paginated) job and parse
            # ONCE at the end. Parsing inside the pagination loop (as the
            # naive approach would) re-parses the growing markdown dict on
            # every round and duplicates earlier pages in the output.
            markdown_pages = {}
            all_blocks = []
            next_token = None
            while True:
                if next_token:
                    response = await client.get_document_analysis(JobId=job_id, NextToken=next_token)
                else:
                    response = await client.get_document_analysis(JobId=job_id)

                # Convert Textract response to markdown using amazon-textract-prettyprinter
                markdown_pages.update(get_text_from_layout_json(
                    textract_json=response,
                    table_format="github",
                    generate_markdown=True
                ))
                all_blocks.extend(response["Blocks"])

                next_token = response.get("NextToken")
                if not next_token:
                    break
            # parse_pages only reads the "Blocks" key of the response dict
            return self.parse_pages(markdown_pages, {"Blocks": all_blocks})

    @staticmethod
    def parse_pages(markdown_pages, response):
        """Convert Textract Blocks into OCRPage objects with word bounding boxes.

        Args:
            markdown_pages: mapping whose values are the per-page markdown
                            texts, in page order.
            response: Textract response dict; only the "Blocks" key is used.

        Returns:
            list of OCRPage, one per PAGE block.
        """
        # Index blocks by id once — the per-child linear scan over
        # response["Blocks"] would be O(n^2) in the number of blocks.
        blocks_by_id = {b["Id"]: b for b in response["Blocks"]}
        page_texts = list(markdown_pages.values())
        pages = []
        page_idx = -1
        for block in response["Blocks"]:
            if block["BlockType"] != "PAGE":
                continue
            page_idx += 1
            bboxes = []
            words = []
            # A PAGE/LINE block may have no "Relationships" key at all
            # (e.g. an empty page) — guard with .get() to avoid KeyError.
            for relationship in block.get("Relationships", []):
                for child_id in relationship["Ids"]:
                    child = blocks_by_id[child_id]
                    if child["BlockType"] != "LINE":
                        continue
                    for line_rel in child.get("Relationships", []):
                        for word_id in line_rel["Ids"]:
                            word_block = blocks_by_id[word_id]
                            if word_block["BlockType"] == "WORD":
                                bbox = word_block["Geometry"]["BoundingBox"]
                                x, y, w, h = bbox["Left"], bbox["Top"], bbox["Width"], bbox["Height"]
                                # Textract gives (left, top, width, height);
                                # store as (x0, y0, x1, y1)
                                bboxes.append((x, y, x + w, y + h))
                                words.append(word_block["Text"])
            pages.append(OCRPage(
                text=page_texts[page_idx],
                words=words,
                bboxes=bboxes
            ))
        return pages
Loading

0 comments on commit cec8c38

Please sign in to comment.