Skip to content

Commit

Permalink
1.4.0 - add option to use visual document extraction models
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjoyo committed Apr 4, 2024
1 parent a315f17 commit cdfd4b0
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 34 deletions.
62 changes: 30 additions & 32 deletions bpm-ai/bpm_ai/extract/extract.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import itertools
import re
from typing import Callable, Any

from bpm_ai_core.classification.zero_shot_classifier import ZeroShotClassifier
from bpm_ai_core.llm.common.blob import Blob
from bpm_ai_core.question_answering.question_answering import QuestionAnswering
from bpm_ai_core.llm.common.llm import LLM
from bpm_ai_core.ocr.ocr import OCR
from bpm_ai_core.prompt.prompt import Prompt
from bpm_ai_core.speech_recognition.asr import ASRModel
from bpm_ai_core.token_classification.zero_shot_token_classifier import ZeroShotTokenClassifier
from bpm_ai_core.tracing.decorators import trace
from bpm_ai_core.util.file import is_supported_img_file
from bpm_ai_core.util.json_schema import expand_simplified_json_schema
from bpm_ai_core.util.markdown import dict_to_md

from bpm_ai.common.errors import MissingParameterError
from bpm_ai.common.multimodal import transcribe_audio, prepare_images_for_llm_prompt, ocr_documents
from bpm_ai.extract.util import merge_dicts, strip_non_numeric_chars, create_json_object


@trace("bpm-ai-extract", ["llm"])
Expand Down Expand Up @@ -70,18 +74,23 @@ def empty_to_none(v):
async def extract_qa(
qa: QuestionAnswering,
classifier: ZeroShotClassifier,
token_classifier: ZeroShotTokenClassifier,
input_data: dict[str, str | dict | None],
output_schema: dict[str, str | dict],
multiple: bool = False,
multiple_description: str = "",
ocr: OCR | None = None,
vqa: QuestionAnswering | None = None,
token_classifier: ZeroShotTokenClassifier | None = None,
asr: ASRModel | None = None
) -> dict | list[dict]:
if all(value is None for value in input_data.values()):
return input_data

input_data = await ocr_documents(input_data, ocr)
if vqa:
input_img_data = {k: v for k, v in input_data.items() if (isinstance(v, str) and is_supported_img_file(v))}
input_data = {k: v for k, v in input_data.items() if k not in input_img_data.keys()}
else:
input_data = await ocr_documents(input_data, ocr)
input_data = await transcribe_audio(input_data, asr)

