Skip to content

Commit

Permalink
Merge pull request #34 from DavidCEllis/cleanup
Browse files Browse the repository at this point in the history
Fix an internal bug in _sqlclasses, some cleanup
  • Loading branch information
DavidCEllis authored Dec 5, 2024
2 parents 67616a2 + 5dcb521 commit fba9870
Show file tree
Hide file tree
Showing 3 changed files with 613 additions and 21 deletions.
4 changes: 1 addition & 3 deletions src/ducktools/env/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,7 @@ def delete_env_command(manager, args):
def main_command() -> int:
executable_name = os.path.splitext(os.path.basename(sys.executable))[0]

if zipapp_path := globals().get("zipapp_path"):
command = f"{executable_name} {zipapp_path}"
elif __name__ == "__main__":
if __name__ == "__main__":
command = f"{executable_name} -m ducktools.env"
else:
command = os.path.basename(sys.argv[0])
Expand Down
55 changes: 37 additions & 18 deletions src/ducktools/env/_sqlclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

# This is a minimal object/database wrapper for ducktools.classbuilder
# Execute the class to see examples of the methods that will be generated
# There are a lot of features that would be needed for a *general* version of this
# This only implements the required features for ducktools-env's use case

import itertools

Expand Down Expand Up @@ -132,7 +134,7 @@ class SQLMeta(SlotMakerMeta):
TABLE_NAME: str
VALID_FIELDS: dict[str, SQLAttribute]
COMPUTED_FIELDS: set[str]
PRIMARY_KEY: str
PK_NAME: str
STR_LIST_COLUMNS: set[str]
BOOL_COLUMNS: set[str]

Expand Down Expand Up @@ -181,20 +183,25 @@ def __init_subclass__(
primary_key = None
for name, field in fields.items():
if field.primary_key:
if primary_key is not None:
raise AttributeError("sqlclass *must* have **only** one primary key")
primary_key = name
break

if primary_key is None:
raise AttributeError("sqlclass *must* have one primary key")

if sum(1 for f in fields.values() if f.primary_key) > 1:
raise AttributeError("sqlclass *must* have **only** one primary key")

cls.PRIMARY_KEY = primary_key
cls.PK_NAME = primary_key
cls.TABLE_NAME = caps_to_snake(cls.__name__)

super().__init_subclass__(**kwargs)

@property
def primary_key(self):
"""
Get the actual value of the primary key on an instance.
"""
return getattr(self, self.PK_NAME)

@classmethod
def create_table(cls, con):
sql_field_list = []
Expand Down Expand Up @@ -256,7 +263,7 @@ def _select_query(cls, cursor, filters: dict[str, MAPPED_TYPES] | None = None):
search_condition = ""

cursor.row_factory = cls.row_factory
result = cursor.execute(f"SELECT * FROM {cls.TABLE_NAME} {search_condition}", filters)
result = cursor.execute(f"SELECT * FROM {cls.TABLE_NAME}{search_condition}", filters)
return result

@classmethod
Expand Down Expand Up @@ -302,7 +309,7 @@ def select_like(cls, con, filters: dict[str, MAPPED_TYPES] | None = None):
try:
cursor.row_factory = cls.row_factory
result = cursor.execute(
f"SELECT * FROM {cls.TABLE_NAME} {search_condition}",
f"SELECT * FROM {cls.TABLE_NAME}{search_condition}",
filters
)
rows = result.fetchall()
Expand All @@ -313,13 +320,13 @@ def select_like(cls, con, filters: dict[str, MAPPED_TYPES] | None = None):

@classmethod
def max_pk(cls, con):
statement = f"SELECT MAX({cls.PRIMARY_KEY}) from {cls.TABLE_NAME}"
statement = f"SELECT MAX({cls.PK_NAME}) FROM {cls.TABLE_NAME}"
result = con.execute(statement)
return result.fetchone()[0]

@classmethod
def row_from_pk(cls, con, pk_value):
return cls.select_row(con, filters={cls.PRIMARY_KEY: pk_value})
return cls.select_row(con, filters={cls.PK_NAME: pk_value})

def insert_row(self, con):
columns = ", ".join(
Expand All @@ -338,16 +345,22 @@ def insert_row(self, con):
with con:
result = con.execute(sql_statement, processed_values)

if getattr(self, self.PRIMARY_KEY) is None:
setattr(self, self.PRIMARY_KEY, result.lastrowid)
if getattr(self, self.PK_NAME) is None:
setattr(self, self.PK_NAME, result.lastrowid)

if self.COMPUTED_FIELDS:
row = self.row_from_pk(con, result.lastrowid)
for field in self.COMPUTED_FIELDS:
setattr(self, field, getattr(row, field))

def update_row(self, con, columns: list[str]):
if self.PRIMARY_KEY is None:
"""
Update the values in the database for this 'row'
:param con: SQLContext
:param columns: list of the columns to update from this class.
"""
if self.primary_key is None:
raise AttributeError("Primary key has not yet been set")

if invalid_columns := (set(columns) - self.VALID_FIELDS.keys()):
Expand All @@ -360,22 +373,28 @@ def update_row(self, con, columns: list[str]):
}

set_columns = ", ".join(f"{name} = :{name}" for name in columns)
search_condition = f"{self.PRIMARY_KEY} = :{self.PRIMARY_KEY}"
search_condition = f"{self.PK_NAME} = :{self.PK_NAME}"

with con:
con.execute(
result = con.execute(
f"UPDATE {self.TABLE_NAME} SET {set_columns} WHERE {search_condition}",
processed_values,
)

# Computed rows may need to be updated
if self.COMPUTED_FIELDS:
row = self.row_from_pk(con, self.primary_key)
for field in self.COMPUTED_FIELDS:
setattr(self, field, getattr(row, field))

def delete_row(self, con):
if self.PRIMARY_KEY is None:
if self.primary_key is None:
raise AttributeError("Primary key has not yet been set")

pk_filter = {self.PRIMARY_KEY: getattr(self, self.PRIMARY_KEY)}
pk_filter = {self.PK_NAME: self.primary_key}

with con:
con.execute(
f"DELETE FROM {self.TABLE_NAME} WHERE {self.PRIMARY_KEY} = :{self.PRIMARY_KEY}",
f"DELETE FROM {self.TABLE_NAME} WHERE {self.PK_NAME} = :{self.PK_NAME}",
pk_filter,
)
Loading

0 comments on commit fba9870

Please sign in to comment.