Skip to content

Commit

Permalink
feat: update adapter interface stub (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
wu-clan authored Jan 6, 2024
1 parent ce8bd5e commit cf53210
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 21 deletions.
40 changes: 29 additions & 11 deletions casbin_async_sqlalchemy_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import List

from casbin import persist
from casbin.persist.adapters.asyncio import AsyncAdapter
from sqlalchemy import Column, Integer, String, delete
from sqlalchemy import or_
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
Expand Down Expand Up @@ -59,7 +60,7 @@ class Filter:
v5 = []


class Adapter(persist.Adapter):
class Adapter(AsyncAdapter):
"""the interface for Casbin adapters."""

def __init__(self, engine, db_class=None, filtered=False, warning=True):
Expand All @@ -72,12 +73,21 @@ def __init__(self, engine, db_class=None, filtered=False, warning=True):
db_class = CasbinRule
if warning:
warnings.warn(
'Using default CasbinRule table, please note the use of the "Adapter().create_table()" method to '
'create the table, and ignore this warning if you are using a custom CasbinRule table.',
"Using default CasbinRule table, please note the use of the 'Adapter().create_table()' method"
" to create the table, and ignore this warning if you are using a custom CasbinRule table.",
RuntimeWarning,
)
else:
for attr in ("id", "ptype", "v0", "v1", "v2", "v3", "v4", "v5"): # id attr was used by filter
for attr in (
"id",
"ptype",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
): # id attr was used by filter
if not hasattr(db_class, attr):
raise Exception(f"{attr} not found in custom DatabaseClass.")
Base.metadata = db_class.metadata
Expand Down Expand Up @@ -124,7 +134,7 @@ async def load_filtered_policy(self, model, filter) -> None:
for line in result.scalars():
persist.load_policy_line(str(line), model)
self._filtered = True

def filter_query(self, stmt, filter):
for attr in ("ptype", "v0", "v1", "v2", "v3", "v4", "v5"):
if len(getattr(filter, attr)) > 0:
Expand Down Expand Up @@ -204,7 +214,9 @@ async def remove_filtered_policy(self, sec, ptype, field_index, *field_values):

return True if r.rowcount > 0 else False

async def update_policy(self, sec: str, ptype: str, old_rule: List[str], new_rule: List[str]) -> None:
async def update_policy(
self, sec: str, ptype: str, old_rule: List[str], new_rule: List[str]
) -> None:
"""
Update the old_rule with the new_rule in the database (storage).
Expand Down Expand Up @@ -236,7 +248,13 @@ async def update_policy(self, sec: str, ptype: str, old_rule: List[str], new_rul
else:
setattr(old_rule_line, "v{}".format(index), None)

async def update_policies(self, sec: str, ptype: str, old_rules: List[List[str]], new_rules: List[List[str]]) -> None:
async def update_policies(
self,
sec: str,
ptype: str,
old_rules: List[List[str]],
new_rules: List[List[str]],
) -> None:
"""
Update the old_rules with the new_rules in the database (storage).
Expand All @@ -250,7 +268,9 @@ async def update_policies(self, sec: str, ptype: str, old_rules: List[List[str]]
for i in range(len(old_rules)):
await self.update_policy(sec, ptype, old_rules[i], new_rules[i])

async def update_filtered_policies(self, sec, ptype, new_rules: List[List[str]], field_index, *field_values) -> List[List[str]]:
async def update_filtered_policies(
self, sec, ptype, new_rules: List[List[str]], field_index, *field_values
) -> List[List[str]]:
"""update_filtered_policies updates all the policies on the basis of the filter."""

filter = Filter()
Expand All @@ -271,9 +291,7 @@ async def _update_filtered_policies(self, new_rules, filter) -> List[List[str]]:
async with self._session_scope() as session:
# Load old policies

stmt = select(self._db_class).where(
self._db_class.ptype == filter.ptype
)
stmt = select(self._db_class).where(self._db_class.ptype == filter.ptype)
filtered_stmt = self.filter_query(stmt, filter)
result = await session.execute(filtered_stmt)
old_rules = result.scalars().all()
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SQLAlchemy>=1.4.0
casbin>=1.23.0
casbin>=1.34.0
19 changes: 10 additions & 9 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@

import casbin
from sqlalchemy import Column, Integer, String, select
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker

from casbin_async_sqlalchemy_adapter import Adapter
from casbin_async_sqlalchemy_adapter import Base
Expand All @@ -39,7 +38,7 @@ async def get_enforcer():
adapter = Adapter(engine)
await adapter.create_table()

async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
async_session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
async with async_session() as s:
s.add(CasbinRule(ptype="p", v0="alice", v1="data1", v2="read"))
s.add(CasbinRule(ptype="p", v0="bob", v1="data2", v2="write"))
Expand Down Expand Up @@ -72,7 +71,7 @@ class CustomRule(Base):
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
async with session() as s:
s.add(CustomRule(not_exist="NotNone"))
await s.commit()
Expand Down Expand Up @@ -122,7 +121,7 @@ async def test_save_policy(self):

async def test_remove_policy(self):
e = await get_enforcer()

self.assertFalse(e.enforce("alice", "data5", "read"))
await e.add_permission_for_user("alice", "data5", "read")
self.assertTrue(e.enforce("alice", "data5", "read"))
Expand All @@ -137,7 +136,9 @@ async def test_remove_policies(self):
await e.add_policies((("alice", "data5", "read"), ("alice", "data6", "read")))
self.assertTrue(e.enforce("alice", "data5", "read"))
self.assertTrue(e.enforce("alice", "data6", "read"))
await e.remove_policies((("alice", "data5", "read"), ("alice", "data6", "read")))
await e.remove_policies(
(("alice", "data5", "read"), ("alice", "data6", "read"))
)
self.assertFalse(e.enforce("alice", "data5", "read"))
self.assertFalse(e.enforce("alice", "data6", "read"))

Expand Down Expand Up @@ -180,7 +181,7 @@ async def test_repr(self):
self.assertEqual(repr(rule), '<CasbinRule None: "p, alice, data1, read">')
engine = create_async_engine("sqlite+aiosqlite://", future=True)

session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
s = session()
Expand Down Expand Up @@ -370,7 +371,7 @@ async def test_update_policies(self):

self.assertFalse(e.enforce("data2_admin", "data2", "write"))
self.assertTrue(e.enforce("data2_admin", "data_test", "write"))

async def test_update_filtered_policies(self):
e = await get_enforcer()

Expand All @@ -392,5 +393,5 @@ async def test_update_filtered_policies(self):
self.assertTrue(e.enforce("bob", "data2", "read"))


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

0 comments on commit cf53210

Please sign in to comment.