From 6617d1af051217015dded28bbb6173463f479445 Mon Sep 17 00:00:00 2001 From: Ronnie Dutta <61982285+MetRonnie@users.noreply.github.com> Date: Tue, 24 Jan 2023 17:45:49 +0000 Subject: [PATCH 1/4] Tidy --- cylc/flow/workflow_db_mgr.py | 58 +++++++++++++++++------------------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/cylc/flow/workflow_db_mgr.py b/cylc/flow/workflow_db_mgr.py index ec8cb03d11d..1d11a258326 100644 --- a/cylc/flow/workflow_db_mgr.py +++ b/cylc/flow/workflow_db_mgr.py @@ -46,8 +46,9 @@ from cylc.flow.scheduler import Scheduler from cylc.flow.task_pool import TaskPool -# # TODO: narrow down Any (should be str | int) after implementing type -# # annotations in cylc.flow.task_state.TaskState +Version = Any +# TODO: narrow down Any (should be str | int) after implementing type +# annotations in cylc.flow.task_state.TaskState DbArgDict = Dict[str, Any] DbUpdateTuple = Tuple[DbArgDict, DbArgDict] @@ -693,18 +694,28 @@ def restart_check(self) -> None: self.put_workflow_params_1(self.KEY_RESTART_COUNT, self.n_restart) self.process_queued_ops() - def _get_last_run_version(self, pri_dao: CylcWorkflowDAO) -> str: - return pri_dao.connect().execute( - rf''' - SELECT - value - FROM - {self.TABLE_WORKFLOW_PARAMS} - WHERE - key == ? - ''', # nosec (table name is a code constant) - [self.KEY_CYLC_VERSION] - ).fetchone()[0] + def _get_last_run_version(self, pri_dao: CylcWorkflowDAO) -> Version: + """Return the version of Cylc this DB was last run with. + + Args: + pri_dao: Open private database connection object. + + """ + try: + last_run_ver = pri_dao.connect().execute( + rf''' + SELECT + value + FROM + {self.TABLE_WORKFLOW_PARAMS} + WHERE + key == ? + ''', # nosec (table name is a code constant) + [self.KEY_CYLC_VERSION] + ).fetchone()[0] + except TypeError: + raise ServiceFileError(f"{INCOMPAT_MSG}, or is corrupted.") + return parse_version(last_run_ver) def upgrade_pre_803(self, pri_dao: CylcWorkflowDAO) -> None: """Upgrade on restart from a pre-8.0.3 database. @@ -760,30 +771,17 @@ def upgrade_pre_810(pri_dao: CylcWorkflowDAO) -> None: ) conn.commit() - def _get_last_run_ver(self, pri_dao): - """Return the version of Cylc this DB was last run with. - - Args: - pri_dao: Open private database connection object. - - """ - try: - last_run_ver = self._get_last_run_version(pri_dao) - except TypeError: - raise ServiceFileError(f"{INCOMPAT_MSG}, or is corrupted.") - return parse_version(last_run_ver) - def upgrade(self): """Upgrade this database to this Cylc version. """ with self.get_pri_dao() as pri_dao: - last_run_ver = self._get_last_run_ver(pri_dao) + last_run_ver = self._get_last_run_version(pri_dao) if last_run_ver < parse_version("8.0.3.dev"): self.upgrade_pre_803(pri_dao) if last_run_ver < parse_version("8.1.0.dev"): self.upgrade_pre_810(pri_dao) - def check_workflow_db_compatibility(self): + def check_workflow_db_compatibility(self) -> Version: """Check this DB is compatible with this Cylc version. Raises: @@ -796,7 +794,7 @@ def check_workflow_db_compatibility(self): raise FileNotFoundError(self.pri_path) with self.get_pri_dao() as pri_dao: - last_run_ver = self._get_last_run_ver(pri_dao) + last_run_ver = self._get_last_run_version(pri_dao) # WARNING: Do no upgrade the DB here restart_incompat_ver = parse_version( From a125310f86010387b8cad2503863c813960cef4f Mon Sep 17 00:00:00 2001 From: Ronnie Dutta <61982285+MetRonnie@users.noreply.github.com> Date: Wed, 25 Jan 2023 18:22:23 +0000 Subject: [PATCH 2/4] Improve CylcWorkflowDAO interface - Do not create tables by default, only if specified via kwarg - Provide context manager functionality to automatically close DB connection - A WorkflowDatabaseManager should only be created by Scheduler; just use CylcWorkflowDAO elsewhere --- cylc/flow/network/scan.py | 10 ++--- cylc/flow/rundb.py | 39 +++++++++++++---- cylc/flow/scheduler.py | 5 ++- cylc/flow/scheduler_cli.py | 8 ++-- cylc/flow/scripts/cat_log.py | 8 ++-- cylc/flow/scripts/report_timings.py | 4 +- cylc/flow/templatevars.py | 20 ++++++--- cylc/flow/workflow_db_mgr.py | 68 +++++++++++++---------------- cylc/flow/workflow_files.py | 15 ++++--- tests/unit/test_scheduler_cli.py | 11 ++--- tests/unit/test_templatevars.py | 48 ++++++++------------ tests/unit/test_workflow_db_mgr.py | 18 ++++---- 12 files changed, 130 insertions(+), 124 deletions(-) diff --git a/cylc/flow/network/scan.py b/cylc/flow/network/scan.py index 9b7fdc45a1d..4dfe5510848 100644 --- a/cylc/flow/network/scan.py +++ b/cylc/flow/network/scan.py @@ -545,14 +545,12 @@ def _callback(_, entry): # NOTE: use the public DB for reading # (only the scheduler process/thread should access the private database) - db_file = Path(get_workflow_run_dir(flow['name'], 'log', 'db')) + db_file = Path(get_workflow_run_dir( + flow['name'], WorkflowFiles.LogDir.DIRNAME, WorkflowFiles.LogDir.DB + )) if db_file.exists(): - dao = CylcWorkflowDAO(db_file, is_public=False) - try: - dao.connect() + with CylcWorkflowDAO(db_file, is_public=True) as dao: dao.select_workflow_params(_callback) flow['workflow_params'] = params - finally: - dao.close() return flow diff --git a/cylc/flow/rundb.py b/cylc/flow/rundb.py index 8dd60d6422b..ceb4c2c3626 100644 --- a/cylc/flow/rundb.py +++ b/cylc/flow/rundb.py @@ -21,12 +21,15 @@ from pprint import pformat import sqlite3 import traceback -from typing import List, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, Union from cylc.flow import LOG from cylc.flow.util import deserialise import cylc.flow.flags +if TYPE_CHECKING: + from pathlib import Path + @dataclass class CylcWorkflowDAOTableColumn: @@ -318,26 +321,44 @@ class CylcWorkflowDAO: ], } - def __init__(self, db_file_name, is_public=False): + def __init__( + self, + db_file_name: Union['Path', str], + is_public: bool = False, + create_tables: bool = False + ): """Initialise database access object. + An instance of this class can also be opened as a context manager + which will automatically close the DB connection. + Args: - db_file_name (str): Path to the database file. - is_public (bool): If True, allow retries, etc. + db_file_name: Path to the database file. + is_public: If True, allow retries, etc. + create_tables: If True, create the tables if they + don't already exist. """ self.db_file_name = expandvars(db_file_name) self.is_public = is_public - self.conn = None + self.conn: Optional[sqlite3.Connection] = None self.n_tries = 0 - self.tables = {} - for name, attrs in sorted(self.TABLES_ATTRS.items()): - self.tables[name] = CylcWorkflowDAOTable(name, attrs) + self.tables = { + name: CylcWorkflowDAOTable(name, attrs) + for name, attrs in sorted(self.TABLES_ATTRS.items()) + } - if not self.is_public: + if create_tables: self.create_tables() + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + """Close DB connection when leaving context manager.""" + self.close() + def add_delete_item(self, table_name, where_args=None): """Queue a DELETE item for a given table. diff --git a/cylc/flow/scheduler.py b/cylc/flow/scheduler.py index f2e5a6bb05c..9acac8bd283 100644 --- a/cylc/flow/scheduler.py +++ b/cylc/flow/scheduler.py @@ -1035,8 +1035,9 @@ def command_reload_workflow(self) -> None: LOG.info("Reloading the workflow definition.") old_tasks = set(self.config.get_task_name_list()) # Things that can't change on workflow reload: - pri_dao = self.workflow_db_mgr._get_pri_dao() - pri_dao.select_workflow_params(self._load_workflow_params) + self.workflow_db_mgr.pri_dao.select_workflow_params( + self._load_workflow_params + ) try: self.load_flow_file(is_reload=True) diff --git a/cylc/flow/scheduler_cli.py b/cylc/flow/scheduler_cli.py index ec5ac2071e3..7b2afc690f2 100644 --- a/cylc/flow/scheduler_cli.py +++ b/cylc/flow/scheduler_cli.py @@ -456,9 +456,8 @@ def _version_check( if not db_file.is_file(): # not a restart return True - wdbm = WorkflowDatabaseManager(db_file.parent) this_version = parse_version(__version__) - last_run_version = wdbm.check_workflow_db_compatibility() + last_run_version = WorkflowDatabaseManager.check_db_compatibility(db_file) for itt, (this, that) in enumerate(zip_longest( this_version.release, @@ -524,7 +523,7 @@ def _version_check( return True -def _upgrade_database(db_file): +def _upgrade_database(db_file: Path) -> None: """Upgrade the workflow database if needed. Note: @@ -532,8 +531,7 @@ def _upgrade_database(db_file): """ if db_file.is_file(): - wdbm = WorkflowDatabaseManager(db_file.parent) - wdbm.upgrade() + WorkflowDatabaseManager.upgrade(db_file) def _print_startup_message(options): diff --git a/cylc/flow/scripts/cat_log.py b/cylc/flow/scripts/cat_log.py index 09ca8d915ba..24fdc6fbb38 100755 --- a/cylc/flow/scripts/cat_log.py +++ b/cylc/flow/scripts/cat_log.py @@ -265,10 +265,10 @@ def get_task_job_attrs(workflow_id, point, task, submit_num): live_job_id is the job ID if job is running, else None. """ - workflow_dao = CylcWorkflowDAO( - get_workflow_run_pub_db_path(workflow_id), is_public=True) - task_job_data = workflow_dao.select_task_job(point, task, submit_num) - workflow_dao.close() + with CylcWorkflowDAO( + get_workflow_run_pub_db_path(workflow_id), is_public=True + ) as dao: + task_job_data = dao.select_task_job(point, task, submit_num) if task_job_data is None: return (None, None, None) job_runner_name = task_job_data["job_runner_name"] diff --git a/cylc/flow/scripts/report_timings.py b/cylc/flow/scripts/report_timings.py index f2e9dcf8cd4..37026d1f48c 100755 --- a/cylc/flow/scripts/report_timings.py +++ b/cylc/flow/scripts/report_timings.py @@ -132,8 +132,8 @@ def main(parser: COP, options: 'Values', workflow_id: str) -> None: # No output specified - choose summary by default options.show_summary = True - run_db = _get_dao(workflow_id) - row_buf = format_rows(*run_db.select_task_times()) + with _get_dao(workflow_id) as dao: + row_buf = format_rows(*dao.select_task_times()) with smart_open(options.output_filename) as output: if options.show_raw: output.write(row_buf.getvalue()) diff --git a/cylc/flow/templatevars.py b/cylc/flow/templatevars.py index 97469285d03..5288185f3c3 100644 --- a/cylc/flow/templatevars.py +++ b/cylc/flow/templatevars.py @@ -17,20 +17,26 @@ from ast import literal_eval from optparse import Values -from typing import Any, Dict +from typing import TYPE_CHECKING, Any, Dict from cylc.flow.exceptions import InputError - - from cylc.flow.rundb import CylcWorkflowDAO +from cylc.flow.workflow_files import WorkflowFiles + +if TYPE_CHECKING: + from pathlib import Path -def get_template_vars_from_db(run_dir): +def get_template_vars_from_db(run_dir: 'Path') -> dict: """Get template vars stored in a workflow run database. """ - template_vars = {} - if (run_dir / 'log/db').exists(): - dao = CylcWorkflowDAO(str(run_dir / 'log/db')) + pub_db_file = ( + run_dir / WorkflowFiles.LogDir.DIRNAME / WorkflowFiles.LogDir.DB + ) + template_vars: dict = {} + if not pub_db_file.exists(): + return template_vars + with CylcWorkflowDAO(pub_db_file, is_public=True) as dao: dao.select_workflow_template_vars( lambda _, row: template_vars.__setitem__(row[0], eval_var(row[1])) ) diff --git a/cylc/flow/workflow_db_mgr.py b/cylc/flow/workflow_db_mgr.py index 1d11a258326..ad46214e561 100644 --- a/cylc/flow/workflow_db_mgr.py +++ b/cylc/flow/workflow_db_mgr.py @@ -23,14 +23,14 @@ * Manage existing run database files on restart. """ -from contextlib import contextmanager import json import os from pkg_resources import parse_version from shutil import copy, rmtree +from sqlite3 import OperationalError from tempfile import mkstemp from typing import ( - Any, AnyStr, Dict, Generator, List, Set, TYPE_CHECKING, Tuple, Union + Any, AnyStr, Dict, List, Set, TYPE_CHECKING, Tuple, Union ) from cylc.flow import LOG @@ -42,6 +42,7 @@ from cylc.flow.util import serialise if TYPE_CHECKING: + from pathlib import Path from cylc.flow.cycling import PointBase from cylc.flow.scheduler import Scheduler from cylc.flow.task_pool import TaskPool @@ -195,23 +196,13 @@ def delete_workflow_stop_task(self): """Delete workflow stop task from workflow_params table.""" self.delete_workflow_params(self.KEY_STOP_TASK) - def _get_pri_dao(self) -> CylcWorkflowDAO: + def get_pri_dao(self) -> CylcWorkflowDAO: """Return the primary DAO. - Note: the DAO should be closed after use. It is better to use the - context manager method below, which handles this for you. + NOTE: the DAO should be closed after use. You can use this function as + a context manager, which handles this for you. """ - return CylcWorkflowDAO(self.pri_path) - - @contextmanager - def get_pri_dao(self) -> Generator[CylcWorkflowDAO, None, None]: - """Return the primary DAO and close it after the context manager - exits.""" - pri_dao = self._get_pri_dao() - try: - yield pri_dao - finally: - pri_dao.close() + return CylcWorkflowDAO(self.pri_path, create_tables=True) @staticmethod def _namedtuple2json(obj): @@ -228,7 +219,7 @@ def _namedtuple2json(obj): else: return json.dumps([type(obj).__name__, obj.__getnewargs__()]) - def on_workflow_start(self, is_restart): + def on_workflow_start(self, is_restart: bool) -> None: """Initialise data access objects. Ensure that: @@ -244,7 +235,7 @@ def on_workflow_start(self, is_restart): # ... however, in case there is a directory at the path for # some bizarre reason: rmtree(self.pri_path, ignore_errors=True) - self.pri_dao = self._get_pri_dao() + self.pri_dao = self.get_pri_dao() os.chmod(self.pri_path, PERM_PRIVATE) self.pub_dao = CylcWorkflowDAO(self.pub_path, is_public=True) self.copy_pri_to_pub() @@ -694,7 +685,8 @@ def restart_check(self) -> None: self.put_workflow_params_1(self.KEY_RESTART_COUNT, self.n_restart) self.process_queued_ops() - def _get_last_run_version(self, pri_dao: CylcWorkflowDAO) -> Version: + @classmethod + def _get_last_run_version(cls, pri_dao: CylcWorkflowDAO) -> Version: """Return the version of Cylc this DB was last run with. Args: @@ -707,17 +699,18 @@ def _get_last_run_version(self, pri_dao: CylcWorkflowDAO) -> Version: SELECT value FROM - {self.TABLE_WORKFLOW_PARAMS} + {cls.TABLE_WORKFLOW_PARAMS} WHERE key == ? ''', # nosec (table name is a code constant) - [self.KEY_CYLC_VERSION] + [cls.KEY_CYLC_VERSION] ).fetchone()[0] - except TypeError: + except (TypeError, OperationalError): raise ServiceFileError(f"{INCOMPAT_MSG}, or is corrupted.") return parse_version(last_run_ver) - def upgrade_pre_803(self, pri_dao: CylcWorkflowDAO) -> None: + @classmethod + def upgrade_pre_803(cls, pri_dao: CylcWorkflowDAO) -> None: """Upgrade on restart from a pre-8.0.3 database. Add "is_manual_submit" column to the task states table. @@ -727,10 +720,10 @@ def upgrade_pre_803(self, pri_dao: CylcWorkflowDAO) -> None: c_name = "is_manual_submit" LOG.info( f"DB upgrade (pre-8.0.3): " - f"add {c_name} column to {self.TABLE_TASK_STATES}" + f"add {c_name} column to {cls.TABLE_TASK_STATES}" ) conn.execute( - rf"ALTER TABLE {self.TABLE_TASK_STATES} " + rf"ALTER TABLE {cls.TABLE_TASK_STATES} " rf"ADD COLUMN {c_name} INTEGER " r"DEFAULT 0 NOT NULL" ) @@ -771,17 +764,19 @@ def upgrade_pre_810(pri_dao: CylcWorkflowDAO) -> None: ) conn.commit() - def upgrade(self): + @classmethod + def upgrade(cls, db_file: Union['Path', str]) -> None: """Upgrade this database to this Cylc version. """ - with self.get_pri_dao() as pri_dao: - last_run_ver = self._get_last_run_version(pri_dao) + with CylcWorkflowDAO(db_file, create_tables=True) as pri_dao: + last_run_ver = cls._get_last_run_version(pri_dao) if last_run_ver < parse_version("8.0.3.dev"): - self.upgrade_pre_803(pri_dao) + cls.upgrade_pre_803(pri_dao) if last_run_ver < parse_version("8.1.0.dev"): - self.upgrade_pre_810(pri_dao) + cls.upgrade_pre_810(pri_dao) - def check_workflow_db_compatibility(self) -> Version: + @classmethod + def check_db_compatibility(cls, db_file: Union['Path', str]) -> Version: """Check this DB is compatible with this Cylc version. Raises: @@ -790,11 +785,11 @@ def check_workflow_db_compatibility(self) -> Version: current version of Cylc. """ - if not os.path.isfile(self.pri_path): - raise FileNotFoundError(self.pri_path) + if not os.path.isfile(db_file): + raise FileNotFoundError(db_file) - with self.get_pri_dao() as pri_dao: - last_run_ver = self._get_last_run_version(pri_dao) + with CylcWorkflowDAO(db_file) as dao: + last_run_ver = cls._get_last_run_version(dao) # WARNING: Do no upgrade the DB here restart_incompat_ver = parse_version( @@ -802,7 +797,6 @@ def check_workflow_db_compatibility(self) -> Version: ) if last_run_ver <= restart_incompat_ver: raise ServiceFileError( - f"{INCOMPAT_MSG} (workflow last run with " - f"Cylc {last_run_ver})." + f"{INCOMPAT_MSG} (workflow last run with Cylc {last_run_ver})." ) return last_run_ver diff --git a/cylc/flow/workflow_files.py b/cylc/flow/workflow_files.py index f32ddc0afaa..e51036da327 100644 --- a/cylc/flow/workflow_files.py +++ b/cylc/flow/workflow_files.py @@ -82,10 +82,10 @@ construct_cylc_server_ssh_cmd, construct_ssh_cmd, ) +from cylc.flow.rundb import CylcWorkflowDAO from cylc.flow.terminal import parse_dirty_json from cylc.flow.unicode_rules import WorkflowNameValidator from cylc.flow.util import cli_format -from cylc.flow.workflow_db_mgr import WorkflowDatabaseManager if TYPE_CHECKING: from optparse import Values @@ -203,6 +203,9 @@ class LogDir: VERSION = 'version' """Version control log dir""" + DB = 'db' + """The public database""" + SHARE_DIR = 'share' """Workflow share directory.""" @@ -1200,7 +1203,7 @@ def get_workflow_title(reg): return title -def get_platforms_from_db(run_dir): +def get_platforms_from_db(run_dir: Path) -> Set[str]: """Return the set of names of platforms (that jobs ran on) from the DB. Warning: @@ -1213,16 +1216,16 @@ def get_platforms_from_db(run_dir): compatibility. We can't apply upgraders which don't exist yet. Args: - run_dir (str): The workflow run directory. + run_dir: The workflow run directory. Raises: sqlite3.OperationalError: in the event the table/field required for cleaning is not present. """ - workflow_db_mgr = WorkflowDatabaseManager( - os.path.join(run_dir, WorkflowFiles.Service.DIRNAME)) - with workflow_db_mgr.get_pri_dao() as pri_dao: + with CylcWorkflowDAO( + run_dir / WorkflowFiles.Service.DIRNAME / WorkflowFiles.Service.DB + ) as pri_dao: platform_names = pri_dao.select_task_job_platforms() return platform_names diff --git a/tests/unit/test_scheduler_cli.py b/tests/unit/test_scheduler_cli.py index 6e5b70b3478..34a32e8c195 100644 --- a/tests/unit/test_scheduler_cli.py +++ b/tests/unit/test_scheduler_cli.py @@ -26,19 +26,14 @@ @pytest.fixture -def stopped_workflow_db(tmp_path, monkeypatch): - """Returns a workflow DB with the `cylc_version` set to the provided string. +def stopped_workflow_db(tmp_path): + """Returns a workflow DB with the `cylc_version` set to the provided + string. def test_x(stopped_workflow_db): db_file = stopped_workflow_db(version) """ - # disable workflow DB upgraders - monkeypatch.setattr( - 'cylc.flow.workflow_files.WorkflowDatabaseManager.upgrade', - lambda x, y, z: None, - ) - def _stopped_workflow_db(version): nonlocal tmp_path db_file = tmp_path / 'db' diff --git a/tests/unit/test_templatevars.py b/tests/unit/test_templatevars.py index 96821098499..fa5283b43c6 100644 --- a/tests/unit/test_templatevars.py +++ b/tests/unit/test_templatevars.py @@ -14,20 +14,17 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from pathlib import Path import pytest -import sqlite3 import tempfile import unittest -from types import SimpleNamespace - - -from cylc.flow.exceptions import PluginError +from cylc.flow.rundb import CylcWorkflowDAO from cylc.flow.templatevars import ( get_template_vars_from_db, - get_template_vars, load_template_vars ) +from cylc.flow.workflow_files import WorkflowFiles class TestTemplatevars(unittest.TestCase): @@ -112,31 +109,22 @@ def test_load_template_vars_from_string_and_file_2(self): @pytest.fixture(scope='module') def _setup_db(tmp_path_factory): - tmp_path = tmp_path_factory.mktemp('test_get_old_tvars') - logfolder = tmp_path / "log/" + tmp_path: Path = tmp_path_factory.mktemp('test_get_old_tvars') + logfolder = tmp_path / WorkflowFiles.LogDir.DIRNAME logfolder.mkdir() - db_path = logfolder / 'db' - conn = sqlite3.connect(db_path) - conn.execute( - r''' - CREATE TABLE workflow_template_vars ( - key, - value - ) - ''' - ) - conn.execute( - r''' - INSERT INTO workflow_template_vars - VALUES - ("FOO", "42"), - ("BAR", "'hello world'"), - ("BAZ", "'foo', 'bar', 48"), - ("QUX", "['foo', 'bar', 21]") - ''' - ) - conn.commit() - conn.close() + db_path = logfolder / WorkflowFiles.LogDir.DB + with CylcWorkflowDAO(db_path, create_tables=True) as dao: + dao.connect().execute( + r''' + INSERT INTO workflow_template_vars + VALUES + ("FOO", "42"), + ("BAR", "'hello world'"), + ("BAZ", "'foo', 'bar', 48"), + ("QUX", "['foo', 'bar', 21]") + ''' + ) + dao.connect().commit() yield get_template_vars_from_db(tmp_path) diff --git a/tests/unit/test_workflow_db_mgr.py b/tests/unit/test_workflow_db_mgr.py index ffd2d978e52..faae984c640 100644 --- a/tests/unit/test_workflow_db_mgr.py +++ b/tests/unit/test_workflow_db_mgr.py @@ -30,8 +30,10 @@ @pytest.fixture def _setup_db(tmp_path): + """Fixture to create old DB.""" def _inner(values): db_file = tmp_path / 'sql.db' + # Note: cannot use CylcWorkflowDAO here as creating outdated DB conn = sqlite3.connect(str(db_file)) conn.execute(( r'CREATE TABLE task_states(name TEXT, cycle TEXT, flow_nums TEXT,' @@ -57,6 +59,7 @@ def _inner(values): r" 4377)" )) conn.commit() + conn.close() return db_file return _inner @@ -69,12 +72,11 @@ def test_upgrade_pre_810_fails_on_multiple_flows(_setup_db): r" '2022-12-05T14:46:40Z', 1, 'succeeded', 0, 0)" ) db_file_name = _setup_db(values) - pri_dao = CylcWorkflowDAO(db_file_name) - with pytest.raises( + with CylcWorkflowDAO(db_file_name) as dao, pytest.raises( CylcError, match='^Cannot .* 8.0.x to 8.1.0 .* used.$' ): - WorkflowDatabaseManager.upgrade_pre_810(pri_dao) + WorkflowDatabaseManager.upgrade_pre_810(dao) def test_upgrade_pre_810_pass_on_single_flow(_setup_db): @@ -85,9 +87,9 @@ def test_upgrade_pre_810_pass_on_single_flow(_setup_db): r" '2022-12-05T14:46:40Z', 1, 'succeeded', 0, 0)" ) db_file_name = _setup_db(values) - pri_dao = CylcWorkflowDAO(db_file_name) - WorkflowDatabaseManager.upgrade_pre_810(pri_dao) - conn = sqlite3.connect(db_file_name) - result = conn.execute( - 'SELECT DISTINCT flow_nums FROM task_jobs;').fetchall()[0][0] + with CylcWorkflowDAO(db_file_name) as dao: + WorkflowDatabaseManager.upgrade_pre_810(dao) + result = dao.connect().execute( + 'SELECT DISTINCT flow_nums FROM task_jobs;' + ).fetchall()[0][0] assert result == '[1]' From a6031214d90cfe552626c4347f58ecb3f30c2e99 Mon Sep 17 00:00:00 2001 From: Ronnie Dutta <61982285+MetRonnie@users.noreply.github.com> Date: Thu, 26 Jan 2023 12:14:00 +0000 Subject: [PATCH 3/4] Add unit tests --- tests/unit/test_rundb.py | 62 ++++++++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 18 deletions(-) diff --git a/tests/unit/test_rundb.py b/tests/unit/test_rundb.py index 88de617849d..c17e9955513 100644 --- a/tests/unit/test_rundb.py +++ b/tests/unit/test_rundb.py @@ -16,6 +16,7 @@ import contextlib import os +from pathlib import Path import sqlite3 import unittest from unittest import mock @@ -119,11 +120,9 @@ def test_remove_columns(): conn.commit() conn.close() - dao = CylcWorkflowDAO(temp_db) - dao.remove_columns('foo', ['bar', 'baz']) - - conn = dao.connect() - data = list(conn.execute(r'SELECT * from foo')) + with CylcWorkflowDAO(temp_db) as dao: + dao.remove_columns('foo', ['bar', 'baz']) + data = list(dao.connect().execute(r'SELECT * from foo')) assert data == [('PUB',)] @@ -131,22 +130,21 @@ def test_operational_error(monkeypatch, tmp_path, caplog): """Test logging on operational error.""" # create a db object db_file = tmp_path / 'db' - dao = CylcWorkflowDAO(db_file) - - # stage some stuff - dao.add_delete_item(CylcWorkflowDAO.TABLE_TASK_JOBS) - dao.add_insert_item(CylcWorkflowDAO.TABLE_TASK_JOBS, ['pub']) - dao.add_update_item(CylcWorkflowDAO.TABLE_TASK_JOBS, ['pub']) + with CylcWorkflowDAO(db_file) as dao: + # stage some stuff + dao.add_delete_item(CylcWorkflowDAO.TABLE_TASK_JOBS) + dao.add_insert_item(CylcWorkflowDAO.TABLE_TASK_JOBS, ['pub']) + dao.add_update_item(CylcWorkflowDAO.TABLE_TASK_JOBS, ['pub']) - # connect the to DB - dao.connect() + # connect the to DB + dao.connect() - # then delete the file - this will result in an OperationalError - db_file.unlink() + # then delete the file - this will result in an OperationalError + db_file.unlink() - # execute & commit the staged items - with pytest.raises(sqlite3.OperationalError): - dao.execute_queued_items() + # execute & commit the staged items + with pytest.raises(sqlite3.OperationalError): + dao.execute_queued_items() # ensure that the failed transaction is logged for debug purposes assert len(caplog.messages) == 1 @@ -155,3 +153,31 @@ def test_operational_error(monkeypatch, tmp_path, caplog): assert 'DELETE FROM task_jobs' in message assert 'INSERT OR REPLACE INTO task_jobs' in message assert 'UPDATE task_jobs' in message + + +def test_table_creation(tmp_path: Path): + """Test tables are NOT created by default.""" + db_file = tmp_path / 'db' + stmt = "SELECT name FROM sqlite_master WHERE type='table'" + with CylcWorkflowDAO(db_file) as dao: + tables = list(dao.connect().execute(stmt)) + assert tables == [] + with CylcWorkflowDAO(db_file, create_tables=True) as dao: + tables = [i[0] for i in dao.connect().execute(stmt)] + assert CylcWorkflowDAO.TABLE_WORKFLOW_PARAMS in tables + + +def test_context_manager_exit( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +): + """Test connection is closed even if an exception occurs somewhere.""" + db_file = tmp_path / 'db' + mock_close = mock.Mock() + with monkeypatch.context() as mp: + mp.setattr(CylcWorkflowDAO, 'close', mock_close) + with CylcWorkflowDAO(db_file) as dao, pytest.raises(RuntimeError): + mock_close.assert_not_called() + raise RuntimeError('mock err') + mock_close.assert_called_once() + # Close connection for real: + dao.close() From 445ae358d13e3bbd4d6b75c5f5b4947d3caa6066 Mon Sep 17 00:00:00 2001 From: Hilary James Oliver Date: Fri, 27 Jan 2023 13:00:53 +1300 Subject: [PATCH 4/4] Update cylc/flow/rundb.py docstring [skip ci] Co-authored-by: Tim Pillinger <26465611+wxtim@users.noreply.github.com> --- cylc/flow/rundb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cylc/flow/rundb.py b/cylc/flow/rundb.py index ceb4c2c3626..0ef237b41bf 100644 --- a/cylc/flow/rundb.py +++ b/cylc/flow/rundb.py @@ -334,7 +334,7 @@ def __init__( Args: db_file_name: Path to the database file. - is_public: If True, allow retries, etc. + is_public: If True, allow retries. create_tables: If True, create the tables if they don't already exist.