diff --git a/internal/datasets/subjectsetbyresourceid.go b/internal/datasets/subjectsetbyresourceid.go index e1f41b7bfd..0cdc7ae6d1 100644 --- a/internal/datasets/subjectsetbyresourceid.go +++ b/internal/datasets/subjectsetbyresourceid.go @@ -1,6 +1,8 @@ package datasets import ( + "fmt" + core "github.com/authzed/spicedb/pkg/proto/core/v1" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" ) @@ -18,30 +20,43 @@ type SubjectSetByResourceID struct { subjectSetByResourceID map[string]SubjectSet } -func (ssr SubjectSetByResourceID) add(resourceID string, subject *v1.FoundSubject) { +func (ssr SubjectSetByResourceID) add(resourceID string, subject *v1.FoundSubject) error { + if subject == nil { + return fmt.Errorf("cannot add a nil subject to SubjectSetByResourceID") + } + _, ok := ssr.subjectSetByResourceID[resourceID] if !ok { ssr.subjectSetByResourceID[resourceID] = NewSubjectSet() } ssr.subjectSetByResourceID[resourceID].Add(subject) + return nil } // AddFromRelationship adds the subject found in the given relationship to this map, indexed at // the resource ID specified in the relationship. -func (ssr SubjectSetByResourceID) AddFromRelationship(relationship *core.RelationTuple) { - ssr.add(relationship.ResourceAndRelation.ObjectId, &v1.FoundSubject{ +func (ssr SubjectSetByResourceID) AddFromRelationship(relationship *core.RelationTuple) error { + return ssr.add(relationship.ResourceAndRelation.ObjectId, &v1.FoundSubject{ SubjectId: relationship.Subject.ObjectId, CaveatExpression: wrapCaveat(relationship.Caveat), }) } // UnionWith unions the map's sets with the other map of sets provided. -func (ssr SubjectSetByResourceID) UnionWith(other map[string]*v1.FoundSubjects) { +func (ssr SubjectSetByResourceID) UnionWith(other map[string]*v1.FoundSubjects) error { for resourceID, subjects := range other { + if subjects == nil { + return fmt.Errorf("received nil FoundSubjects in other map of SubjectSetByResourceID's UnionWith for key %s", resourceID) + } + for _, subject := range subjects.FoundSubjects { - ssr.add(resourceID, subject) + if err := ssr.add(resourceID, subject); err != nil { + return err + } } } + + return nil } // IntersectionDifference performs an in-place intersection between the two maps' sets. diff --git a/internal/datasets/subjectsetbyresourceid_test.go b/internal/datasets/subjectsetbyresourceid_test.go index 59321b01b4..897236b928 100644 --- a/internal/datasets/subjectsetbyresourceid_test.go +++ b/internal/datasets/subjectsetbyresourceid_test.go @@ -15,11 +15,11 @@ func TestSubjectSetByResourceIDBasicOperations(t *testing.T) { ssr := NewSubjectSetByResourceID() require.True(t, ssr.IsEmpty()) - ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")) + require.NoError(t, ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))) require.False(t, ssr.IsEmpty()) - ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#...")) - ssr.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#...")) + require.NoError(t, ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#..."))) + require.NoError(t, ssr.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#..."))) expected := map[string]*v1.FoundSubjects{ "firstdoc": { @@ -47,11 +47,11 @@ func TestSubjectSetByResourceIDBasicOperations(t *testing.T) { func TestSubjectSetByResourceIDUnionWith(t *testing.T) { ssr := NewSubjectSetByResourceID() - ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")) - ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#...")) - ssr.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#...")) + require.NoError(t, ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))) + require.NoError(t, ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#..."))) + require.NoError(t, ssr.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#..."))) - ssr.UnionWith(map[string]*v1.FoundSubjects{ + err := ssr.UnionWith(map[string]*v1.FoundSubjects{ "firstdoc": { FoundSubjects: []*v1.FoundSubject{ {SubjectId: "tom"}, @@ -64,6 +64,7 @@ func TestSubjectSetByResourceIDUnionWith(t *testing.T) { }, }, }) + require.NoError(t, err) found := ssr.AsMap() sort.Sort(sortFoundSubjects(found["firstdoc"].FoundSubjects)) @@ -99,14 +100,14 @@ func (a sortFoundSubjects) Less(i, j int) bool { func TestSubjectSetByResourceIDIntersectionDifference(t *testing.T) { first := NewSubjectSetByResourceID() - first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")) - first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#...")) - first.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#...")) + require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))) + require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#..."))) + require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#..."))) second := NewSubjectSetByResourceID() - second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")) - second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#...")) - second.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:george#...")) + require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))) + require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#..."))) + require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:george#..."))) first.IntersectionDifference(second) @@ -121,13 +122,13 @@ func TestSubjectSetByResourceIDIntersectionDifference(t *testing.T) { func TestSubjectSetByResourceIDIntersectionDifferenceMissingKey(t *testing.T) { first := NewSubjectSetByResourceID() - first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")) - first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#...")) - first.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#...")) + require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))) + require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#..."))) + require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#..."))) second := NewSubjectSetByResourceID() - second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")) - second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#...")) + require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))) + require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#..."))) first.IntersectionDifference(second) @@ -142,13 +143,13 @@ func TestSubjectSetByResourceIDIntersectionDifferenceMissingKey(t *testing.T) { func TestSubjectSetByResourceIDIntersectionDifferenceItemInSecondSet(t *testing.T) { first := NewSubjectSetByResourceID() - first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")) - first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#...")) + require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))) + require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#..."))) second := NewSubjectSetByResourceID() - second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")) - second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#...")) - second.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#...")) + require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))) + require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#..."))) + require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#..."))) first.IntersectionDifference(second) @@ -163,14 +164,14 @@ func TestSubjectSetByResourceIDIntersectionDifferenceItemInSecondSet(t *testing. func TestSubjectSetByResourceIDSubtractAll(t *testing.T) { first := NewSubjectSetByResourceID() - first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")) - first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#...")) - first.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#...")) + require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))) + require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#..."))) + require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#..."))) second := NewSubjectSetByResourceID() - second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")) - second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#...")) - second.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:george#...")) + require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))) + require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#..."))) + require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:george#..."))) first.SubtractAll(second) @@ -190,11 +191,11 @@ func TestSubjectSetByResourceIDSubtractAll(t *testing.T) { func TestSubjectSetByResourceIDSubtractAllEmpty(t *testing.T) { first := NewSubjectSetByResourceID() - first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")) - first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:mi#...")) + require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))) + require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:mi#..."))) second := NewSubjectSetByResourceID() - second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:*#...")) + require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:*#..."))) first.SubtractAll(second) @@ -205,8 +206,8 @@ func TestSubjectSetByResourceIDBasicCaveatedOperations(t *testing.T) { ssr := NewSubjectSetByResourceID() require.True(t, ssr.IsEmpty()) - ssr.AddFromRelationship(tuple.WithCaveat(tuple.MustParse("document:firstdoc#viewer@user:tom#..."), "first")) - ssr.AddFromRelationship(tuple.WithCaveat(tuple.MustParse("document:firstdoc#viewer@user:tom#..."), "second")) + require.NoError(t, ssr.AddFromRelationship(tuple.WithCaveat(tuple.MustParse("document:firstdoc#viewer@user:tom#..."), "first"))) + require.NoError(t, ssr.AddFromRelationship(tuple.WithCaveat(tuple.MustParse("document:firstdoc#viewer@user:tom#..."), "second"))) expected := map[string]*v1.FoundSubjects{ "firstdoc": { @@ -225,3 +226,13 @@ func TestSubjectSetByResourceIDBasicCaveatedOperations(t *testing.T) { require.Equal(t, expected, asMap) } + +func TestSubjectSetByResoureIDUnionWithBadMap(t *testing.T) { + ssr := NewSubjectSetByResourceID() + require.True(t, ssr.IsEmpty()) + + err := ssr.UnionWith(map[string]*v1.FoundSubjects{ + "isnil": nil, + }) + require.NotNil(t, err) +} diff --git a/internal/graph/lookupsubjects.go b/internal/graph/lookupsubjects.go index a1f32ef081..13078b78bf 100644 --- a/internal/graph/lookupsubjects.go +++ b/internal/graph/lookupsubjects.go @@ -49,11 +49,10 @@ func (cl *ConcurrentLookupSubjects) LookupSubjects( // If the resource type matches the subject type, yield directly. if req.SubjectRelation.Namespace == req.ResourceRelation.Namespace && req.SubjectRelation.Relation == req.ResourceRelation.Relation { - err := stream.Publish(&v1.DispatchLookupSubjectsResponse{ + if err := stream.Publish(&v1.DispatchLookupSubjectsResponse{ FoundSubjectsByResourceId: subjectsForConcreteIds(req.ResourceIds), Metadata: emptyMetadata, - }) - if err != nil { + }); err != nil { return err } } @@ -120,7 +119,9 @@ func (cl *ConcurrentLookupSubjects) lookupDirectSubjects( if tpl.Subject.Namespace == req.SubjectRelation.Namespace && tpl.Subject.Relation == req.SubjectRelation.Relation { - foundSubjectsByResourceID.AddFromRelationship(tpl) + if err := foundSubjectsByResourceID.AddFromRelationship(tpl); err != nil { + return fmt.Errorf("failed to call AddFromRelationship in lookupDirectSubjects: %w", err) + } } if tpl.Subject.Relation != tuple.Ellipsis { @@ -130,11 +131,10 @@ func (cl *ConcurrentLookupSubjects) lookupDirectSubjects( } if !foundSubjectsByResourceID.IsEmpty() { - err := stream.Publish(&v1.DispatchLookupSubjectsResponse{ + if err := stream.Publish(&v1.DispatchLookupSubjectsResponse{ FoundSubjectsByResourceId: foundSubjectsByResourceID.AsMap(), Metadata: emptyMetadata, - }) - if err != nil { + }); err != nil { return err } } @@ -149,8 +149,7 @@ func (cl *ConcurrentLookupSubjects) lookupViaComputed( cu *core.ComputedUserset, ) error { ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision) - err := namespace.CheckNamespaceAndRelation(ctx, parentRequest.ResourceRelation.Namespace, cu.Relation, true, ds) - if err != nil { + if err := namespace.CheckNamespaceAndRelation(ctx, parentRequest.ResourceRelation.Namespace, cu.Relation, true, ds); err != nil { if errors.As(err, &namespace.ErrRelationNotFound{}) { return nil } @@ -221,8 +220,7 @@ func (cl *ConcurrentLookupSubjects) lookupViaTupleToUserset( // Map the found subject types by the computed userset relation, so that we dispatch to it. toDispatchByComputedRelationType, err := toDispatchByTuplesetType.Map(func(resourceType *core.RelationReference) (*core.RelationReference, error) { - err := namespace.CheckNamespaceAndRelation(ctx, resourceType.Namespace, ttu.ComputedUserset.Relation, false, ds) - if err != nil { + if err := namespace.CheckNamespaceAndRelation(ctx, resourceType.Namespace, ttu.ComputedUserset.Relation, false, ds); err != nil { if errors.As(err, &namespace.ErrRelationNotFound{}) { return nil, nil } @@ -374,17 +372,27 @@ func (cl *ConcurrentLookupSubjects) dispatchTo( // If the relationship has no caveat, simply map the resource ID. if relationship.GetCaveat() == nil { - mappedFoundSubjects[relationship.ResourceAndRelation.ObjectId] = combineFoundSubjects(existing, foundSubjects) + combined, err := combineFoundSubjects(existing, foundSubjects) + if err != nil { + return nil, false, fmt.Errorf("could not combine caveat-less subjects: %w", err) + } + mappedFoundSubjects[relationship.ResourceAndRelation.ObjectId] = combined continue } // Otherwise, apply the caveat to all found subjects for that resource and map to the resource ID. foundSubjectSet := datasets.NewSubjectSet() foundSubjectSet.UnionWith(foundSubjects.FoundSubjects) - mappedFoundSubjects[relationship.ResourceAndRelation.ObjectId] = combineFoundSubjects( + + combined, err := combineFoundSubjects( existing, foundSubjectSet.WithParentCaveatExpression(wrapCaveat(relationship.Caveat)).AsFoundSubjects(), ) + if err != nil { + return nil, false, fmt.Errorf("could not combine caveated subjects: %w", err) + } + + mappedFoundSubjects[relationship.ResourceAndRelation.ObjectId] = combined } } @@ -414,14 +422,18 @@ func (cl *ConcurrentLookupSubjects) dispatchTo( return g.Wait() } -func combineFoundSubjects(existing *v1.FoundSubjects, toAdd *v1.FoundSubjects) *v1.FoundSubjects { +func combineFoundSubjects(existing *v1.FoundSubjects, toAdd *v1.FoundSubjects) (*v1.FoundSubjects, error) { if existing == nil { - return toAdd + return toAdd, nil + } + + if toAdd == nil { + return nil, fmt.Errorf("toAdd FoundSubject cannot be nil") } return &v1.FoundSubjects{ FoundSubjects: append(existing.FoundSubjects, toAdd.FoundSubjects...), - } + }, nil } type lookupSubjectsReducer interface { @@ -460,7 +472,9 @@ func (lsu *lookupSubjectsUnion) CompletedChildOperations() error { for _, result := range collector.Results() { metadata = combineResponseMetadata(metadata, result.Metadata) - foundSubjects.UnionWith(result.FoundSubjectsByResourceId) + if err := foundSubjects.UnionWith(result.FoundSubjectsByResourceId); err != nil { + return fmt.Errorf("failed to UnionWith under lookupSubjectsUnion: %w", err) + } } } @@ -506,7 +520,9 @@ func (lsi *lookupSubjectsIntersection) CompletedChildOperations() error { results := datasets.NewSubjectSetByResourceID() for _, result := range collector.Results() { metadata = combineResponseMetadata(metadata, result.Metadata) - results.UnionWith(result.FoundSubjectsByResourceId) + if err := results.UnionWith(result.FoundSubjectsByResourceId); err != nil { + return fmt.Errorf("failed to UnionWith under lookupSubjectsIntersection: %w", err) + } } if index == 0 { @@ -553,7 +569,9 @@ func (lse *lookupSubjectsExclusion) CompletedChildOperations() error { results := datasets.NewSubjectSetByResourceID() for _, result := range collector.Results() { metadata = combineResponseMetadata(metadata, result.Metadata) - results.UnionWith(result.FoundSubjectsByResourceId) + if err := results.UnionWith(result.FoundSubjectsByResourceId); err != nil { + return fmt.Errorf("failed to UnionWith under lookupSubjectsExclusion: %w", err) + } } if index == 0 {