From 2f87d7431b5120486753aa6bd676684b1c70a6d3 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Fri, 6 Aug 2021 12:25:05 +0200 Subject: [PATCH] REST API: make the profile configurable as request parameter To make this possible, after parsing the query string but before performing the request, the desired profile needs to be loaded. A new method `load_profile` is added to the `BaseResource` class. All methods that access the storage, such as the `get` methods, need to invoke this method before handing the request. The `load_profile` method will call `load_profile` with `allow_switch` set to True, in order to allow changing the profile if another had already been loaded. The profile that is loaded is determined from the `profile` query parameter specified in the request. If not specified, the profile will be taken that was specified in the `kwargs` of the resources constructor. Note that the parsing of the request path and query parameters had to be refactored a bit to prevent the parsing having to be performed twice, which would result in a performance regression. When the REST API is invoked through the `verdi` CLI, the profile specified by the `-p` option, or the default profile if not specified, is passed to the API, which will be passed to the resource constructors. This guarantees that if `profile` is not specified in the query parameters the profile with which `verdi restapi` was invoked will be loaded. --- aiida/cmdline/commands/cmd_restapi.py | 4 +- aiida/restapi/common/utils.py | 14 +- aiida/restapi/resources.py | 217 +++++++++++++++++++------- aiida/restapi/run_api.py | 4 +- 4 files changed, 179 insertions(+), 60 deletions(-) diff --git a/aiida/cmdline/commands/cmd_restapi.py b/aiida/cmdline/commands/cmd_restapi.py index cb3266be9b..799d6350ab 100644 --- a/aiida/cmdline/commands/cmd_restapi.py +++ b/aiida/cmdline/commands/cmd_restapi.py @@ -45,7 +45,8 @@ help='Enable POST endpoints (currently only /querybuilder).', hidden=True, ) -def restapi(hostname, port, config_dir, debug, wsgi_profile, posting): +@click.pass_context +def restapi(ctx, hostname, port, config_dir, debug, wsgi_profile, posting): """ Run the AiiDA REST API server. @@ -58,6 +59,7 @@ def restapi(hostname, port, config_dir, debug, wsgi_profile, posting): # Invoke the runner try: run_api( + profile=ctx.obj['profile'].name, hostname=hostname, port=port, config=config_dir, diff --git a/aiida/restapi/common/utils.py b/aiida/restapi/common/utils.py index 733673ae26..f3a61a1892 100644 --- a/aiida/restapi/common/utils.py +++ b/aiida/restapi/common/utils.py @@ -486,6 +486,7 @@ def build_translator_parameters(self, field_list): extras = None extras_filter = None full_type = None + profile = None # io tree limit parameters tree_in_limit = None @@ -539,10 +540,17 @@ def build_translator_parameters(self, field_list): raise RestInputValidationError('You cannot specify extras_filter more than once') if 'full_type' in field_counts and field_counts['full_type'] > 1: raise RestInputValidationError('You cannot specify full_type more than once') + if 'profile' in field_counts and field_counts['profile'] > 1: + raise RestInputValidationError('You cannot specify profile more than once') ## Extract results for field in field_list: - if field[0] == 'limit': + if field[0] == 'profile': + if field[1] == '=': + profile = field[2] + else: + raise RestInputValidationError("only assignment operator '=' is permitted after 'profile'") + elif field[0] == 'limit': if field[1] == '=': limit = field[2] else: @@ -658,7 +666,7 @@ def build_translator_parameters(self, field_list): return ( limit, offset, perpage, orderby, filters, download_format, download, filename, tree_in_limit, - tree_out_limit, attributes, attributes_filter, extras, extras_filter, full_type + tree_out_limit, attributes, attributes_filter, extras, extras_filter, full_type, profile ) def parse_query_string(self, query_string): @@ -680,7 +688,7 @@ def parse_query_string(self, query_string): ## Define grammar # key types - key = Word(f'{alphas}_', f'{alphanums}_') + key = Word(f'{alphas}_', f'{alphanums}_-') # operators operator = ( Literal('=like=') | Literal('=ilike=') | Literal('=in=') | Literal('=notin=') | Literal('=') | diff --git a/aiida/restapi/resources.py b/aiida/restapi/resources.py index 2c3e32dec2..c99a64975e 100644 --- a/aiida/restapi/resources.py +++ b/aiida/restapi/resources.py @@ -14,6 +14,7 @@ from flask_restful import Resource from aiida.common.lang import classproperty +from aiida.manage import load_profile from aiida.restapi.common.exceptions import RestInputValidationError from aiida.restapi.common.utils import Utils, close_thread_connection from aiida.restapi.translator.nodes.node import NodeTranslator @@ -33,11 +34,10 @@ def get(self): It returns the general info about the REST API :return: returns current AiiDA version defined in aiida/__init__.py """ - # Decode url parts path = unquote(request.path) url = unquote(request.url) url_root = unquote(request.url_root) - + query_string = unquote(request.query_string.decode('utf-8')) subpath = self.utils.strip_api_prefix(path).strip('/') pathlist = self.utils.split_path(subpath) @@ -81,7 +81,7 @@ def get(self): url=url, url_root=url_root, path=path, - query_string=request.query_string.decode('utf-8'), + query_string=query_string, resource_type='Info', data=response ) @@ -103,8 +103,10 @@ class BaseResource(Resource): ## TODO add the caching support. I cache total count, results, and possibly - def __init__(self, **kwargs): + def __init__(self, profile, **kwargs): + """Construct the resource.""" self.trans = self._translator_class(**kwargs) + self.profile = profile # Configure utils utils_conf_keys = ('PREFIX', 'PERPAGE_DEFAULT', 'LIMIT_DEFAULT') @@ -115,6 +117,26 @@ def __init__(self, **kwargs): if 'get_decorators' in kwargs and isinstance(kwargs['get_decorators'], (tuple, list, set)): self.method_decorators = {'get': list(kwargs['get_decorators'])} + @staticmethod + def unquote_request(): + """Unquote and return various parts of the request. + + :returns: Tuple of the request path, url, url root and query string. + """ + path = unquote(request.path) + url = unquote(request.url) + url_root = unquote(request.url_root) + query_string = unquote(request.query_string.decode('utf-8')) + return path, url, url_root, query_string + + def parse_path(self, path): + """Parse the request path.""" + return self.utils.parse_path(path, parse_pk_uuid=self.parse_pk_uuid) + + def parse_query_string(self, query_string): + """Parse the request query string.""" + return self.utils.parse_query_string(query_string) + @classproperty def parse_pk_uuid(cls): # pylint: disable=no-self-argument return cls._parse_pk_uuid @@ -131,6 +153,14 @@ def _load_and_verify(self, node_id=None): return node + def load_profile(self, profile): + """Load the required profile. + + This will load the profile specified by the ``profile`` keyword in the query parameters, and if not specified it + will default to the profile defined in the constructor. + """ + load_profile(profile, allow_switch=True) + def get(self, id=None, page=None): # pylint: disable=redefined-builtin,invalid-name,unused-argument # pylint: disable=too-many-locals """ @@ -139,21 +169,34 @@ def get(self, id=None, page=None): # pylint: disable=redefined-builtin,invalid- :param page: page no, used for pagination :return: http response """ - - ## Decode url parts - path = unquote(request.path) - query_string = unquote(request.query_string.decode('utf-8')) - url = unquote(request.url) - url_root = unquote(request.url_root) - - ## Parse request - (resource_type, page, node_id, query_type) = self.utils.parse_path(path, parse_pk_uuid=self.parse_pk_uuid) - - # pylint: disable=unused-variable - ( - limit, offset, perpage, orderby, filters, download_format, download, filename, tree_in_limit, - tree_out_limit, attributes, attributes_filter, extras, extras_filter, full_type - ) = self.utils.parse_query_string(query_string) + path, url, url_root, query_string = self.unquote_request() + resource_type, page, node_id, query_type = self.parse_path(path) + parameters = self.parse_query_string(query_string) + limit = parameters[0] + offset = parameters[1] + perpage = parameters[2] + orderby = parameters[3] + filters = parameters[4] + profile = parameters[-1] + + try: + self.load_profile(profile) + except RestInputValidationError as exception: + return self.utils.build_response( + status=400, + headers=self.utils.build_headers(url=request.url, total_count=0), + data={ + 'method': request.method, + 'url': url, + 'url_root': url_root, + 'path': path, + 'query_string': query_string, + 'resource_type': self.__class__.__name__, + 'data': { + 'message': str(exception) + }, + } + ) ## Validate request self.utils.validate_request( @@ -175,6 +218,7 @@ def get(self, id=None, page=None): # pylint: disable=redefined-builtin,invalid- headers = self.utils.build_headers(url=request.url, total_count=1) else: + ## Set the query, and initialize qb object self.trans.set_query(filters=filters, orders=orderby, node_id=node_id) @@ -200,7 +244,7 @@ def get(self, id=None, page=None): # pylint: disable=redefined-builtin,invalid- url_root=url_root, path=request.path, id=node_id, - query_string=request.query_string.decode('utf-8'), + query_string=query_string, resource_type=resource_type, data=results ) @@ -234,16 +278,17 @@ def __init__(self, **kwargs): def get(self): # pylint: disable=arguments-differ """Static return to state information about this endpoint.""" + path, url, url_root, query_string = self.unquote_request() headers = self.utils.build_headers(url=request.url, total_count=1) return self.utils.build_response( status=405, # Method Not Allowed headers=headers, data={ 'method': request.method, - 'url': unquote(request.url), - 'url_root': unquote(request.url_root), - 'path': unquote(request.path), - 'query_string': request.query_string.decode('utf-8'), + 'url': url, + 'url_root': url_root, + 'path': path, + 'query_string': query_string, 'resource_type': self.__class__.__name__, 'data': {'message': self.GET_MESSAGE}, }, @@ -262,6 +307,28 @@ def post(self): # pylint: disable=too-many-branches :return: QueryBuilder result of AiiDA entities in "standard" REST API format. """ # pylint: disable=protected-access + path, url, url_root, query_string = self.unquote_request() + profile = self.parse_query_string(query_string)[-1] + + try: + self.load_profile(profile) + except RestInputValidationError as exception: + return self.utils.build_response( + status=400, + headers=self.utils.build_headers(url=request.url, total_count=0), + data={ + 'method': request.method, + 'url': url, + 'url_root': url_root, + 'path': path, + 'query_string': query_string, + 'resource_type': self.__class__.__name__, + 'data': { + 'message': str(exception) + }, + } + ) + self.trans._query_help = request.get_json(force=True) # While the data may be correct JSON, it MUST be a single JSON Object, # equivalent of a QueryBuilder.as_dict() dictionary. @@ -350,20 +417,31 @@ def get(self, id=None, page=None): # pylint: disable=redefined-builtin,invalid- :param page: page no, used for pagination :return: http response """ - - ## Decode url parts - path = unquote(request.path) - query_string = unquote(request.query_string.decode('utf-8')) - url = unquote(request.url) - url_root = unquote(request.url_root) - - ## Parse request - (resource_type, page, node_id, query_type) = self.utils.parse_path(path, parse_pk_uuid=self.parse_pk_uuid) - + path, url, url_root, query_string = self.unquote_request() + (resource_type, page, node_id, query_type) = self.parse_path(path) ( limit, offset, perpage, orderby, filters, download_format, download, filename, tree_in_limit, - tree_out_limit, attributes, attributes_filter, extras, extras_filter, full_type - ) = self.utils.parse_query_string(query_string) + tree_out_limit, attributes, attributes_filter, extras, extras_filter, full_type, profile + ) = self.parse_query_string(query_string) + + try: + self.load_profile(profile) + except RestInputValidationError as exception: + return self.utils.build_response( + status=400, + headers=self.utils.build_headers(url=request.url, total_count=0), + data={ + 'method': request.method, + 'url': url, + 'url_root': url_root, + 'path': path, + 'query_string': query_string, + 'resource_type': self.__class__.__name__, + 'data': { + 'message': str(exception) + }, + } + ) ## Validate request self.utils.validate_request( @@ -495,7 +573,7 @@ def get(self, id=None, page=None): # pylint: disable=redefined-builtin,invalid- url_root=url_root, path=path, id=node_id, - query_string=request.query_string.decode('utf-8'), + query_string=query_string, resource_type=resource_type, data=results ) @@ -540,14 +618,31 @@ def get(self, id=None, page=None): # pylint: disable=redefined-builtin :param id: node identifier :return: http response """ + path, url, url_root, query_string = self.unquote_request() + resource_type, page, node_id, query_type = self.parse_path(path) + profile = self.parse_query_string(query_string)[-1] + + try: + self.load_profile(profile) + except RestInputValidationError as exception: + return self.utils.build_response( + status=400, + headers=self.utils.build_headers(url=request.url, total_count=0), + data={ + 'method': request.method, + 'url': url, + 'url_root': url_root, + 'path': path, + 'query_string': query_string, + 'resource_type': self.__class__.__name__, + 'data': { + 'message': str(exception) + }, + } + ) headers = self.utils.build_headers(url=request.url, total_count=1) - path = unquote(request.path) - - ## Parse request - (resource_type, page, node_id, query_type) = self.utils.parse_path(path, parse_pk_uuid=self.parse_pk_uuid) - results = None if query_type == 'report': node = self._load_and_verify(node_id) @@ -562,11 +657,11 @@ def get(self, id=None, page=None): # pylint: disable=redefined-builtin ## Build response data = dict( method=request.method, - url=unquote(request.url), - url_root=unquote(request.url_root), + url=url, + url_root=url_root, path=path, id=node_id, - query_string=request.query_string.decode('utf-8'), + query_string=query_string, resource_type=resource_type, data=results ) @@ -587,14 +682,28 @@ def get(self, id=None, page=None): # pylint: disable=redefined-builtin :param id: node identifier :return: http response """ - - ## Decode url parts - path = unquote(request.path) - url = unquote(request.url) - url_root = unquote(request.url_root) - - ## Parse request - (resource_type, page, node_id, query_type) = self.utils.parse_path(path, parse_pk_uuid=self.parse_pk_uuid) + path, url, url_root, query_string = self.unquote_request() + resource_type, page, node_id, query_type = self.parse_path(path) + profile = self.parse_query_string(query_string)[-1] + + try: + self.load_profile(profile) + except RestInputValidationError as exception: + return self.utils.build_response( + status=400, + headers=self.utils.build_headers(url=request.url, total_count=0), + data={ + 'method': request.method, + 'url': url, + 'url_root': url_root, + 'path': path, + 'query_string': query_string, + 'resource_type': self.__class__.__name__, + 'data': { + 'message': str(exception) + }, + } + ) node = self._load_and_verify(node_id) results = None @@ -617,7 +726,7 @@ def get(self, id=None, page=None): # pylint: disable=redefined-builtin url_root=url_root, path=path, id=node_id, - query_string=request.query_string.decode('utf-8'), + query_string=query_string, resource_type=resource_type, data=results ) diff --git a/aiida/restapi/run_api.py b/aiida/restapi/run_api.py index cd1a8b3106..bf884789ec 100755 --- a/aiida/restapi/run_api.py +++ b/aiida/restapi/run_api.py @@ -68,8 +68,8 @@ def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **k :returns: Flask RESTful API :rtype: :py:class:`flask_restful.Api` """ - # Unpack parameters + profile = kwargs.pop('profile', None) config = kwargs.pop('config', CLI_DEFAULTS['CONFIG_DIR']) catch_internal_server = kwargs.pop('catch_internal_server', CLI_DEFAULTS['CATCH_INTERNAL_SERVER']) wsgi_profile = kwargs.pop('wsgi_profile', CLI_DEFAULTS['WSGI_PROFILE']) @@ -106,4 +106,4 @@ def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **k app.wsgi_app = ProfilerMiddleware(app.wsgi_app, restrictions=[30]) # Instantiate and return a Flask RESTful API by associating its app - return flask_api(app, posting=posting, **config_module.API_CONFIG) + return flask_api(app, posting=posting, **config_module.API_CONFIG, profile=profile)