Skip to content

Commit

Permalink
deterministic ConnectionError
Browse files — browse the repository at this point in the history
  • Loading branch information
spicy-sauce committed Aug 7, 2024
1 parent 33e3aab commit a440f8f
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 80 deletions.
1 change: 1 addition & 0 deletions core/src/datayoga_core/blocks/relational/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -54,6 +54,7 @@ def get_engine(connection_name: str, context: Context, autocommit: bool = True)
query=query_args),
echo=connection_details.get("debug", False),
connect_args=connect_args,
pool_pre_ping=True,
**extra)

return engine, db_type
Expand Down
140 changes: 60 additions & 80 deletions core/src/datayoga_core/blocks/relational/write/block.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,9 +10,7 @@
from datayoga_core.context import Context
from datayoga_core.opcode import OpCode
from datayoga_core.result import BlockResult, Result, Status
from sqlalchemy import select, text
from sqlalchemy.exc import (DatabaseError, OperationalError,
PendingRollbackError)
from sqlalchemy import text
from sqlalchemy.sql.expression import ColumnCollection

logger = logging.getLogger("dy")
Expand All @@ -27,67 +25,52 @@ def init(self, context: Optional[Context] = None):

self.context = context
self.engine = None
self.connection = None
self.setup_engine()

def setup_engine(self):
"""Sets up the SQLAlchemy engine and configure it."""
if self.engine:
return

try:
self.engine, self.db_type = relational_utils.get_engine(self.properties["connection"], self.context)

logger.debug(f"Connecting to {self.db_type}")
self.connection = self.engine.connect()

# Disable the new MySQL 8.0.17+ default behavior of requiring an alias for ON DUPLICATE KEY UPDATE
# This behavior is not supported by pymysql driver
if self.engine.driver == "pymysql":
self.engine.dialect._requires_alias_for_on_duplicate_key = False

self.schema = self.properties.get("schema")
self.table = self.properties.get("table")
self.opcode_field = self.properties.get("opcode_field")
self.load_strategy = self.properties.get("load_strategy")
self.keys = self.properties.get("keys")
self.mapping = self.properties.get("mapping")
self.foreach = self.properties.get("foreach")
self.tbl = sa.Table(self.table, sa.MetaData(schema=self.schema), autoload_with=self.engine)

if self.opcode_field:
self.business_key_columns = [column["column"] for column in write_utils.get_column_mapping(self.keys)]
self.mapping_columns = [column["column"] for column in write_utils.get_column_mapping(self.mapping)]

self.columns = self.business_key_columns + [x for x in self.mapping_columns
if x not in self.business_key_columns]

for column in self.columns:
if not any(col.name.lower() == column.lower() for col in self.tbl.columns):
raise ValueError(f"{column} column does not exist in {self.tbl.fullname} table")

conditions = []
for business_key_column in self.business_key_columns:
for tbl_column in self.tbl.columns:
if tbl_column.name.lower() == business_key_column.lower():
conditions.append(tbl_column == sa.bindparam(business_key_column))
break

self.delete_stmt = self.tbl.delete().where(sa.and_(*conditions))
self.upsert_stmt = self.generate_upsert_stmt()

except (OperationalError, PendingRollbackError, DatabaseError) as e:
self._handle_connection_error(e)

def dispose_engine(self):
with suppress(Exception):
self.connection.close()
with suppress(Exception):
self.engine.dispose()
self.engine, self.db_type = relational_utils.get_engine(self.properties["connection"], self.context)

for attr in self._engine_fields:
setattr(self, attr, None)
# Disable the new MySQL 8.0.17+ default behavior of requiring an alias for ON DUPLICATE KEY UPDATE
# This behavior is not supported by pymysql driver
if self.engine.driver == "pymysql":
self.engine.dialect._requires_alias_for_on_duplicate_key = False

