diff --git a/casbin_async_redis_adapter/adapter.py b/casbin_async_redis_adapter/adapter.py index 57797e5..de8f400 100644 --- a/casbin_async_redis_adapter/adapter.py +++ b/casbin_async_redis_adapter/adapter.py @@ -100,6 +100,10 @@ async def _delete_policy_lines(self, ptype, rule): else: await self.client.lrem(self.key, 0, json.dumps(line.dict())) + async def _update_policy_line(self, ptype, old_rule, new_rule): + # TODO + ... + async def save_policy(self, model) -> bool: """Implement add Interface for casbin. Save the policy in mongodb @@ -131,6 +135,21 @@ async def add_policy(self, sec, ptype, rule): await self._save_policy_line(ptype, rule) return True + async def add_policies(self, sec, ptype, rules): + """Adds policy rules to redis + + Args: + sec (str): Section name, 'g' or 'p' + ptype (str): Policy type, 'g', 'g2', 'p', etc. + rules (Iterable[CasbinRule]): Casbin rules will be added + + Returns: + bool: True if succeed else False + """ + for rule in rules: + await self._save_policy_line(ptype, rule) + return True + async def remove_policy(self, sec, ptype, rule): """Remove policy rules in redis(rules duplicate will all be removed) @@ -145,6 +164,21 @@ async def remove_policy(self, sec, ptype, rule): await self._delete_policy_lines(ptype, rule) return True + async def remove_policies(self, sec, ptype, rules): + """Remove policy rules in redis(rules duplicate will all be removed) + + Args: + sec (str): Section name, 'g' or 'p' + ptype (str): Policy type, 'g', 'g2', 'p', etc. + rules (Iterable[CasbinRule]): Casbin rules if it is exactly same as will be removed. + + Returns: + bool: True if succeed else False + """ + for rule in rules: + await self._delete_policy_lines(ptype, rule) + return True + async def remove_filtered_policy(self, sec, ptype, field_index, *field_values): """Remove policy rules that match the filter from the storage. This is part of the Auto-Save feature. @@ -183,3 +217,34 @@ async def remove_filtered_policy(self, sec, ptype, field_index, *field_values): await self.client.lrem(self.key, 0, "__CASBIN_DELETED__") return True + + async def update_policy(self, sec, ptype, old_rule, new_rule): + """Update policy rule in redis. + + Args: + sec (str): section type + ptype (str): policy type + old_rule (List[str]): old: the old rule that needs to be modified + new_rule (List[str]): new: the new rule to replace the old rule + + Returns: + bool: True if succeed else False + """ + await self._update_policy_line(ptype, old_rule, new_rule) + return True + + async def update_policies(self, sec, ptype, old_rules, new_rules): + """Update policy rule in redis. + + Args: + sec (str): section type + ptype (str): policy type + old_rules (List[List[str]]): old: the old rule that needs to be modified + new_rules (List[List[str]]): new: the new rule to replace the old rule + + Returns: + bool: True if succeed else False + """ + for i in range(len(old_rules)): + await self.update_policy(sec, ptype, old_rules[i], new_rules[i]) + return True diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 974c383..c0bec8d 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -100,6 +100,33 @@ async def test_add_policy(self): self.assertTrue(e.enforce("alice", "data2", "read")) self.assertTrue(e.enforce("alice", "data2", "write")) + async def test_add_policies(self): + e = await get_enforcer() + adapter = e.get_adapter() + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + # test add_policy after insert rules + await adapter.add_policies( + sec="p", + ptype="p", + rules=(("alice", "data1", "write"), ("bob", "data2", "read")), + ) + + # reload policies from database + await e.load_policy() + + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertTrue(e.enforce("alice", "data1", "write")) + self.assertTrue(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + async def test_remove_policy(self): """ test remove_policy @@ -129,6 +156,37 @@ async def test_remove_policy(self): self.assertFalse(e.enforce("alice", "data2", "write")) self.assertTrue(result) + async def test_remove_policies(self): + """ + test remove_policy + """ + e = await get_enforcer() + adapter = e.get_adapter() + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + # test remove_policy after delete a role definition + result = await adapter.remove_policies( + sec="p", + ptype="p", + rules=(("data2_admin", "data2", "read"), ("data2_admin", "data2", "write")), + ) + + # reload policies from database + await e.load_policy() + + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertFalse(e.enforce("alice", "data2", "read")) + self.assertFalse(e.enforce("alice", "data2", "write")) + self.assertTrue(result) + async def test_remove_policy_no_remove_when_rule_is_incomplete(self): adapter = Adapter("localhost", 6379) e = casbin.AsyncEnforcer(get_fixture("rbac_with_resources_roles.conf"), adapter)