diff --git a/primed/cdsa/audit/signed_agreement_audit.py b/primed/cdsa/audit/signed_agreement_audit.py index c264271f..03ec421d 100644 --- a/primed/cdsa/audit/signed_agreement_audit.py +++ b/primed/cdsa/audit/signed_agreement_audit.py @@ -4,6 +4,7 @@ import django_tables2 as tables from anvil_consortium_manager.models import ManagedGroup from django.conf import settings +from django.db.models import QuerySet from django.urls import reverse from django.utils.safestring import mark_safe @@ -143,13 +144,24 @@ class SignedAgreementAccessAudit: results_table_class = SignedAgreementAccessAuditTable - def __init__(self): + def __init__(self, signed_agreement_queryset=None): # Store the CDSA group for auditing membership. self.completed = False # Set up lists to hold audit results. self.verified = [] self.needs_action = [] self.errors = [] + # Store the queryset to run the audit on. + if signed_agreement_queryset is None: + signed_agreement_queryset = models.SignedAgreement.objects.all() + if not ( + isinstance(signed_agreement_queryset, QuerySet) + and signed_agreement_queryset.model is models.SignedAgreement + ): + raise ValueError( + "signed_agreement_queryset must be a queryset of SignedAgreement objects." + ) + self.signed_agreement_queryset = signed_agreement_queryset def _audit_primary_agreement(self, signed_agreement): """Audit a single component signed agreement. @@ -338,7 +350,7 @@ def _audit_signed_agreement(self, signed_agreement): def run_audit(self): """Run an audit on all SignedAgreements.""" - for signed_agreement in models.SignedAgreement.objects.all(): + for signed_agreement in self.signed_agreement_queryset: self._audit_signed_agreement(signed_agreement) self.completed = True diff --git a/primed/cdsa/audit/workspace_audit.py b/primed/cdsa/audit/workspace_audit.py index e6ac2fc4..5add31f8 100644 --- a/primed/cdsa/audit/workspace_audit.py +++ b/primed/cdsa/audit/workspace_audit.py @@ -3,6 +3,7 @@ import django_tables2 as tables from anvil_consortium_manager.models import GroupGroupMembership, ManagedGroup from django.conf import settings +from django.db.models import QuerySet from django.urls import reverse from django.utils.safestring import mark_safe @@ -139,7 +140,7 @@ class WorkspaceAccessAudit: results_table_class = WorkspaceAccessAuditTable - def __init__(self): + def __init__(self, cdsa_workspace_queryset=None): # Store the CDSA group for auditing membership. self.anvil_cdsa_group = ManagedGroup.objects.get( name=settings.ANVIL_CDSA_GROUP_NAME @@ -149,6 +150,17 @@ def __init__(self): self.verified = [] self.needs_action = [] self.errors = [] + # Store the queryset to run the audit on. + if cdsa_workspace_queryset is None: + cdsa_workspace_queryset = models.CDSAWorkspace.objects.all() + if not ( + isinstance(cdsa_workspace_queryset, QuerySet) + and cdsa_workspace_queryset.model is models.CDSAWorkspace + ): + raise ValueError( + "cdsa_workspace_queryset must be a queryset of CDSAWorkspace objects." + ) + self.cdsa_workspace_queryset = cdsa_workspace_queryset def _audit_workspace(self, workspace): # Check if the access group is in the overall CDSA group. @@ -233,7 +245,7 @@ def _audit_workspace(self, workspace): def run_audit(self): """Run an audit on all SignedAgreements.""" - for workspace in models.CDSAWorkspace.objects.all(): + for workspace in self.cdsa_workspace_queryset: self._audit_workspace(workspace) self.completed = True diff --git a/primed/cdsa/tests/test_audit.py b/primed/cdsa/tests/test_audit.py index 9a85e2e5..208be9f5 100644 --- a/primed/cdsa/tests/test_audit.py +++ b/primed/cdsa/tests/test_audit.py @@ -93,8 +93,8 @@ def test_anvil_group_name_setting(self): self.assertEqual(instance.anvil_cdsa_group, group) -class SignedAgreementAccessAuditResultTest(TestCase): - """Tests for the SignedAgreementAccessAuditResult class.""" +class SignedAgreementAccessAuditTest(TestCase): + """Tests for the SignedAgreementAccessAudit class.""" def setUp(self): super().setUp() @@ -118,14 +118,68 @@ def test_no_signed_agreements(self): self.assertEqual(len(cdsa_audit.needs_action), 0) self.assertEqual(len(cdsa_audit.errors), 0) - def test_loops_over_signed_agreements(self): - """run_audit loops over all signed agreements.""" + def test_one_signed_agreement(self): + """Audit works when there is one signed agreement.""" + this_agreement = factories.MemberAgreementFactory.create() + cdsa_audit = signed_agreement_audit.SignedAgreementAccessAudit() + cdsa_audit.run_audit() + self.assertEqual(len(cdsa_audit.verified), 0) + self.assertEqual(len(cdsa_audit.needs_action), 1) + self.assertEqual(len(cdsa_audit.errors), 0) + record = cdsa_audit.needs_action[0] + self.assertIsInstance(record, signed_agreement_audit.GrantAccess) + self.assertEqual(record.signed_agreement, this_agreement.signed_agreement) + self.assertEqual(record.note, cdsa_audit.ACTIVE_PRIMARY_AGREEMENT) + + def test_two_signed_agreements(self): + """Audit runs on all signed agreements by default.""" # Create two signed agreements that need to be added to the SAG group. factories.MemberAgreementFactory.create_batch(2) cdsa_audit = signed_agreement_audit.SignedAgreementAccessAudit() cdsa_audit.run_audit() self.assertEqual(len(cdsa_audit.needs_action), 2) + def test_signed_agreement_queryset(self): + """Audit only runs on SignedAgreements in the signed_agreement_queryset.""" + this_agreement = factories.MemberAgreementFactory.create() + factories.MemberAgreementFactory.create() + cdsa_audit = signed_agreement_audit.SignedAgreementAccessAudit( + signed_agreement_queryset=models.SignedAgreement.objects.filter( + pk=this_agreement.signed_agreement.pk + ) + ) + cdsa_audit.run_audit() + self.assertEqual(len(cdsa_audit.verified), 0) + self.assertEqual(len(cdsa_audit.needs_action), 1) + self.assertEqual(len(cdsa_audit.errors), 0) + record = cdsa_audit.needs_action[0] + self.assertIsInstance(record, signed_agreement_audit.GrantAccess) + self.assertEqual(record.signed_agreement, this_agreement.signed_agreement) + self.assertEqual(record.note, cdsa_audit.ACTIVE_PRIMARY_AGREEMENT) + + def test_dbgap_application_queryset_wrong_class(self): + """dbGaPAccessAudit raises error if dbgap_application_queryset has the wrong model class.""" + with self.assertRaises(ValueError) as e: + signed_agreement_audit.SignedAgreementAccessAudit( + signed_agreement_queryset=models.MemberAgreement.objects.all() + ) + self.assertEqual( + str(e.exception), + "signed_agreement_queryset must be a queryset of SignedAgreement objects.", + ) + + def test_dbgap_application_queryset_not_queryset(self): + """dbGaPAccessAudit raises error if dbgap_application_queryset is not a queryset.""" + member_agreement = factories.MemberAgreementFactory.create() + with self.assertRaises(ValueError) as e: + signed_agreement_audit.SignedAgreementAccessAudit( + signed_agreement_queryset=member_agreement.signed_agreement + ) + self.assertEqual( + str(e.exception), + "signed_agreement_queryset must be a queryset of SignedAgreement objects.", + ) + def test_member_primary_in_group(self): """Member primary agreement with valid version in CDSA group.""" this_agreement = factories.MemberAgreementFactory.create() @@ -1478,6 +1532,46 @@ def test_completed(self): cdsa_audit.run_audit() self.assertTrue(cdsa_audit.completed) + def test_cdsa_workspace_queryset(self): + """Audit only runs on CDSAWorkspaces in the cdsa_workspace_queryset.""" + cdsa_workspace = factories.CDSAWorkspaceFactory.create() + factories.CDSAWorkspaceFactory.create() + cdsa_audit = workspace_audit.WorkspaceAccessAudit( + cdsa_workspace_queryset=models.CDSAWorkspace.objects.filter( + pk=cdsa_workspace.workspace.pk + ) + ) + cdsa_audit.run_audit() + self.assertEqual(len(cdsa_audit.verified), 1) + self.assertEqual(len(cdsa_audit.needs_action), 0) + self.assertEqual(len(cdsa_audit.errors), 0) + record = cdsa_audit.verified[0] + self.assertIsInstance(record, workspace_audit.VerifiedNoAccess) + self.assertEqual(record.workspace, cdsa_workspace) + self.assertIsNone(record.data_affiliate_agreement) + self.assertEqual(record.note, cdsa_audit.NO_PRIMARY_AGREEMENT) + + def test_cdsa_workspace_queryset_wrong_class(self): + """Audit raises error if dbgap_application_queryset has the wrong model class.""" + with self.assertRaises(ValueError) as e: + workspace_audit.WorkspaceAccessAudit( + cdsa_workspace_queryset=models.SignedAgreement.objects.all() + ) + self.assertEqual( + str(e.exception), + "cdsa_workspace_queryset must be a queryset of CDSAWorkspace objects.", + ) + + def test_cdsa_workspace_queryset_not_queryset(self): + """Audit raises error if dbgap_application_queryset is not a queryset.""" + workspace = factories.CDSAWorkspaceFactory.create() + with self.assertRaises(ValueError) as e: + workspace_audit.WorkspaceAccessAudit(cdsa_workspace_queryset=workspace) + self.assertEqual( + str(e.exception), + "cdsa_workspace_queryset must be a queryset of CDSAWorkspace objects.", + ) + def test_primary_in_auth_domain(self): study = StudyFactory.create() workspace = factories.CDSAWorkspaceFactory.create(study=study)