self.schema = self.properties.get("schema")
self.table = self.properties.get("table")
self.opcode_field = self.properties.get("opcode_field")
self.load_strategy = self.properties.get("load_strategy")
self.keys = self.properties.get("keys")
self.mapping = self.properties.get("mapping")
self.foreach = self.properties.get("foreach")
self.tbl = sa.Table(self.table, sa.MetaData(schema=self.schema), autoload_with=self.engine)

if self.opcode_field:
self.business_key_columns = [column["column"] for column in write_utils.get_column_mapping(self.keys)]
self.mapping_columns = [column["column"] for column in write_utils.get_column_mapping(self.mapping)]

self.columns = self.business_key_columns + [x for x in self.mapping_columns
if x not in self.business_key_columns]

for column in self.columns:
if not any(col.name.lower() == column.lower() for col in self.tbl.columns):
raise ValueError(f"{column} column does not exist in {self.tbl.fullname} table")

conditions = []
for business_key_column in self.business_key_columns:
for tbl_column in self.tbl.columns:
if tbl_column.name.lower() == business_key_column.lower():
conditions.append(tbl_column == sa.bindparam(business_key_column))
break

self.delete_stmt = self.tbl.delete().where(sa.and_(*conditions))
self.upsert_stmt = self.generate_upsert_stmt()

async def run(self, data: List[Dict[str, Any]]) -> BlockResult:
"""Runs the block with provided data and return the result."""
logger.debug(f"Running {self.get_block_name()}")
rejected_records: List[Result] = []

Expand Down Expand Up @@ -185,17 +168,24 @@ def generate_upsert_stmt(self) -> Any:
))

def execute(self, statement: Any, records: List[Dict[str, Any]]):
"""Executes a SQL statement with given records."""
if isinstance(statement, str):
statement = text(statement)

logger.debug(f"Executing {statement} on {records}")
try:
if isinstance(statement, str):
statement = text(statement)
logger.debug(f"Executing {statement} on {records}")
self.connection.execute(statement, records)
if not self.connection._is_autocommit_isolation():
self.connection.commit()
except (OperationalError, PendingRollbackError, DatabaseError) as e:
self._handle_connection_error(e)
with self.engine.connect() as connection:
try:
connection.execute(statement, records)
if not connection._is_autocommit_isolation():
connection.commit()
except Exception:
raise
except Exception as e:
raise ConnectionError(e) from e

def execute_upsert(self, records: List[Dict[str, Any]]):
"""Upserts records into the table."""
if records:
logger.debug(f"Upserting {len(records)} record(s) to {self.table} table")
records_to_upsert = []
Expand All @@ -206,6 +196,7 @@ def execute_upsert(self, records: List[Dict[str, Any]]):
self.execute(self.upsert_stmt, records_to_upsert)

def execute_delete(self, records: List[Dict[str, Any]]):
"""Deletes records from the table."""
if records:
logger.debug(f"Deleting {len(records)} record(s) from {self.table} table")
records_to_delete = []
Expand All @@ -216,21 +207,10 @@ def execute_delete(self, records: List[Dict[str, Any]]):
self.execute(self.delete_stmt, records_to_delete)

def stop(self):
self.dispose_engine()
"""Disposes of the engine and cleans up resources."""
with suppress(Exception):
if self.engine:
self.engine.dispose()

def _is_connection_valid(self) -> bool:
"""Checks if the current database connection is still valid."""
try:
# Execute a simple query to check if the connection is still valid
self.connection.scalar(select(1))
return True
except (OperationalError, PendingRollbackError, DatabaseError):
return False

def _handle_connection_error(self, error: Exception):
"""Handles connection errors by disposing the engine if necessary and raising ConnectionError."""
if self.connection is not None and not self._is_connection_valid():
self.dispose_engine()
raise ConnectionError(error)
else:
raise
for attr in self._engine_fields:
setattr(self, attr, None)

0 comments on commit a440f8f

Please sign in to comment.