diff --git a/peewee_async.py b/peewee_async.py index 009eac1..5a0c53a 100644 --- a/peewee_async.py +++ b/peewee_async.py @@ -78,10 +78,8 @@ class ConnectionContext: def __init__(self, connection): self.connection = connection - self.transactions = [] - - def has_transactions(self): - return len(self.transactions) > 0 + # needs for to know whether begin a transaction or create a savepoint + self.transaction_is_opened = False connection_context: ContextVar[Optional[ConnectionContext]] = ContextVar("connection_context", default=None) @@ -185,10 +183,8 @@ def __init__(self, connection, is_savepoint=False): self.savepoint = None @property - def name(self): - if self.savepoint: - return self.savepoint - return "begin" + def is_savepoint(self): + return self.savepoint is not None async def execute(self, sql): async with self.connection.cursor() as cursor: @@ -248,16 +244,22 @@ async def __aexit__(self, *args): class TransactionContextManager(ConnectionContextManager): async def __aenter__(self): connection = await super().__aenter__() - self.transaction = Transaction( - connection, is_savepoint=self.connection_context.has_transactions() - ) + begin_transaction = self.connection_context.transaction_is_opened is False + + self.transaction = Transaction(connection, is_savepoint=begin_transaction is False) await self.transaction.__aenter__() - self.connection_context.transactions.append(self.transaction.name) + + if begin_transaction is True: + self.connection_context.transaction_is_opened = True return connection async def __aexit__(self, exc_type, exc_value, exc_tb): await self.transaction.__aexit__(exc_type, exc_value, exc_tb) - self.connection_context.transactions.pop() + + end_transaction = self.transaction.is_savepoint is False + if end_transaction is True: + self.connection_context.transaction_is_opened = False + await super().__aexit__()