Skip to content

Commit

Permalink
WIP: implement nodes download endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
agoscinski committed Nov 19, 2024
1 parent 3b7e343 commit aae8d9f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 1 deletion.
51 changes: 51 additions & 0 deletions aiida_restapi/routers/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from aiida.common.exceptions import EntryPointError
from aiida.plugins.entry_point import load_entry_point
from fastapi import APIRouter, Depends, File, HTTPException
from fastapi.responses import StreamingResponse

from aiida_restapi import models

Expand Down Expand Up @@ -38,6 +39,56 @@ async def read_node(nodes_id: int) -> Optional[models.Node]:
return qbobj.dict()[0]['node']


@router.get('/nodes/{nodes_id}/download?download_format={download_format}', response_model=StreamingResponse)
@with_dbenv()
async def download_node(nodes_id: int, download_format: str | None = None) -> Optional[models.Node]:
"""Get nodes by id."""
qbobj = orm.QueryBuilder()
qbobj.append(orm.Node, filters={'id': nodes_id}, project='**', tag='node').limit(1)
node = qbobj.dict()[0]['node']

# TODO that is the behavior in the old REST API but I think we should rather send error message
# since it does not agree with the response_model
if len(node) == 0:
return {}
# raise HTTPException(status_code=404, detail=f"Could no find any node with id {nodes_id}")

if download_format is None:
raise RestInputValidationError(
'Please specify the download format. '
'The available download formats can be '
'queried using the /nodes/download_formats/ endpoint.'
)

elif download_format in node.get_export_formats():
# byteobj, dict with {filename: filecontent}
import io

bytes, metadata = node._exportcontent(download_format)[0]
return StreamingResponse(
io.BytesIO(bytes)
) # , media_type=next(iter(medata.values()))) # Don. know what metadata contains, but could be maybe used to define media_type
# try:
# response['data'] =
# response['status'] = 200
# try:
# response['filename'] = node.filename
# except AttributeError:
# response['filename'] = f'{node.uuid}.{download_format}'
# except LicensingException as exc:
# response['status'] = 500
# response['data'] = str(exc)

# return response

else:
raise RestInputValidationError(
'The format {} is not supported. '
'The available download formats can be '
'queried using the /nodes/download_formats/ endpoint.'.format(download_format)
)


@router.post('/nodes', response_model=models.Node)
@with_dbenv()
async def create_node(
Expand Down
2 changes: 1 addition & 1 deletion docs/source/user_guide/graphql.md
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ NOT YET SPECIFICALLY IMPLEMENTED
http://localhost:5000/api/v4/nodes/download_formats
```

NOT YET IMPLEMENTED
Not implemented for GraphQL, please adapt the URL for this use case.


```html
Expand Down
8 changes: 8 additions & 0 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,11 @@ def test_create_bool_with_extra(client, authenticate): # pylint: disable=unused
assert check_response.status_code == 200, response.content
assert check_response.json()['extras']['extra_one'] == 'value_1'
assert check_response.json()['extras']['extra_two'] == 'value_2'


def test_get_download_node(default_nodes, client):
"""Test get projectable properites for nodes."""

for nodes_id in default_nodes:
response = client.get(f'/nodes/{nodes_id}/download?download_format=cif') #TODO download_format is wrong need to check this
assert response.status_code == 200

0 comments on commit aae8d9f

Please sign in to comment.