From 97024f75a337281d7b218487033b292ce63fbd61 Mon Sep 17 00:00:00 2001 From: David C Ellis Date: Mon, 2 Dec 2024 08:33:23 +0000 Subject: [PATCH 1/8] Remove unreachable if branch --- src/ducktools/env/__main__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/ducktools/env/__main__.py b/src/ducktools/env/__main__.py index 397c942..4b9ef5f 100644 --- a/src/ducktools/env/__main__.py +++ b/src/ducktools/env/__main__.py @@ -434,9 +434,7 @@ def delete_env_command(manager, args): def main_command() -> int: executable_name = os.path.splitext(os.path.basename(sys.executable))[0] - if zipapp_path := globals().get("zipapp_path"): - command = f"{executable_name} {zipapp_path}" - elif __name__ == "__main__": + if __name__ == "__main__": command = f"{executable_name} -m ducktools.env" else: command = os.path.basename(sys.argv[0]) From a945ec2d96f1bf8f456ca943eda322569ea9a996 Mon Sep 17 00:00:00 2001 From: David C Ellis Date: Mon, 2 Dec 2024 17:36:08 +0000 Subject: [PATCH 2/8] Start writing direct sqlclasses tests --- tests/test_sql_classes.py | 203 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 tests/test_sql_classes.py diff --git a/tests/test_sql_classes.py b/tests/test_sql_classes.py new file mode 100644 index 0000000..a12c21e --- /dev/null +++ b/tests/test_sql_classes.py @@ -0,0 +1,203 @@ +# ducktools.env +# MIT License +# +# Copyright (c) 2024 David C Ellis +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +from unittest import mock +import typing + +import pytest + +from ducktools.env._sqlclasses import ( + _laz, + TYPE_MAP, + MAPPED_TYPES, + SQLContext, + SQLAttribute, + SQLClass, + + get_sql_fields, + flatten_list, + separate_list, + caps_to_snake, +) + + +def test_type_map(): + # Check that the MAPPED_TYPES matches the union of types in TYPE_MAP + mapped_type_construct = typing.Union[*TYPE_MAP.keys()] + assert MAPPED_TYPES == mapped_type_construct + + +class TestListFlattenSeparate: + def test_flatten(self): + l = ['a', 'b', 'c'] + assert flatten_list(l) == "a;b;c" + + def test_separate(self): + l = "a;b;c" + assert separate_list(l) == ['a', 'b', 'c'] + + +def test_caps_to_snake(): + assert caps_to_snake("CapsNamedClass") == "caps_named_class" + + +def test_sql_context(): + with mock.patch.object(_laz.sql, "connect") as sql_connect: + connection_mock = mock.MagicMock() + sql_connect.return_value = connection_mock + + with SQLContext("FakeDB") as con: + assert con is connection_mock + + sql_connect.assert_called_once_with("FakeDB") + connection_mock.close.assert_called() + + +def test_sql_attribute(): + attrib = SQLAttribute(primary_key=True, unique=False, internal=False, computed=None) + assert attrib.primary_key is True + assert attrib.unique is False + assert attrib.internal is False + assert attrib.computed is None + + with pytest.raises(AttributeError): + # This currently raises an error to avoid double specifying + attrib = SQLAttribute(primary_key=True, unique=True) + + +class TestWithExample: + @property + def example_class(self): + class ExampleClass(SQLClass): + uid: int = SQLAttribute(primary_key=True) + name: str = SQLAttribute(unique=True) + age: int = SQLAttribute(internal=True) + height_m: float + height_feet: float = SQLAttribute(computed="height_m * 3.28084") + friends: list[str] = SQLAttribute(default_factory=list) + some_bool: bool + + return ExampleClass + + @property + def field_dict(self): + return { + "uid": SQLAttribute(primary_key=True, type=int), + "name": SQLAttribute(unique=True, type=str), + "age": SQLAttribute(internal=True, type=int), + "height_m": SQLAttribute(type=float), + "height_feet": SQLAttribute(computed="height_m * 3.28084", type=float), + "friends": SQLAttribute(default_factory=list, type=list[str]), + "some_bool": SQLAttribute(type=bool), + } + + def test_table_features(self): + ex_cls = self.example_class + assert ex_cls.PRIMARY_KEY == "uid" + assert ex_cls.TABLE_NAME == "example_class" + + def test_get_sql_fields(self): + fields = get_sql_fields(self.example_class) + assert fields == self.field_dict + + def test_valid_fields(self): + valid_fields = self.field_dict + valid_fields.pop("age") # Internal only field should be excluded + assert valid_fields == self.example_class.VALID_FIELDS + + def test_computed_fields(self): + assert self.example_class.COMPUTED_FIELDS == {"height_feet"} + + def test_str_list_columns(self): + assert self.example_class.STR_LIST_COLUMNS == {"friends"} + + def test_bool_columns(self): + assert self.example_class.BOOL_COLUMNS == {"some_bool"} + + def test_create_table(self): + mock_con = mock.MagicMock() + self.example_class.create_table(mock_con) + + mock_con.execute.assert_called_with( + "CREATE TABLE IF NOT EXISTS example_class(" + "uid INTEGER PRIMARY KEY, " + "name TEXT UNIQUE, " + "height_m REAL, " + "height_feet REAL GENERATED ALWAYS AS (height_m * 3.28084), " + "friends TEXT, " # list[str] is converted to TEXT + "some_bool INTEGER" # Bools are converted to INTEGERS + ")" + ) + + def test_drop_table(self): + mock_con = mock.MagicMock() + self.example_class.drop_table(mock_con) + + mock_con.execute.assert_called_with("DROP TABLE IF EXISTS example_class") + + def test_select_rows_no_filters(self): + mock_con = mock.MagicMock() + mock_cursor = mock.MagicMock() + mock_rows = mock.MagicMock() + mock_fetchall = mock.MagicMock() + + mock_con.cursor.return_value = mock_cursor + mock_cursor.execute.return_value = mock_rows + mock_rows.fetchall.return_value = mock_fetchall + + row_out = self.example_class.select_rows(mock_con) + assert row_out is mock_fetchall + + mock_rows.fetchall.assert_called_once() + mock_cursor.execute.assert_called_once_with( + "SELECT * FROM example_class ", + {} + ) + mock_con.cursor.assert_called_once() + mock_cursor.close.assert_called_once() + + + + +def test_failed_class_pk(): + with pytest.raises(AttributeError): + class ExampleClass(SQLClass): + name: str = SQLAttribute(unique=True) + age: int = SQLAttribute(internal=True) + height_m: float + height_feet: float = SQLAttribute(computed="height_m * 3.28084") + friends: list[str] = SQLAttribute(default_factory=list) + some_bool: bool + + +def test_failed_class_double_pk(): + with pytest.raises(AttributeError): + class ExampleClass(SQLClass): + uid: int = SQLAttribute(primary_key=True) + ununiqueid: int = SQLAttribute(primary_key=True) + name: str = SQLAttribute(unique=True) + age: int = SQLAttribute(internal=True) + height_m: float + height_feet: float = SQLAttribute(computed="height_m * 3.28084") + friends: list[str] = SQLAttribute(default_factory=list) + some_bool: bool + From 2ea621cae2447f61cb80a83ac9e29070d599af0b Mon Sep 17 00:00:00 2001 From: David C Ellis Date: Mon, 2 Dec 2024 17:45:29 +0000 Subject: [PATCH 3/8] remove unnecessary trailing space, more tests --- src/ducktools/env/_sqlclasses.py | 4 ++- tests/test_sql_classes.py | 42 +++++++++++++++++++++++++++++++- 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/ducktools/env/_sqlclasses.py b/src/ducktools/env/_sqlclasses.py index cfae8aa..2420409 100644 --- a/src/ducktools/env/_sqlclasses.py +++ b/src/ducktools/env/_sqlclasses.py @@ -23,6 +23,8 @@ # This is a minimal object/database wrapper for ducktools.classbuilder # Execute the class to see examples of the methods that will be generated +# There are a lot of features that would be needed for a *general* version of this +# This only implements the required features for ducktools-env's use case import itertools @@ -256,7 +258,7 @@ def _select_query(cls, cursor, filters: dict[str, MAPPED_TYPES] | None = None): search_condition = "" cursor.row_factory = cls.row_factory - result = cursor.execute(f"SELECT * FROM {cls.TABLE_NAME} {search_condition}", filters) + result = cursor.execute(f"SELECT * FROM {cls.TABLE_NAME}{search_condition}", filters) return result @classmethod diff --git a/tests/test_sql_classes.py b/tests/test_sql_classes.py index a12c21e..911d081 100644 --- a/tests/test_sql_classes.py +++ b/tests/test_sql_classes.py @@ -169,13 +169,53 @@ def test_select_rows_no_filters(self): mock_rows.fetchall.assert_called_once() mock_cursor.execute.assert_called_once_with( - "SELECT * FROM example_class ", + "SELECT * FROM example_class", {} ) mock_con.cursor.assert_called_once() mock_cursor.close.assert_called_once() + def test_select_row_no_filters(self): + mock_con = mock.MagicMock() + mock_cursor = mock.MagicMock() + mock_rows = mock.MagicMock() + mock_fetchone = mock.MagicMock() + + mock_con.cursor.return_value = mock_cursor + mock_cursor.execute.return_value = mock_rows + mock_rows.fetchone.return_value = mock_fetchone + row_out = self.example_class.select_row(mock_con) + assert row_out is mock_fetchone + + mock_rows.fetchone.assert_called_once() + mock_cursor.execute.assert_called_once_with( + "SELECT * FROM example_class", + {} + ) + mock_con.cursor.assert_called_once() + mock_cursor.close.assert_called_once() + + def test_select_rows_filters(self): + mock_con = mock.MagicMock() + mock_cursor = mock.MagicMock() + mock_rows = mock.MagicMock() + mock_fetchall = mock.MagicMock() + + mock_con.cursor.return_value = mock_cursor + mock_cursor.execute.return_value = mock_rows + mock_rows.fetchall.return_value = mock_fetchall + + row_out = self.example_class.select_rows(mock_con, {"name": "John"}) + assert row_out is mock_fetchall + + mock_rows.fetchall.assert_called_once() + mock_cursor.execute.assert_called_once_with( + "SELECT * FROM example_class WHERE name = :name", + {"name": "John"} + ) + mock_con.cursor.assert_called_once() + mock_cursor.close.assert_called_once() def test_failed_class_pk(): From a12d07b1fde6558e2e11971a7d4679be0efa6499 Mon Sep 17 00:00:00 2001 From: David C Ellis Date: Mon, 2 Dec 2024 17:48:54 +0000 Subject: [PATCH 4/8] Missed one --- src/ducktools/env/_sqlclasses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ducktools/env/_sqlclasses.py b/src/ducktools/env/_sqlclasses.py index 2420409..fa8eba0 100644 --- a/src/ducktools/env/_sqlclasses.py +++ b/src/ducktools/env/_sqlclasses.py @@ -304,7 +304,7 @@ def select_like(cls, con, filters: dict[str, MAPPED_TYPES] | None = None): try: cursor.row_factory = cls.row_factory result = cursor.execute( - f"SELECT * FROM {cls.TABLE_NAME} {search_condition}", + f"SELECT * FROM {cls.TABLE_NAME}{search_condition}", filters ) rows = result.fetchall() From 5ab08560714c72b2b01187e02196bf5a0ec5a0e3 Mon Sep 17 00:00:00 2001 From: David C Ellis Date: Mon, 2 Dec 2024 17:49:15 +0000 Subject: [PATCH 5/8] Test for select like --- tests/test_sql_classes.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_sql_classes.py b/tests/test_sql_classes.py index 911d081..9531b20 100644 --- a/tests/test_sql_classes.py +++ b/tests/test_sql_classes.py @@ -217,6 +217,27 @@ def test_select_rows_filters(self): mock_con.cursor.assert_called_once() mock_cursor.close.assert_called_once() + def test_select_rows_like(self): + mock_con = mock.MagicMock() + mock_cursor = mock.MagicMock() + mock_rows = mock.MagicMock() + mock_fetchall = mock.MagicMock() + + mock_con.cursor.return_value = mock_cursor + mock_cursor.execute.return_value = mock_rows + mock_rows.fetchall.return_value = mock_fetchall + + row_out = self.example_class.select_like(mock_con, {"name": "John"}) + assert row_out is mock_fetchall + + mock_rows.fetchall.assert_called_once() + mock_cursor.execute.assert_called_once_with( + "SELECT * FROM example_class WHERE name LIKE :name", + {"name": "John"} + ) + mock_con.cursor.assert_called_once() + mock_cursor.close.assert_called_once() + def test_failed_class_pk(): with pytest.raises(AttributeError): From 002319de7f64f460bda7e032adc83856b48d7dc9 Mon Sep 17 00:00:00 2001 From: David C Ellis Date: Thu, 5 Dec 2024 13:10:04 +0000 Subject: [PATCH 6/8] Bugfix, incorrectly checked if primary key was set on class instead of instance. Change attribute name to PK_NAME to avoid this confusion and add a `primary_key` property that will get the value of the primary key no matter what it is named in the row. --- src/ducktools/env/_sqlclasses.py | 49 +++++--- tests/test_sql_classes.py | 200 +++++++++++++++++++++++++++---- 2 files changed, 207 insertions(+), 42 deletions(-) diff --git a/src/ducktools/env/_sqlclasses.py b/src/ducktools/env/_sqlclasses.py index fa8eba0..74ac054 100644 --- a/src/ducktools/env/_sqlclasses.py +++ b/src/ducktools/env/_sqlclasses.py @@ -134,7 +134,7 @@ class SQLMeta(SlotMakerMeta): TABLE_NAME: str VALID_FIELDS: dict[str, SQLAttribute] COMPUTED_FIELDS: set[str] - PRIMARY_KEY: str + PK_NAME: str STR_LIST_COLUMNS: set[str] BOOL_COLUMNS: set[str] @@ -183,20 +183,25 @@ def __init_subclass__( primary_key = None for name, field in fields.items(): if field.primary_key: + if primary_key is not None: + raise AttributeError("sqlclass *must* have **only** one primary key") primary_key = name - break if primary_key is None: raise AttributeError("sqlclass *must* have one primary key") - if sum(1 for f in fields.values() if f.primary_key) > 1: - raise AttributeError("sqlclass *must* have **only** one primary key") - - cls.PRIMARY_KEY = primary_key + cls.PK_NAME = primary_key cls.TABLE_NAME = caps_to_snake(cls.__name__) super().__init_subclass__(**kwargs) + @property + def primary_key(self): + """ + Get the actual value of the primary key on an instance. + """ + return getattr(self, self.PK_NAME) + @classmethod def create_table(cls, con): sql_field_list = [] @@ -315,13 +320,13 @@ def select_like(cls, con, filters: dict[str, MAPPED_TYPES] | None = None): @classmethod def max_pk(cls, con): - statement = f"SELECT MAX({cls.PRIMARY_KEY}) from {cls.TABLE_NAME}" + statement = f"SELECT MAX({cls.PK_NAME}) FROM {cls.TABLE_NAME}" result = con.execute(statement) return result.fetchone()[0] @classmethod def row_from_pk(cls, con, pk_value): - return cls.select_row(con, filters={cls.PRIMARY_KEY: pk_value}) + return cls.select_row(con, filters={cls.PK_NAME: pk_value}) def insert_row(self, con): columns = ", ".join( @@ -340,8 +345,8 @@ def insert_row(self, con): with con: result = con.execute(sql_statement, processed_values) - if getattr(self, self.PRIMARY_KEY) is None: - setattr(self, self.PRIMARY_KEY, result.lastrowid) + if getattr(self, self.PK_NAME) is None: + setattr(self, self.PK_NAME, result.lastrowid) if self.COMPUTED_FIELDS: row = self.row_from_pk(con, result.lastrowid) @@ -349,7 +354,13 @@ def insert_row(self, con): setattr(self, field, getattr(row, field)) def update_row(self, con, columns: list[str]): - if self.PRIMARY_KEY is None: + """ + Update the values in the database for this 'row' + + :param con: SQLContext + :param columns: list of the columns to update from this class. + """ + if self.primary_key is None: raise AttributeError("Primary key has not yet been set") if invalid_columns := (set(columns) - self.VALID_FIELDS.keys()): @@ -362,22 +373,28 @@ def update_row(self, con, columns: list[str]): } set_columns = ", ".join(f"{name} = :{name}" for name in columns) - search_condition = f"{self.PRIMARY_KEY} = :{self.PRIMARY_KEY}" + search_condition = f"{self.PK_NAME} = :{self.PK_NAME}" with con: - con.execute( + result = con.execute( f"UPDATE {self.TABLE_NAME} SET {set_columns} WHERE {search_condition}", processed_values, ) + # Computed rows may need to be updated + if self.COMPUTED_FIELDS: + row = self.row_from_pk(con, self.primary_key) + for field in self.COMPUTED_FIELDS: + setattr(self, field, getattr(row, field)) + def delete_row(self, con): - if self.PRIMARY_KEY is None: + if self.primary_key is None: raise AttributeError("Primary key has not yet been set") - pk_filter = {self.PRIMARY_KEY: getattr(self, self.PRIMARY_KEY)} + pk_filter = {self.PK_NAME: self.primary_key} with con: con.execute( - f"DELETE FROM {self.TABLE_NAME} WHERE {self.PRIMARY_KEY} = :{self.PRIMARY_KEY}", + f"DELETE FROM {self.TABLE_NAME} WHERE {self.PK_NAME} = :{self.PK_NAME}", pk_filter, ) diff --git a/tests/test_sql_classes.py b/tests/test_sql_classes.py index 9531b20..6e8d0db 100644 --- a/tests/test_sql_classes.py +++ b/tests/test_sql_classes.py @@ -22,9 +22,11 @@ # SOFTWARE. from unittest import mock import typing +import types import pytest +# noinspection PyProtectedMember from ducktools.env._sqlclasses import ( _laz, TYPE_MAP, @@ -84,15 +86,15 @@ def test_sql_attribute(): attrib = SQLAttribute(primary_key=True, unique=True) -class TestWithExample: +class SharedExample: @property def example_class(self): class ExampleClass(SQLClass): - uid: int = SQLAttribute(primary_key=True) + uid: int = SQLAttribute(default=None, primary_key=True) name: str = SQLAttribute(unique=True) age: int = SQLAttribute(internal=True) height_m: float - height_feet: float = SQLAttribute(computed="height_m * 3.28084") + height_feet: float = SQLAttribute(default=None, computed="height_m * 3.28084") friends: list[str] = SQLAttribute(default_factory=list) some_bool: bool @@ -101,18 +103,24 @@ class ExampleClass(SQLClass): @property def field_dict(self): return { - "uid": SQLAttribute(primary_key=True, type=int), + "uid": SQLAttribute(default=None, primary_key=True, type=int), "name": SQLAttribute(unique=True, type=str), "age": SQLAttribute(internal=True, type=int), "height_m": SQLAttribute(type=float), - "height_feet": SQLAttribute(computed="height_m * 3.28084", type=float), + "height_feet": SQLAttribute(default=None, computed="height_m * 3.28084", type=float), "friends": SQLAttribute(default_factory=list, type=list[str]), "some_bool": SQLAttribute(type=bool), } + +class TestClassConstruction(SharedExample): + """ + Test that the basic class features are built correctly + """ + def test_table_features(self): ex_cls = self.example_class - assert ex_cls.PRIMARY_KEY == "uid" + assert ex_cls.PK_NAME == "uid" assert ex_cls.TABLE_NAME == "example_class" def test_get_sql_fields(self): @@ -133,6 +141,11 @@ def test_str_list_columns(self): def test_bool_columns(self): assert self.example_class.BOOL_COLUMNS == {"some_bool"} + +class TestSQLGeneration(SharedExample): + """ + Test that the generated SQL looks correct + """ def test_create_table(self): mock_con = mock.MagicMock() self.example_class.create_table(mock_con) @@ -238,27 +251,162 @@ def test_select_rows_like(self): mock_con.cursor.assert_called_once() mock_cursor.close.assert_called_once() + def test_max_pk(self): + mock_con = mock.MagicMock() + mock_result = mock.MagicMock() + mock_con.execute.return_value = mock_result -def test_failed_class_pk(): - with pytest.raises(AttributeError): - class ExampleClass(SQLClass): - name: str = SQLAttribute(unique=True) - age: int = SQLAttribute(internal=True) - height_m: float - height_feet: float = SQLAttribute(computed="height_m * 3.28084") - friends: list[str] = SQLAttribute(default_factory=list) - some_bool: bool + max_pk = self.example_class.max_pk(mock_con) + mock_con.execute.assert_called_with("SELECT MAX(uid) FROM example_class") + mock_result.fetchone.assert_called() -def test_failed_class_double_pk(): - with pytest.raises(AttributeError): - class ExampleClass(SQLClass): - uid: int = SQLAttribute(primary_key=True) - ununiqueid: int = SQLAttribute(primary_key=True) - name: str = SQLAttribute(unique=True) - age: int = SQLAttribute(internal=True) - height_m: float - height_feet: float = SQLAttribute(computed="height_m * 3.28084") - friends: list[str] = SQLAttribute(default_factory=list) - some_bool: bool + def test_row_from_pk(self): + mock_con = mock.MagicMock() + mock_cursor = mock.MagicMock() + mock_con.cursor.return_value = mock_cursor + + row = self.example_class.row_from_pk(mock_con, 42) + + mock_cursor.execute.assert_called_once_with( + "SELECT * FROM example_class WHERE uid = :uid", + {"uid": 42}, + ) + mock_cursor.close.assert_called_once() + + def test_insert_row(self): + mock_con = mock.MagicMock() + + result_row = mock.MagicMock() + mock_con.execute.return_value = result_row + result_row.lastrowid = 100 + + ExampleClass = self.example_class + ex = ExampleClass( + name="John", + age=42, + height_m=1.0, + some_bool=False, + ) + + assert ex.uid is None + assert ex.height_feet is None + + with mock.patch.object(ExampleClass, "row_from_pk") as computed_check: + return_row = types.SimpleNamespace(height_feet=6.0) + computed_check.return_value = return_row + + ex.insert_row(mock_con) + + # Check the values were correctly updated + assert ex.uid == ex.primary_key == 100 + assert ex.height_feet == 6.0 + + # Check the call + mock_con.execute.assert_called_with( + "INSERT INTO example_class VALUES(:uid, :name, :height_m, :friends, :some_bool)", + { + "uid": None, + "name": "John", + "height_m": 1.0, + "friends": "", + "some_bool": False, + } + ) + + def test_update_row(self): + ExampleClass = self.example_class + ex = ExampleClass( + uid=1, + name="John", + age=42, + height_m=1.0, + some_bool=True, + ) + + mock_con = mock.MagicMock() + + with mock.patch.object(ExampleClass, "row_from_pk") as computed_check: + return_row = types.SimpleNamespace(height_feet=6.0) + computed_check.return_value = return_row + + ex.update_row(mock_con, ["some_bool"]) + + assert ex.height_feet == 6.0 + + mock_con.execute.assert_called_with( + "UPDATE example_class SET some_bool = :some_bool WHERE uid = :uid", + { + "uid": 1, + "name": "John", + "height_m": 1.0, + "friends": "", + "some_bool": True, + } + ) + + def test_update_row_fail_no_pk(self): + ExampleClass = self.example_class + ex = ExampleClass( + uid=None, + name="John", + age=42, + height_m=1.0, + some_bool=True, + ) + + mock_con = mock.MagicMock() + + with pytest.raises(AttributeError): + ex.update_row(mock_con, ["some_bool"]) + + def test_delete_row(self): + ExampleClass = self.example_class + ex = ExampleClass( + uid=1, + name="John", + age=42, + height_m=1.0, + some_bool=True, + ) + + mock_con = mock.MagicMock() + + ex.delete_row(mock_con) + + mock_con.execute.assert_called_with( + "DELETE FROM example_class WHERE uid = :uid", + {"uid": 1}, + ) + + + +class TestSQLExecution: + """ + Test that the generated SQL actually does what we expect. + """ + + +class TestIncorrectConstruction: + def test_failed_class_pk(self): + with pytest.raises(AttributeError): + class ExampleClass(SQLClass): + name: str = SQLAttribute(unique=True) + age: int = SQLAttribute(internal=True) + height_m: float + height_feet: float = SQLAttribute(computed="height_m * 3.28084") + friends: list[str] = SQLAttribute(default_factory=list) + some_bool: bool + + def test_failed_class_double_pk(self): + with pytest.raises(AttributeError): + class ExampleClass(SQLClass): + uid: int = SQLAttribute(primary_key=True) + ununiqueid: int = SQLAttribute(primary_key=True) + name: str = SQLAttribute(unique=True) + age: int = SQLAttribute(internal=True) + height_m: float + height_feet: float = SQLAttribute(computed="height_m * 3.28084") + friends: list[str] = SQLAttribute(default_factory=list) + some_bool: bool From b9c82770c5d8c8407afc22341f501a9d6e579853 Mon Sep 17 00:00:00 2001 From: David C Ellis Date: Thu, 5 Dec 2024 13:46:59 +0000 Subject: [PATCH 7/8] Tests for SQLClasses --- tests/test_sql_classes.py | 166 +++++++++++++++++++++++++++++++++++++- 1 file changed, 163 insertions(+), 3 deletions(-) diff --git a/tests/test_sql_classes.py b/tests/test_sql_classes.py index 6e8d0db..8e94202 100644 --- a/tests/test_sql_classes.py +++ b/tests/test_sql_classes.py @@ -92,7 +92,7 @@ def example_class(self): class ExampleClass(SQLClass): uid: int = SQLAttribute(default=None, primary_key=True) name: str = SQLAttribute(unique=True) - age: int = SQLAttribute(internal=True) + age: int = SQLAttribute(default=20, internal=True) height_m: float height_feet: float = SQLAttribute(default=None, computed="height_m * 3.28084") friends: list[str] = SQLAttribute(default_factory=list) @@ -105,7 +105,7 @@ def field_dict(self): return { "uid": SQLAttribute(default=None, primary_key=True, type=int), "name": SQLAttribute(unique=True, type=str), - "age": SQLAttribute(internal=True, type=int), + "age": SQLAttribute(default=20, internal=True, type=int), "height_m": SQLAttribute(type=float), "height_feet": SQLAttribute(default=None, computed="height_m * 3.28084", type=float), "friends": SQLAttribute(default_factory=list, type=list[str]), @@ -230,6 +230,12 @@ def test_select_rows_filters(self): mock_con.cursor.assert_called_once() mock_cursor.close.assert_called_once() + def test_select_row_invalid_filter(self): + mock_con = mock.MagicMock() + + with pytest.raises(KeyError): + self.example_class.select_rows(mock_con, {"NotAField": 42}) + def test_select_rows_like(self): mock_con = mock.MagicMock() mock_cursor = mock.MagicMock() @@ -251,6 +257,33 @@ def test_select_rows_like(self): mock_con.cursor.assert_called_once() mock_cursor.close.assert_called_once() + def test_select_rows_like_empty(self): + mock_con = mock.MagicMock() + mock_cursor = mock.MagicMock() + mock_rows = mock.MagicMock() + mock_fetchall = mock.MagicMock() + + mock_con.cursor.return_value = mock_cursor + mock_cursor.execute.return_value = mock_rows + mock_rows.fetchall.return_value = mock_fetchall + + row_out = self.example_class.select_like(mock_con, {}) + assert row_out is mock_fetchall + + mock_rows.fetchall.assert_called_once() + mock_cursor.execute.assert_called_once_with( + "SELECT * FROM example_class", + {} + ) + mock_con.cursor.assert_called_once() + mock_cursor.close.assert_called_once() + + def test_select_like_invalid_filter(self): + mock_con = mock.MagicMock() + + with pytest.raises(KeyError): + self.example_class.select_like(mock_con, {"NotAField": "*John"}) + def test_max_pk(self): mock_con = mock.MagicMock() mock_result = mock.MagicMock() @@ -345,6 +378,21 @@ def test_update_row(self): } ) + def test_update_row_invalid(self): + ExampleClass = self.example_class + ex = ExampleClass( + uid=1, + name="John", + age=42, + height_m=1.0, + some_bool=True, + ) + + mock_con = mock.MagicMock() + + with pytest.raises(ValueError): + ex.update_row(mock_con, ["NotAField"]) + def test_update_row_fail_no_pk(self): ExampleClass = self.example_class ex = ExampleClass( @@ -379,12 +427,124 @@ def test_delete_row(self): {"uid": 1}, ) + def test_delete_row_before_set(self): + ExampleClass = self.example_class + ex = ExampleClass( + uid=None, + name="John", + age=42, + height_m=1.0, + some_bool=True, + ) + + mock_con = mock.MagicMock() + + with pytest.raises(AttributeError): + ex.delete_row(mock_con) -class TestSQLExecution: +class TestSQLExecution(SharedExample): """ Test that the generated SQL actually does what we expect. """ + def test_table_create_drop(self): + ExampleClass = self.example_class + context = SQLContext(":memory:") + with context as con: + # Table doesn't exist + cursor = con.cursor() + try: + result = con.execute( + "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = :name", + {"name": ExampleClass.TABLE_NAME}, + ) + row = result.fetchone() + finally: + cursor.close() + + assert row is None + + # Create the Table + ExampleClass.create_table(con) + + # Now it should be in the schema + cursor = con.cursor() + try: + result = con.execute( + "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = :name", + {"name": ExampleClass.TABLE_NAME}, + ) + row = result.fetchone() + finally: + cursor.close() + + assert row[0] == "example_class" + + # Drop the table + ExampleClass.drop_table(con) + + cursor = con.cursor() + try: + result = con.execute( + "SELECT name FROM sqlite_schema WHERE type = 'table' AND name = :name", + {"name": ExampleClass.TABLE_NAME}, + ) + row = result.fetchone() + finally: + cursor.close() + + assert row is None + + def test_create_table_row_retrieve(self): + ExampleClass = self.example_class + context = SQLContext(":memory:") + with context as con: + ExampleClass.create_table(con) + + ex = ExampleClass( + name="John", + height_m=1.0, + some_bool=True, + ) + + ex.insert_row(con) + + ex_retrieved = ExampleClass.row_from_pk(con, ex.primary_key) + + assert ex == ex_retrieved + + ex.delete_row(con) + + ex_retrieved = ExampleClass.row_from_pk(con, ex.primary_key) + assert ex_retrieved is None + + def test_select_row_rows(self): + ExampleClass = self.example_class + context = SQLContext(":memory:") + with context as con: + ExampleClass.create_table(con) + + ex = ExampleClass( + name="John", + height_m=1.0, + some_bool=True, + ) + + ex.insert_row(con) + + ex_retrieved = ExampleClass.select_row(con, {"name": "John"}) + + assert ex_retrieved == ex + + ex_retrieved = ExampleClass.select_rows(con, {"name": "John"})[0] + + assert ex_retrieved == ex + + def test_select_missing_row_rows(self): + ExampleClass = self.example_class + context = SQLContext(":memory:") + with context as con: + ExampleClass.create_table(con) class TestIncorrectConstruction: From 5dcb521919eafebcceb12f473c3312537ac50474 Mon Sep 17 00:00:00 2001 From: David C Ellis Date: Thu, 5 Dec 2024 14:00:29 +0000 Subject: [PATCH 8/8] Construct the union in stages for Python 3.10 --- tests/test_sql_classes.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_sql_classes.py b/tests/test_sql_classes.py index 8e94202..090ee5d 100644 --- a/tests/test_sql_classes.py +++ b/tests/test_sql_classes.py @@ -44,8 +44,11 @@ def test_type_map(): # Check that the MAPPED_TYPES matches the union of types in TYPE_MAP - mapped_type_construct = typing.Union[*TYPE_MAP.keys()] - assert MAPPED_TYPES == mapped_type_construct + union = None + for t in TYPE_MAP.keys(): + union = typing.Union[union, t] + + assert MAPPED_TYPES == union class TestListFlattenSeparate: