Skip to content

Commit

Permalink
Merge pull request cylc#5330 from MetRonnie/db-2
Browse files Browse the repository at this point in the history
DB interface & safety improvements
  • Loading branch information
hjoliver authored Jan 27, 2023
2 parents c122bec + 445ae35 commit 8c855b4
Show file tree
Hide file tree
Showing 13 changed files with 196 additions and 166 deletions.
10 changes: 4 additions & 6 deletions cylc/flow/network/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,14 +545,12 @@ def _callback(_, entry):

# NOTE: use the public DB for reading
# (only the scheduler process/thread should access the private database)
db_file = Path(get_workflow_run_dir(flow['name'], 'log', 'db'))
db_file = Path(get_workflow_run_dir(
flow['name'], WorkflowFiles.LogDir.DIRNAME, WorkflowFiles.LogDir.DB
))
if db_file.exists():
dao = CylcWorkflowDAO(db_file, is_public=False)
try:
dao.connect()
with CylcWorkflowDAO(db_file, is_public=True) as dao:
dao.select_workflow_params(_callback)
flow['workflow_params'] = params
finally:
dao.close()

return flow
39 changes: 30 additions & 9 deletions cylc/flow/rundb.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@
from pprint import pformat
import sqlite3
import traceback
from typing import List, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from cylc.flow import LOG
from cylc.flow.util import deserialise
import cylc.flow.flags

if TYPE_CHECKING:
from pathlib import Path


@dataclass
class CylcWorkflowDAOTableColumn:
Expand Down Expand Up @@ -318,26 +321,44 @@ class CylcWorkflowDAO:
],
}

def __init__(self, db_file_name, is_public=False):
def __init__(
self,
db_file_name: Union['Path', str],
is_public: bool = False,
create_tables: bool = False
):
"""Initialise database access object.
An instance of this class can also be opened as a context manager
which will automatically close the DB connection.
Args:
db_file_name (str): Path to the database file.
is_public (bool): If True, allow retries, etc.
db_file_name: Path to the database file.
is_public: If True, allow retries.
create_tables: If True, create the tables if they
don't already exist.
"""
self.db_file_name = expandvars(db_file_name)
self.is_public = is_public
self.conn = None
self.conn: Optional[sqlite3.Connection] = None
self.n_tries = 0

self.tables = {}
for name, attrs in sorted(self.TABLES_ATTRS.items()):
self.tables[name] = CylcWorkflowDAOTable(name, attrs)
self.tables = {
name: CylcWorkflowDAOTable(name, attrs)
for name, attrs in sorted(self.TABLES_ATTRS.items())
}

if not self.is_public:
if create_tables:
self.create_tables()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
"""Close DB connection when leaving context manager."""
self.close()

def add_delete_item(self, table_name, where_args=None):
"""Queue a DELETE item for a given table.
Expand Down
5 changes: 3 additions & 2 deletions cylc/flow/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,8 +1035,9 @@ def command_reload_workflow(self) -> None:
LOG.info("Reloading the workflow definition.")
old_tasks = set(self.config.get_task_name_list())
# Things that can't change on workflow reload:
pri_dao = self.workflow_db_mgr._get_pri_dao()
pri_dao.select_workflow_params(self._load_workflow_params)
self.workflow_db_mgr.pri_dao.select_workflow_params(
self._load_workflow_params
)

try:
self.load_flow_file(is_reload=True)
Expand Down
8 changes: 3 additions & 5 deletions cylc/flow/scheduler_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,8 @@ def _version_check(
if not db_file.is_file():
# not a restart
return True
wdbm = WorkflowDatabaseManager(db_file.parent)
this_version = parse_version(__version__)
last_run_version = wdbm.check_workflow_db_compatibility()
last_run_version = WorkflowDatabaseManager.check_db_compatibility(db_file)

for itt, (this, that) in enumerate(zip_longest(
this_version.release,
Expand Down Expand Up @@ -524,16 +523,15 @@ def _version_check(
return True


def _upgrade_database(db_file):
def _upgrade_database(db_file: Path) -> None:
"""Upgrade the workflow database if needed.
Note:
Do this after the user has confirmed that they want to upgrade!
"""
if db_file.is_file():
wdbm = WorkflowDatabaseManager(db_file.parent)
wdbm.upgrade()
WorkflowDatabaseManager.upgrade(db_file)


def _print_startup_message(options):
Expand Down
8 changes: 4 additions & 4 deletions cylc/flow/scripts/cat_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,10 @@ def get_task_job_attrs(workflow_id, point, task, submit_num):
live_job_id is the job ID if job is running, else None.
"""
workflow_dao = CylcWorkflowDAO(
get_workflow_run_pub_db_path(workflow_id), is_public=True)
task_job_data = workflow_dao.select_task_job(point, task, submit_num)
workflow_dao.close()
with CylcWorkflowDAO(
get_workflow_run_pub_db_path(workflow_id), is_public=True
) as dao:
task_job_data = dao.select_task_job(point, task, submit_num)
if task_job_data is None:
return (None, None, None)
job_runner_name = task_job_data["job_runner_name"]
Expand Down
4 changes: 2 additions & 2 deletions cylc/flow/scripts/report_timings.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def main(parser: COP, options: 'Values', workflow_id: str) -> None:
# No output specified - choose summary by default
options.show_summary = True

