diff --git a/casbin_async_sqlalchemy_adapter/adapter.py b/casbin_async_sqlalchemy_adapter/adapter.py index d7a76d3..e7f76de 100644 --- a/casbin_async_sqlalchemy_adapter/adapter.py +++ b/casbin_async_sqlalchemy_adapter/adapter.py @@ -16,6 +16,7 @@ from typing import List from casbin import persist +from casbin.persist.adapters.asyncio import AsyncAdapter from sqlalchemy import Column, Integer, String, delete from sqlalchemy import or_ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession @@ -59,7 +60,7 @@ class Filter: v5 = [] -class Adapter(persist.Adapter): +class Adapter(AsyncAdapter): """the interface for Casbin adapters.""" def __init__(self, engine, db_class=None, filtered=False, warning=True): @@ -72,12 +73,21 @@ def __init__(self, engine, db_class=None, filtered=False, warning=True): db_class = CasbinRule if warning: warnings.warn( - 'Using default CasbinRule table, please note the use of the "Adapter().create_table()" method to ' - 'create the table, and ignore this warning if you are using a custom CasbinRule table.', + "Using default CasbinRule table, please note the use of the 'Adapter().create_table()' method" + " to create the table, and ignore this warning if you are using a custom CasbinRule table.", RuntimeWarning, ) else: - for attr in ("id", "ptype", "v0", "v1", "v2", "v3", "v4", "v5"): # id attr was used by filter + for attr in ( + "id", + "ptype", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + ): # id attr was used by filter if not hasattr(db_class, attr): raise Exception(f"{attr} not found in custom DatabaseClass.") Base.metadata = db_class.metadata @@ -124,7 +134,7 @@ async def load_filtered_policy(self, model, filter) -> None: for line in result.scalars(): persist.load_policy_line(str(line), model) self._filtered = True - + def filter_query(self, stmt, filter): for attr in ("ptype", "v0", "v1", "v2", "v3", "v4", "v5"): if len(getattr(filter, attr)) > 0: @@ -204,7 +214,9 @@ async def remove_filtered_policy(self, sec, ptype, field_index, *field_values): return True if r.rowcount > 0 else False - async def update_policy(self, sec: str, ptype: str, old_rule: List[str], new_rule: List[str]) -> None: + async def update_policy( + self, sec: str, ptype: str, old_rule: List[str], new_rule: List[str] + ) -> None: """ Update the old_rule with the new_rule in the database (storage). @@ -236,7 +248,13 @@ async def update_policy(self, sec: str, ptype: str, old_rule: List[str], new_rul else: setattr(old_rule_line, "v{}".format(index), None) - async def update_policies(self, sec: str, ptype: str, old_rules: List[List[str]], new_rules: List[List[str]]) -> None: + async def update_policies( + self, + sec: str, + ptype: str, + old_rules: List[List[str]], + new_rules: List[List[str]], + ) -> None: """ Update the old_rules with the new_rules in the database (storage). @@ -250,7 +268,9 @@ async def update_policies(self, sec: str, ptype: str, old_rules: List[List[str]] for i in range(len(old_rules)): await self.update_policy(sec, ptype, old_rules[i], new_rules[i]) - async def update_filtered_policies(self, sec, ptype, new_rules: List[List[str]], field_index, *field_values) -> List[List[str]]: + async def update_filtered_policies( + self, sec, ptype, new_rules: List[List[str]], field_index, *field_values + ) -> List[List[str]]: """update_filtered_policies updates all the policies on the basis of the filter.""" filter = Filter() @@ -271,9 +291,7 @@ async def _update_filtered_policies(self, new_rules, filter) -> List[List[str]]: async with self._session_scope() as session: # Load old policies - stmt = select(self._db_class).where( - self._db_class.ptype == filter.ptype - ) + stmt = select(self._db_class).where(self._db_class.ptype == filter.ptype) filtered_stmt = self.filter_query(stmt, filter) result = await session.execute(filtered_stmt) old_rules = result.scalars().all() diff --git a/requirements.txt b/requirements.txt index e033a93..720388e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ SQLAlchemy>=1.4.0 -casbin>=1.23.0 +casbin>=1.34.0 diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 31d9f98..08d84d9 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -18,8 +18,7 @@ import casbin from sqlalchemy import Column, Integer, String, select -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession -from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from casbin_async_sqlalchemy_adapter import Adapter from casbin_async_sqlalchemy_adapter import Base @@ -39,7 +38,7 @@ async def get_enforcer(): adapter = Adapter(engine) await adapter.create_table() - async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) + async_session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) async with async_session() as s: s.add(CasbinRule(ptype="p", v0="alice", v1="data1", v2="read")) s.add(CasbinRule(ptype="p", v0="bob", v1="data2", v2="write")) @@ -72,7 +71,7 @@ class CustomRule(Base): async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) - session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) + session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) async with session() as s: s.add(CustomRule(not_exist="NotNone")) await s.commit() @@ -122,7 +121,7 @@ async def test_save_policy(self): async def test_remove_policy(self): e = await get_enforcer() - + self.assertFalse(e.enforce("alice", "data5", "read")) await e.add_permission_for_user("alice", "data5", "read") self.assertTrue(e.enforce("alice", "data5", "read")) @@ -137,7 +136,9 @@ async def test_remove_policies(self): await e.add_policies((("alice", "data5", "read"), ("alice", "data6", "read"))) self.assertTrue(e.enforce("alice", "data5", "read")) self.assertTrue(e.enforce("alice", "data6", "read")) - await e.remove_policies((("alice", "data5", "read"), ("alice", "data6", "read"))) + await e.remove_policies( + (("alice", "data5", "read"), ("alice", "data6", "read")) + ) self.assertFalse(e.enforce("alice", "data5", "read")) self.assertFalse(e.enforce("alice", "data6", "read")) @@ -180,7 +181,7 @@ async def test_repr(self): self.assertEqual(repr(rule), '') engine = create_async_engine("sqlite+aiosqlite://", future=True) - session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) + session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) s = session() @@ -370,7 +371,7 @@ async def test_update_policies(self): self.assertFalse(e.enforce("data2_admin", "data2", "write")) self.assertTrue(e.enforce("data2_admin", "data_test", "write")) - + async def test_update_filtered_policies(self): e = await get_enforcer() @@ -392,5 +393,5 @@ async def test_update_filtered_policies(self): self.assertTrue(e.enforce("bob", "data2", "read")) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()