Skip to content

Commit

Permalink
feat: add exponential retries to parallel mode calls (#216)
Browse files Browse the repository at this point in the history
Replace our parallel mode retry logic with the `backoff` library. This
gives us exponential backoff, retryable error codes, etc with just a
decorator, which really cleans up the code.

Changes:
* Refactor `partition_file_via_api` and move the request with backoff to
`call_api`
* Add `backoff` as a dependency and `pip compile`
* Make sure we don't dump api parameters on every parallel call
* Don't allow internal calls to bypass the 503 low memory gate (Should
be handled in the retries like everything else)

To test this, try adding an HTTPException to the code.

Add a non-retryable exception in `partition_pdf_splits`:
```
    # If it's small enough, just process locally
    # (Some kwargs need to be renamed for local partition)
    if len(pdf_pages) <= pages_per_pdf:
        raise HTTPException(status_code=400)

```

When you run this and send a file, you'll get the 400 back immediately:
```
export UNSTRUCTURED_PARALLEL_MODE_ENABLED=true
export UNSTRUCTURED_PARALLEL_MODE_URL=http://localhost:8000/general/v0/general
export UNSTRUCTURED_PARALLEL_NUM_THREADS=1

make run-web-app


curl -X POST 'http://localhost:8000/general/v0/general' --form files=@sample-docs/layout-parser-paper.pdf
{"detail":"Bad Request"}
```

Now, return a 500 error instead and run again. In this case you'll get a
server error, but in the logs you should see that the retries happened:
```
Giving up call_api(...) after 3 tries (fastapi.exceptions.HTTPException)
```
  • Loading branch information
awalker4 authored Sep 13, 2023
1 parent 91c0617 commit ee82bb7
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 100 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## 0.0.44-dev0
## 0.0.44-dev1

* Bump unstructured to 0.10.14
* Improve parallel mode retry handling

## 0.0.43

Expand Down
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,7 @@ As mentioned above, processing a pdf using `hi_res` is currently a slow operatio
* `UNSTRUCTURED_PARALLEL_MODE_URL` - the location to send pdf pages asynchronously, no default setting at the moment.
* `UNSTRUCTURED_PARALLEL_MODE_THREADS` - the number of threads making requests at once, default is `3`.
* `UNSTRUCTURED_PARALLEL_MODE_SPLIT_SIZE` - the number of pages to be processed in one request, default is `1`.
* `UNSTRUCTURED_PARALLEL_RETRY_ATTEMPTS` - the number of retry attempts, default is `1`.
* `UNSTRUCTURED_PARALLEL_RETRY_BACKOFF_TIME` - the backoff time in seconds for each retry attempt, default is `1.0`.
* `UNSTRUCTURED_PARALLEL_RETRY_ATTEMPTS` - the number of retry attempts on a retryable error, default is `2`. (i.e. 3 attempts are made in total)

### Generating Python files from the pipeline notebooks

Expand Down
141 changes: 73 additions & 68 deletions prepline_general/api/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from unstructured.staging.base import convert_to_isd, convert_to_dataframe, elements_from_json
import psutil
import requests
import time
from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES
import backoff
import logging
from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES


app = FastAPI()
Expand All @@ -39,7 +39,7 @@ def is_expected_response_type(media_type, response_type):
return False


# pipeline-api
logger = logging.getLogger("unstructured_api")