run_db = _get_dao(workflow_id)
row_buf = format_rows(*run_db.select_task_times())
with _get_dao(workflow_id) as dao:
row_buf = format_rows(*dao.select_task_times())
with smart_open(options.output_filename) as output:
if options.show_raw:
output.write(row_buf.getvalue())
Expand Down
20 changes: 13 additions & 7 deletions cylc/flow/templatevars.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,26 @@

from ast import literal_eval
from optparse import Values
from typing import Any, Dict
from typing import TYPE_CHECKING, Any, Dict

from cylc.flow.exceptions import InputError


from cylc.flow.rundb import CylcWorkflowDAO
from cylc.flow.workflow_files import WorkflowFiles

if TYPE_CHECKING:
from pathlib import Path


def get_template_vars_from_db(run_dir):
def get_template_vars_from_db(run_dir: 'Path') -> dict:
"""Get template vars stored in a workflow run database.
"""
template_vars = {}
if (run_dir / 'log/db').exists():
dao = CylcWorkflowDAO(str(run_dir / 'log/db'))
pub_db_file = (
run_dir / WorkflowFiles.LogDir.DIRNAME / WorkflowFiles.LogDir.DB
)
template_vars: dict = {}
if not pub_db_file.exists():
return template_vars
with CylcWorkflowDAO(pub_db_file, is_public=True) as dao:
dao.select_workflow_template_vars(
lambda _, row: template_vars.__setitem__(row[0], eval_var(row[1]))
)
Expand Down
114 changes: 53 additions & 61 deletions cylc/flow/workflow_db_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
* Manage existing run database files on restart.
"""

from contextlib import contextmanager
import json
import os
from pkg_resources import parse_version
from shutil import copy, rmtree
from sqlite3 import OperationalError
from tempfile import mkstemp
from typing import (
Any, AnyStr, Dict, Generator, List, Set, TYPE_CHECKING, Tuple, Union
Any, AnyStr, Dict, List, Set, TYPE_CHECKING, Tuple, Union
)

from cylc.flow import LOG
Expand All @@ -42,12 +42,14 @@
from cylc.flow.util import serialise

if TYPE_CHECKING:
from pathlib import Path
from cylc.flow.cycling import PointBase
from cylc.flow.scheduler import Scheduler
from cylc.flow.task_pool import TaskPool

# # TODO: narrow down Any (should be str | int) after implementing type
# # annotations in cylc.flow.task_state.TaskState
Version = Any
# TODO: narrow down Any (should be str | int) after implementing type
# annotations in cylc.flow.task_state.TaskState
DbArgDict = Dict[str, Any]
DbUpdateTuple = Tuple[DbArgDict, DbArgDict]

Expand Down Expand Up @@ -194,23 +196,13 @@ def delete_workflow_stop_task(self):
"""Delete workflow stop task from workflow_params table."""
self.delete_workflow_params(self.KEY_STOP_TASK)

def _get_pri_dao(self) -> CylcWorkflowDAO:
def get_pri_dao(self) -> CylcWorkflowDAO:
"""Return the primary DAO.
Note: the DAO should be closed after use. It is better to use the
context manager method below, which handles this for you.
NOTE: the DAO should be closed after use. You can use this function as
a context manager, which handles this for you.
"""
return CylcWorkflowDAO(self.pri_path)

@contextmanager
def get_pri_dao(self) -> Generator[CylcWorkflowDAO, None, None]:
"""Return the primary DAO and close it after the context manager
exits."""
pri_dao = self._get_pri_dao()
try:
yield pri_dao
finally:
pri_dao.close()
return CylcWorkflowDAO(self.pri_path, create_tables=True)

@staticmethod
def _namedtuple2json(obj):
Expand All @@ -227,7 +219,7 @@ def _namedtuple2json(obj):
else:
return json.dumps([type(obj).__name__, obj.__getnewargs__()])

def on_workflow_start(self, is_restart):
def on_workflow_start(self, is_restart: bool) -> None:
"""Initialise data access objects.
Ensure that:
Expand All @@ -243,7 +235,7 @@ def on_workflow_start(self, is_restart):
# ... however, in case there is a directory at the path for
# some bizarre reason:
rmtree(self.pri_path, ignore_errors=True)
self.pri_dao = self._get_pri_dao()
self.pri_dao = self.get_pri_dao()
os.chmod(self.pri_path, PERM_PRIVATE)
self.pub_dao = CylcWorkflowDAO(self.pub_path, is_public=True)
self.copy_pri_to_pub()
Expand Down Expand Up @@ -693,20 +685,32 @@ def restart_check(self) -> None:
self.put_workflow_params_1(self.KEY_RESTART_COUNT, self.n_restart)
self.process_queued_ops()