if not output_schema:
Expand All @@ -90,33 +99,6 @@ async def extract_qa(
input_md = dict_to_md(input_data).strip()
output_schema = expand_simplified_json_schema(output_schema)["properties"]

def strip_non_numeric_chars(s):
while len(s) > 0 and not s[0].isdigit():
s = s[1:]
while len(s) > 0 and not s[-1].isdigit():
s = s[:-1]
return s

async def create_json_object(target: str, schema, get_value: Callable, current_obj=None, root_obj=None, parent_key='', prefix=''):
if current_obj is None:
current_obj = {}
root_obj = {}

for name, properties in schema.items():
full_key = f'{parent_key}.{name}' if parent_key else name
if properties['type'] == 'object':
current_obj[name] = await create_json_object(target, properties['properties'], get_value, {}, root_obj, full_key)
else:
description = properties.get('description')
enum = properties.get('enum', None)
if prefix:
description = prefix + (description[:1].lower() + description[1:])
value = await get_value(target, full_key, properties['type'], description, enum, root_obj)
current_obj[name] = value
root_obj[full_key] = value

return current_obj

async def extract_value(text: str, field_name: str, field_type: str, description: str, enum: list, existing_values: dict) -> Any:
"""
Extract value of type `field_type` from `text` based on `description`.
Expand All @@ -131,9 +113,16 @@ async def extract_value(text: str, field_name: str, field_type: str, description
question = question.format(**existing_values)
question = question[:1].upper() + question[1:] # capitalize first word

qa_result = await qa.answer(text, question, confidence_threshold=0.01)
if vqa and is_supported_img_file(text):
model = vqa
context = Blob.from_path_or_url(text)
else:
model = qa
context = text

qa_result = await model.answer(context, question, confidence_threshold=0.01 if not vqa else 0.1)

if qa_result is None:
if qa_result is None or qa_result.answer is None:
return None

if field_type == "integer":
Expand All @@ -150,7 +139,16 @@ async def extract_value(text: str, field_name: str, field_type: str, description
return qa_result.answer.strip(" .,;:!?")

if not multiple:
return await create_json_object(input_md, output_schema, extract_value)
result_dict = await create_json_object(input_md, output_schema, extract_value)
if vqa:
img_result_dicts = [
await create_json_object(img, output_schema, extract_value) for img in input_img_data.values()
]
# visual models can't process text and text models can't process documents, so if both modalities
# are present we use crude merging of multiple result dicts, giving precedence to visual results
return merge_dicts([result_dict], precedence_dicts=img_result_dicts)
else:
return result_dict
else:
if not multiple_description or multiple_description.isspace():
raise MissingParameterError("Description for entity type is required.")
Expand Down
71 changes: 71 additions & 0 deletions bpm-ai/bpm_ai/extract/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import Callable


def merge_dicts(dicts, precedence_dicts):
    """
    Merge multiple dictionaries with identical keys and structure into one dictionary.

    Non-None values from `precedence_dicts` win; otherwise the first non-None
    value from `dicts` is used. Nested dictionaries are merged recursively.
    Keys for which every candidate value is None are kept with a None value so
    the merged result preserves the input structure (previously such keys were
    silently dropped).

    Args:
        dicts (list): A list of dictionaries to merge. The keys of the first
            dictionary define the keys of the result.
        precedence_dicts (list): A list of dictionaries whose non-None values
            take precedence over values from `dicts`.

    Returns:
        dict: The merged dictionary.
    """
    if not dicts:
        return {}

    merged_dict = {}

    for key in dicts[0].keys():
        # values is never empty: dicts[0] always contains key
        values = [d[key] for d in dicts if key in d]
        precedence_values = [d[key] for d in precedence_dicts if key in d and d[key] is not None]

        if isinstance(values[0], dict):
            # Recursive call for nested dictionaries
            merged_dict[key] = merge_dicts(
                [v for v in values if v is not None],
                [v for v in precedence_values if isinstance(v, dict)]
            )
        elif precedence_values:
            # First match from precedence dicts wins (already filtered to non-None)
            merged_dict[key] = precedence_values[0]
        else:
            # First non-None regular value; default to None so the key survives
            merged_dict[key] = next((v for v in values if v is not None), None)

    return merged_dict


async def create_json_object(target: str, schema, get_value: Callable, current_obj=None, root_obj=None, parent_key='', prefix=''):
    """
    Recursively build a dict mirroring `schema`, resolving every leaf field by
    awaiting `get_value`.

    `root_obj` is a flat mapping of dotted field paths to already-resolved
    values, shared across all recursion levels so later fields can reference
    earlier answers. Leaf fields resolve to None when `target` is empty/falsy.

    Args:
        target: The text/document to extract values from.
        schema: Expanded JSON-schema-like mapping of field name -> properties.
        get_value: Async callable (target, full_key, type, description, enum,
            root_obj) -> value, used to resolve each leaf field.
        current_obj: Accumulator for the current nesting level (internal).
        root_obj: Flat path -> value mapping shared across recursion (internal).
        parent_key: Dotted path prefix of the current level (internal).
        prefix: Optional text prepended to each field description.

    Returns:
        The nested dict of resolved values for this schema level.
    """
    if current_obj is None:
        current_obj = {}
        root_obj = {}

    for field_name, field_props in schema.items():
        qualified_key = f'{parent_key}.{field_name}' if parent_key else field_name

        if field_props['type'] == 'object':
            # Nested object: recurse with a fresh container, sharing root_obj.
            current_obj[field_name] = await create_json_object(
                target, field_props['properties'], get_value, {}, root_obj, qualified_key
            )
            continue

        description = field_props.get('description')
        enum = field_props.get('enum', None)
        if prefix:
            description = prefix + (description[:1].lower() + description[1:])

        if target:
            resolved = await get_value(target, qualified_key, field_props['type'], description, enum, root_obj)
        else:
            resolved = None
        current_obj[field_name] = resolved
        root_obj[qualified_key] = resolved

    return current_obj


def strip_non_numeric_chars(s):
    """
    Strip leading and trailing non-digit characters from `s`.

    Returns the substring spanning from the first digit to the last digit
    (inclusive), or an empty string if `s` contains no digits. Uses a single
    index scan instead of repeatedly re-slicing the string, which was
    quadratic in the worst case.
    """
    start = 0
    end = len(s)
    while start < end and not s[start].isdigit():
        start += 1
    while end > start and not s[end - 1].isdigit():
        end -= 1
    return s[start:end]
2 changes: 1 addition & 1 deletion bpm-ai/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "bpm-ai"
version = "1.3.1"
version = "1.4.0"
description = "AI task automation for BPM engines."
authors = ["Bennet Krause <[email protected]>"]
repository = "https://github.com/holunda-io/bpm-ai"
Expand Down
69 changes: 68 additions & 1 deletion bpm-ai/tests/test_extract.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from bpm_ai_core.llm.common.message import AssistantMessage
from bpm_ai_inference.classification.transformers_classifier import TransformersClassifier
from bpm_ai_inference.ocr.tesseract import TesseractOCR
from bpm_ai_inference.question_answering.transformers_docvqa import TransformersDocVQA
from bpm_ai_inference.question_answering.transformers_qa import TransformersExtractiveQA
from bpm_ai_inference.speech_recognition.faster_whisper import FasterWhisperASR
from bpm_ai_core.testing.fake_llm import FakeLLM
Expand Down Expand Up @@ -253,4 +254,70 @@ async def test_extract_qa_multiple():
multiple=True,
multiple_description="Meal Order"
)
assert actual == [{'product': 'Pizza', 'price_eur': 10.99}, {'product': 'Steak', 'price_eur': 28.89}]
assert actual == [{'product': 'Pizza', 'price_eur': 10.99}, {'product': 'Steak', 'price_eur': 28.89}]


async def test_extract_qa_vqa():
    """Mixed-modality extraction: text fields via extractive QA, image fields via visual QA."""
    extractive_qa = TransformersExtractiveQA()
    zero_shot_classifier = TransformersClassifier()
    doc_vqa = TransformersDocVQA()
    input_data = {
        "email": "Hey it's me, John Meier. You can find an invoice attached, please pay asap.",
        "document": "invoice-simple.webp"
    }
    output_schema = {
        "lastname": "What is the family name (not forename)?",
        "firstname": "What is the forename of the person named {lastname}?",
        "invoice_number": {
            "type": "integer",
            "description": "What is the invoice number?"
        },
        "total": "What is the total?"
    }
    result = await extract_qa(
        qa=extractive_qa,
        vqa=doc_vqa,
        classifier=zero_shot_classifier,
        input_data=input_data,
        output_schema=output_schema
    )
    expected = {'lastname': 'Meier', 'firstname': 'John', 'invoice_number': 102, 'total': '$300.00'}
    assert result == expected


async def test_extract_vqa():
    """Image-only extraction: every field answered by the visual QA model."""
    extractive_qa = TransformersExtractiveQA()
    zero_shot_classifier = TransformersClassifier()
    doc_vqa = TransformersDocVQA()
    input_data = {
        "document": "invoice-simple.webp"
    }
    output_schema = {
        "invoice_number": {
            "type": "integer",
            "description": "What is the invoice number?"
        },
        "total": "What is the total?"
    }
    result = await extract_qa(
        qa=extractive_qa,
        vqa=doc_vqa,
        classifier=zero_shot_classifier,
        input_data=input_data,
        output_schema=output_schema
    )
    expected = {'invoice_number': 102, 'total': '$300.00'}
    assert result == expected


async def test_extract_qa_vqa_no_image():
    """VQA model supplied but no image input: text fields still resolved via extractive QA."""
    extractive_qa = TransformersExtractiveQA()
    zero_shot_classifier = TransformersClassifier()
    doc_vqa = TransformersDocVQA()
    input_data = {
        "email": "Hey it's me, John Meier. You can find an invoice attached, please pay asap.",
    }
    output_schema = {
        "lastname": "What is the family name (not forename)?",
        "firstname": "What is the forename of the person named {lastname}?",
    }
    result = await extract_qa(
        qa=extractive_qa,
        vqa=doc_vqa,
        classifier=zero_shot_classifier,
        input_data=input_data,
        output_schema=output_schema
    )
    expected = {'lastname': 'Meier', 'firstname': 'John'}
    assert result == expected

0 comments on commit cdfd4b0

Please sign in to comment.