Skip to content

Commit

Permalink
Add the "download image" command (#97)
Browse files Browse the repository at this point in the history
  • Loading branch information
Stannislav authored Nov 9, 2021
1 parent 3824394 commit 8ac7513
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 1 deletion.
81 changes: 81 additions & 0 deletions src/atldld/cli/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,84 @@ def download_dataset(

with metadata_path.open("w") as f:
json.dump(metadata, f, indent=4)


@download_cmd.command("image", help="Download a section image.")
@click.argument("image_id", type=str)
@click.argument(
"output_folder",
type=click.Path(exists=False, dir_okay=True, path_type=pathlib.Path),
)
@click.option(
"--downsample-img",
type=int,
default=0,
show_default=True,
help="Downsampling coefficient for the image download.",
)
@click.option(
"-e",
"--include-expression",
is_flag=True,
help="Include expression image.",
)
def download_image(
image_id,
output_folder,
downsample_img,
include_expression,
):
"""Download a section image."""
import textwrap

import requests
from PIL import Image

from atldld.utils import get_image

cli_input = f"""
Image ID : {image_id}
Downsample image : {downsample_img}
Include expression : {include_expression}
Output folder : {output_folder.resolve()}
"""
click.secho(textwrap.dedent(cli_input).strip(), fg="blue")

# Download the image and the expression
try:
click.secho("Downloading image...", fg="green")
image = get_image(image_id, downsample=downsample_img)
if include_expression:
click.secho("Downloading expression...", fg="green")
expression = get_image(image_id, expression=True, downsample=downsample_img)
else:
expression = None
except requests.ConnectionError:
raise click.ClickException(
"no network connection; check your network or try again later."
)
except requests.HTTPError as exc:
raise click.ClickException(
"the server responded with an error: "
f"{exc.response.reason} ({exc.response.status_code})"
)

# Prepare paths
if not output_folder.exists():
output_folder.mkdir(parents=True)
file_name = f"{image_id}-{downsample_img}"

# Save the image to disk
click.secho("Saving the image...", fg="green")
image_path = output_folder / f"{file_name}.png"
Image.fromarray(image, mode="RGB").save(image_path)
click.secho(f"Image saved to {image_path.resolve().as_uri()}", fg="green")

# Save the expression to disk
if expression is not None:
click.secho("Saving the expression...", fg="green")
expression_path = output_folder / f"{file_name}-expression.png"
Image.fromarray(expression, mode="RGB").save(expression_path)
click.secho(
f"Expression saved to {expression_path.resolve().as_uri()}", fg="green"
)
96 changes: 95 additions & 1 deletion tests/cli/test_download_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import json
import re
from collections import defaultdict

import numpy as np
import pytest
import requests
import responses
from click.testing import CliRunner
from PIL import Image

from atldld.base import DisplacementField
from atldld.cli.download import download_cmd, download_dataset
from atldld.cli.download import download_cmd, download_dataset, download_image
from atldld.sync import DatasetNotFoundError


Expand Down Expand Up @@ -130,3 +134,93 @@ def fake_run():
"section_coordinate",
"section_coordinate_scaled",
} == set(image_metadata.keys())


class TestDownloadImage:
@responses.activate
def test_connection_errors_are_caught(self):
responses.add(responses.GET, re.compile(r""), requests.ConnectionError())

# Testing
runner = CliRunner()
result = runner.invoke(download_image, ["0", "out"])
assert result.exit_code != 0 # should exit with an error code
assert "Error: no network connection" in result.output

@responses.activate
def test_nonexistent_image_is_handled(self):
response = requests.Response()
response.reason = "some reason"
response.status_code = 999
responses.add(
responses.GET, re.compile(r""), requests.HTTPError(response=response)
)

# Testing
runner = CliRunner()
result = runner.invoke(download_image, ["0", "out"])
assert result.exit_code != 0 # should exit with an error code
assert "Error:" in result.output
assert response.reason in result.output
assert str(response.status_code) in result.output

@responses.activate
def test_normal_download(self, mocker, tmp_path):
responses.add(responses.GET, re.compile(r""))

image_arr = np.random.randint(0, 255, size=(5, 5, 3), dtype=np.uint8)
image_id = 12345
output_dir = tmp_path / "out"
get_image = mocker.patch("atldld.utils.get_image", return_value=image_arr)

runner = CliRunner()
result = runner.invoke(download_image, [str(image_id), str(output_dir)])

assert result.exit_code == 0
get_image.assert_called_once()

image_path = output_dir / f"{image_id}-0.png"
assert image_path.exists()
with Image.open(image_path) as lazy_img:
saved_image = np.asarray(lazy_img)
assert np.array_equal(image_arr, saved_image)

@responses.activate
def test_download_with_expression(self, mocker, tmp_path):
responses.add(responses.GET, re.compile(r""))

image_arr = np.random.randint(0, 255, size=(5, 5, 3), dtype=np.uint8)
expression_arr = np.random.randint(0, 255, size=(5, 5, 3), dtype=np.uint8)
image_id = 12345
downsample_img = 5

def get_image(_image_id, folder=None, expression=False, downsample=0):
if expression:
return expression_arr
else:
return image_arr

mocker.patch("atldld.utils.get_image", new=get_image)

runner = CliRunner()
args = [
str(image_id),
str(tmp_path),
"--include-expression",
"--downsample-img",
str(downsample_img),
]
result = runner.invoke(download_image, args)

assert result.exit_code == 0

image_path = tmp_path / f"{image_id}-{downsample_img}.png"
expression_path = tmp_path / f"{image_id}-{downsample_img}-expression.png"
assert image_path.exists()
assert expression_path.exists()
with Image.open(image_path) as lazy_img:
saved_image = np.asarray(lazy_img)
with Image.open(expression_path) as lazy_img:
saved_expression = np.asarray(lazy_img)
assert np.array_equal(image_arr, saved_image)
assert np.array_equal(expression_arr, saved_expression)

0 comments on commit 8ac7513

Please sign in to comment.