DEFAULT_MIMETYPES = (
Expand Down Expand Up @@ -92,6 +92,38 @@ def get_pdf_splits(pdf_pages, split_size=1):
return split_pdfs


# Do not retry with these status codes
def is_non_retryable(e):
    """Giveup predicate for the backoff decorator on ``call_api``.

    Client errors (4xx) indicate a bad request that will fail identically on
    every attempt, so we give up immediately — with one exception: 429
    (Too Many Requests) explicitly asks the client to retry later, so it is
    treated as retryable alongside the 5xx server errors.

    Args:
        e: the raised exception; must expose a ``status_code`` attribute
           (e.g. ``fastapi.HTTPException``).

    Returns:
        True if the error should NOT be retried.
    """
    return 400 <= e.status_code < 500 and e.status_code != 429


@backoff.on_exception(
    backoff.expo,
    HTTPException,
    # NOTE: max_tries is a callable so the env var is read at call time rather
    # than once at module import — otherwise runtime config changes (and test
    # monkeypatching of the env var) would be silently ignored.
    max_tries=lambda: int(os.environ.get("UNSTRUCTURED_PARALLEL_RETRY_ATTEMPTS", 2)) + 1,
    giveup=is_non_retryable,
    logger=logger,
)
def call_api(request_url, api_key, filename, file, content_type, **partition_kwargs):
    """
    Call the api with the given request_url.

    Posts the file plus any partition kwargs as form data. On a non-200
    response, raises HTTPException carrying the remote status code and detail;
    retryable errors are re-attempted with exponential backoff by the
    decorator above.

    Returns the raw response body (JSON text of the partitioned elements).
    """
    headers = {"unstructured-api-key": api_key}

    response = requests.post(
        request_url,
        files={"files": (filename, file, content_type)},
        data=partition_kwargs,
        headers=headers,
    )

    if response.status_code != 200:
        # The error body may not be JSON (e.g. an HTML page from a proxy or
        # load balancer) — fall back to the raw text rather than masking the
        # real HTTP error with a decode failure.
        try:
            detail = response.json().get("detail") or response.text
        except ValueError:
            detail = response.text
        raise HTTPException(status_code=response.status_code, detail=detail)

    return response.text


def partition_file_via_api(file_tuple, request, filename, content_type, **partition_kwargs):
"""
Send the given file to be partitioned remotely with retry logic,
Expand All @@ -103,40 +135,16 @@ def partition_file_via_api(file_tuple, request, filename, content_type, **partit
filename and content_type are passed in the file form data
partition_kwargs holds any form parameters to be sent on
"""
request_url = os.environ.get("UNSTRUCTURED_PARALLEL_MODE_URL")
file, page_offset = file_tuple

request_url = os.environ.get("UNSTRUCTURED_PARALLEL_MODE_URL")
if not request_url:
raise HTTPException(status_code=500, detail="Parallel mode enabled but no url set!")

file, page_offset = file_tuple

headers = {"unstructured-api-key": request.headers.get("unstructured-api-key")}
api_key = request.headers.get("unstructured-api-key")

# Retry parameters
try_attempts = int(os.environ.get("UNSTRUCTURED_PARALLEL_RETRY_ATTEMPTS", 1)) + 1
retry_backoff_time = float(os.environ.get("UNSTRUCTURED_PARALLEL_RETRY_BACKOFF_TIME", 1.0))

while try_attempts >= 0:
response = requests.post(
request_url,
files={"files": (filename, file, content_type)},
data=partition_kwargs,
headers=headers,
)
try_attempts -= 1
non_retryable_error_codes = [400, 401, 402, 403]
status_code = response.status_code
if status_code != 200:
if try_attempts == 0 or status_code in non_retryable_error_codes:
detail = response.json().get("detail") or response.text
raise HTTPException(status_code=response.status_code, detail=detail)
else:
# Retry after backoff
time.sleep(retry_backoff_time)
else:
break

elements = elements_from_json(text=response.text)
result = call_api(request_url, api_key, filename, file, content_type, **partition_kwargs)
elements = elements_from_json(text=result)

# We need to account for the original page numbers
for element in elements:
Expand Down Expand Up @@ -196,9 +204,6 @@ def partition_pdf_splits(
return results


logger = logging.getLogger("unstructured_api")


def pipeline_api(
file,
request=None,
Expand All @@ -215,47 +220,47 @@ def pipeline_api(
m_strategy=[],
m_xml_keep_tags=[],
):
logger.debug(
"pipeline_api input params: {}".format(
json.dumps(
{
"filename": filename,
"file_content_type": file_content_type,
"response_type": response_type,
"m_coordinates": m_coordinates,
"m_encoding": m_encoding,
"m_hi_res_model_name": m_hi_res_model_name,
"m_include_page_breaks": m_include_page_breaks,
"m_ocr_languages": m_ocr_languages,
"m_pdf_infer_table_structure": m_pdf_infer_table_structure,
"m_skip_infer_table_types": m_skip_infer_table_types,
"m_strategy": m_strategy,
"m_xml_keep_tags": m_xml_keep_tags,
},
default=str,
if filename.endswith(".msg"):
# Note(yuming): convert file type for msg files
# since fast api might sent the wrong one.
file_content_type = "application/x-ole-storage"

# We don't want to keep logging the same params for every parallel call
origin_ip = request.headers.get("X-Forwarded-For") or request.client.host
is_internal_request = origin_ip.startswith("10.")

if not is_internal_request:
logger.debug(
"pipeline_api input params: {}".format(
json.dumps(
{
"filename": filename,
"response_type": response_type,
"m_coordinates": m_coordinates,
"m_encoding": m_encoding,
"m_hi_res_model_name": m_hi_res_model_name,
"m_include_page_breaks": m_include_page_breaks,
"m_ocr_languages": m_ocr_languages,
"m_pdf_infer_table_structure": m_pdf_infer_table_structure,
"m_skip_infer_table_types": m_skip_infer_table_types,
"m_strategy": m_strategy,
"m_xml_keep_tags": m_xml_keep_tags,
},
default=str,
)
)
)
)

logger.debug(f"filetype: {file_content_type}")

# If this var is set, reject traffic when free memory is below minimum
# Allow internal requests - these are parallel calls already in progress
mem = psutil.virtual_memory()
memory_free_minimum = int(os.environ.get("UNSTRUCTURED_MEMORY_FREE_MINIMUM_MB", 0))

if memory_free_minimum > 0 and mem.available <= memory_free_minimum * 1024 * 1024:
# Note(yuming): Use X-Forwarded-For header to find the orginal IP for external API
# requests,since LB forwards requests in AWS
origin_ip = request.headers.get("X-Forwarded-For") or request.client.host

if not origin_ip.startswith("10."):
raise HTTPException(
status_code=503, detail="Server is under heavy load. Please try again later."
)

if filename.endswith(".msg"):
# Note(yuming): convert file type for msg files
# since fast api might sent the wrong one.
file_content_type = "application/x-ole-storage"
raise HTTPException(
status_code=503, detail="Server is under heavy load. Please try again later."
)

if file_content_type == "application/pdf":
try:
Expand Down
1 change: 1 addition & 0 deletions requirements/base.in
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ fastapi
uvicorn
ratelimit
requests
backoff
pypdf
pycryptodome
psutil
Expand Down
2 changes: 2 additions & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ anyio==3.7.1
# via
# fastapi
# starlette
backoff==2.2.1
# via -r requirements/base.in
beautifulsoup4==4.12.2
# via unstructured
certifi==2023.7.22
Expand Down
2 changes: 2 additions & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ babel==2.12.1
# via jupyterlab-server
backcall==0.2.0
# via ipython
backoff==2.2.1
# via -r requirements/base.txt
beautifulsoup4==4.12.2
# via
# -r requirements/base.txt
Expand Down
41 changes: 12 additions & 29 deletions test_general/api/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import requests
import pandas as pd
from fastapi.testclient import TestClient
from fastapi import HTTPException
from unittest.mock import Mock, ANY

from prepline_general.api.app import app
Expand Down Expand Up @@ -384,23 +385,6 @@ def test_general_api_returns_503(monkeypatch, mocker):

assert response.status_code == 503

mock_client = mocker.patch("fastapi.Request.client")
mock_client.host = "10.5.0.0"
response = client.post(
MAIN_API_ROUTE,
files=[("files", (str(test_file), open(test_file, "rb")))],
)

assert response.status_code == 200

mock_client.host = "10.4.0.0"
response = client.post(
MAIN_API_ROUTE,
files=[("files", (str(test_file), open(test_file, "rb")))],
)

assert response.status_code == 200


class MockResponse:
def __init__(self, status_code):
Expand Down Expand Up @@ -514,17 +498,14 @@ def test_partition_file_via_api_will_retry(monkeypatch, mocker):
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_URL", "unused")
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_THREADS", "1")

monkeypatch.setenv("UNSTRUCTURED_PARALLEL_RETRY_ATTEMPTS", "2")
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_RETRY_BACKOFF_TIME", "0.1")

num_calls = 0

# Return a transient error the first time
# Validate the retry count by returning an error the first 2 times
def mock_response(*args, **kwargs):
nonlocal num_calls
num_calls += 1

if num_calls == 1:
if num_calls <= 2:
return MockResponse(status_code=500)

return MockResponse(status_code=200)
Expand All @@ -549,34 +530,36 @@ def mock_response(*args, **kwargs):
assert response.status_code == 200


def test_partition_file_via_api_no_retryable_error_code(monkeypatch, mocker):
def test_partition_file_via_api_not_retryable_error_code(monkeypatch, mocker):
"""
Verify we didn't retry if the error code is not retryable
"""
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_ENABLED", "true")
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_URL", "unused")
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_THREADS", "1")
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_RETRY_ATTEMPTS", "3")

monkeypatch.setenv("UNSTRUCTURED_PARALLEL_RETRY_ATTEMPTS", "2")
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_RETRY_BACKOFF_TIME", "0.1")
remote_partition = Mock(side_effect=HTTPException(status_code=401))

monkeypatch.setattr(
requests,
"post",
lambda *args, **kwargs: MockResponse(status_code=401),
remote_partition,
)
mock_sleep = mocker.patch("time.sleep")
client = TestClient(app)
test_file = Path("sample-docs") / "layout-parser-paper.pdf"

response = client.post(
MAIN_API_ROUTE,
files=[("files", (str(test_file), open(test_file, "rb"), "application/pdf"))],
data={"pdf_processing_mode": "parallel"},
)

assert response.status_code == 401
assert mock_sleep.call_count == 0

# Often page 2 will start processing before the page 1 exception is raised.
# So we can't assert called_once, but we can assert the count is less than it
# would have been if we used all retries.
assert remote_partition.call_count < 4


def test_password_protected_pdf():
Expand Down

0 comments on commit ee82bb7

Please sign in to comment.