Skip to content

Commit

Permalink
Merge pull request #56 from Aleph-Alpha/fix-image-loading
Browse files Browse the repository at this point in the history
Fix image loading
  • Loading branch information
ahartel authored Oct 4, 2022
2 parents 262e50e + 263b6a4 commit f391534
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 14 deletions.
6 changes: 6 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
TEST_API_URL=https://test.api.aleph-alpha.com
TEST_MODEL=luminous-extended
TEST_TOKEN=
TEST_CHECKPOINT=
TEST_CHECKPOINT_QA=
TEST_CHECKPOINT_SUMMARIZATION=
4 changes: 4 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 2.4.4

- Fix: ImagePrompt.from_url raises if status-code not OK

## 2.4.3

- Fix: Dependency `urllib` is specified to be at least of version `1.26`.
Expand Down
35 changes: 22 additions & 13 deletions aleph_alpha_client/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class Cropping:
"""
Describes a quadratic crop of the file.
"""

def __init__(self, upper_left_x: int, upper_left_y: int, size: int):
self.upper_left_x = upper_left_x
self.upper_left_y = upper_left_y
Expand All @@ -18,6 +19,7 @@ class ImagePrompt:
An image send as part of a prompt to a model. The image is represented as
base64.
"""

def __init__(
self,
base_64: str,
Expand All @@ -39,20 +41,20 @@ def from_url(cls, url: str):
Downloads a file and prepare it to be used in a prompt.
The image will be [center cropped](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.CenterCrop)
"""
bytes = requests.get(url).content
return cls.from_bytes(bytes)
return cls.from_bytes(cls._get_url(url))

@classmethod
def from_url_with_cropping(cls, url: str, upper_left_x: int,
upper_left_y: int, crop_size: int):
def from_url_with_cropping(
cls, url: str, upper_left_x: int, upper_left_y: int, crop_size: int
):
"""
Downloads a file and prepare it to be used in a prompt.
upper_left_x, upper_left_y and crop_size are used to crop the image.
"""
cropping = Cropping(upper_left_x=upper_left_x,
upper_left_y=upper_left_y,
size=crop_size)
bytes = requests.get(url).content
cropping = Cropping(
upper_left_x=upper_left_x, upper_left_y=upper_left_y, size=crop_size
)
bytes = cls._get_url(url)
return cls.from_bytes(bytes, cropping=cropping)

@classmethod
Expand All @@ -66,19 +68,26 @@ def from_file(cls, path: str):
return cls.from_bytes(image)

@classmethod
def from_file_with_cropping(cls, path: str, upper_left_x: int,
upper_left_y: int, crop_size: int):
def from_file_with_cropping(
cls, path: str, upper_left_x: int, upper_left_y: int, crop_size: int
):
"""
Load an image from disk and prepare it to be used in a prompt
upper_left_x, upper_left_y and crop_size are used to crop the image.
"""
cropping = Cropping(upper_left_x=upper_left_x,
upper_left_y=upper_left_y,
size=crop_size)
cropping = Cropping(
upper_left_x=upper_left_x, upper_left_y=upper_left_y, size=crop_size
)
with open(path, "rb") as f:
bytes = f.read()
return cls.from_bytes(bytes, cropping=cropping)

@classmethod
def _get_url(cls, url: str) -> bytes:
response = requests.get(url)
response.raise_for_status()
return response.content

def _to_prompt_item(self) -> Dict[str, Any]:
"""
A dict if serialized to JSON is suitable as a prompt element
Expand Down
2 changes: 1 addition & 1 deletion aleph_alpha_client/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.4.3"
__version__ = "2.4.4"
16 changes: 16 additions & 0 deletions tests/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from http import HTTPStatus
from pytest import raises
from pytest_httpserver import HTTPServer
from requests import RequestException

from aleph_alpha_client.image import ImagePrompt


def test_from_url_with_non_OK_response(httpserver: HTTPServer):
path = "/image"
httpserver.expect_request(path).respond_with_data(
"html", status=HTTPStatus.FORBIDDEN
)

with raises(RequestException) as e:
ImagePrompt.from_url(httpserver.url_for(path))

0 comments on commit f391534

Please sign in to comment.