From cc16cd23de8e7155928c85ca05b21e4a133375ed Mon Sep 17 00:00:00 2001 From: Jonas Carson Date: Fri, 28 Apr 2023 08:46:48 -0700 Subject: [PATCH] Fix bug where when we went from any partner groups to none the group was not removed. Add tests --- gregor_django/users/adapters.py | 34 ++++++++++++---------- gregor_django/users/tests/test_adapters.py | 6 ++++ 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/gregor_django/users/adapters.py b/gregor_django/users/adapters.py index 183e2f02..22184279 100644 --- a/gregor_django/users/adapters.py +++ b/gregor_django/users/adapters.py @@ -37,12 +37,13 @@ def update_user_name(self, user, extra_data: Dict): def update_user_partner_groups(self, user, extra_data: Dict): partner_groups = extra_data.get("partner_group", []) logger.debug(f"partner groups: {partner_groups} for user {user}") + partner_group_object_list = [] if partner_groups: if not isinstance(partner_groups, list): raise ImproperlyConfigured( "sociallogin.extra_data.partner_groups should be None or a list" ) - partner_group_object_list = [] + for pg_name in partner_groups: try: pg = PartnerGroup.objects.get(short_name=pg_name) @@ -71,24 +72,25 @@ def update_user_partner_groups(self, user, extra_data: Dict): f"partner_groups user: {user} rc: {pg}" ) - for existing_pg in user.partner_groups.all(): - if existing_pg not in partner_group_object_list: - user.partner_groups.remove(existing_pg) - logger.info( - "[SocialAccountAdapter:update_user_partner_groups] " - f"removing pg {existing_pg} for user {user}" - ) + for existing_pg in user.partner_groups.all(): + if existing_pg not in partner_group_object_list: + user.partner_groups.remove(existing_pg) + logger.info( + "[SocialAccountAdapter:update_user_partner_groups] " + f"removing pg {existing_pg} for user {user}" + ) def update_user_research_centers(self, user, extra_data: Dict): # Get list of research centers in domain table research_center_or_site = extra_data.get("research_center_or_site", []) + research_center_object_list = [] if research_center_or_site: if not isinstance(research_center_or_site, list): raise ImproperlyConfigured( "sociallogin.extra_data.research_center_or_site should be a list" ) - research_center_object_list = [] + for rc_name in research_center_or_site: try: # For transition from passed full name to short name support both @@ -118,13 +120,13 @@ def update_user_research_centers(self, user, extra_data: Dict): f"research_centers user: {user} rc: {rc}" ) - for existing_rc in user.research_centers.all(): - if existing_rc not in research_center_object_list: - user.research_centers.remove(existing_rc) - logger.info( - "[SocialAccountAdatpter:update_user_research_centers] " - f"removing rc {existing_rc} for user {user}" - ) + for existing_rc in user.research_centers.all(): + if existing_rc not in research_center_object_list: + user.research_centers.remove(existing_rc) + logger.info( + "[SocialAccountAdatpter:update_user_research_centers] " + f"removing rc {existing_rc} for user {user}" + ) def update_user_groups(self, user, extra_data: Dict): managed_scope_status = extra_data.get("managed_scope_status") diff --git a/gregor_django/users/tests/test_adapters.py b/gregor_django/users/tests/test_adapters.py index 8a7fb238..7305435e 100644 --- a/gregor_django/users/tests/test_adapters.py +++ b/gregor_django/users/tests/test_adapters.py @@ -104,6 +104,9 @@ def test_update_user_research_centers_remove(self): assert user.research_centers.filter(pk=rc1.pk).exists() assert user.research_centers.all().count() == 1 + adapter.update_user_research_centers(user, dict(research_center_or_site=None)) + assert user.research_centers.all().count() == 0 + def test_update_research_centers_malformed(self): adapter = SocialAccountAdapter() user = UserFactory() @@ -178,6 +181,9 @@ def test_update_user_partner_groups_remove(self): assert user.partner_groups.filter(pk=pg1.pk).exists() assert user.partner_groups.all().count() == 1 + adapter.update_user_partner_groups(user, dict(partner_group=None)) + assert user.partner_groups.all().count() == 0 + def test_update_partner_groups_malformed(self): adapter = SocialAccountAdapter() user = UserFactory()