Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ORM: Use pydantic to specify a schema for each ORM entity #6255

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ py:meth click.Option.get_default
py:meth fail

py:class ComputedFieldInfo
py:class BaseModel
py:class pydantic.fields.Field
py:class pydantic.fields.FieldInfo
py:class pydantic.main.BaseModel
py:class PluggableSchemaValidator

Expand All @@ -157,6 +159,7 @@ py:class frozenset

py:class numpy.bool_
py:class numpy.ndarray
py:class np.ndarray
py:class ndarray

py:class paramiko.proxy.ProxyCommand
Expand Down
11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,17 @@ requires-python = '>=3.9'
'process.workflow.workchain' = 'aiida.orm.nodes.process.workflow.workchain:WorkChainNode'
'process.workflow.workfunction' = 'aiida.orm.nodes.process.workflow.workfunction:WorkFunctionNode'

[project.entry-points.'aiida.orm']
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it really necessary to make orm pluggable, what is the use case of it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sphuber correct me if I'm wrong, but this has to do with the use of orm_class, right? Under the implementation section of the AEP, it states that orm_class is a shortcut to model_to_orm used to define a model field that points to an ORM instance. So instead of

class Model(...):
    user = MetadataField(
        ...,
        model_to_orm=lambda pk: User.collection.get(id=pk),
    )

you can do

class Model(...):
    user = MetadataField(
        ...,
        orm_class=User,
    )

The various ORM classes that may be used in orm_class are then defined in project.entry-points.'aiida.orm', so that you can do in Entity.model_to_orm_field_values(...):

if orm_class := get_metadata(field, 'orm_class'):
    if isinstance(orm_class, str):
        orm_class = BaseFactory('aiida.orm', orm_class)
        fields[key] = orm_class.collection.get(id=field_value)

But then I guess my question is, do these not already have entry points? Or is the problem that you can't fetch them all generically using the BaseFactory?

'core.auth_info' = 'aiida.orm.authinfos:AuthInfo'
'core.comment' = 'aiida.orm.comments:Comment'
'core.computer' = 'aiida.orm.computers:Computer'
'core.data' = 'aiida.orm.nodes.data.data:Data'
'core.entity' = 'aiida.orm.entities:Entity'
'core.group' = 'aiida.orm.groups:Group'
'core.log' = 'aiida.orm.logs:Log'
'core.node' = 'aiida.orm.nodes.node:Node'
'core.user' = 'aiida.orm.users:User'

[project.entry-points.'aiida.parsers']
'core.arithmetic.add' = 'aiida.parsers.plugins.arithmetic.add:ArithmeticAddParser'
'core.templatereplacer' = 'aiida.parsers.plugins.templatereplacer.parser:TemplatereplacerParser'
Expand Down
6 changes: 2 additions & 4 deletions src/aiida/cmdline/commands/cmd_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ def verdi_code():
"""Setup and manage codes."""


def create_code(ctx: click.Context, cls, non_interactive: bool, **kwargs):
def create_code(ctx: click.Context, cls, **kwargs):
"""Create a new `Code` instance."""
try:
instance = cls(**kwargs)
instance = cls.from_model(cls.Model(**kwargs))
except (TypeError, ValueError) as exception:
echo.echo_critical(f'Failed to create instance `{cls}`: {exception}')

Expand Down Expand Up @@ -243,9 +243,7 @@ def show(code):
@with_dbenv()
def export(code, output_file, overwrite, sort):
"""Export code to a yaml file. If no output file is given, default name is created based on the code label."""

other_args = {'sort': sort}

fileformat = 'yaml'

