From 6617d1af051217015dded28bbb6173463f479445 Mon Sep 17 00:00:00 2001
From: Ronnie Dutta <61982285+MetRonnie@users.noreply.github.com>
Date: Tue, 24 Jan 2023 17:45:49 +0000
Subject: [PATCH 1/4] Tidy
---
cylc/flow/workflow_db_mgr.py | 58 +++++++++++++++++-------------------
1 file changed, 28 insertions(+), 30 deletions(-)
diff --git a/cylc/flow/workflow_db_mgr.py b/cylc/flow/workflow_db_mgr.py
index ec8cb03d11d..1d11a258326 100644
--- a/cylc/flow/workflow_db_mgr.py
+++ b/cylc/flow/workflow_db_mgr.py
@@ -46,8 +46,9 @@
from cylc.flow.scheduler import Scheduler
from cylc.flow.task_pool import TaskPool
-# # TODO: narrow down Any (should be str | int) after implementing type
-# # annotations in cylc.flow.task_state.TaskState
+Version = Any
+# TODO: narrow down Any (should be str | int) after implementing type
+# annotations in cylc.flow.task_state.TaskState
DbArgDict = Dict[str, Any]
DbUpdateTuple = Tuple[DbArgDict, DbArgDict]
@@ -693,18 +694,28 @@ def restart_check(self) -> None:
self.put_workflow_params_1(self.KEY_RESTART_COUNT, self.n_restart)
self.process_queued_ops()
- def _get_last_run_version(self, pri_dao: CylcWorkflowDAO) -> str:
- return pri_dao.connect().execute(
- rf'''
- SELECT
- value
- FROM
- {self.TABLE_WORKFLOW_PARAMS}
- WHERE
- key == ?
- ''', # nosec (table name is a code constant)
- [self.KEY_CYLC_VERSION]
- ).fetchone()[0]
+ def _get_last_run_version(self, pri_dao: CylcWorkflowDAO) -> Version:
+ """Return the version of Cylc this DB was last run with.
+
+ Args:
+ pri_dao: Open private database connection object.
+
+ """
+ try:
+ last_run_ver = pri_dao.connect().execute(
+ rf'''
+ SELECT
+ value
+ FROM
+ {self.TABLE_WORKFLOW_PARAMS}
+ WHERE
+ key == ?
+ ''', # nosec (table name is a code constant)
+ [self.KEY_CYLC_VERSION]
+ ).fetchone()[0]
+ except TypeError:
+ raise ServiceFileError(f"{INCOMPAT_MSG}, or is corrupted.")
+ return parse_version(last_run_ver)
def upgrade_pre_803(self, pri_dao: CylcWorkflowDAO) -> None:
"""Upgrade on restart from a pre-8.0.3 database.
@@ -760,30 +771,17 @@ def upgrade_pre_810(pri_dao: CylcWorkflowDAO) -> None:
)
conn.commit()
- def _get_last_run_ver(self, pri_dao):
- """Return the version of Cylc this DB was last run with.
-
- Args:
- pri_dao: Open private database connection object.
-
- """
- try:
- last_run_ver = self._get_last_run_version(pri_dao)
- except TypeError:
- raise ServiceFileError(f"{INCOMPAT_MSG}, or is corrupted.")
- return parse_version(last_run_ver)
-
def upgrade(self):
"""Upgrade this database to this Cylc version.
"""
with self.get_pri_dao() as pri_dao:
- last_run_ver = self._get_last_run_ver(pri_dao)
+ last_run_ver = self._get_last_run_version(pri_dao)
if last_run_ver < parse_version("8.0.3.dev"):
self.upgrade_pre_803(pri_dao)
if last_run_ver < parse_version("8.1.0.dev"):
self.upgrade_pre_810(pri_dao)
- def check_workflow_db_compatibility(self):
+ def check_workflow_db_compatibility(self) -> Version:
"""Check this DB is compatible with this Cylc version.
Raises:
@@ -796,7 +794,7 @@ def check_workflow_db_compatibility(self):
raise FileNotFoundError(self.pri_path)
with self.get_pri_dao() as pri_dao:
- last_run_ver = self._get_last_run_ver(pri_dao)
+ last_run_ver = self._get_last_run_version(pri_dao)
# WARNING: Do no upgrade the DB here
restart_incompat_ver = parse_version(
From a125310f86010387b8cad2503863c813960cef4f Mon Sep 17 00:00:00 2001
From: Ronnie Dutta <61982285+MetRonnie@users.noreply.github.com>
Date: Wed, 25 Jan 2023 18:22:23 +0000
Subject: [PATCH 2/4] Improve CylcWorkflowDAO interface
- Do not create tables by default, only if specified via kwarg
- Provide context manager functionality to automatically close DB connection
- A WorkflowDatabaseManager should only be created by Scheduler; just use CylcWorkflowDAO elsewhere
---
cylc/flow/network/scan.py | 10 ++---
cylc/flow/rundb.py | 39 +++++++++++++----
cylc/flow/scheduler.py | 5 ++-
cylc/flow/scheduler_cli.py | 8 ++--
cylc/flow/scripts/cat_log.py | 8 ++--
cylc/flow/scripts/report_timings.py | 4 +-
cylc/flow/templatevars.py | 20 ++++++---
cylc/flow/workflow_db_mgr.py | 68 +++++++++++++----------------
cylc/flow/workflow_files.py | 15 ++++---
tests/unit/test_scheduler_cli.py | 11 ++---
tests/unit/test_templatevars.py | 48 ++++++++------------
tests/unit/test_workflow_db_mgr.py | 18 ++++----
12 files changed, 130 insertions(+), 124 deletions(-)
diff --git a/cylc/flow/network/scan.py b/cylc/flow/network/scan.py
index 9b7fdc45a1d..4dfe5510848 100644
--- a/cylc/flow/network/scan.py
+++ b/cylc/flow/network/scan.py
@@ -545,14 +545,12 @@ def _callback(_, entry):
# NOTE: use the public DB for reading
# (only the scheduler process/thread should access the private database)
- db_file = Path(get_workflow_run_dir(flow['name'], 'log', 'db'))
+ db_file = Path(get_workflow_run_dir(
+ flow['name'], WorkflowFiles.LogDir.DIRNAME, WorkflowFiles.LogDir.DB
+ ))
if db_file.exists():
- dao = CylcWorkflowDAO(db_file, is_public=False)
- try:
- dao.connect()
+ with CylcWorkflowDAO(db_file, is_public=True) as dao:
dao.select_workflow_params(_callback)
flow['workflow_params'] = params
- finally:
- dao.close()
return flow
diff --git a/cylc/flow/rundb.py b/cylc/flow/rundb.py
index 8dd60d6422b..ceb4c2c3626 100644
--- a/cylc/flow/rundb.py
+++ b/cylc/flow/rundb.py
@@ -21,12 +21,15 @@
from pprint import pformat
import sqlite3
import traceback
-from typing import List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from cylc.flow import LOG
from cylc.flow.util import deserialise
import cylc.flow.flags
+if TYPE_CHECKING:
+ from pathlib import Path
+
@dataclass
class CylcWorkflowDAOTableColumn:
@@ -318,26 +321,44 @@ class CylcWorkflowDAO:
],
}
- def __init__(self, db_file_name, is_public=False):
+ def __init__(
+ self,
+ db_file_name: Union['Path', str],
+ is_public: bool = False,
+ create_tables: bool = False
+ ):
"""Initialise database access object.
+ An instance of this class can also be opened as a context manager
+ which will automatically close the DB connection.
+
Args:
- db_file_name (str): Path to the database file.
- is_public (bool): If True, allow retries, etc.
+ db_file_name: Path to the database file.
+ is_public: If True, allow retries, etc.
+ create_tables: If True, create the tables if they
+ don't already exist.
"""
self.db_file_name = expandvars(db_file_name)
self.is_public = is_public
- self.conn = None
+ self.conn: Optional[sqlite3.Connection] = None
self.n_tries = 0
- self.tables = {}
- for name, attrs in sorted(self.TABLES_ATTRS.items()):
- self.tables[name] = CylcWorkflowDAOTable(name, attrs)
+ self.tables = {
+ name: CylcWorkflowDAOTable(name, attrs)
+ for name, attrs in sorted(self.TABLES_ATTRS.items())
+ }
- if not self.is_public:
+ if create_tables:
self.create_tables()
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ """Close DB connection when leaving context manager."""
+ self.close()
+
def add_delete_item(self, table_name, where_args=None):
"""Queue a DELETE item for a given table.
diff --git a/cylc/flow/scheduler.py b/cylc/flow/scheduler.py
index f2e5a6bb05c..9acac8bd283 100644
--- a/cylc/flow/scheduler.py
+++ b/cylc/flow/scheduler.py
@@ -1035,8 +1035,9 @@ def command_reload_workflow(self) -> None:
LOG.info("Reloading the workflow definition.")
old_tasks = set(self.config.get_task_name_list())
# Things that can't change on workflow reload:
- pri_dao = self.workflow_db_mgr._get_pri_dao()
- pri_dao.select_workflow_params(self._load_workflow_params)
+ self.workflow_db_mgr.pri_dao.select_workflow_params(
+ self._load_workflow_params
+ )
try:
self.load_flow_file(is_reload=True)
diff --git a/cylc/flow/scheduler_cli.py b/cylc/flow/scheduler_cli.py
index ec5ac2071e3..7b2afc690f2 100644
--- a/cylc/flow/scheduler_cli.py
+++ b/cylc/flow/scheduler_cli.py
@@ -456,9 +456,8 @@ def _version_check(
if not db_file.is_file():
# not a restart
return True
- wdbm = WorkflowDatabaseManager(db_file.parent)
this_version = parse_version(__version__)
- last_run_version = wdbm.check_workflow_db_compatibility()
+ last_run_version = WorkflowDatabaseManager.check_db_compatibility(db_file)
for itt, (this, that) in enumerate(zip_longest(
this_version.release,
@@ -524,7 +523,7 @@ def _version_check(
return True
-def _upgrade_database(db_file):
+def _upgrade_database(db_file: Path) -> None:
"""Upgrade the workflow database if needed.
Note:
@@ -532,8 +531,7 @@ def _upgrade_database(db_file):
"""
if db_file.is_file():
- wdbm = WorkflowDatabaseManager(db_file.parent)
- wdbm.upgrade()
+ WorkflowDatabaseManager.upgrade(db_file)
def _print_startup_message(options):
diff --git a/cylc/flow/scripts/cat_log.py b/cylc/flow/scripts/cat_log.py
index 09ca8d915ba..24fdc6fbb38 100755
--- a/cylc/flow/scripts/cat_log.py
+++ b/cylc/flow/scripts/cat_log.py
@@ -265,10 +265,10 @@ def get_task_job_attrs(workflow_id, point, task, submit_num):
live_job_id is the job ID if job is running, else None.
"""
- workflow_dao = CylcWorkflowDAO(
- get_workflow_run_pub_db_path(workflow_id), is_public=True)
- task_job_data = workflow_dao.select_task_job(point, task, submit_num)
- workflow_dao.close()
+ with CylcWorkflowDAO(
+ get_workflow_run_pub_db_path(workflow_id), is_public=True
+ ) as dao:
+ task_job_data = dao.select_task_job(point, task, submit_num)
if task_job_data is None:
return (None, None, None)
job_runner_name = task_job_data["job_runner_name"]
diff --git a/cylc/flow/scripts/report_timings.py b/cylc/flow/scripts/report_timings.py
index f2e9dcf8cd4..37026d1f48c 100755
--- a/cylc/flow/scripts/report_timings.py
+++ b/cylc/flow/scripts/report_timings.py
@@ -132,8 +132,8 @@ def main(parser: COP, options: 'Values', workflow_id: str) -> None:
# No output specified - choose summary by default
options.show_summary = True
- run_db = _get_dao(workflow_id)
- row_buf = format_rows(*run_db.select_task_times())
+ with _get_dao(workflow_id) as dao:
+ row_buf = format_rows(*dao.select_task_times())
with smart_open(options.output_filename) as output:
if options.show_raw:
output.write(row_buf.getvalue())
diff --git a/cylc/flow/templatevars.py b/cylc/flow/templatevars.py
index 97469285d03..5288185f3c3 100644
--- a/cylc/flow/templatevars.py
+++ b/cylc/flow/templatevars.py
@@ -17,20 +17,26 @@
from ast import literal_eval
from optparse import Values
-from typing import Any, Dict
+from typing import TYPE_CHECKING, Any, Dict
from cylc.flow.exceptions import InputError
-
-
from cylc.flow.rundb import CylcWorkflowDAO
+from cylc.flow.workflow_files import WorkflowFiles
+
+if TYPE_CHECKING:
+ from pathlib import Path
-def get_template_vars_from_db(run_dir):
+def get_template_vars_from_db(run_dir: 'Path') -> dict:
"""Get template vars stored in a workflow run database.
"""
- template_vars = {}
- if (run_dir / 'log/db').exists():
- dao = CylcWorkflowDAO(str(run_dir / 'log/db'))
+ pub_db_file = (
+ run_dir / WorkflowFiles.LogDir.DIRNAME / WorkflowFiles.LogDir.DB
+ )
+ template_vars: dict = {}
+ if not pub_db_file.exists():
+ return template_vars
+ with CylcWorkflowDAO(pub_db_file, is_public=True) as dao:
dao.select_workflow_template_vars(
lambda _, row: template_vars.__setitem__(row[0], eval_var(row[1]))
)
diff --git a/cylc/flow/workflow_db_mgr.py b/cylc/flow/workflow_db_mgr.py
index 1d11a258326..ad46214e561 100644
--- a/cylc/flow/workflow_db_mgr.py
+++ b/cylc/flow/workflow_db_mgr.py
@@ -23,14 +23,14 @@
* Manage existing run database files on restart.
"""
-from contextlib import contextmanager
import json
import os
from pkg_resources import parse_version
from shutil import copy, rmtree
+from sqlite3 import OperationalError
from tempfile import mkstemp
from typing import (
- Any, AnyStr, Dict, Generator, List, Set, TYPE_CHECKING, Tuple, Union
+ Any, AnyStr, Dict, List, Set, TYPE_CHECKING, Tuple, Union
)
from cylc.flow import LOG
@@ -42,6 +42,7 @@
from cylc.flow.util import serialise
if TYPE_CHECKING:
+ from pathlib import Path
from cylc.flow.cycling import PointBase
from cylc.flow.scheduler import Scheduler
from cylc.flow.task_pool import TaskPool
@@ -195,23 +196,13 @@ def delete_workflow_stop_task(self):
"""Delete workflow stop task from workflow_params table."""
self.delete_workflow_params(self.KEY_STOP_TASK)
- def _get_pri_dao(self) -> CylcWorkflowDAO:
+ def get_pri_dao(self) -> CylcWorkflowDAO:
"""Return the primary DAO.
- Note: the DAO should be closed after use. It is better to use the
- context manager method below, which handles this for you.
+ NOTE: the DAO should be closed after use. You can use this function as
+ a context manager, which handles this for you.
"""
- return CylcWorkflowDAO(self.pri_path)
-
- @contextmanager
- def get_pri_dao(self) -> Generator[CylcWorkflowDAO, None, None]:
- """Return the primary DAO and close it after the context manager
- exits."""
- pri_dao = self._get_pri_dao()
- try:
- yield pri_dao
- finally:
- pri_dao.close()
+ return CylcWorkflowDAO(self.pri_path, create_tables=True)
@staticmethod
def _namedtuple2json(obj):
@@ -228,7 +219,7 @@ def _namedtuple2json(obj):
else:
return json.dumps([type(obj).__name__, obj.__getnewargs__()])
- def on_workflow_start(self, is_restart):
+ def on_workflow_start(self, is_restart: bool) -> None:
"""Initialise data access objects.
Ensure that:
@@ -244,7 +235,7 @@ def on_workflow_start(self, is_restart):
# ... however, in case there is a directory at the path for
# some bizarre reason:
rmtree(self.pri_path, ignore_errors=True)
- self.pri_dao = self._get_pri_dao()
+ self.pri_dao = self.get_pri_dao()
os.chmod(self.pri_path, PERM_PRIVATE)
self.pub_dao = CylcWorkflowDAO(self.pub_path, is_public=True)
self.copy_pri_to_pub()
@@ -694,7 +685,8 @@ def restart_check(self) -> None:
self.put_workflow_params_1(self.KEY_RESTART_COUNT, self.n_restart)
self.process_queued_ops()
- def _get_last_run_version(self, pri_dao: CylcWorkflowDAO) -> Version:
+ @classmethod
+ def _get_last_run_version(cls, pri_dao: CylcWorkflowDAO) -> Version:
"""Return the version of Cylc this DB was last run with.
Args:
@@ -707,17 +699,18 @@ def _get_last_run_version(self, pri_dao: CylcWorkflowDAO) -> Version:
SELECT
value
FROM
- {self.TABLE_WORKFLOW_PARAMS}
+ {cls.TABLE_WORKFLOW_PARAMS}
WHERE
key == ?
''', # nosec (table name is a code constant)
- [self.KEY_CYLC_VERSION]
+ [cls.KEY_CYLC_VERSION]
).fetchone()[0]
- except TypeError:
+ except (TypeError, OperationalError):
raise ServiceFileError(f"{INCOMPAT_MSG}, or is corrupted.")
return parse_version(last_run_ver)
- def upgrade_pre_803(self, pri_dao: CylcWorkflowDAO) -> None:
+ @classmethod
+ def upgrade_pre_803(cls, pri_dao: CylcWorkflowDAO) -> None:
"""Upgrade on restart from a pre-8.0.3 database.
Add "is_manual_submit" column to the task states table.
@@ -727,10 +720,10 @@ def upgrade_pre_803(self, pri_dao: CylcWorkflowDAO) -> None:
c_name = "is_manual_submit"
LOG.info(
f"DB upgrade (pre-8.0.3): "
- f"add {c_name} column to {self.TABLE_TASK_STATES}"
+ f"add {c_name} column to {cls.TABLE_TASK_STATES}"
)
conn.execute(
- rf"ALTER TABLE {self.TABLE_TASK_STATES} "
+ rf"ALTER TABLE {cls.TABLE_TASK_STATES} "
rf"ADD COLUMN {c_name} INTEGER "
r"DEFAULT 0 NOT NULL"
)
@@ -771,17 +764,19 @@ def upgrade_pre_810(pri_dao: CylcWorkflowDAO) -> None:
)
conn.commit()
- def upgrade(self):
+ @classmethod
+ def upgrade(cls, db_file: Union['Path', str]) -> None:
"""Upgrade this database to this Cylc version.
"""
- with self.get_pri_dao() as pri_dao:
- last_run_ver = self._get_last_run_version(pri_dao)
+ with CylcWorkflowDAO(db_file, create_tables=True) as pri_dao:
+ last_run_ver = cls._get_last_run_version(pri_dao)
if last_run_ver < parse_version("8.0.3.dev"):
- self.upgrade_pre_803(pri_dao)
+ cls.upgrade_pre_803(pri_dao)
if last_run_ver < parse_version("8.1.0.dev"):
- self.upgrade_pre_810(pri_dao)
+ cls.upgrade_pre_810(pri_dao)
- def check_workflow_db_compatibility(self) -> Version:
+ @classmethod
+ def check_db_compatibility(cls, db_file: Union['Path', str]) -> Version:
"""Check this DB is compatible with this Cylc version.
Raises:
@@ -790,11 +785,11 @@ def check_workflow_db_compatibility(self) -> Version:
current version of Cylc.
"""
- if not os.path.isfile(self.pri_path):
- raise FileNotFoundError(self.pri_path)
+ if not os.path.isfile(db_file):
+ raise FileNotFoundError(db_file)
- with self.get_pri_dao() as pri_dao:
- last_run_ver = self._get_last_run_version(pri_dao)
+ with CylcWorkflowDAO(db_file) as dao:
+ last_run_ver = cls._get_last_run_version(dao)
# WARNING: Do no upgrade the DB here
restart_incompat_ver = parse_version(
@@ -802,7 +797,6 @@ def check_workflow_db_compatibility(self) -> Version:
)
if last_run_ver <= restart_incompat_ver:
raise ServiceFileError(
- f"{INCOMPAT_MSG} (workflow last run with "
- f"Cylc {last_run_ver})."
+ f"{INCOMPAT_MSG} (workflow last run with Cylc {last_run_ver})."
)
return last_run_ver
diff --git a/cylc/flow/workflow_files.py b/cylc/flow/workflow_files.py
index f32ddc0afaa..e51036da327 100644
--- a/cylc/flow/workflow_files.py
+++ b/cylc/flow/workflow_files.py
@@ -82,10 +82,10 @@
construct_cylc_server_ssh_cmd,
construct_ssh_cmd,
)
+from cylc.flow.rundb import CylcWorkflowDAO
from cylc.flow.terminal import parse_dirty_json
from cylc.flow.unicode_rules import WorkflowNameValidator
from cylc.flow.util import cli_format
-from cylc.flow.workflow_db_mgr import WorkflowDatabaseManager
if TYPE_CHECKING:
from optparse import Values
@@ -203,6 +203,9 @@ class LogDir:
VERSION = 'version'
"""Version control log dir"""
+ DB = 'db'
+ """The public database"""
+
SHARE_DIR = 'share'
"""Workflow share directory."""
@@ -1200,7 +1203,7 @@ def get_workflow_title(reg):
return title
-def get_platforms_from_db(run_dir):
+def get_platforms_from_db(run_dir: Path) -> Set[str]:
"""Return the set of names of platforms (that jobs ran on) from the DB.
Warning:
@@ -1213,16 +1216,16 @@ def get_platforms_from_db(run_dir):
compatibility. We can't apply upgraders which don't exist yet.
Args:
- run_dir (str): The workflow run directory.
+ run_dir: The workflow run directory.
Raises:
sqlite3.OperationalError: in the event the table/field required for
cleaning is not present.
"""
- workflow_db_mgr = WorkflowDatabaseManager(
- os.path.join(run_dir, WorkflowFiles.Service.DIRNAME))
- with workflow_db_mgr.get_pri_dao() as pri_dao:
+ with CylcWorkflowDAO(
+ run_dir / WorkflowFiles.Service.DIRNAME / WorkflowFiles.Service.DB
+ ) as pri_dao:
platform_names = pri_dao.select_task_job_platforms()
return platform_names
diff --git a/tests/unit/test_scheduler_cli.py b/tests/unit/test_scheduler_cli.py
index 6e5b70b3478..34a32e8c195 100644
--- a/tests/unit/test_scheduler_cli.py
+++ b/tests/unit/test_scheduler_cli.py
@@ -26,19 +26,14 @@
@pytest.fixture
-def stopped_workflow_db(tmp_path, monkeypatch):
- """Returns a workflow DB with the `cylc_version` set to the provided string.
+def stopped_workflow_db(tmp_path):
+ """Returns a workflow DB with the `cylc_version` set to the provided
+ string.
def test_x(stopped_workflow_db):
db_file = stopped_workflow_db(version)
"""
- # disable workflow DB upgraders
- monkeypatch.setattr(
- 'cylc.flow.workflow_files.WorkflowDatabaseManager.upgrade',
- lambda x, y, z: None,
- )
-
def _stopped_workflow_db(version):
nonlocal tmp_path
db_file = tmp_path / 'db'
diff --git a/tests/unit/test_templatevars.py b/tests/unit/test_templatevars.py
index 96821098499..fa5283b43c6 100644
--- a/tests/unit/test_templatevars.py
+++ b/tests/unit/test_templatevars.py
@@ -14,20 +14,17 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
+from pathlib import Path
import pytest
-import sqlite3
import tempfile
import unittest
-from types import SimpleNamespace
-
-
-from cylc.flow.exceptions import PluginError
+from cylc.flow.rundb import CylcWorkflowDAO
from cylc.flow.templatevars import (
get_template_vars_from_db,
- get_template_vars,
load_template_vars
)
+from cylc.flow.workflow_files import WorkflowFiles
class TestTemplatevars(unittest.TestCase):
@@ -112,31 +109,22 @@ def test_load_template_vars_from_string_and_file_2(self):
@pytest.fixture(scope='module')
def _setup_db(tmp_path_factory):
- tmp_path = tmp_path_factory.mktemp('test_get_old_tvars')
- logfolder = tmp_path / "log/"
+ tmp_path: Path = tmp_path_factory.mktemp('test_get_old_tvars')
+ logfolder = tmp_path / WorkflowFiles.LogDir.DIRNAME
logfolder.mkdir()
- db_path = logfolder / 'db'
- conn = sqlite3.connect(db_path)
- conn.execute(
- r'''
- CREATE TABLE workflow_template_vars (
- key,
- value
- )
- '''
- )
- conn.execute(
- r'''
- INSERT INTO workflow_template_vars
- VALUES
- ("FOO", "42"),
- ("BAR", "'hello world'"),
- ("BAZ", "'foo', 'bar', 48"),
- ("QUX", "['foo', 'bar', 21]")
- '''
- )
- conn.commit()
- conn.close()
+ db_path = logfolder / WorkflowFiles.LogDir.DB
+ with CylcWorkflowDAO(db_path, create_tables=True) as dao:
+ dao.connect().execute(
+ r'''
+ INSERT INTO workflow_template_vars
+ VALUES
+ ("FOO", "42"),
+ ("BAR", "'hello world'"),
+ ("BAZ", "'foo', 'bar', 48"),
+ ("QUX", "['foo', 'bar', 21]")
+ '''
+ )
+ dao.connect().commit()
yield get_template_vars_from_db(tmp_path)
diff --git a/tests/unit/test_workflow_db_mgr.py b/tests/unit/test_workflow_db_mgr.py
index ffd2d978e52..faae984c640 100644
--- a/tests/unit/test_workflow_db_mgr.py
+++ b/tests/unit/test_workflow_db_mgr.py
@@ -30,8 +30,10 @@
@pytest.fixture
def _setup_db(tmp_path):
+ """Fixture to create old DB."""
def _inner(values):
db_file = tmp_path / 'sql.db'
+ # Note: cannot use CylcWorkflowDAO here as creating outdated DB
conn = sqlite3.connect(str(db_file))
conn.execute((
r'CREATE TABLE task_states(name TEXT, cycle TEXT, flow_nums TEXT,'
@@ -57,6 +59,7 @@ def _inner(values):
r" 4377)"
))
conn.commit()
+ conn.close()
return db_file
return _inner
@@ -69,12 +72,11 @@ def test_upgrade_pre_810_fails_on_multiple_flows(_setup_db):
r" '2022-12-05T14:46:40Z', 1, 'succeeded', 0, 0)"
)
db_file_name = _setup_db(values)
- pri_dao = CylcWorkflowDAO(db_file_name)
- with pytest.raises(
+ with CylcWorkflowDAO(db_file_name) as dao, pytest.raises(
CylcError,
match='^Cannot .* 8.0.x to 8.1.0 .* used.$'
):
- WorkflowDatabaseManager.upgrade_pre_810(pri_dao)
+ WorkflowDatabaseManager.upgrade_pre_810(dao)
def test_upgrade_pre_810_pass_on_single_flow(_setup_db):
@@ -85,9 +87,9 @@ def test_upgrade_pre_810_pass_on_single_flow(_setup_db):
r" '2022-12-05T14:46:40Z', 1, 'succeeded', 0, 0)"
)
db_file_name = _setup_db(values)
- pri_dao = CylcWorkflowDAO(db_file_name)
- WorkflowDatabaseManager.upgrade_pre_810(pri_dao)
- conn = sqlite3.connect(db_file_name)
- result = conn.execute(
- 'SELECT DISTINCT flow_nums FROM task_jobs;').fetchall()[0][0]
+ with CylcWorkflowDAO(db_file_name) as dao:
+ WorkflowDatabaseManager.upgrade_pre_810(dao)
+ result = dao.connect().execute(
+ 'SELECT DISTINCT flow_nums FROM task_jobs;'
+ ).fetchall()[0][0]
assert result == '[1]'
From a6031214d90cfe552626c4347f58ecb3f30c2e99 Mon Sep 17 00:00:00 2001
From: Ronnie Dutta <61982285+MetRonnie@users.noreply.github.com>
Date: Thu, 26 Jan 2023 12:14:00 +0000
Subject: [PATCH 3/4] Add unit tests
---
tests/unit/test_rundb.py | 62 ++++++++++++++++++++++++++++------------
1 file changed, 44 insertions(+), 18 deletions(-)
diff --git a/tests/unit/test_rundb.py b/tests/unit/test_rundb.py
index 88de617849d..c17e9955513 100644
--- a/tests/unit/test_rundb.py
+++ b/tests/unit/test_rundb.py
@@ -16,6 +16,7 @@
import contextlib
import os
+from pathlib import Path
import sqlite3
import unittest
from unittest import mock
@@ -119,11 +120,9 @@ def test_remove_columns():
conn.commit()
conn.close()
- dao = CylcWorkflowDAO(temp_db)
- dao.remove_columns('foo', ['bar', 'baz'])
-
- conn = dao.connect()
- data = list(conn.execute(r'SELECT * from foo'))
+ with CylcWorkflowDAO(temp_db) as dao:
+ dao.remove_columns('foo', ['bar', 'baz'])
+ data = list(dao.connect().execute(r'SELECT * from foo'))
assert data == [('PUB',)]
@@ -131,22 +130,21 @@ def test_operational_error(monkeypatch, tmp_path, caplog):
"""Test logging on operational error."""
# create a db object
db_file = tmp_path / 'db'
- dao = CylcWorkflowDAO(db_file)
-
- # stage some stuff
- dao.add_delete_item(CylcWorkflowDAO.TABLE_TASK_JOBS)
- dao.add_insert_item(CylcWorkflowDAO.TABLE_TASK_JOBS, ['pub'])
- dao.add_update_item(CylcWorkflowDAO.TABLE_TASK_JOBS, ['pub'])
+ with CylcWorkflowDAO(db_file) as dao:
+ # stage some stuff
+ dao.add_delete_item(CylcWorkflowDAO.TABLE_TASK_JOBS)
+ dao.add_insert_item(CylcWorkflowDAO.TABLE_TASK_JOBS, ['pub'])
+ dao.add_update_item(CylcWorkflowDAO.TABLE_TASK_JOBS, ['pub'])
- # connect the to DB
- dao.connect()
+ # connect the to DB
+ dao.connect()
- # then delete the file - this will result in an OperationalError
- db_file.unlink()
+ # then delete the file - this will result in an OperationalError
+ db_file.unlink()
- # execute & commit the staged items
- with pytest.raises(sqlite3.OperationalError):
- dao.execute_queued_items()
+ # execute & commit the staged items
+ with pytest.raises(sqlite3.OperationalError):
+ dao.execute_queued_items()
# ensure that the failed transaction is logged for debug purposes
assert len(caplog.messages) == 1
@@ -155,3 +153,31 @@ def test_operational_error(monkeypatch, tmp_path, caplog):
assert 'DELETE FROM task_jobs' in message
assert 'INSERT OR REPLACE INTO task_jobs' in message
assert 'UPDATE task_jobs' in message
+
+
+def test_table_creation(tmp_path: Path):
+ """Test tables are NOT created by default."""
+ db_file = tmp_path / 'db'
+ stmt = "SELECT name FROM sqlite_master WHERE type='table'"
+ with CylcWorkflowDAO(db_file) as dao:
+ tables = list(dao.connect().execute(stmt))
+ assert tables == []
+ with CylcWorkflowDAO(db_file, create_tables=True) as dao:
+ tables = [i[0] for i in dao.connect().execute(stmt)]
+ assert CylcWorkflowDAO.TABLE_WORKFLOW_PARAMS in tables
+
+
+def test_context_manager_exit(
+ monkeypatch: pytest.MonkeyPatch, tmp_path: Path
+):
+ """Test connection is closed even if an exception occurs somewhere."""
+ db_file = tmp_path / 'db'
+ mock_close = mock.Mock()
+ with monkeypatch.context() as mp:
+ mp.setattr(CylcWorkflowDAO, 'close', mock_close)
+ with CylcWorkflowDAO(db_file) as dao, pytest.raises(RuntimeError):
+ mock_close.assert_not_called()
+ raise RuntimeError('mock err')
+ mock_close.assert_called_once()
+ # Close connection for real:
+ dao.close()
From 445ae358d13e3bbd4d6b75c5f5b4947d3caa6066 Mon Sep 17 00:00:00 2001
From: Hilary James Oliver
Date: Fri, 27 Jan 2023 13:00:53 +1300
Subject: [PATCH 4/4] Update cylc/flow/rundb.py docstring [skip ci]
Co-authored-by: Tim Pillinger <26465611+wxtim@users.noreply.github.com>
---
cylc/flow/rundb.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/cylc/flow/rundb.py b/cylc/flow/rundb.py
index ceb4c2c3626..0ef237b41bf 100644
--- a/cylc/flow/rundb.py
+++ b/cylc/flow/rundb.py
@@ -334,7 +334,7 @@ def __init__(
Args:
db_file_name: Path to the database file.
- is_public: If True, allow retries, etc.
+ is_public: If True, allow retries.
create_tables: If True, create the tables if they
don't already exist.