def _get_last_run_version(self, pri_dao: CylcWorkflowDAO) -> str:
return pri_dao.connect().execute(
rf'''
SELECT
value
FROM
{self.TABLE_WORKFLOW_PARAMS}
WHERE
key == ?
''', # nosec (table name is a code constant)
[self.KEY_CYLC_VERSION]
).fetchone()[0]

def upgrade_pre_803(self, pri_dao: CylcWorkflowDAO) -> None:
@classmethod
def _get_last_run_version(cls, pri_dao: CylcWorkflowDAO) -> Version:
"""Return the version of Cylc this DB was last run with.
Args:
pri_dao: Open private database connection object.
"""
try:
last_run_ver = pri_dao.connect().execute(
rf'''
SELECT
value
FROM
{cls.TABLE_WORKFLOW_PARAMS}
WHERE
key == ?
''', # nosec (table name is a code constant)
[cls.KEY_CYLC_VERSION]
).fetchone()[0]
except (TypeError, OperationalError):
raise ServiceFileError(f"{INCOMPAT_MSG}, or is corrupted.")
return parse_version(last_run_ver)

@classmethod
def upgrade_pre_803(cls, pri_dao: CylcWorkflowDAO) -> None:
"""Upgrade on restart from a pre-8.0.3 database.
Add "is_manual_submit" column to the task states table.
Expand All @@ -716,10 +720,10 @@ def upgrade_pre_803(self, pri_dao: CylcWorkflowDAO) -> None:
c_name = "is_manual_submit"
LOG.info(
f"DB upgrade (pre-8.0.3): "
f"add {c_name} column to {self.TABLE_TASK_STATES}"
f"add {c_name} column to {cls.TABLE_TASK_STATES}"
)
conn.execute(
rf"ALTER TABLE {self.TABLE_TASK_STATES} "
rf"ALTER TABLE {cls.TABLE_TASK_STATES} "
rf"ADD COLUMN {c_name} INTEGER "
r"DEFAULT 0 NOT NULL"
)
Expand Down Expand Up @@ -760,30 +764,19 @@ def upgrade_pre_810(pri_dao: CylcWorkflowDAO) -> None:
)
conn.commit()

def _get_last_run_ver(self, pri_dao):
"""Return the version of Cylc this DB was last run with.
Args:
pri_dao: Open private database connection object.
"""
try:
last_run_ver = self._get_last_run_version(pri_dao)
except TypeError:
raise ServiceFileError(f"{INCOMPAT_MSG}, or is corrupted.")
return parse_version(last_run_ver)

def upgrade(self):
@classmethod
def upgrade(cls, db_file: Union['Path', str]) -> None:
"""Upgrade this database to this Cylc version.
"""
with self.get_pri_dao() as pri_dao:
last_run_ver = self._get_last_run_ver(pri_dao)
with CylcWorkflowDAO(db_file, create_tables=True) as pri_dao:
last_run_ver = cls._get_last_run_version(pri_dao)
if last_run_ver < parse_version("8.0.3.dev"):
self.upgrade_pre_803(pri_dao)
cls.upgrade_pre_803(pri_dao)
if last_run_ver < parse_version("8.1.0.dev"):
self.upgrade_pre_810(pri_dao)
cls.upgrade_pre_810(pri_dao)

def check_workflow_db_compatibility(self):
@classmethod
def check_db_compatibility(cls, db_file: Union['Path', str]) -> Version:
"""Check this DB is compatible with this Cylc version.
Raises:
Expand All @@ -792,19 +785,18 @@ def check_workflow_db_compatibility(self):
current version of Cylc.
"""
if not os.path.isfile(self.pri_path):
raise FileNotFoundError(self.pri_path)
if not os.path.isfile(db_file):
raise FileNotFoundError(db_file)

with self.get_pri_dao() as pri_dao:
last_run_ver = self._get_last_run_ver(pri_dao)
with CylcWorkflowDAO(db_file) as dao:
last_run_ver = cls._get_last_run_version(dao)
# WARNING: Do no upgrade the DB here

restart_incompat_ver = parse_version(
CylcWorkflowDAO.RESTART_INCOMPAT_VERSION
)
if last_run_ver <= restart_incompat_ver:
raise ServiceFileError(
f"{INCOMPAT_MSG} (workflow last run with "
f"Cylc {last_run_ver})."
f"{INCOMPAT_MSG} (workflow last run with Cylc {last_run_ver})."
)
return last_run_ver
Loading

0 comments on commit 8c855b4

Please sign in to comment.