diff --git a/aiida_restapi/config.py b/aiida_restapi/config.py index fead108..9f877de 100644 --- a/aiida_restapi/config.py +++ b/aiida_restapi/config.py @@ -17,3 +17,6 @@ 'disabled': False, } } + +# Chunk size (in bytes) used when streaming data for download +DOWNLOAD_CHUNK_SIZE = 1024 diff --git a/aiida_restapi/routers/nodes.py b/aiida_restapi/routers/nodes.py index 7565dd3..a34ce8f 100644 --- a/aiida_restapi/routers/nodes.py +++ b/aiida_restapi/routers/nodes.py @@ -4,17 +4,18 @@ import os import tempfile from pathlib import Path -from typing import Any, List, Optional +from typing import Any, Generator, List, Optional from aiida import orm from aiida.cmdline.utils.decorators import with_dbenv -from aiida.common.exceptions import EntryPointError +from aiida.common.exceptions import EntryPointError, LicensingException, NotExistent from aiida.plugins.entry_point import load_entry_point from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile -from pydantic import ValidationError from fastapi.responses import StreamingResponse +from pydantic import ValidationError from aiida_restapi import models +from aiida_restapi.config import DOWNLOAD_CHUNK_SIZE from .auth import get_current_active_user @@ -42,65 +43,59 @@ async def get_nodes_download_formats() -> dict[str, Any]: return models.Node.get_all_download_formats() -@router.get('/nodes/{nodes_id}', response_model=models.Node) +@router.get('/nodes/{nodes_id}/download') @with_dbenv() -async def read_node(nodes_id: int) -> Optional[models.Node]: +async def download_node(nodes_id: int, download_format: Optional[str] = None) -> StreamingResponse: """Get nodes by id.""" - qbobj = orm.QueryBuilder() - qbobj.append(orm.Node, filters={'id': nodes_id}, project='**', tag='node').limit(1) - return qbobj.dict()[0]['node'] - + from aiida.orm import load_node -@router.get('/nodes/{nodes_id}/download?download_format={download_format}', response_model=StreamingResponse) -@with_dbenv() -async def download_node(nodes_id: int, 
download_format: str | None = None) -> Optional[models.Node]: - """Get nodes by id.""" - qbobj = orm.QueryBuilder() - qbobj.append(orm.Node, filters={'id': nodes_id}, project='**', tag='node').limit(1) - node = qbobj.dict()[0]['node'] - - # TODO that is the behavior in the old REST API but I think we should rather send error message - # since it does not agree with the response_model - if len(node) == 0: - return {} - # raise HTTPException(status_code=404, detail=f"Could no find any node with id {nodes_id}") + try: + node = load_node(nodes_id) + except NotExistent: + raise HTTPException(status_code=404, detail=f'Could no find any node with id {nodes_id}') if download_format is None: - raise RestInputValidationError( - 'Please specify the download format. ' + raise HTTPException( + status_code=422, + detail='Please specify the download format. ' 'The available download formats can be ' - 'queried using the /nodes/download_formats/ endpoint.' + 'queried using the /nodes/download_formats/ endpoint.', ) elif download_format in node.get_export_formats(): # byteobj, dict with {filename: filecontent} import io - bytes, metadata = node._exportcontent(download_format)[0] - return StreamingResponse( - io.BytesIO(bytes) - ) # , media_type=next(iter(medata.values()))) # Don. 
know what metadata contains, but could be maybe used to define media_type - # try: - # response['data'] = - # response['status'] = 200 - # try: - # response['filename'] = node.filename - # except AttributeError: - # response['filename'] = f'{node.uuid}.{download_format}' - # except LicensingException as exc: - # response['status'] = 500 - # response['data'] = str(exc) - - # return response + try: + exported_bytes, _ = node._exportcontent(download_format) + except LicensingException as exc: + raise HTTPException(status_code=500, detail=str(exc)) + + def stream() -> Generator[bytes, None, None]: + with io.BytesIO(exported_bytes) as handler: + while chunk := handler.read(DOWNLOAD_CHUNK_SIZE): + yield chunk + + return StreamingResponse(stream(), media_type=f'application/{download_format}') else: - raise RestInputValidationError( - 'The format {} is not supported. ' + raise HTTPException( + status_code=422, + detail='The format {} is not supported. ' 'The available download formats can be ' - 'queried using the /nodes/download_formats/ endpoint.'.format(download_format) + 'queried using the /nodes/download_formats/ endpoint.'.format(download_format), ) +@router.get('/nodes/{nodes_id}', response_model=models.Node) +@with_dbenv() +async def read_node(nodes_id: int) -> Optional[models.Node]: + """Get nodes by id.""" + qbobj = orm.QueryBuilder() + qbobj.append(orm.Node, filters={'id': nodes_id}, project='**', tag='node').limit(1) + return qbobj.dict()[0]['node'] + + @router.post('/nodes', response_model=models.Node) @with_dbenv() async def create_node( diff --git a/pyproject.toml b/pyproject.toml index 5d1ac28..c23cef3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,8 @@ testing = [ 'pytest-regressions', 'pytest-cov', 'requests', - 'httpx' + 'httpx', + 'numpy~=1.21' ] [project.urls] diff --git a/tests/conftest.py b/tests/conftest.py index c39c403..b680eba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ from datetime import datetime from 
typing import Any, Callable, Mapping, MutableMapping, Optional, Union +import numpy as np import pytest import pytz from aiida import orm @@ -164,6 +165,17 @@ def default_nodes(): return [node_1.pk, node_2.pk, node_3.pk, node_4.pk] +@pytest.fixture(scope='function') +def array_data_node(): + """Populate database with downloadable node (implementing a _prepare_* function). + For testing the chunking of the streaming we create an array that needs to be split into two chunks.""" + + from aiida_restapi.config import DOWNLOAD_CHUNK_SIZE + + nb_elements = DOWNLOAD_CHUNK_SIZE // 64 + 1 + return orm.ArrayData(np.arange(nb_elements, dtype=np.int64)).store() + + @pytest.fixture(scope='function') def authenticate(): """Authenticate user. diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 5d652f3..72c8444 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -332,9 +332,22 @@ def test_create_bool_with_extra(client, authenticate): # pylint: disable=unused assert check_response.json()['extras']['extra_two'] == 'value_2' - def test_get_download_node(default_nodes, client): - """Test get projectable properites for nodes.""" +@pytest.mark.anyio +async def test_get_download_node(array_data_node, async_client): + """Test download node /nodes/{nodes_id}/download. 
+ The async client is needed to avoid an error caused by an I/O operation on closed file""" - for nodes_id in default_nodes: - response = client.get(f'/nodes/{nodes_id}/download?download_format=cif') #TODO download_format is wrong need to check this - assert response.status_code == 200 + # Test that array is correctly downloaded as json + response = await async_client.get(f'/nodes/{array_data_node.pk}/download?download_format=json') + assert response.status_code == 200, response.json() + assert response.json().get('default', None) == array_data_node.get_array().tolist() + + # Test exception when wrong download format given + response = await async_client.get(f'/nodes/{array_data_node.pk}/download?download_format=cif') + assert response.status_code == 422, response.json() + assert 'format cif is not supported' in response.json()['detail'] + + # Test exception when no download format given + response = await async_client.get(f'/nodes/{array_data_node.pk}/download') + assert response.status_code == 422, response.json() + assert 'Please specify the download format' in response.json()['detail']