Skip to content

Commit

Permalink
Config: Switch from jsonschema to pydantic (#6117)
Browse files Browse the repository at this point in the history
The configuration of an AiiDA instance is written in JSON format to the
`config.json` file. The schema is defined using `jsonschema` to take
care of validation; however, some validation — for example of the config
options — was still happening manually.

Other parts of the code want to start using `pydantic` for model
definition and configuration purposes, which has become the de-facto
standard for these use-cases in the Python ecosystem. Before introducing
another dependency, the existing `jsonschema` approach is replaced by
`pydantic` in the current code base first.
  • Loading branch information
sphuber authored Oct 25, 2023
1 parent d16792f commit 4203f16
Show file tree
Hide file tree
Showing 15 changed files with 246 additions and 129 deletions.
7 changes: 4 additions & 3 deletions aiida/cmdline/commands/cmd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ def verdi_config_set(ctx, option, value, globally, append, remove):
List values are split by whitespace, e.g. "a b" becomes ["a", "b"].
"""
from aiida.manage.configuration import Config, ConfigValidationError, Profile
from aiida.common.exceptions import ConfigurationError
from aiida.manage.configuration import Config, Profile

if append and remove:
echo.echo_critical('Cannot flag both append and remove')
Expand All @@ -137,7 +138,7 @@ def verdi_config_set(ctx, option, value, globally, append, remove):
if append or remove:
try:
current = config.get_option(option.name, scope=scope)
except ConfigValidationError as error:
except ConfigurationError as error:
echo.echo_critical(str(error))
if not isinstance(current, list):
echo.echo_critical(f'cannot append/remove to value: {current}')
Expand All @@ -149,7 +150,7 @@ def verdi_config_set(ctx, option, value, globally, append, remove):
# Set the specified option
try:
value = config.set_option(option.name, value, scope=scope)
except ConfigValidationError as error:
except ConfigurationError as error:
echo.echo_critical(str(error))

config.store()
Expand Down
3 changes: 3 additions & 0 deletions aiida/common/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import collections
import contextlib
import enum
import logging
import types
from typing import cast
Expand Down Expand Up @@ -52,6 +53,8 @@ def report(self, msg: str, *args, **kwargs) -> None:
logging.getLevelName(logging.CRITICAL): logging.CRITICAL,
}

LogLevels = enum.Enum('LogLevels', {key: key for key in LOG_LEVELS}) # type: ignore[misc]

AIIDA_LOGGER = cast(AiidaLoggerType, logging.getLogger('aiida'))

CLI_ACTIVE: bool | None = None
Expand Down
2 changes: 0 additions & 2 deletions aiida/manage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
'BROKER_DEFAULTS',
'CURRENT_CONFIG_VERSION',
'Config',
'ConfigValidationError',
'MIGRATIONS',
'ManagementApiConnectionError',
'OLDEST_COMPATIBLE_CONFIG_VERSION',
Expand All @@ -43,7 +42,6 @@
'RabbitmqManagementClient',
'check_and_migrate_config',
'config_needs_migrating',
'config_schema',
'disable_caching',
'downgrade_config',
'enable_caching',
Expand Down
2 changes: 0 additions & 2 deletions aiida/manage/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,12 @@
__all__ = (
'CURRENT_CONFIG_VERSION',
'Config',
'ConfigValidationError',
'MIGRATIONS',
'OLDEST_COMPATIBLE_CONFIG_VERSION',
'Option',
'Profile',
'check_and_migrate_config',
'config_needs_migrating',
'config_schema',
'downgrade_config',
'get_current_version',
'get_option',
Expand Down
208 changes: 175 additions & 33 deletions aiida/manage/configuration/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,194 @@
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Module that defines the configuration file of an AiiDA instance and functions to create and load it."""
"""Module that defines the configuration file of an AiiDA instance and functions to create and load it.
Despite the import of the annotations backport below which enables postponed type annotation evaluation as implemented
with PEP 563 (https://peps.python.org/pep-0563/), this is not compatible with ``pydantic`` for Python 3.9 and older (
See https://github.com/pydantic/pydantic/issues/2678 for details).
"""
from __future__ import annotations

import codecs
from functools import cache
import json
import os
from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Tuple
import uuid

from pydantic import ( # pylint: disable=no-name-in-module
BaseModel,
ConfigDict,
Field,
ValidationError,
field_serializer,
field_validator,
)

from aiida.common.exceptions import ConfigurationError
from aiida.common.log import LogLevels

from . import schema as schema_module
from .options import Option, get_option, get_option_names, parse_option
from .profile import Profile

__all__ = ('Config', 'config_schema', 'ConfigValidationError')
__all__ = ('Config',)


class ConfigVersionSchema(BaseModel, defer_build=True):
    """Schema for the version configuration of an AiiDA instance.

    Mirrors the ``CONFIG_VERSION`` entry of ``config.json``. Both fields are required; there are no defaults.
    ``defer_build=True`` postpones construction of the pydantic core schema until first use.
    """

    # Version of the schema with which this config was last written.
    CURRENT: int
    # Oldest schema version that can still read this config without migration.
    OLDEST_COMPATIBLE: int


class ProfileOptionsSchema(BaseModel, defer_build=True):
    """Schema for the options of an AiiDA profile.

    Field names use double underscores as separators (e.g. ``daemon__timeout``); presumably these map onto the
    dotted option names (``daemon.timeout``) used elsewhere — TODO confirm against the ``Option`` parsing code.
    """

    # ``use_enum_values=True`` makes enum-typed fields (the ``LogLevels`` options below) store and dump their
    # plain string values (e.g. ``'REPORT'``) instead of enum members.
    model_config = ConfigDict(use_enum_values=True)

    # --- Engine / daemon tuning ---
    runner__poll__interval: int = Field(60, description='Polling interval in seconds to be used by process runners.')
    daemon__default_workers: int = Field(
        1, description='Default number of workers to be launched by `verdi daemon start`.'
    )
    daemon__timeout: int = Field(
        2,
        description=
        'Used to set default timeout in the :class:`aiida.engine.daemon.client.DaemonClient` for calls to the daemon.'
    )
    daemon__worker_process_slots: int = Field(
        200, description='Maximum number of concurrent process tasks that each daemon worker can handle.'
    )
    daemon__recursion_limit: int = Field(3000, description='Maximum recursion depth for the daemon workers.')

    # --- Database ---
    db__batch_size: int = Field(
        100000,
        description='Batch size for bulk CREATE operations in the database. Avoids hitting MaxAllocSize of PostgreSQL '
        '(1GB) when creating large numbers of database records in one go.'
    )

    # --- CLI ---
    verdi__shell__auto_import: str = Field(
        ':',
        description='Additional modules/functions/classes to be automatically loaded in `verdi shell`, split by `:`.'
    )

    # --- Logging levels, one per logger namespace ---
    logging__aiida_loglevel: LogLevels = Field(
        'REPORT', description='Minimum level to log to daemon log and the `DbLog` table for the `aiida` logger.'
    )
    logging__verdi_loglevel: LogLevels = Field(
        'REPORT', description='Minimum level to log to console when running a `verdi` command.'
    )
    logging__db_loglevel: LogLevels = Field('REPORT', description='Minimum level to log to the DbLog table.')
    logging__plumpy_loglevel: LogLevels = Field(
        'WARNING', description='Minimum level to log to daemon log and the `DbLog` table for the `plumpy` logger.'
    )
    logging__kiwipy_loglevel: LogLevels = Field(
        'WARNING', description='Minimum level to log to daemon log and the `DbLog` table for the `kiwipy` logger'
    )
    logging__paramiko_loglevel: LogLevels = Field(
        'WARNING', description='Minimum level to log to daemon log and the `DbLog` table for the `paramiko` logger'
    )
    logging__alembic_loglevel: LogLevels = Field(
        'WARNING', description='Minimum level to log to daemon log and the `DbLog` table for the `alembic` logger'
    )
    logging__sqlalchemy_loglevel: LogLevels = Field(
        'WARNING', description='Minimum level to log to daemon log and the `DbLog` table for the `sqlalchemy` logger'
    )
    logging__circus_loglevel: LogLevels = Field(
        'INFO', description='Minimum level to log to daemon log and the `DbLog` table for the `circus` logger'
    )
    logging__aiopika_loglevel: LogLevels = Field(
        'WARNING', description='Minimum level to log to daemon log and the `DbLog` table for the `aiopika` logger'
    )

    # --- Warnings ---
    warnings__showdeprecations: bool = Field(True, description='Whether to print AiiDA deprecation warnings.')
    warnings__rabbitmq_version: bool = Field(
        True, description='Whether to print a warning when an incompatible version of RabbitMQ is configured.'
    )

    # --- Transport / RabbitMQ task handling ---
    transport__task_retry_initial_interval: int = Field(
        20, description='Initial time interval for the exponential backoff mechanism.'
    )
    transport__task_maximum_attempts: int = Field(
        5, description='Maximum number of transport task attempts before a Process is Paused.'
    )
    rmq__task_timeout: int = Field(10, description='Timeout in seconds for communications with RabbitMQ.')

    # --- Storage / caching ---
    storage__sandbox: Optional[str] = Field(
        None, description='Absolute path to the directory to store sandbox folders.'
    )
    caching__default_enabled: bool = Field(False, description='Enable calculation caching by default.')
    caching__enabled_for: List[str] = Field([], description='Calculation entry points to enable caching on.')
    caching__disabled_for: List[str] = Field([], description='Calculation entry points to disable caching on.')

    @field_validator('caching__enabled_for', 'caching__disabled_for')
    @classmethod
    def validate_caching_identifier_pattern(cls, value: List[str]) -> List[str]:
        """Validate the caching identifier patterns.

        Delegates each identifier to :func:`aiida.manage.caching._validate_identifier_pattern`, which presumably
        raises on an invalid pattern — confirm the exception type surfaced through pydantic's ``ValidationError``.

        :param value: List of caching identifier patterns to validate.
        :return: The validated list, unchanged.
        """
        from aiida.manage.caching import _validate_identifier_pattern
        for identifier in value:
            _validate_identifier_pattern(identifier=identifier)

        return value

SCHEMA_FILE = 'config-v9.schema.json'

class GlobalOptionsSchema(ProfileOptionsSchema):
    """Schema for the global options of an AiiDA instance.

    Extends :class:`ProfileOptionsSchema` with options that only make sense at the instance level, such as the
    autofill defaults used when creating new profiles.
    """

    # Defaults used to prefill user details when a new profile is created.
    autofill__user__email: Optional[str] = Field(
        default=None, description='Default user email to use when creating new profiles.'
    )
    autofill__user__first_name: Optional[str] = Field(
        default=None, description='Default user first name to use when creating new profiles.'
    )
    autofill__user__last_name: Optional[str] = Field(
        default=None, description='Default user last name to use when creating new profiles.'
    )
    autofill__user__institution: Optional[str] = Field(
        default=None, description='Default user institution to use when creating new profiles.'
    )

    # REST API behavior toggle.
    rest_api__profile_switching: bool = Field(
        default=False,
        description='Toggle whether the profile can be specified in requests submitted to the REST API.'
    )

    # Instance-level warning toggle.
    warnings__development_version: bool = Field(
        default=True,
        description='Whether to print a warning when a profile is loaded while a development version is installed.'
    )

@cache
def config_schema() -> Dict[str, Any]:
"""Return the configuration schema."""
from importlib.resources import files

return json.loads(files(schema_module).joinpath(SCHEMA_FILE).read_text(encoding='utf8'))
class ProfileStorageConfig(BaseModel, defer_build=True):
    """Schema for the storage backend configuration of an AiiDA profile.

    The ``config`` dictionary is backend-specific and is not validated here; it is passed through as-is to the
    storage backend identified by ``backend``.
    """

    # Identifier of the storage backend implementation.
    backend: str
    # Free-form, backend-specific configuration dictionary.
    config: dict[str, Any]

class ConfigValidationError(ConfigurationError):
"""Configuration error raised when the file contents fails validation."""

def __init__(
self, message: str, keypath: Sequence[Any] = (), schema: Optional[dict] = None, filepath: Optional[str] = None
):
super().__init__(message)
self._message = message
self._keypath = keypath
self._filepath = filepath
self._schema = schema
class ProcessControlConfig(BaseModel, defer_build=True):
    """Schema for the process control configuration of an AiiDA profile.

    Defaults correspond to a stock RabbitMQ broker running on localhost with the ``guest`` account.
    """

    broker_protocol: str = Field('amqp', description='Protocol for connecting to the message broker.')
    broker_username: str = Field('guest', description='Username for message broker authentication.')
    broker_password: str = Field('guest', description='Password for message broker.')
    broker_host: str = Field('127.0.0.1', description='Hostname of the message broker.')
    # BUGFIX: the default was 5432, which is PostgreSQL's port; the standard AMQP (RabbitMQ) port is 5672.
    # With the old default, every connection attempt using unconfigured settings would target the wrong service.
    broker_port: int = Field(5672, description='Port of the message broker.')
    broker_virtual_host: str = Field('', description='Virtual host to use for the message broker.')
    broker_parameters: dict[str, Any] = Field(
        default_factory=dict, description='Arguments to be encoded as query parameters.'
    )

def __str__(self) -> str:
prefix = f'{self._filepath}:' if self._filepath else ''
path = '/' + '/'.join(str(k) for k in self._keypath) + ': ' if self._keypath else ''
schema = f'\n schema:\n {self._schema}' if self._schema else ''
return f'Validation Error: {prefix}{path}{self._message}{schema}'

class ProfileSchema(BaseModel, defer_build=True):
    """Schema for the configuration of an AiiDA profile."""

    # NOTE: the field is typed ``str`` but ``default_factory`` returns a ``uuid.UUID`` instance; pydantic does not
    # validate default-factory results by default, which is why the serializer below coerces to ``str`` on dump.
    uuid: str = Field(description='A UUID that uniquely identifies the profile.', default_factory=uuid.uuid4)
    storage: ProfileStorageConfig
    process_control: ProcessControlConfig
    default_user_email: Optional[str] = None
    test_profile: bool = False
    options: Optional[ProfileOptionsSchema] = None

    @field_serializer('uuid')
    def serialize_uuid(self, value, _info) -> str:
        """Serialize the ``uuid`` field to a plain string.

        Renamed from the copy-pasted ``serialize_dt`` (and its incorrect ``uuid.UUID``-only annotation dropped):
        ``value`` may be either a ``str`` loaded from ``config.json`` or the ``uuid.UUID`` produced by the
        default factory. The registered serializer name does not affect the dumped output.
        """
        return str(value)


class ConfigSchema(BaseModel, defer_build=True):
    """Schema for the configuration of an AiiDA instance.

    Top-level model for ``config.json``. Every section is optional, so an empty file validates.
    """

    # Schema version bookkeeping, see :class:`ConfigVersionSchema`.
    CONFIG_VERSION: Optional[ConfigVersionSchema] = Field(None)
    # Mapping of profile name to its configuration.
    profiles: Optional[dict[str, ProfileSchema]] = Field(None)
    # Instance-wide options.
    options: Optional[GlobalOptionsSchema] = Field(None)
    # Name of the profile loaded when none is specified explicitly.
    default_profile: Optional[str] = Field(None)


class Config: # pylint: disable=too-many-public-methods
Expand Down Expand Up @@ -125,13 +270,10 @@ def _backup(cls, filepath):
@staticmethod
def validate(config: dict, filepath: Optional[str] = None):
"""Validate a configuration dictionary."""
import jsonschema
try:
jsonschema.validate(instance=config, schema=config_schema())
except jsonschema.ValidationError as error:
raise ConfigValidationError(
message=error.message, keypath=error.path, schema=error.schema, filepath=filepath
)
ConfigSchema(**config)
except ValidationError as exception:
raise ConfigurationError(f'invalid config schema: {filepath}: {str(exception)}')

def __init__(self, filepath: str, config: dict, validate: bool = True):
"""Instantiate a configuration object from a configuration dictionary and its filepath.
Expand Down Expand Up @@ -470,7 +612,7 @@ def get_options(self, scope: Optional[str] = None) -> Dict[str, Tuple[Option, st
elif name in self.options:
value = self.options.get(name)
source = 'global'
elif 'default' in option.schema:
elif option.default is not None:
value = option.default
source = 'default'
else:
Expand Down
Loading

0 comments on commit 4203f16

Please sign in to comment.