if output_file is None:
Expand Down
1 change: 0 additions & 1 deletion src/aiida/cmdline/commands/cmd_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def verdi_profile():
def command_create_profile(
ctx: click.Context,
storage_cls,
non_interactive: bool,
profile: Profile,
set_as_default: bool = True,
email: str | None = None,
Expand Down
20 changes: 11 additions & 9 deletions src/aiida/cmdline/groups/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,29 +88,25 @@ def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None
command = super().get_command(ctx, cmd_name)
return command

def call_command(self, ctx, cls, **kwargs):
def call_command(self, ctx, cls, non_interactive, **kwargs):
edan-bainglass marked this conversation as resolved.
Show resolved Hide resolved
"""Call the ``command`` after validating the provided inputs."""
from pydantic import ValidationError

if hasattr(cls, 'Model'):
# The plugin defines a pydantic model: use it to validate the provided arguments
try:
model = cls.Model(**kwargs)
cls.Model(**kwargs)
except ValidationError as exception:
param_hint = [
f'--{loc.replace("_", "-")}' # type: ignore[union-attr]
for loc in exception.errors()[0]['loc']
]
message = '\n'.join([str(e['ctx']['error']) for e in exception.errors()])
message = '\n'.join([str(e['msg']) for e in exception.errors()])
raise click.BadParameter(
message,
param_hint=param_hint or 'multiple parameters', # type: ignore[arg-type]
param_hint=param_hint or 'one or more parameters', # type: ignore[arg-type]
) from exception

# Update the arguments with the dictionary representation of the model. This will include any type coercions
# that may have been applied with validators defined for the model.
kwargs.update(**model.model_dump())

return self._command(ctx, cls, **kwargs)

def create_command(self, ctx: click.Context, entry_point: str) -> click.Command:
Expand Down Expand Up @@ -154,6 +150,8 @@ def list_options(self, entry_point: str) -> list:
"""
from pydantic_core import PydanticUndefined

from aiida.common.pydantic import get_metadata

cls = self.factory(entry_point)

if not hasattr(cls, 'Model'):
Expand All @@ -170,6 +168,9 @@ def list_options(self, entry_point: str) -> list:
options_spec = {}

for key, field_info in cls.Model.model_fields.items():
if get_metadata(field_info, 'exclude_from_cli'):
continue

default = field_info.default_factory if field_info.default is PydanticUndefined else field_info.default

# If the annotation has the ``__args__`` attribute it is an instance of a type from ``typing`` and the real
Expand All @@ -194,7 +195,8 @@ def list_options(self, entry_point: str) -> list:
}
for metadata in field_info.metadata:
for metadata_key, metadata_value in metadata.items():
options_spec[key][metadata_key] = metadata_value
if metadata_key in ('priority', 'short_name', 'option_cls'):
options_spec[key][metadata_key] = metadata_value

options_ordered = []

Expand Down
58 changes: 55 additions & 3 deletions src/aiida/common/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,41 @@
import typing as t

from pydantic import Field
from pydantic_core import PydanticUndefined

if t.TYPE_CHECKING:
from pydantic import BaseModel

from aiida.orm import Entity


def get_metadata(field_info, key: str, default: t.Any | None = None):
"""Return a the metadata of the given field for a particular key.

:param field_info: The field from which to retrieve the metadata.
:param key: The metadata name.
:param default: Optional default value to return in case the metadata is not defined on the field.
:returns: The metadata if defined, otherwise the default.
"""
for element in field_info.metadata:
if key in element:
return element[key]
return default


def MetadataField( # noqa: N802
default: t.Any | None = None,
default: t.Any = PydanticUndefined,
*,
priority: int = 0,
short_name: str | None = None,
option_cls: t.Any | None = None,
orm_class: type['Entity'] | str | None = None,
orm_to_model: t.Callable[['Entity'], t.Any] | None = None,
model_to_orm: t.Callable[['BaseModel'], t.Any] | None = None,
exclude_to_orm: bool = False,
exclude_from_cli: bool = False,
is_attribute: bool = True,
is_subscriptable: bool = False,
**kwargs,
):
"""Return a :class:`pydantic.fields.Field` instance with additional metadata.
Expand All @@ -37,11 +64,36 @@ class Model(BaseModel):
:param priority: Used to order the list of all fields in the model. Ordering is done from small to large priority.
:param short_name: Optional short name to use for an option on a command line interface.
:param option_cls: The :class:`click.Option` class to use to construct the option.
:param orm_class: The class, or entry point name thereof, to which the field should be converted. If this field is
defined, the value of this field should acccept an integer which will automatically be converted to an instance
of said ORM class using ``orm_class.collection.get(id={field_value})``. This is useful, for example, where a
field represents an instance of a different entity, such as an instance of ``User``. The serialized data would
store the ``pk`` of the user, but the ORM entity instance would receive the actual ``User`` instance with that
primary key.
:param orm_to_model: Optional callable to convert the value of a field from an ORM instance to a model instance.
:param model_to_orm: Optional callable to convert the value of a field from a model instance to an ORM instance.
:param exclude_to_orm: When set to ``True``, this field value will not be passed to the ORM entity constructor
through ``Entity.from_model``.
:param exclude_to_orm: When set to ``True``, this field value will not be exposed on the CLI command that is
edan-bainglass marked this conversation as resolved.
Show resolved Hide resolved
dynamically generated to create a new instance.
:param is_attribute: Whether the field is stored as an attribute.
:param is_subscriptable: Whether the field can be indexed like a list or dictionary.
"""
field_info = Field(default, **kwargs)

for key, value in (('priority', priority), ('short_name', short_name), ('option_cls', option_cls)):
if value is not None and field_info is not None:
for key, value in (
('priority', priority),
('short_name', short_name),
('option_cls', option_cls),
('orm_class', orm_class),
('orm_to_model', orm_to_model),
('model_to_orm', model_to_orm),
('exclude_to_orm', exclude_to_orm),
('exclude_from_cli', exclude_from_cli),
('is_attribute', is_attribute),
('is_subscriptable', is_subscriptable),
):
if value is not None:
field_info.metadata.append({key: value})

return field_info
100 changes: 66 additions & 34 deletions src/aiida/orm/authinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@
###########################################################################
"""Module for the `AuthInfo` ORM class."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Optional, Type

from aiida.common import exceptions
from aiida.common.pydantic import MetadataField
from aiida.manage import get_manager
from aiida.plugins import TransportFactory

from . import entities, users
from .fields import add_field
from .computers import Computer
from .users import User

if TYPE_CHECKING:
from aiida.orm import Computer, User
from aiida.orm.implementation import StorageBackend
from aiida.orm.implementation.authinfos import BackendAuthInfo # noqa: F401
from aiida.transports import Transport
Expand All @@ -45,51 +48,60 @@ class AuthInfo(entities.Entity['BackendAuthInfo', AuthInfoCollection]):
"""ORM class that models the authorization information that allows a `User` to connect to a `Computer`."""

_CLS_COLLECTION = AuthInfoCollection
PROPERTY_WORKDIR = 'workdir'

__qb_fields__ = [
add_field(
'enabled',
dtype=bool,
class Model(entities.Entity.Model):
computer: int = MetadataField(
description='The PK of the computer',
is_attribute=False,
doc='Whether the instance is enabled',
),
add_field(
'auth_params',
dtype=Dict[str, Any],
orm_class=Computer,
orm_to_model=lambda auth_info: auth_info.computer.pk, # type: ignore[attr-defined]
)
user: int = MetadataField(
description='The PK of the user',
edan-bainglass marked this conversation as resolved.
Show resolved Hide resolved
is_attribute=False,
doc='Dictionary of authentication parameters',
),
add_field(
'metadata',
dtype=Dict[str, Any],
orm_class=User,
orm_to_model=lambda auth_info: auth_info.user.pk, # type: ignore[attr-defined]
)
enabled: bool = MetadataField(
True,
description='Whether the instance is enabled',
is_attribute=False,
doc='Dictionary of metadata',
),
add_field(
'computer_pk',
dtype=int,
)
auth_params: Dict[str, Any] = MetadataField(
default_factory=dict,
description='Dictionary of authentication parameters',
is_attribute=False,
doc='The PK of the computer',
),
add_field(
'user_pk',
dtype=int,
)
metadata: Dict[str, Any] = MetadataField(
default_factory=dict,
description='Dictionary of metadata',
is_attribute=False,
doc='The PK of the user',
),
]

PROPERTY_WORKDIR = 'workdir'

def __init__(self, computer: 'Computer', user: 'User', backend: Optional['StorageBackend'] = None) -> None:
)

def __init__(
self,
computer: 'Computer',
user: 'User',
enabled: bool = True,
auth_params: Dict[str, Any] | None = None,
metadata: Dict[str, Any] | None = None,
backend: Optional['StorageBackend'] = None,
) -> None:
"""Create an `AuthInfo` instance for the given computer and user.

:param computer: a `Computer` instance
:param user: a `User` instance
:param backend: the backend to use for the instance, or use the default backend if None
"""
backend = backend or get_manager().get_profile_storage()
model = backend.authinfos.create(computer=computer.backend_entity, user=user.backend_entity)
model = backend.authinfos.create(
computer=computer.backend_entity,
user=user.backend_entity,
enabled=enabled,
auth_params=auth_params or {},
metadata=metadata or {},
)
super().__init__(model)

def __str__(self) -> str:
Expand All @@ -98,6 +110,18 @@ def __str__(self) -> str:

return f'AuthInfo for {self.user.email} on {self.computer.label} [DISABLED]'

def __eq__(self, other) -> bool:
if not isinstance(other, AuthInfo):
return False

return (
self.user.pk == other.user.pk
and self.computer.pk == other.computer.pk
and self.enabled == other.enabled
and self.auth_params == other.auth_params
and self.metadata == other.metadata
)

@property
def enabled(self) -> bool:
"""Return whether this instance is enabled.
Expand Down Expand Up @@ -126,6 +150,14 @@ def user(self) -> 'User':
"""Return the user associated with this instance."""
return entities.from_backend_entity(users.User, self._backend_entity.user)

@property
def auth_params(self) -> Dict[str, Any]:
return self._backend_entity.get_auth_params()

@property
def metadata(self) -> Dict[str, Any]:
return self._backend_entity.get_metadata()

def get_auth_params(self) -> Dict[str, Any]:
"""Return the dictionary of authentication parameters

Expand Down
Loading
Loading