Skip to content

Commit

Permalink
it works for a test!
Browse files Browse the repository at this point in the history
  • Loading branch information
kalombos committed May 7, 2024
1 parent c448560 commit 6889227
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
36 changes: 26 additions & 10 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,16 +570,28 @@ async def make_for_all_rows(cls, cursor, query):

class Transaction:

def __init__(self, _connection_context: ConnectionContext):
self.connection_context = _connection_context
self.savepoint = None
def __init__(self, connection, is_savepoint=False):
self.connection = connection
if is_savepoint:
self.savepoint = f"PWASYNC__{uuid.uuid4().hex}"
else:
self.savepoint = None

@property
def name(self):
if self.savepoint:
return self.savepoint
return "begin"

async def execute(self, sql):
async with self.connection.cursor() as cursor:
await cursor.execute(sql)

async def begin(self):
sql = "BEGIN"
if self.connection_context.has_transactions():
self.savepoint = f"AIODB__{uuid.uuid4().hex}"
if self.savepoint:
sql = f"SAVEPOINT {self.savepoint}"
return await self.connection_context.connection.execute(sql)
return await self.execute(sql)

async def __aenter__(self):
await self.begin()
Expand All @@ -596,13 +608,13 @@ async def commit(self):
sql = "COMMIT"
if self.savepoint:
sql = f"RELEASE SAVEPOINT {self.savepoint}"
return await self.connection_context.connection.execute(sql)
return await self.execute(sql)

async def rollback(self):
sql = "ROLLBACK"
if self.savepoint:
sql = f"ROLLBACK TO SAVEPOINT {self.savepoint}"
return await self.connection_context.connection.execute(sql)
return await self.execute(sql)


class ConnectionContextManager:
Expand All @@ -614,6 +626,7 @@ async def __aenter__(self):
if self.connection_context is not None:
if self.connection_context.has_transactions() is False:
raise Exception("Connection already open")
connection = self.connection_context.connection
else:
connection = await self.aio_pool.acquire()
self.connection_context = ConnectionContext(connection)
Expand All @@ -625,12 +638,15 @@ async def __aexit__(self, *args):
self.aio_pool.release(self.connection_context.connection)
connection_context.set(None)


class TransactionContextManager(ConnectionContextManager):
async def __aenter__(self):
connection = await super().__aenter__()
self.transaction = Transaction(connection)
self.transaction = Transaction(
connection, is_savepoint=self.connection_context.has_transactions()
)
await self.transaction.__aenter__()
self.connection_context.transactions.append(self)
self.connection_context.transactions.append(self.transaction.name)
return connection

async def __aexit__(self, exc_type, exc_value, exc_tb):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ async def test_atomic_success(manager):
async with manager.database.aio_atomic():
obj.text = 'BAR'
await manager.update(obj)
async with manager.database.aio_atomic():
obj.text = 'BAR'
await manager.update(obj)

res = await manager.get(TestModel, id=obj_id)
assert res.text == 'BAR'
Expand Down

0 comments on commit 6889227

Please sign in to comment.