Skip to content

Commit

Permalink
fix(ingest/datahub): Use server side cursor instead of local one (dat…
Browse files Browse the repository at this point in the history
  • Loading branch information
treff7es authored and sleeperdeep committed Dec 17, 2024
1 parent 76ac477 commit 149b840
Showing 1 changed file with 41 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,47 @@ def query(self) -> str:
version
"""

def execute_server_cursor(
    self, query: str, params: Dict[str, Any]
) -> Iterable[Dict[str, Any]]:
    """Run ``query`` and lazily yield each result row as a dictionary.

    Rows are pulled from the database in batches of
    ``self.config.database_query_batch_size``. On PostgreSQL the statement
    is executed through the engine connection with ``stream_results`` and
    ``yield_per`` enabled; on MySQL the raw DBAPI connection is asked for a
    ``MySQLdb.cursors.SSCursor`` and rows are read with ``fetchmany``.

    Args:
        query: SQL statement to execute.
        params: Bind parameters passed alongside ``query``.

    Yields:
        One dictionary per result row. On the MySQL path the keys are the
        first element of each ``cursor.description`` entry.

    Raises:
        ValueError: If the engine's dialect is neither ``postgresql`` nor
            ``mysql``. Raised on first iteration, before any connection is
            opened.
    """
    dialect_name = self.engine.dialect.name
    if dialect_name not in ("postgresql", "mysql"):
        # Fail fast: there is no point checking out a connection that
        # cannot be used for streaming.
        raise ValueError(f"Unsupported dialect: {dialect_name}")

    batch_size = self.config.database_query_batch_size

    with self.engine.connect() as conn:
        if dialect_name == "postgresql":
            with conn.begin():  # Transaction required for PostgreSQL server-side cursor
                # Keep the managed connection under its own name; the
                # streaming options are applied to a separate reference.
                streaming_conn = conn.execution_options(
                    stream_results=True,
                    yield_per=batch_size,
                )
                for row in streaming_conn.execute(query, params):
                    yield dict(row)
        else:  # dialect_name == "mysql"
            # Imported lazily so the driver is only needed for MySQL sources.
            import MySQLdb

            with contextlib.closing(
                conn.connection.cursor(MySQLdb.cursors.SSCursor)
            ) as cursor:
                logger.debug("Using Cursor type: %s", cursor.__class__.__name__)
                cursor.execute(query, params)

                columns = [desc[0] for desc in cursor.description]
                while True:
                    rows = cursor.fetchmany(batch_size)
                    if not rows:
                        # Leave the loop so the cursor and connection
                        # context managers close normally.
                        break
                    for row in rows:
                        yield dict(zip(columns, row))

def _get_rows(
    self, from_createdon: datetime, stop_time: datetime
) -> Iterable[Dict[str, Any]]:
    """Stream rows for ``self.query`` through ``execute_server_cursor``.

    Args:
        from_createdon: Lower time bound; formatted with ``DATETIME_FORMAT``
            and bound to the ``since_createdon`` query parameter.
        stop_time: Accepted as part of the signature; not used in this
            method.

    Yields:
        Whatever ``execute_server_cursor`` yields for the query.
    """
    excluded = list(self.config.exclude_aspects)
    since = from_createdon.strftime(DATETIME_FORMAT)
    query_params = {
        "exclude_aspects": excluded,
        "since_createdon": since,
    }
    row_stream = self.execute_server_cursor(self.query, query_params)
    yield from row_stream

def get_aspects(
self, from_createdon: datetime, stop_time: datetime
) -> Iterable[Tuple[MetadataChangeProposalWrapper, datetime]]:
Expand All @@ -159,27 +200,6 @@ def get_aspects(
if mcp:
yield mcp, row["createdon"]

def _get_rows(
    self, from_createdon: datetime, stop_time: datetime
) -> Iterable[Dict[str, Any]]:
    """Execute ``self.query`` on a raw DBAPI cursor and yield rows as dicts.

    The cursor is obtained from ``conn.connection.cursor()`` with no
    arguments, and rows are read in batches of
    ``self.config.database_query_batch_size`` until a fetch returns nothing.

    Args:
        from_createdon: Lower time bound; formatted with ``DATETIME_FORMAT``
            and bound to the ``since_createdon`` query parameter.
        stop_time: Accepted as part of the signature; not used in this
            method.

    Yields:
        One dictionary per row, keyed by the first element of each
        ``cursor.description`` entry.
    """
    with self.engine.connect() as conn, contextlib.closing(
        conn.connection.cursor()
    ) as cursor:
        run = cursor.execute
        sql = self.query
        bind_values = {
            "exclude_aspects": list(self.config.exclude_aspects),
            "since_createdon": from_createdon.strftime(DATETIME_FORMAT),
        }
        run(sql, bind_values)

        column_names = [column[0] for column in cursor.description]
        batch = cursor.fetchmany(self.config.database_query_batch_size)
        while batch:
            for values in batch:
                yield dict(zip(column_names, values))
            batch = cursor.fetchmany(self.config.database_query_batch_size)

def get_soft_deleted_rows(self) -> Iterable[Dict[str, Any]]:
"""
Fetches all soft-deleted entities from the database.
Expand Down

0 comments on commit 149b840

Please sign in to comment.