From 2065a791120e4563d019ee450c940074012611e4 Mon Sep 17 00:00:00 2001 From: SiuHui <127535046+siuhui@users.noreply.github.com> Date: Sat, 6 Apr 2024 10:43:30 +0800 Subject: [PATCH] feat: support BatchAdapter and UpdateAdapter interfaces (#6) --- casbin_async_redis_adapter/adapter.py | 119 ++++++++++++++++- requirements.txt | 2 +- tests/test_adapter.py | 177 ++++++++++++++++++++++++++ 3 files changed, 296 insertions(+), 2 deletions(-) diff --git a/casbin_async_redis_adapter/adapter.py b/casbin_async_redis_adapter/adapter.py index 57797e5..b6aa69f 100644 --- a/casbin_async_redis_adapter/adapter.py +++ b/casbin_async_redis_adapter/adapter.py @@ -101,7 +101,7 @@ async def _delete_policy_lines(self, ptype, rule): await self.client.lrem(self.key, 0, json.dumps(line.dict())) async def save_policy(self, model) -> bool: - """Implement add Interface for casbin. Save the policy in mongodb + """Implement add Interface for casbin. Save the policy in redis Args: model (Class Model): Casbin Model which loads from .conf file usually. @@ -131,6 +131,21 @@ async def add_policy(self, sec, ptype, rule): await self._save_policy_line(ptype, rule) return True + async def add_policies(self, sec, ptype, rules): + """AddPolicies adds policy rules to the storage. + + Args: + sec (str): Section name, 'g' or 'p' + ptype (str): Policy type, 'g', 'g2', 'p', etc. + rules: Casbin rules will be added + + Returns: + bool: True if succeed else False + """ + for rule in rules: + await self.add_policy(sec, ptype, rule) + return True + async def remove_policy(self, sec, ptype, rule): """Remove policy rules in redis(rules duplicate will all be removed) @@ -145,6 +160,21 @@ async def remove_policy(self, sec, ptype, rule): await self._delete_policy_lines(ptype, rule) return True + async def remove_policies(self, sec, ptype, rules): + """RemovePolicies removes policy rules from the storage. + + Args: + sec (str): Section name, 'g' or 'p' + ptype (str): Policy type, 'g', 'g2', 'p', etc. + rules: Casbin rules will be removed + + Returns: + bool: True if succeed else False + """ + for rule in rules: + await self.remove_policy(sec, ptype, rule) + return True + async def remove_filtered_policy(self, sec, ptype, field_index, *field_values): """Remove policy rules that match the filter from the storage. This is part of the Auto-Save feature. @@ -183,3 +213,90 @@ async def remove_filtered_policy(self, sec, ptype, field_index, *field_values): await self.client.lrem(self.key, 0, "__CASBIN_DELETED__") return True + + async def update_policy(self, sec, ptype, old_rule, new_rule): + """ + update_policy updates a policy rule from storage. + This is part of the Auto-Save feature. + + Args: + sec (str): Section name, 'g' or 'p' + ptype (str): Policy type, 'g', 'g2', 'p', etc. + old_rule: Casbin rule if it is exactly same as will be removed. + new_rule: Casbin rule if it is exactly same as will be added. + + Returns: + bool: True if succeed else False + """ + old_rule_obj = CasbinRule(ptype=ptype) + new_rule_obj = CasbinRule(ptype=ptype) + for index, value in enumerate(old_rule): + setattr(old_rule_obj, f"v{index}", value) + for index, value in enumerate(new_rule): + setattr(new_rule_obj, f"v{index}", value) + + # Convert old_rule_obj and new_rule_obj to json + old_rule_json = json.dumps(old_rule_obj.dict()) + new_rule_json = json.dumps(new_rule_obj.dict()) + + lua_script = """ + local old_rule_json = ARGV[1] + local new_rule_json = ARGV[2] + local rules = redis.call('lrange', KEYS[1], 0, -1) + for i, rule_json in ipairs(rules) do + local rule = cjson.decode(rule_json) + if rule.ptype == ARGV[3] and rule_json == old_rule_json then + redis.call('lset', KEYS[1], i-1, new_rule_json) + return 1 + end + end + return 0 + """ + + result = await self.client.eval( + lua_script, 1, self.key, old_rule_json, new_rule_json, ptype + ) + + return result == 1 + + async def update_policies(self, sec, ptype, old_rules, new_rules): + """ + UpdatePolicies updates some policy rules to storage, like db, redis. + + Args: + sec (str): Section name, 'g' or 'p' + ptype (str): Policy type, 'g', 'g2', 'p', etc. + old_rules: Casbin rule if it is exactly same as will be removed. + new_rules: Casbin rule if it is exactly same as will be added. + + Returns: + bool: True if succeed else False + """ + for i in range(len(old_rules)): + await self.update_policy(sec, ptype, old_rules[i], new_rules[i]) + return True + + async def update_filtered_policies( + self, sec, ptype, new_rules, field_index, *field_values + ): + """ + update_filtered_policies deletes old rules and adds new rules. + + Args: + sec (str): Section name, 'g' or 'p' + ptype (str): Policy type, 'g', 'g2', 'p', etc. + new_rules: Casbin rule if it is exactly same as will be added. + field_index (int): The policy index at which the filed_values begins filtering. Its range is [0, 5] + field_values(List[str]): A list of rules to filter policy which starts from + + Returns: + bool: True if succeed else False + """ + if not (0 <= field_index <= 5): + return False + if not (1 <= field_index + len(field_values) <= 6): + return False + + await self.remove_filtered_policy(sec, ptype, field_index, *field_values) + await self.add_policies(sec, ptype, new_rules) + return True diff --git a/requirements.txt b/requirements.txt index d4f3367..5ad68f4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ casbin>=1.34.0 -redis>=5.0.0 +redis>=5.0.0 \ No newline at end of file diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 974c383..417b98b 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -100,6 +100,36 @@ async def test_add_policy(self): self.assertTrue(e.enforce("alice", "data2", "read")) self.assertTrue(e.enforce("alice", "data2", "write")) + async def test_add_policies(self): + """ + test add_policies + """ + e = await get_enforcer() + adapter = e.get_adapter() + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + # test add_policies after insert 2 rules + await adapter.add_policies( + sec="p", + ptype="p", + rules=(("alice", "data1", "write"), ("bob", "data2", "read")), + ) + + # reload policies from database + await e.load_policy() + + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertTrue(e.enforce("alice", "data1", "write")) + self.assertTrue(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + async def test_remove_policy(self): """ test remove_policy @@ -129,6 +159,38 @@ async def test_remove_policy(self): self.assertFalse(e.enforce("alice", "data2", "write")) self.assertTrue(result) + async def test_remove_policies(self): + """ + test remove_policies + """ + e = await get_enforcer() + adapter = e.get_adapter() + + self.assertFalse(e.enforce("alice", "data3", "write")) + self.assertFalse(e.enforce("alice", "data3", "read")) + + await adapter.add_policies( + sec="p", + ptype="p", + rules=(("alice", "data3", "write"), ("alice", "data3", "read")), + ) + + await e.load_policy() + self.assertTrue(e.enforce("alice", "data3", "write")) + self.assertTrue(e.enforce("alice", "data3", "read")) + + # test remove_policies after delete delete 2 rules + result = await adapter.remove_policies( + sec="p", + ptype="p", + rules=(("alice", "data3", "read"), ("alice", "data3", "write")), + ) + + await e.load_policy() + self.assertFalse(e.enforce("alice", "data3", "write")) + self.assertFalse(e.enforce("alice", "data3", "read")) + self.assertTrue(result) + async def test_remove_policy_no_remove_when_rule_is_incomplete(self): adapter = Adapter("localhost", 6379) e = casbin.AsyncEnforcer(get_fixture("rbac_with_resources_roles.conf"), adapter) @@ -213,6 +275,121 @@ async def test_remove_filtered_policy(self): self.assertFalse(e.enforce("alice", "data2", "read")) self.assertFalse(e.enforce("alice", "data2", "write")) + async def test_update_policy(self): + """ + test update_policy + """ + e = await get_enforcer() + adapter = e.get_adapter() + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + # test update_policy after update a rule + result = await adapter.update_policy( + sec="p", + ptype="p", + old_rule=("bob", "data2", "write"), + new_rule=("bob", "data1", "write"), + ) + + # reload policies from database + await e.load_policy() + + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertFalse(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("bob", "data1", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + self.assertTrue(result) + + async def test_update_policies(self): + """ + test update_policies + """ + e = await get_enforcer() + adapter = e.get_adapter() + self.assertFalse(e.enforce("alice", "data3", "write")) + self.assertFalse(e.enforce("alice", "data3", "read")) + + await adapter.add_policies( + sec="p", + ptype="p", + rules=(("alice", "data3", "write"), ("alice", "data3", "read")), + ) + + await e.load_policy() + self.assertTrue(e.enforce("alice", "data3", "write")) + self.assertTrue(e.enforce("alice", "data3", "read")) + + # test update_policies after update 2 rules + result = await adapter.update_policies( + sec="p", + ptype="p", + old_rules=(("alice", "data3", "write"), ("alice", "data3", "read")), + new_rules=(("alice", "data4", "write"), ("alice", "data4", "read")), + ) + + await e.load_policy() + self.assertFalse(e.enforce("alice", "data3", "write")) + self.assertFalse(e.enforce("alice", "data3", "read")) + self.assertTrue(e.enforce("alice", "data4", "write")) + self.assertTrue(e.enforce("alice", "data4", "read")) + self.assertTrue(result) + + async def test_update_filtered_policies(self): + """ + test update_filtered_policies + """ + e = await get_enforcer() + adapter = e.get_adapter() + self.assertFalse(e.enforce("alice", "data3", "write")) + self.assertFalse(e.enforce("alice", "data3", "read")) + + await adapter.add_policies( + sec="p", + ptype="p", + rules=(("alice", "data3", "write"), ("alice", "data3", "read")), + ) + + await e.load_policy() + self.assertTrue(e.enforce("alice", "data3", "write")) + self.assertTrue(e.enforce("alice", "data3", "read")) + + # test update_filtered_policies + result = await adapter.remove_filtered_policy( + "g", "g", 6, "alice", "data2_admin" + ) + await e.load_policy() + self.assertFalse(result) + + result = await adapter.remove_filtered_policy( + "g", "g", 0, *[f"v{i}" for i in range(7)] + ) + await e.load_policy() + self.assertFalse(result) + + result = await adapter.update_filtered_policies( + "p", + "p", + (("alice", "data4", "write"), ("alice", "data4", "read")), + 0, + "alice", + "data3", + ) + + await e.load_policy() + self.assertFalse(e.enforce("alice", "data3", "write")) + self.assertFalse(e.enforce("alice", "data3", "read")) + self.assertTrue(e.enforce("alice", "data4", "write")) + self.assertTrue(e.enforce("alice", "data4", "read")) + self.assertTrue(result) + def test_str(self): """ test __str__ function