Skip to content

Commit

Permalink
finish implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
agoscinski committed Nov 21, 2024
1 parent 50e8d2e commit 0d4202a
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 50 deletions.
3 changes: 3 additions & 0 deletions aiida_restapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@
'disabled': False,
}
}

# The chunks size for streaming data for download
DOWNLOAD_CHUNK_SIZE = 1024
83 changes: 39 additions & 44 deletions aiida_restapi/routers/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
import os
import tempfile
from pathlib import Path
from typing import Any, List, Optional
from typing import Any, Generator, List, Optional

from aiida import orm
from aiida.cmdline.utils.decorators import with_dbenv
from aiida.common.exceptions import EntryPointError
from aiida.common.exceptions import EntryPointError, LicensingException, NotExistent
from aiida.plugins.entry_point import load_entry_point
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from pydantic import ValidationError
from fastapi.responses import StreamingResponse
from pydantic import ValidationError

from aiida_restapi import models
from aiida_restapi.config import DOWNLOAD_CHUNK_SIZE

from .auth import get_current_active_user

Expand Down Expand Up @@ -42,65 +43,59 @@ async def get_nodes_download_formats() -> dict[str, Any]:
return models.Node.get_all_download_formats()


@router.get('/nodes/{nodes_id}', response_model=models.Node)
@router.get('/nodes/{nodes_id}/download')
@with_dbenv()
async def read_node(nodes_id: int) -> Optional[models.Node]:
async def download_node(nodes_id: int, download_format: Optional[str] = None) -> StreamingResponse:
"""Get nodes by id."""
qbobj = orm.QueryBuilder()
qbobj.append(orm.Node, filters={'id': nodes_id}, project='**', tag='node').limit(1)
return qbobj.dict()[0]['node']

from aiida.orm import load_node

@router.get('/nodes/{nodes_id}/download?download_format={download_format}', response_model=StreamingResponse)
@with_dbenv()
async def download_node(nodes_id: int, download_format: str | None = None) -> Optional[models.Node]:
"""Get nodes by id."""
qbobj = orm.QueryBuilder()
qbobj.append(orm.Node, filters={'id': nodes_id}, project='**', tag='node').limit(1)
node = qbobj.dict()[0]['node']

# TODO that is the behavior in the old REST API but I think we should rather send error message
# since it does not agree with the response_model
if len(node) == 0:
return {}
# raise HTTPException(status_code=404, detail=f"Could no find any node with id {nodes_id}")
try:
node = load_node(nodes_id)
except NotExistent:
raise HTTPException(status_code=404, detail=f'Could no find any node with id {nodes_id}')

if download_format is None:
raise RestInputValidationError(
'Please specify the download format. '
raise HTTPException(
status_code=422,
detail='Please specify the download format. '
'The available download formats can be '
'queried using the /nodes/download_formats/ endpoint.'
'queried using the /nodes/download_formats/ endpoint.',
)

elif download_format in node.get_export_formats():
# byteobj, dict with {filename: filecontent}
import io

bytes, metadata = node._exportcontent(download_format)[0]
return StreamingResponse(
io.BytesIO(bytes)
) # , media_type=next(iter(medata.values()))) # Don. know what metadata contains, but could be maybe used to define media_type
# try:
# response['data'] =
# response['status'] = 200
# try:
# response['filename'] = node.filename
# except AttributeError:
# response['filename'] = f'{node.uuid}.{download_format}'
# except LicensingException as exc:
# response['status'] = 500
# response['data'] = str(exc)

# return response
try:
exported_bytes, _ = node._exportcontent(download_format)
except LicensingException as exc:
raise HTTPException(status_code=500, detail=str(exc))

def stream() -> Generator[bytes, None, None]:
with io.BytesIO(exported_bytes) as handler:
while chunk := handler.read(DOWNLOAD_CHUNK_SIZE):
yield chunk

return StreamingResponse(stream(), media_type=f'application/{download_format}')

else:
raise RestInputValidationError(
'The format {} is not supported. '
raise HTTPException(
status_code=422,
detail='The format {} is not supported. '
'The available download formats can be '
'queried using the /nodes/download_formats/ endpoint.'.format(download_format)
'queried using the /nodes/download_formats/ endpoint.'.format(download_format),
)


@router.get('/nodes/{nodes_id}', response_model=models.Node)
@with_dbenv()
async def read_node(nodes_id: int) -> Optional[models.Node]:
"""Get nodes by id."""
qbobj = orm.QueryBuilder()
qbobj.append(orm.Node, filters={'id': nodes_id}, project='**', tag='node').limit(1)
return qbobj.dict()[0]['node']


@router.post('/nodes', response_model=models.Node)
@with_dbenv()
async def create_node(
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ testing = [
'pytest-regressions',
'pytest-cov',
'requests',
'httpx'
'httpx',
'numpy~=1.21'
]

[project.urls]
Expand Down
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime
from typing import Any, Callable, Mapping, MutableMapping, Optional, Union

import numpy as np
import pytest
import pytz
from aiida import orm
Expand Down Expand Up @@ -164,6 +165,17 @@ def default_nodes():
return [node_1.pk, node_2.pk, node_3.pk, node_4.pk]


@pytest.fixture(scope='function')
def array_data_node():
"""Populate database with downloadable node (implmenting a _prepare_* function).
For testing the chunking of the streaming we create an array that needs to be splitted int two chunks."""

from aiida_restapi.config import DOWNLOAD_CHUNK_SIZE

nb_elements = DOWNLOAD_CHUNK_SIZE // 64 + 1
return orm.ArrayData(np.arange(nb_elements, dtype=np.int64)).store()


@pytest.fixture(scope='function')
def authenticate():
"""Authenticate user.
Expand Down
23 changes: 18 additions & 5 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,22 @@ def test_create_bool_with_extra(client, authenticate): # pylint: disable=unused
assert check_response.json()['extras']['extra_two'] == 'value_2'


def test_get_download_node(default_nodes, client):
"""Test get projectable properites for nodes."""
@pytest.mark.anyio
async def test_get_download_node(array_data_node, async_client):
"""Test download node /nodes/{nodes_id}/download.
The async client is needed to avoid an error caused by an I/O operation on closed file"""

for nodes_id in default_nodes:
response = client.get(f'/nodes/{nodes_id}/download?download_format=cif') #TODO download_format is wrong need to check this
assert response.status_code == 200
# Test that array is correctly downloaded as json
response = await async_client.get(f'/nodes/{array_data_node.pk}/download?download_format=json')
assert response.status_code == 200, response.json()
assert response.json().get('default', None) == array_data_node.get_array().tolist()

# Test exception when wrong download format given
response = await async_client.get(f'/nodes/{array_data_node.pk}/download?download_format=cif')
assert response.status_code == 422, response.json()
assert 'format cif is not supported' in response.json()['detail']

# Test exception when no download format given
response = await async_client.get(f'/nodes/{array_data_node.pk}/download')
assert response.status_code == 422, response.json()
assert 'Please specify the download format' in response.json()['detail']

0 comments on commit 0d4202a

Please sign in to comment.