Skip to content

Commit

Permalink
feat: support BatchAdapter and UpdateAdapter interfaces (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
siuhui authored Apr 6, 2024
1 parent 328d63b commit 2065a79
Show file tree
Hide file tree
Showing 3 changed files with 296 additions and 2 deletions.
119 changes: 118 additions & 1 deletion casbin_async_redis_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def _delete_policy_lines(self, ptype, rule):
await self.client.lrem(self.key, 0, json.dumps(line.dict()))

async def save_policy(self, model) -> bool:
"""Implement add Interface for casbin. Save the policy in mongodb
"""Implement add Interface for casbin. Save the policy in redis
Args:
model (Class Model): Casbin Model which loads from .conf file usually.
Expand Down Expand Up @@ -131,6 +131,21 @@ async def add_policy(self, sec, ptype, rule):
await self._save_policy_line(ptype, rule)
return True

async def add_policies(self, sec, ptype, rules):
"""AddPolicies adds policy rules to the storage.
Args:
sec (str): Section name, 'g' or 'p'
ptype (str): Policy type, 'g', 'g2', 'p', etc.
rules: Casbin rules will be added
Returns:
bool: True if succeed else False
"""
for rule in rules:
await self.add_policy(sec, ptype, rule)
return True

async def remove_policy(self, sec, ptype, rule):
"""Remove policy rules in redis(rules duplicate will all be removed)
Expand All @@ -145,6 +160,21 @@ async def remove_policy(self, sec, ptype, rule):
await self._delete_policy_lines(ptype, rule)
return True

async def remove_policies(self, sec, ptype, rules):
"""RemovePolicies removes policy rules from the storage.
Args:
sec (str): Section name, 'g' or 'p'
ptype (str): Policy type, 'g', 'g2', 'p', etc.
rules: Casbin rules will be removed
Returns:
bool: True if succeed else False
"""
for rule in rules:
await self.remove_policy(sec, ptype, rule)
return True

async def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
"""Remove policy rules that match the filter from the storage.
This is part of the Auto-Save feature.
Expand Down Expand Up @@ -183,3 +213,90 @@ async def remove_filtered_policy(self, sec, ptype, field_index, *field_values):

await self.client.lrem(self.key, 0, "__CASBIN_DELETED__")
return True

async def update_policy(self, sec, ptype, old_rule, new_rule):
"""
update_policy updates a policy rule from storage.
This is part of the Auto-Save feature.
Args:
sec (str): Section name, 'g' or 'p'
ptype (str): Policy type, 'g', 'g2', 'p', etc.
old_rule: Casbin rule if it is exactly same as will be removed.
new_rule: Casbin rule if it is exactly same as will be added.
Returns:
bool: True if succeed else False
"""
old_rule_obj = CasbinRule(ptype=ptype)
new_rule_obj = CasbinRule(ptype=ptype)
for index, value in enumerate(old_rule):
setattr(old_rule_obj, f"v{index}", value)
for index, value in enumerate(new_rule):
setattr(new_rule_obj, f"v{index}", value)

# Convert old_rule_obj and new_rule_obj to json
old_rule_json = json.dumps(old_rule_obj.dict())
new_rule_json = json.dumps(new_rule_obj.dict())

lua_script = """
local old_rule_json = ARGV[1]
local new_rule_json = ARGV[2]
local rules = redis.call('lrange', KEYS[1], 0, -1)
for i, rule_json in ipairs(rules) do
local rule = cjson.decode(rule_json)
if rule.ptype == ARGV[3] and rule_json == old_rule_json then
redis.call('lset', KEYS[1], i-1, new_rule_json)
return 1
end
end
return 0
"""

result = await self.client.eval(
lua_script, 1, self.key, old_rule_json, new_rule_json, ptype
)

return result == 1

async def update_policies(self, sec, ptype, old_rules, new_rules):
"""
UpdatePolicies updates some policy rules to storage, like db, redis.
Args:
sec (str): Section name, 'g' or 'p'
ptype (str): Policy type, 'g', 'g2', 'p', etc.
old_rules: Casbin rule if it is exactly same as will be removed.
new_rules: Casbin rule if it is exactly same as will be added.
Returns:
bool: True if succeed else False
"""
for i in range(len(old_rules)):
await self.update_policy(sec, ptype, old_rules[i], new_rules[i])
return True

async def update_filtered_policies(
self, sec, ptype, new_rules, field_index, *field_values
):
"""
update_filtered_policies deletes old rules and adds new rules.
Args:
sec (str): Section name, 'g' or 'p'
ptype (str): Policy type, 'g', 'g2', 'p', etc.
new_rules: Casbin rule if it is exactly same as will be added.
field_index (int): The policy index at which the filed_values begins filtering. Its range is [0, 5]
field_values(List[str]): A list of rules to filter policy which starts from
Returns:
bool: True if succeed else False
"""
if not (0 <= field_index <= 5):
return False
if not (1 <= field_index + len(field_values) <= 6):
return False

await self.remove_filtered_policy(sec, ptype, field_index, *field_values)
await self.add_policies(sec, ptype, new_rules)
return True
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
casbin>=1.34.0
redis>=5.0.0
redis>=5.0.0
177 changes: 177 additions & 0 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,36 @@ async def test_add_policy(self):
self.assertTrue(e.enforce("alice", "data2", "read"))
self.assertTrue(e.enforce("alice", "data2", "write"))

async def test_add_policies(self):
"""
test add_policies
"""
e = await get_enforcer()
adapter = e.get_adapter()
self.assertTrue(e.enforce("alice", "data1", "read"))
self.assertFalse(e.enforce("alice", "data1", "write"))
self.assertFalse(e.enforce("bob", "data2", "read"))
self.assertTrue(e.enforce("bob", "data2", "write"))
self.assertTrue(e.enforce("alice", "data2", "read"))
self.assertTrue(e.enforce("alice", "data2", "write"))

# test add_policies after insert 2 rules
await adapter.add_policies(
sec="p",
ptype="p",
rules=(("alice", "data1", "write"), ("bob", "data2", "read")),
)

# reload policies from database
await e.load_policy()

self.assertTrue(e.enforce("alice", "data1", "read"))
self.assertTrue(e.enforce("alice", "data1", "write"))
self.assertTrue(e.enforce("bob", "data2", "read"))
self.assertTrue(e.enforce("bob", "data2", "write"))
self.assertTrue(e.enforce("alice", "data2", "read"))
self.assertTrue(e.enforce("alice", "data2", "write"))

async def test_remove_policy(self):
"""
test remove_policy
Expand Down Expand Up @@ -129,6 +159,38 @@ async def test_remove_policy(self):
self.assertFalse(e.enforce("alice", "data2", "write"))
self.assertTrue(result)

async def test_remove_policies(self):
"""
test remove_policies
"""
e = await get_enforcer()
adapter = e.get_adapter()

self.assertFalse(e.enforce("alice", "data3", "write"))
self.assertFalse(e.enforce("alice", "data3", "read"))

await adapter.add_policies(
sec="p",
ptype="p",
rules=(("alice", "data3", "write"), ("alice", "data3", "read")),
)

await e.load_policy()
self.assertTrue(e.enforce("alice", "data3", "write"))
self.assertTrue(e.enforce("alice", "data3", "read"))

# test remove_policies after delete delete 2 rules
result = await adapter.remove_policies(
sec="p",
ptype="p",
rules=(("alice", "data3", "read"), ("alice", "data3", "write")),
)

await e.load_policy()
self.assertFalse(e.enforce("alice", "data3", "write"))
self.assertFalse(e.enforce("alice", "data3", "read"))
self.assertTrue(result)

async def test_remove_policy_no_remove_when_rule_is_incomplete(self):
adapter = Adapter("localhost", 6379)
e = casbin.AsyncEnforcer(get_fixture("rbac_with_resources_roles.conf"), adapter)
Expand Down Expand Up @@ -213,6 +275,121 @@ async def test_remove_filtered_policy(self):
self.assertFalse(e.enforce("alice", "data2", "read"))
self.assertFalse(e.enforce("alice", "data2", "write"))

async def test_update_policy(self):
"""
test update_policy
"""
e = await get_enforcer()
adapter = e.get_adapter()
self.assertTrue(e.enforce("alice", "data1", "read"))
self.assertFalse(e.enforce("alice", "data1", "write"))
self.assertFalse(e.enforce("bob", "data2", "read"))
self.assertTrue(e.enforce("bob", "data2", "write"))
self.assertTrue(e.enforce("alice", "data2", "read"))
self.assertTrue(e.enforce("alice", "data2", "write"))

# test update_policy after update a rule
result = await adapter.update_policy(
sec="p",
ptype="p",
old_rule=("bob", "data2", "write"),
new_rule=("bob", "data1", "write"),
)

# reload policies from database
await e.load_policy()

self.assertTrue(e.enforce("alice", "data1", "read"))
self.assertFalse(e.enforce("alice", "data1", "write"))
self.assertFalse(e.enforce("bob", "data2", "read"))
self.assertFalse(e.enforce("bob", "data2", "write"))
self.assertTrue(e.enforce("bob", "data1", "write"))
self.assertTrue(e.enforce("alice", "data2", "read"))
self.assertTrue(e.enforce("alice", "data2", "write"))
self.assertTrue(result)

async def test_update_policies(self):
"""
test update_policies
"""
e = await get_enforcer()
adapter = e.get_adapter()
self.assertFalse(e.enforce("alice", "data3", "write"))
self.assertFalse(e.enforce("alice", "data3", "read"))

await adapter.add_policies(
sec="p",
ptype="p",
rules=(("alice", "data3", "write"), ("alice", "data3", "read")),
)

await e.load_policy()
self.assertTrue(e.enforce("alice", "data3", "write"))
self.assertTrue(e.enforce("alice", "data3", "read"))

# test update_policies after update 2 rules
result = await adapter.update_policies(
sec="p",
ptype="p",
old_rules=(("alice", "data3", "write"), ("alice", "data3", "read")),
new_rules=(("alice", "data4", "write"), ("alice", "data4", "read")),
)

await e.load_policy()
self.assertFalse(e.enforce("alice", "data3", "write"))
self.assertFalse(e.enforce("alice", "data3", "read"))
self.assertTrue(e.enforce("alice", "data4", "write"))
self.assertTrue(e.enforce("alice", "data4", "read"))
self.assertTrue(result)

async def test_update_filtered_policies(self):
"""
test update_filtered_policies
"""
e = await get_enforcer()
adapter = e.get_adapter()
self.assertFalse(e.enforce("alice", "data3", "write"))
self.assertFalse(e.enforce("alice", "data3", "read"))

await adapter.add_policies(
sec="p",
ptype="p",
rules=(("alice", "data3", "write"), ("alice", "data3", "read")),
)

await e.load_policy()
self.assertTrue(e.enforce("alice", "data3", "write"))
self.assertTrue(e.enforce("alice", "data3", "read"))

# test update_filtered_policies
result = await adapter.remove_filtered_policy(
"g", "g", 6, "alice", "data2_admin"
)
await e.load_policy()
self.assertFalse(result)

result = await adapter.remove_filtered_policy(
"g", "g", 0, *[f"v{i}" for i in range(7)]
)
await e.load_policy()
self.assertFalse(result)

result = await adapter.update_filtered_policies(
"p",
"p",
(("alice", "data4", "write"), ("alice", "data4", "read")),
0,
"alice",
"data3",
)

await e.load_policy()
self.assertFalse(e.enforce("alice", "data3", "write"))
self.assertFalse(e.enforce("alice", "data3", "read"))
self.assertTrue(e.enforce("alice", "data4", "write"))
self.assertTrue(e.enforce("alice", "data4", "read"))
self.assertTrue(result)

def test_str(self):
"""
test __str__ function
Expand Down

0 comments on commit 2065a79

Please sign in to comment.