Skip to content

Commit

Permalink
table alterations mysql
Browse files Browse the repository at this point in the history
  • Loading branch information
CrispenGari committed Feb 23, 2024
1 parent 4f58e56 commit a76a981
Show file tree
Hide file tree
Showing 11 changed files with 291 additions and 246 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,13 @@ loom = Loom(
logs_filename="logs.sql",
port=5432,
)

# OR with connection_uri
loom = Loom(
dialect="mysql",
connection_uri = "mysql://root:root@localhost:3306/hi",
# ...
)
```

The `Loom` class takes in the following options:
Expand Down
2 changes: 1 addition & 1 deletion dataloom/keys.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Configuration file for unit testing.


push = True
push = False


class PgConfig:
Expand Down
64 changes: 48 additions & 16 deletions dataloom/loom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ class Loom(ILoom):
"""

def __get_database_name(self, uri: str) -> str | None:
if self.dialect == "postgres" or self.dialect == "mysql":
from urllib.parse import urlparse

components = urlparse(uri)
db = components.path.lstrip("/")
return db
return None

def __init__(
self,
dialect: DIALECT_LITERAL,
Expand All @@ -81,12 +90,16 @@ def __init__(
sql_logger: Optional[SQL_LOGGER_LITERAL] = None,
logs_filename: Optional[str] = "dataloom.sql",
) -> None:
self.database = database
self.conn = None
self.sql_logger = sql_logger
self.dialect = dialect
self.logs_filename = logs_filename
self.connection_uri = connection_uri
self.database = (
database
if self.connection_uri is None
else self.__get_database_name(self.connection_uri)
)

try:
config = instances[dialect]
Expand Down Expand Up @@ -1128,7 +1141,7 @@ def tables(self) -> list[str]:
"""
sql = GetStatement(self.dialect)._get_tables_command
res = self._execute_sql(sql, fetchall=True)
res = self._execute_sql(sql, fetchall=True, _verbose=0)
if self.dialect == "sqlite":
return [t[0] for t in res if not str(t[0]).lower().startswith("sqlite_")]
return [t[0] for t in res]
Expand Down Expand Up @@ -1331,19 +1344,8 @@ def connect_and_sync(
sql_logger=self.sql_logger,
logs_filename=self.logs_filename,
)
for model in models:
if drop or force:
self._execute_sql(model._drop_sql(dialect=self.dialect))
for sql in model._create_sql(dialect=self.dialect):
if sql is not None:
self._execute_sql(sql)
elif alter:
pass
else:
for sql in model._create_sql(dialect=self.dialect):
if sql is not None:
self._execute_sql(sql)
return self.conn, self.tables
tables = self.sync(models=models, drop=drop, force=force, alter=alter)
return self.conn, tables
except Exception as e:
raise Exception(e)

Expand Down Expand Up @@ -1407,7 +1409,37 @@ def sync(
if sql is not None:
self._execute_sql(sql)
elif alter:
pass
# 1. we only alter the table if it does exists
# 2. if not we just have to create a new table
if model._get_table_name() in self.tables:
sql1 = model._get_describe_stm(
dialect=self.dialect, fields=["column_name"]
)
args = None
if self.dialect == "mysql":
args = (self.database, model._get_table_name())
elif self.dialect == "postgres":
args = ("public", model._get_table_name())
elif self.dialect == "sqlite":
args = ()
cols = self._execute_sql(
sql1, _verbose=0, args=args, fetchall=True
)
if cols is not None:
if self.dialect == "mysql":
old_columns = [col for (col,) in cols]
elif self.dialect == "postgres":
old_columns = [col for (col,) in cols]
else:
old_columns = [col[1] for col in cols]
sql = model._alter_sql(
dialect=self.dialect, old_columns=old_columns
)
self._execute_sql(sql)
else:
for sql in model._create_sql(dialect=self.dialect):
if sql is not None:
self._execute_sql(sql)
else:
for sql in model._create_sql(dialect=self.dialect):
if sql is not None:
Expand Down
8 changes: 7 additions & 1 deletion dataloom/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,18 @@ class Model:
"""

@classmethod
def _create_sql(cls, dialect: DIALECT_LITERAL, ignore_exists=True):
def _create_sql(cls, dialect: DIALECT_LITERAL):
sqls = GetStatement(
dialect=dialect, model=cls, table_name=cls._get_table_name()
)._get_create_table_command
return sqls

@classmethod
def _alter_sql(cls, dialect: DIALECT_LITERAL, old_columns: list[str]):
return GetStatement(
dialect=dialect, model=cls, table_name=cls._get_table_name()
)._get_alter_table_command(old_columns=old_columns)

@classmethod
def _get_table_name(self):
__tablename__ = None
Expand Down
98 changes: 57 additions & 41 deletions dataloom/statements/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
get_relationships,
get_create_table_params,
get_table_fields,
get_alter_table_params,
)


Expand Down Expand Up @@ -147,28 +148,25 @@ def _get_tables_command(self) -> Optional[str]:
return sql

@property
def _get_create_table_command(self) -> Optional[str]:
def _get_create_table_command(self) -> Optional[list[str]]:
# is the primary key defined in this table?
_, pk_name, _, _ = get_table_fields(model=self.model, dialect=self.dialect)
pks, user_fields, predefined_fields, sql2 = get_create_table_params(
pks, user_fields, predefined_fields = get_create_table_params(
dialect=self.dialect,
model=self.model,
child_alias_name=self.model.__name__.lower(),
child_pk_name=pk_name,
child_name=self.model._get_table_name(),
)

if len(pks) == 0:
raise PkNotDefinedException(
"Your table does not have a primary key column."
)
if len(pks) > 1:
raise TooManyPkException(
f"You have defined many field as primary keys which is not allowed. Fields ({', '.join(pks)}) are primary keys."
)
fields = [*user_fields, *predefined_fields]
fields_name = ", ".join(f for f in [" ".join(field) for field in fields])
if self.dialect == "postgres":
# do we have a single primary key or not?
if len(pks) == 0:
raise PkNotDefinedException(
"Your table does not have a primary key column."
)
if len(pks) > 1:
raise TooManyPkException(
f"You have defined many field as primary keys which is not allowed. Fields ({', '.join(pks)}) are primary keys."
)
fields = [*user_fields, *predefined_fields]
fields_name = ", ".join(f for f in [" ".join(field) for field in fields])
sql = (
PgStatements.CREATE_NEW_TABLE.format(
table_name=f'"{self.table_name}"', fields_name=fields_name
Expand All @@ -180,17 +178,6 @@ def _get_create_table_command(self) -> Optional[str]:
)

elif self.dialect == "mysql":
# do we have a single primary key or not?
if len(pks) == 0:
raise PkNotDefinedException(
"Your table does not have a primary key column."
)
if len(pks) > 1:
raise TooManyPkException(
f"You have defined many field as primary keys which is not allowed. Fields ({', '.join(pks)}) are primary keys."
)
fields = [*user_fields, *predefined_fields]
fields_name = ", ".join(f for f in [" ".join(field) for field in fields])
sql = (
MySqlStatements.CREATE_NEW_TABLE.format(
table_name=f"`{self.table_name}`", fields_name=fields_name
Expand All @@ -202,23 +189,12 @@ def _get_create_table_command(self) -> Optional[str]:
)

elif self.dialect == "sqlite":
# do we have a single primary key or not?
if len(pks) == 0:
raise PkNotDefinedException(
"Your table does not have a primary key column."
)
if len(pks) > 1:
raise TooManyPkException(
f"You have defined many field as primary keys which is not allowed. Fields ({', '.join(pks)}) are primary keys."
)
fields = [*user_fields, *predefined_fields]
fields_name = ", ".join(f for f in [" ".join(field) for field in fields])
sql = (
MySqlStatements.CREATE_NEW_TABLE.format(
Sqlite3Statements.CREATE_NEW_TABLE.format(
table_name=f"`{self.table_name}`", fields_name=fields_name
)
if not self.ignore_exists
else MySqlStatements.CREATE_NEW_TABLE_IF_NOT_EXITS.format(
else Sqlite3Statements.CREATE_NEW_TABLE_IF_NOT_EXITS.format(
table_name=f"`{self.table_name}`", fields_name=fields_name
)
)
Expand All @@ -227,7 +203,7 @@ def _get_create_table_command(self) -> Optional[str]:
raise UnsupportedDialectException(
"The dialect passed is not supported the supported dialects are: {'postgres', 'mysql', 'sqlite'}"
)
return [sql, sql2]
return [sql]

def _get_select_where_command(
self,
Expand Down Expand Up @@ -851,3 +827,43 @@ def _get_pk_command(
"The dialect passed is not supported the supported dialects are: {'postgres', 'mysql', 'sqlite'}"
)
return sql

def _get_alter_table_command(self, old_columns: list[str]) -> str:
"""
1. get table columns
2. check if is new column
3. check if the column has been removed
"""
_, pk_name, _, _ = get_table_fields(model=self.model, dialect=self.dialect)
pks, alterations = get_alter_table_params(
dialect=self.dialect, model=self.model, old_columns=old_columns
)

alterations = ", ".join(alterations)
# do we have a single primary key or not?
if len(pks) == 0:
raise PkNotDefinedException(
"Your table does not have a primary key column."
)
if len(pks) > 1:
raise TooManyPkException(
f"You have defined many field as primary keys which is not allowed. Fields ({', '.join(pks)}) are primary keys."
)
if self.dialect == "postgres":
sql = PgStatements.ALTER_TABLE_COMMAND.format(
table_name=f'"{self.table_name}"', alterations=alterations
)
elif self.dialect == "mysql":
sql = MySqlStatements.ALTER_TABLE_COMMAND.format(
table_name=f"`{self.table_name}`", alterations=alterations
)

elif self.dialect == "sqlite":
sql = Sqlite3Statements.ALTER_TABLE_COMMAND.format(
table_name=f"`{self.table_name}`", alterations=alterations
)
else:
raise UnsupportedDialectException(
"The dialect passed is not supported the supported dialects are: {'postgres', 'mysql', 'sqlite'}"
)
return sql
16 changes: 16 additions & 0 deletions dataloom/statements/statements.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
class MySqlStatements:
# Altering tables

ALTER_TABLE_COMMAND = """
ALTER TABLE {table_name} {alterations};
"""

# describing tables

DESCRIBE_TABLE_COMMAND = """
Expand Down Expand Up @@ -137,6 +143,11 @@ class MySqlStatements:


class Sqlite3Statements:
# Altering tables

ALTER_TABLE_COMMAND = """
ALTER TABLE {table_name} {alterations};
"""
# describing table

DESCRIBE_TABLE_COMMAND = """PRAGMA table_info({table_name});"""
Expand Down Expand Up @@ -254,6 +265,11 @@ class Sqlite3Statements:


class PgStatements:
# Altering tables

ALTER_TABLE_COMMAND = """
ALTER TABLE {table_name} {alterations};
"""
# describing table
DESCRIBE_TABLE_COMMAND = """
SELECT {fields}
Expand Down
2 changes: 2 additions & 0 deletions dataloom/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dataloom.utils.logger import console_logger, file_logger
from dataloom.utils.create_table import get_create_table_params
from dataloom.utils.alter_table import get_alter_table_params
from dataloom.utils.aggregations import get_groups
from dataloom.utils.helpers import is_collection
from dataloom.utils.tables import (
Expand Down Expand Up @@ -162,4 +163,5 @@ def get_formatted_query(
print_pretty_table,
is_collection,
get_groups,
get_alter_table_params,
]
Loading

0 comments on commit a76a981

Please sign in to comment.