Dependencies: Update to sqlalchemy~=2.0 #6146

Merged
merged 1 commit on Nov 8, 2023
1 change: 0 additions & 1 deletion .github/workflows/ci-code.yml
@@ -112,7 +112,6 @@ jobs:
- name: Run test suite
env:
AIIDA_WARN_v3: 1
SQLALCHEMY_WARN_20: 1
run:
.github/workflows/tests.sh

1 change: 0 additions & 1 deletion .github/workflows/test-install.yml
@@ -253,7 +253,6 @@ jobs:
- name: Run test suite
env:
AIIDA_WARN_v3: 1
SQLALCHEMY_WARN_20: 1
run:
.github/workflows/tests.sh

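
Both CI workflows drop the `SQLALCHEMY_WARN_20` environment variable. That switch was SQLAlchemy 1.4's opt-in for surfacing 2.0-incompatible usage; with the dependency now at 2.0 it no longer does anything useful. If a test run still wants to escalate remaining SQLAlchemy deprecations, a plain warnings filter can serve the same purpose; the snippet below is only a sketch of that idea, not something added by this PR.

```python
# Sketch only (not part of this PR): escalate remaining SQLAlchemy deprecation
# warnings to errors in a test run, now that SQLALCHEMY_WARN_20 is obsolete.
import warnings

from sqlalchemy.exc import SADeprecationWarning

warnings.simplefilter('error', SADeprecationWarning)
```
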
15 changes: 14 additions & 1 deletion aiida/cmdline/commands/cmd_data/cmd_list.py
@@ -52,6 +52,14 @@ def query(datatype, project, past_days, group_pks, all_users):
n_days_ago = now - datetime.timedelta(days=past_days)
data_filters.update({'ctime': {'>=': n_days_ago}})

# Since the query results are sorted on ``ctime`` it has to be projected on. If it doesn't exist, append it to the
# projections, but make sure to pop it again from the final results since it wasn't part of the original projections
if 'ctime' in project:
pop_ctime = False
else:
project.append('ctime')
pop_ctime = True

qbl.append(datatype, tag='data', with_user='creator', filters=data_filters, project=project)

# If there is a group restriction
@@ -63,7 +71,12 @@ def query(datatype, project, past_days, group_pks, all_users):
qbl.order_by({datatype: {'ctime': 'asc'}})

object_list = qbl.distinct()
return object_list.all()
results = object_list.all()

if pop_ctime:
return [element[:-1] for element in results]

return results


# pylint: disable=unused-argument,too-many-arguments
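
Nothing in the diff states the motivation, but this matches a known SQLAlchemy 2.0 behaviour change (my reading, so treat it as an assumption): when `DISTINCT` is used, columns referenced in `ORDER BY` must be listed explicitly in the SELECT, whereas the 1.x ORM added them implicitly. Because this query orders on `ctime`, the column is appended to the projections when absent and stripped from the results afterwards; the `bands.py` change further down adds `ctime` to its projection for the same reason. A minimal, self-contained sketch of the explicit-projection requirement with a throwaway model:

```python
# Minimal sketch (model and engine are illustrative, not AiiDA's): with DISTINCT,
# the ORDER BY column has to be projected explicitly in SQLAlchemy 2.0.
from sqlalchemy import Column, DateTime, Integer, create_engine, select
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Node(Base):
    __tablename__ = 'node'
    id = Column(Integer, primary_key=True)
    ctime = Column(DateTime)


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

# ``ctime`` is part of the SELECT list, so it can also be used to order the
# DISTINCT result; callers that did not ask for it can drop the extra column.
stmt = select(Node.id, Node.ctime).distinct().order_by(Node.ctime.asc())
with engine.connect() as connection:
    rows = [tuple(row)[:-1] for row in connection.execute(stmt)]
```
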
4 changes: 2 additions & 2 deletions aiida/orm/nodes/data/array/bands.py
@@ -1840,14 +1840,14 @@ def get_bands_and_parents_structure(args, backend=None):
tag='sdata',
with_descendants='bdata',
# We don't care about the creator of StructureData
project=['id', 'attributes.kinds', 'attributes.sites']
project=['id', 'attributes.kinds', 'attributes.sites', 'ctime']
)

q_build.order_by({orm.StructureData: {'ctime': 'desc'}})

structure_dict = {}
list_data = q_build.distinct().all()
for bid, _, _, _, akinds, asites in list_data:
for bid, _, _, _, akinds, asites, _ in list_data:
structure_dict[bid] = (akinds, asites)

entry_list = []
7 changes: 4 additions & 3 deletions aiida/storage/psql_dos/alembic_cli.py
@@ -11,9 +11,10 @@
"""Simple wrapper around the alembic command line tool that first loads an AiiDA profile."""
from __future__ import annotations

import contextlib

import alembic
import click
from sqlalchemy.util.compat import nullcontext

from aiida.cmdline import is_verbose
from aiida.cmdline.groups.verdi import VerdiCommandGroup
@@ -38,8 +39,8 @@ def execute_alembic_command(self, command_name, connect=True, **kwargs):
raise click.ClickException('No profile specified')
migrator = PsqlDosMigrator(self.profile)

context = migrator._alembic_connect() if connect else nullcontext(migrator._alembic_config()) # pylint: disable=protected-access
with context as config: # type: ignore[attr-defined]
context = migrator._alembic_connect() if connect else contextlib.nullcontext(migrator._alembic_config()) # pylint: disable=protected-access
with context as config:
command = getattr(alembic.command, command_name)
config.stdout = click.get_text_stream('stdout')
command(config, **kwargs)
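
`sqlalchemy.util.compat.nullcontext` is an internal helper that should not be relied on, and the PR swaps it for the standard library's `contextlib.nullcontext`, which also removes the need for the `# type: ignore` on the `with` statement. A small self-contained sketch of the pattern, using hypothetical stand-ins for the migrator's private methods:

```python
# Sketch with stand-ins (open_connection/load_config are hypothetical, not AiiDA
# APIs): pick a real context manager or a no-op one, then use the result uniformly.
import contextlib


@contextlib.contextmanager
def open_connection():
    # stand-in for PsqlDosMigrator._alembic_connect(); yields a config-like object
    yield {'connected': True}


def load_config():
    # stand-in for PsqlDosMigrator._alembic_config()
    return {'connected': False}


def run(connect: bool = True) -> None:
    context = open_connection() if connect else contextlib.nullcontext(load_config())
    with context as config:
        print(config)


run(connect=True)   # {'connected': True}
run(connect=False)  # {'connected': False}
```
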
30 changes: 18 additions & 12 deletions aiida/storage/psql_dos/backend.py
@@ -15,6 +15,7 @@
import pathlib
from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Set, Union

from sqlalchemy import column, insert, update
from sqlalchemy.orm import Session, scoped_session, sessionmaker

from aiida.common.exceptions import ClosedStorage, ConfigurationError, IntegrityError
@@ -188,8 +189,9 @@ def _clear(self) -> None:

with self.transaction() as session:
session.execute(
DbSetting.__table__.update().where(DbSetting.key == REPOSITORY_UUID_KEY
).values(val=repository_uuid)
DbSetting.__table__.update().where(
DbSetting.key == REPOSITORY_UUID_KEY # type: ignore[attr-defined]
).values(val=repository_uuid)
)

def get_repository(self) -> 'DiskObjectStoreRepositoryBackend':
Expand Down Expand Up @@ -305,8 +307,8 @@ def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults
# by contrast, in sqlite, bulk_insert is faster: https://docs.sqlalchemy.org/en/14/faq/performance.html
session = self.get_session()
with (nullcontext() if self.in_transaction else self.transaction()):
session.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True)
return [row['id'] for row in rows]
result = session.execute(insert(mapper).returning(mapper, column('id')), rows).fetchall()
return [row.id for row in result]

def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None:
mapper, keys = self._get_mapper_from_entity(entity_type, True)
@@ -319,7 +321,7 @@ def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None:
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
session = self.get_session()
with (nullcontext() if self.in_transaction else self.transaction()):
session.bulk_update_mappings(mapper, rows)
session.execute(update(mapper), rows)

def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None:
# pylint: disable=no-value-for-parameter
@@ -331,14 +333,17 @@ def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None:

session = self.get_session()
# Delete the membership of these nodes to groups.
session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete))
session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete)) # type: ignore[attr-defined]
).delete(synchronize_session='fetch')
# Delete the links coming out of the nodes marked for deletion.
session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))
).delete(synchronize_session='fetch') # type: ignore[attr-defined]
# Delete the links pointing to the nodes marked for deletion.
session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))
).delete(synchronize_session='fetch') # type: ignore[attr-defined]
# Delete the actual nodes
session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))
).delete(synchronize_session='fetch') # type: ignore[attr-defined]

def get_backend_entity(self, model: base.Base) -> BackendEntity:
"""
@@ -356,9 +361,10 @@ def set_global_variable(

session = self.get_session()
with (nullcontext() if self.in_transaction else self.transaction()):
if session.query(DbSetting).filter(DbSetting.key == key).count():
if session.query(DbSetting).filter(DbSetting.key == key).count(): # type: ignore[attr-defined]
if overwrite:
session.query(DbSetting).filter(DbSetting.key == key).update(dict(val=value))
session.query(DbSetting).filter(DbSetting.key == key
).update(dict(val=value)) # type: ignore[attr-defined]
else:
raise ValueError(f'The setting {key} already exists')
else:
@@ -369,7 +375,7 @@ def get_global_variable(self, key: str) -> Union[None, str, int, float]:

session = self.get_session()
with (nullcontext() if self.in_transaction else self.transaction()):
setting = session.query(DbSetting).filter(DbSetting.key == key).one_or_none()
setting = session.query(DbSetting).filter(DbSetting.key == key).one_or_none() # type: ignore[attr-defined]
if setting is None:
raise KeyError(f'No setting found with key {key}')
return setting.val
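
The backend moves from the legacy `bulk_insert_mappings`/`bulk_update_mappings` session methods to 2.0-style statements executed with a list of parameter dictionaries: ORM bulk INSERT with RETURNING to recover the generated primary keys, and ORM bulk UPDATE by primary key. The `sqlite_temp` backend further down receives the identical change. A minimal sketch of both patterns against a throwaway mapped class (not an AiiDA model):

```python
# Sketch only (``Node`` is illustrative, not an AiiDA model): 2.0-style bulk
# INSERT with RETURNING and bulk UPDATE by primary key. RETURNING on SQLite
# assumes a library version of 3.35 or newer.
from sqlalchemy import Column, Integer, String, create_engine, insert, update
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Node(Base):
    __tablename__ = 'node'
    id = Column(Integer, primary_key=True)
    label = Column(String)


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Bulk INSERT; RETURNING hands back the generated primary keys, replacing
    # ``bulk_insert_mappings(..., return_defaults=True)``.
    result = session.execute(insert(Node).returning(Node.id), [{'label': 'a'}, {'label': 'b'}])
    pks = [row.id for row in result]

    # Bulk UPDATE by primary key: every dictionary must carry the ``id`` key,
    # replacing ``bulk_update_mappings``.
    session.execute(update(Node), [{'id': pk, 'label': f'node-{pk}'} for pk in pks])
    session.commit()
```
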
6 changes: 3 additions & 3 deletions aiida/storage/psql_dos/migrations/utils/legacy_workflows.py
@@ -41,9 +41,9 @@ def export_workflow_data(connection, profile):
DbWorkflowData = table('db_dbworkflowdata')
DbWorkflowStep = table('db_dbworkflowstep')

count_workflow = connection.execute(select(func.count()).select_from(DbWorkflow)).scalar()
count_workflow_data = connection.execute(select(func.count()).select_from(DbWorkflowData)).scalar()
count_workflow_step = connection.execute(select(func.count()).select_from(DbWorkflowStep)).scalar()
count_workflow = connection.execute(select(func.count()).select_from(DbWorkflow)).scalar() # pylint: disable=not-callable
count_workflow_data = connection.execute(select(func.count()).select_from(DbWorkflowData)).scalar() # pylint: disable=not-callable
count_workflow_step = connection.execute(select(func.count()).select_from(DbWorkflowStep)).scalar() # pylint: disable=not-callable

# Nothing to do if all tables are empty
if count_workflow == 0 and count_workflow_data == 0 and count_workflow_step == 0:
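
The added `# pylint: disable=not-callable` comments work around a pylint false positive: `func.count` is produced dynamically by SQLAlchemy's function factory, so pylint cannot see that the attribute is callable. A tiny sketch of the construct being silenced, with an illustrative table name and the statement only:

```python
# Sketch only: build the 2.0-style COUNT(*) statement that pylint flags.
from sqlalchemy import func, select, table

db_dbworkflow = table('db_dbworkflow')  # lightweight construct; columns not needed for COUNT(*)
stmt = select(func.count()).select_from(db_dbworkflow)  # pylint: disable=not-callable
print(stmt)  # SELECT count(*) AS count_1 FROM db_dbworkflow
```
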
@@ -33,7 +33,7 @@ def migrate_repository(connection, profile):
column('repository_metadata', JSONB),
)

node_count = connection.execute(select(func.count()).select_from(DbNode)).scalar()
node_count = connection.execute(select(func.count()).select_from(DbNode)).scalar() # pylint: disable=not-callable
missing_repo_folder = []
shard_count = 256

@@ -126,8 +126,9 @@ def migrate_infer_calculation_entry_point(alembic_op):
fallback_cases.append([uuid, type_string, entry_point_string])

connection.execute(
DbNode.update().where(DbNode.c.type == alembic_op.inline_literal(type_string)
).values(process_type=alembic_op.inline_literal(entry_point_string))
DbNode.update().where(
DbNode.c.type == alembic_op.inline_literal(type_string) # type: ignore[attr-defined]
).values(process_type=alembic_op.inline_literal(entry_point_string))
)

if fallback_cases:
@@ -50,15 +50,15 @@ def upgrade():
column('attributes', JSONB),
)

nodes = connection.execute(
nodes = connection.execute( # type: ignore[var-annotated]
select(DbNode.c.id,
DbNode.c.uuid).where(DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))
).fetchall()

for pk, uuid in nodes:
symbols = load_numpy_array_from_repository(repo_path, uuid, 'symbols').tolist()
connection.execute(
DbNode.update().where(DbNode.c.id == pk).values(
DbNode.update().where(DbNode.c.id == pk).values( # type: ignore[attr-defined]
attributes=func.jsonb_set(DbNode.c.attributes, op.inline_literal('{"symbols"}'), cast(symbols, JSONB))
)
)
@@ -45,7 +45,7 @@ def upgrade():
column('attributes', JSONB),
)

nodes = connection.execute(
nodes = connection.execute( # type: ignore[var-annotated]
select(DbNode.c.id,
DbNode.c.uuid).where(DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))
).fetchall()
@@ -42,7 +42,7 @@ def upgrade():
sa.column('type', sa.String),
)

nodes = connection.execute(
nodes = connection.execute( # type: ignore[var-annotated]
sa.select(node_model.c.id, node_model.c.uuid).where(
node_model.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.')
)
@@ -44,7 +44,7 @@ def upgrade():
# sa.column('attributes', JSONB),
)

nodes = connection.execute(
nodes = connection.execute( # type: ignore[var-annotated]
sa.select(node_tbl.c.id, node_tbl.c.uuid).where(
node_tbl.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.')
)
@@ -88,16 +88,22 @@ def upgrade():
op.add_column('db_dbnode', sa.Column('extras', postgresql.JSONB(astext_type=sa.Text()), nullable=True))

# transition attributes and extras to node
node_count = conn.execute(select(func.count()).select_from(node_tbl)).scalar()
node_count = conn.execute(select(func.count()).select_from(node_tbl)).scalar() # pylint: disable=not-callable
if node_count:
with get_progress_reporter()(total=node_count, desc='Updating attributes and extras') as progress:
for node in conn.execute(select(node_tbl)).all():
attr_list = conn.execute(select(attr_tbl).where(attr_tbl.c.dbnode_id == node.id)).all()
attr_list = conn.execute( # type: ignore[var-annotated]
select(attr_tbl).where(attr_tbl.c.dbnode_id == node.id)
).all()
attributes, _ = attributes_to_dict(sorted(attr_list, key=lambda a: a.key))
extra_list = conn.execute(select(extra_tbl).where(extra_tbl.c.dbnode_id == node.id)).all()
extra_list = conn.execute( # type: ignore[var-annotated]
select(extra_tbl).where(extra_tbl.c.dbnode_id == node.id)
).all()
extras, _ = attributes_to_dict(sorted(extra_list, key=lambda a: a.key))
conn.execute(
node_tbl.update().where(node_tbl.c.id == node.id).values(attributes=attributes, extras=extras)
node_tbl.update().where( # type: ignore[attr-defined]
node_tbl.c.id == node.id
).values(attributes=attributes, extras=extras)
)
progress.update()

@@ -107,7 +113,7 @@ def upgrade():
op.add_column('db_dbsetting', sa.Column('val', postgresql.JSONB(astext_type=sa.Text()), nullable=True))

# transition settings
setting_count = conn.execute(select(func.count()).select_from(setting_tbl)).scalar()
setting_count = conn.execute(select(func.count()).select_from(setting_tbl)).scalar() # pylint: disable=not-callable
if setting_count:
with get_progress_reporter()(total=setting_count, desc='Updating settings') as progress:
for setting in conn.execute(select(setting_tbl)).all():
@@ -129,8 +135,9 @@ def upgrade():
else:
val = setting.dval
conn.execute(
setting_tbl.update().where(setting_tbl.c.id == setting.id
).values(val=cast(val, postgresql.JSONB(astext_type=sa.Text())))
setting_tbl.update().where( # type: ignore[attr-defined]
setting_tbl.c.id == setting.id
).values(val=cast(val, postgresql.JSONB(astext_type=sa.Text())))
)
progress.update()

10 changes: 5 additions & 5 deletions aiida/storage/psql_dos/orm/querybuilder/main.py
@@ -760,10 +760,10 @@ def get_creation_statistics(self, user_pk: Optional[int] = None) -> Dict[str, An
retdict: Dict[Any, Any] = {}

total_query = session.query(self.Node)
types_query = session.query(self.Node.node_type.label('typestring'), sa_func.count(self.Node.id)) # pylint: disable=no-member
types_query = session.query(self.Node.node_type.label('typestring'), sa_func.count(self.Node.id)) # pylint: disable=no-member,not-callable
stat_query = session.query(
sa_func.date_trunc('day', self.Node.ctime).label('cday'), # pylint: disable=no-member
sa_func.count(self.Node.id), # pylint: disable=no-member
sa_func.count(self.Node.id), # pylint: disable=no-member,not-callable
)

if user_pk is not None:
@@ -1088,11 +1088,11 @@ def _get_projection(
if func is None:
pass
elif func == 'max':
entity_to_project = sa_func.max(entity_to_project)
entity_to_project = sa_func.max(entity_to_project) # pylint: disable=not-callable
elif func == 'min':
entity_to_project = sa_func.max(entity_to_project)
entity_to_project = sa_func.max(entity_to_project) # pylint: disable=not-callable
elif func == 'count':
entity_to_project = sa_func.count(entity_to_project)
entity_to_project = sa_func.count(entity_to_project) # pylint: disable=not-callable
else:
raise ValueError(f'\nInvalid function specification {func}')

6 changes: 2 additions & 4 deletions aiida/storage/psql_dos/utils.py
@@ -45,22 +45,20 @@ def create_sqlalchemy_engine(config: PsqlConfig):
password=config['database_password'],
hostname=hostname,
port=config['database_port'],
name=config['database_name']
name=config['database_name'],
)
return create_engine(
engine_url,
json_serializer=json.dumps,
json_deserializer=json.loads,
future=True,
encoding='utf-8',
**config.get('engine_kwargs', {}),
)


def create_scoped_session_factory(engine, **kwargs):
"""Create scoped SQLAlchemy session factory"""
from sqlalchemy.orm import scoped_session, sessionmaker
return scoped_session(sessionmaker(bind=engine, future=True, **kwargs))
return scoped_session(sessionmaker(bind=engine, **kwargs))


def flag_modified(instance, key):
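
SQLAlchemy 2.0 only offers the "future" engine and session behaviour, so the explicit `future=True` flags are redundant, and the legacy `encoding` argument to `create_engine` has been removed outright. A stripped-down sketch of the resulting setup; the URL is a placeholder and requires the psycopg2 driver, so swap in `'sqlite://'` to try it locally:

```python
# Sketch only: 2.0-style engine and scoped session factory without the removed
# ``future`` and ``encoding`` arguments. The connection URL is a placeholder.
import json

from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

engine = create_engine(
    'postgresql+psycopg2://user:password@localhost:5432/aiida_db',  # placeholder URL
    json_serializer=json.dumps,
    json_deserializer=json.loads,
)
SessionFactory = scoped_session(sessionmaker(bind=engine))
```
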
7 changes: 4 additions & 3 deletions aiida/storage/sqlite_temp/backend.py
@@ -19,6 +19,7 @@
import shutil
from typing import Any, BinaryIO, Iterator, Sequence

from sqlalchemy import column, insert, update
from sqlalchemy.orm import Session

from aiida.common.exceptions import ClosedStorage, IntegrityError
@@ -260,8 +261,8 @@ def bulk_insert(self, entity_type: EntityTypes, rows: list[dict], allow_defaults
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {keys}')
session = self.get_session()
with (nullcontext() if self.in_transaction else self.transaction()):
session.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True)
return [row['id'] for row in rows]
result = session.execute(insert(mapper).returning(mapper, column('id')), rows).fetchall()
return [row.id for row in result]

def bulk_update(self, entity_type: EntityTypes, rows: list[dict]) -> None:
mapper, keys = self._get_mapper_from_entity(entity_type, True)
@@ -274,7 +275,7 @@ def bulk_update(self, entity_type: EntityTypes, rows: list[dict]) -> None:
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
session = self.get_session()
with (nullcontext() if self.in_transaction else self.transaction()):
session.bulk_update_mappings(mapper, rows)
session.execute(update(mapper), rows)

def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]):
raise NotImplementedError
4 changes: 2 additions & 2 deletions aiida/storage/sqlite_zip/migrations/legacy_to_main.py
@@ -180,7 +180,7 @@ def _json_to_sqlite( # pylint: disable=too-many-branches,too-many-locals

# get mapping of node IDs to node UUIDs
node_uuid_map = { # pylint: disable=unnecessary-comprehension
uuid: pk for uuid, pk in connection.execute(select(v1_schema.DbNode.uuid, v1_schema.DbNode.id))
uuid: pk for uuid, pk in connection.execute(select(v1_schema.DbNode.uuid, v1_schema.DbNode.id)) # pylint: disable=not-an-iterable
}

# links
Expand Down Expand Up @@ -211,7 +211,7 @@ def _transform_link(link_row):
if data['groups_uuid']:
# get mapping of node IDs to node UUIDs
group_uuid_map = { # pylint: disable=unnecessary-comprehension
uuid: pk for uuid, pk in connection.execute(select(v1_schema.DbGroup.uuid, v1_schema.DbGroup.id))
uuid: pk for uuid, pk in connection.execute(select(v1_schema.DbGroup.uuid, v1_schema.DbGroup.id)) # pylint: disable=not-an-iterable
}
length = sum(len(uuids) for uuids in data['groups_uuid'].values())
unknown_nodes: Dict[str, set] = {}
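
The `not-an-iterable` disables are another pylint false positive: the `Result` returned by `connection.execute()` iterates row by row, but pylint cannot infer that through SQLAlchemy's typing. A self-contained sketch of the comprehension pattern being silenced, with made-up data:

```python
# Sketch only: iterate a 2.0 Result directly to build a uuid -> pk mapping.
from sqlalchemy import create_engine, text

engine = create_engine('sqlite://')
with engine.connect() as connection:
    rows = connection.execute(text("SELECT 1 AS id, 'abc' AS uuid"))
    uuid_map = {uuid: pk for pk, uuid in rows}  # pylint: disable=not-an-iterable
print(uuid_map)  # {'abc': 1}
```
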