From aae8d9fbfd1c1c1b478c6c6991bfa567b4bda8d4 Mon Sep 17 00:00:00 2001 From: Alexander Goscinski Date: Tue, 19 Nov 2024 19:22:28 +0100 Subject: [PATCH] WIP: implement nodes download endpoint --- aiida_restapi/routers/nodes.py | 51 +++++++++++++++++++++++++++++++ docs/source/user_guide/graphql.md | 2 +- tests/test_nodes.py | 8 +++++ 3 files changed, 60 insertions(+), 1 deletion(-) diff --git a/aiida_restapi/routers/nodes.py b/aiida_restapi/routers/nodes.py index 80a3637..103e843 100644 --- a/aiida_restapi/routers/nodes.py +++ b/aiida_restapi/routers/nodes.py @@ -7,6 +7,7 @@ from aiida.common.exceptions import EntryPointError from aiida.plugins.entry_point import load_entry_point from fastapi import APIRouter, Depends, File, HTTPException +from fastapi.responses import StreamingResponse from aiida_restapi import models @@ -38,6 +39,56 @@ async def read_node(nodes_id: int) -> Optional[models.Node]: return qbobj.dict()[0]['node'] +@router.get('/nodes/{nodes_id}/download?download_format={download_format}', response_model=StreamingResponse) +@with_dbenv() +async def download_node(nodes_id: int, download_format: str | None = None) -> Optional[models.Node]: + """Get nodes by id.""" + qbobj = orm.QueryBuilder() + qbobj.append(orm.Node, filters={'id': nodes_id}, project='**', tag='node').limit(1) + node = qbobj.dict()[0]['node'] + + # TODO that is the behavior in the old REST API but I think we should rather send error message + # since it does not agree with the response_model + if len(node) == 0: + return {} + # raise HTTPException(status_code=404, detail=f"Could no find any node with id {nodes_id}") + + if download_format is None: + raise RestInputValidationError( + 'Please specify the download format. ' + 'The available download formats can be ' + 'queried using the /nodes/download_formats/ endpoint.' + ) + + elif download_format in node.get_export_formats(): + # byteobj, dict with {filename: filecontent} + import io + + bytes, metadata = node._exportcontent(download_format)[0] + return StreamingResponse( + io.BytesIO(bytes) + ) # , media_type=next(iter(medata.values()))) # Don. know what metadata contains, but could be maybe used to define media_type + # try: + # response['data'] = + # response['status'] = 200 + # try: + # response['filename'] = node.filename + # except AttributeError: + # response['filename'] = f'{node.uuid}.{download_format}' + # except LicensingException as exc: + # response['status'] = 500 + # response['data'] = str(exc) + + # return response + + else: + raise RestInputValidationError( + 'The format {} is not supported. ' + 'The available download formats can be ' + 'queried using the /nodes/download_formats/ endpoint.'.format(download_format) + ) + + @router.post('/nodes', response_model=models.Node) @with_dbenv() async def create_node( diff --git a/docs/source/user_guide/graphql.md b/docs/source/user_guide/graphql.md index 18214a5..4748674 100644 --- a/docs/source/user_guide/graphql.md +++ b/docs/source/user_guide/graphql.md @@ -377,7 +377,7 @@ NOT YET SPECIFICALLY IMPLEMENTED http://localhost:5000/api/v4/nodes/download_formats ``` -NOT YET IMPLEMENTED +Not implemented for GraphQL, please adapt the URL for this use case. ```html diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 5552bd7..5d856b6 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -285,3 +285,11 @@ def test_create_bool_with_extra(client, authenticate): # pylint: disable=unused assert check_response.status_code == 200, response.content assert check_response.json()['extras']['extra_one'] == 'value_1' assert check_response.json()['extras']['extra_two'] == 'value_2' + + + def test_get_download_node(default_nodes, client): + """Test get projectable properites for nodes.""" + + for nodes_id in default_nodes: + response = client.get(f'/nodes/{nodes_id}/download?download_format=cif') #TODO download_format is wrong need to check this + assert response.status_code == 200