Skip to content

Commit

Permalink
Merge pull request #1008 from josephschorr/ensure-no-nil
Browse files Browse the repository at this point in the history
Catch nil values for FoundSubjectsByResourceID map and return as errors
  • Loading branch information
josephschorr authored Nov 21, 2022
2 parents 485d475 + 5518ee1 commit 7f3f207
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 58 deletions.
25 changes: 20 additions & 5 deletions internal/datasets/subjectsetbyresourceid.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package datasets

import (
"fmt"

core "github.com/authzed/spicedb/pkg/proto/core/v1"
v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
)
Expand All @@ -18,30 +20,43 @@ type SubjectSetByResourceID struct {
subjectSetByResourceID map[string]SubjectSet
}

func (ssr SubjectSetByResourceID) add(resourceID string, subject *v1.FoundSubject) {
func (ssr SubjectSetByResourceID) add(resourceID string, subject *v1.FoundSubject) error {
if subject == nil {
return fmt.Errorf("cannot add a nil subject to SubjectSetByResourceID")
}

_, ok := ssr.subjectSetByResourceID[resourceID]
if !ok {
ssr.subjectSetByResourceID[resourceID] = NewSubjectSet()
}
ssr.subjectSetByResourceID[resourceID].Add(subject)
return nil
}

// AddFromRelationship adds the subject found in the given relationship to this map, indexed at
// the resource ID specified in the relationship.
func (ssr SubjectSetByResourceID) AddFromRelationship(relationship *core.RelationTuple) {
ssr.add(relationship.ResourceAndRelation.ObjectId, &v1.FoundSubject{
func (ssr SubjectSetByResourceID) AddFromRelationship(relationship *core.RelationTuple) error {
return ssr.add(relationship.ResourceAndRelation.ObjectId, &v1.FoundSubject{
SubjectId: relationship.Subject.ObjectId,
CaveatExpression: wrapCaveat(relationship.Caveat),
})
}

// UnionWith unions the map's sets with the other map of sets provided.
func (ssr SubjectSetByResourceID) UnionWith(other map[string]*v1.FoundSubjects) {
func (ssr SubjectSetByResourceID) UnionWith(other map[string]*v1.FoundSubjects) error {
for resourceID, subjects := range other {
if subjects == nil {
return fmt.Errorf("received nil FoundSubjects in other map of SubjectSetByResourceID's UnionWith for key %s", resourceID)
}

for _, subject := range subjects.FoundSubjects {
ssr.add(resourceID, subject)
if err := ssr.add(resourceID, subject); err != nil {
return err
}
}
}

return nil
}

// IntersectionDifference performs an in-place intersection between the two maps' sets.
Expand Down
79 changes: 45 additions & 34 deletions internal/datasets/subjectsetbyresourceid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ func TestSubjectSetByResourceIDBasicOperations(t *testing.T) {
ssr := NewSubjectSetByResourceID()
require.True(t, ssr.IsEmpty())

ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))
require.NoError(t, ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")))
require.False(t, ssr.IsEmpty())

ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#..."))
ssr.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#..."))
require.NoError(t, ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#...")))
require.NoError(t, ssr.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#...")))

expected := map[string]*v1.FoundSubjects{
"firstdoc": {
Expand Down Expand Up @@ -47,11 +47,11 @@ func TestSubjectSetByResourceIDBasicOperations(t *testing.T) {

func TestSubjectSetByResourceIDUnionWith(t *testing.T) {
ssr := NewSubjectSetByResourceID()
ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))
ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#..."))
ssr.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#..."))
require.NoError(t, ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")))
require.NoError(t, ssr.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#...")))
require.NoError(t, ssr.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#...")))

ssr.UnionWith(map[string]*v1.FoundSubjects{
err := ssr.UnionWith(map[string]*v1.FoundSubjects{
"firstdoc": {
FoundSubjects: []*v1.FoundSubject{
{SubjectId: "tom"},
Expand All @@ -64,6 +64,7 @@ func TestSubjectSetByResourceIDUnionWith(t *testing.T) {
},
},
})
require.NoError(t, err)

found := ssr.AsMap()
sort.Sort(sortFoundSubjects(found["firstdoc"].FoundSubjects))
Expand Down Expand Up @@ -99,14 +100,14 @@ func (a sortFoundSubjects) Less(i, j int) bool {

func TestSubjectSetByResourceIDIntersectionDifference(t *testing.T) {
first := NewSubjectSetByResourceID()
first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))
first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#..."))
first.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#..."))
require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")))
require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#...")))
require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#...")))

second := NewSubjectSetByResourceID()
second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))
second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#..."))
second.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:george#..."))
require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")))
require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#...")))
require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:george#...")))

first.IntersectionDifference(second)

Expand All @@ -121,13 +122,13 @@ func TestSubjectSetByResourceIDIntersectionDifference(t *testing.T) {

func TestSubjectSetByResourceIDIntersectionDifferenceMissingKey(t *testing.T) {
first := NewSubjectSetByResourceID()
first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))
first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#..."))
first.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#..."))
require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")))
require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#...")))
require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#...")))

second := NewSubjectSetByResourceID()
second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))
second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#..."))
require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")))
require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#...")))

first.IntersectionDifference(second)

Expand All @@ -142,13 +143,13 @@ func TestSubjectSetByResourceIDIntersectionDifferenceMissingKey(t *testing.T) {

func TestSubjectSetByResourceIDIntersectionDifferenceItemInSecondSet(t *testing.T) {
first := NewSubjectSetByResourceID()
first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))
first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#..."))
require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")))
require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#...")))

second := NewSubjectSetByResourceID()
second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))
second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#..."))
second.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#..."))
require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")))
require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#...")))
require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#...")))

first.IntersectionDifference(second)

Expand All @@ -163,14 +164,14 @@ func TestSubjectSetByResourceIDIntersectionDifferenceItemInSecondSet(t *testing.

func TestSubjectSetByResourceIDSubtractAll(t *testing.T) {
first := NewSubjectSetByResourceID()
first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))
first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#..."))
first.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#..."))
require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")))
require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:sarah#...")))
require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:fred#...")))

second := NewSubjectSetByResourceID()
second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))
second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#..."))
second.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:george#..."))
require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")))
require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:micah#...")))
require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:seconddoc#viewer@user:george#...")))

first.SubtractAll(second)

Expand All @@ -190,11 +191,11 @@ func TestSubjectSetByResourceIDSubtractAll(t *testing.T) {

func TestSubjectSetByResourceIDSubtractAllEmpty(t *testing.T) {
first := NewSubjectSetByResourceID()
first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#..."))
first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:mi#..."))
require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:tom#...")))
require.NoError(t, first.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:mi#...")))

second := NewSubjectSetByResourceID()
second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:*#..."))
require.NoError(t, second.AddFromRelationship(tuple.MustParse("document:firstdoc#viewer@user:*#...")))

first.SubtractAll(second)

Expand All @@ -205,8 +206,8 @@ func TestSubjectSetByResourceIDBasicCaveatedOperations(t *testing.T) {
ssr := NewSubjectSetByResourceID()
require.True(t, ssr.IsEmpty())

ssr.AddFromRelationship(tuple.WithCaveat(tuple.MustParse("document:firstdoc#viewer@user:tom#..."), "first"))
ssr.AddFromRelationship(tuple.WithCaveat(tuple.MustParse("document:firstdoc#viewer@user:tom#..."), "second"))
require.NoError(t, ssr.AddFromRelationship(tuple.WithCaveat(tuple.MustParse("document:firstdoc#viewer@user:tom#..."), "first")))
require.NoError(t, ssr.AddFromRelationship(tuple.WithCaveat(tuple.MustParse("document:firstdoc#viewer@user:tom#..."), "second")))

expected := map[string]*v1.FoundSubjects{
"firstdoc": {
Expand All @@ -225,3 +226,13 @@ func TestSubjectSetByResourceIDBasicCaveatedOperations(t *testing.T) {

require.Equal(t, expected, asMap)
}

func TestSubjectSetByResoureIDUnionWithBadMap(t *testing.T) {
ssr := NewSubjectSetByResourceID()
require.True(t, ssr.IsEmpty())

err := ssr.UnionWith(map[string]*v1.FoundSubjects{
"isnil": nil,
})
require.NotNil(t, err)
}
56 changes: 37 additions & 19 deletions internal/graph/lookupsubjects.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@ func (cl *ConcurrentLookupSubjects) LookupSubjects(
// If the resource type matches the subject type, yield directly.
if req.SubjectRelation.Namespace == req.ResourceRelation.Namespace &&
req.SubjectRelation.Relation == req.ResourceRelation.Relation {
err := stream.Publish(&v1.DispatchLookupSubjectsResponse{
if err := stream.Publish(&v1.DispatchLookupSubjectsResponse{
FoundSubjectsByResourceId: subjectsForConcreteIds(req.ResourceIds),
Metadata: emptyMetadata,
})
if err != nil {
}); err != nil {
return err
}
}
Expand Down Expand Up @@ -120,7 +119,9 @@ func (cl *ConcurrentLookupSubjects) lookupDirectSubjects(

if tpl.Subject.Namespace == req.SubjectRelation.Namespace &&
tpl.Subject.Relation == req.SubjectRelation.Relation {
foundSubjectsByResourceID.AddFromRelationship(tpl)
if err := foundSubjectsByResourceID.AddFromRelationship(tpl); err != nil {
return fmt.Errorf("failed to call AddFromRelationship in lookupDirectSubjects: %w", err)
}
}

if tpl.Subject.Relation != tuple.Ellipsis {
Expand All @@ -130,11 +131,10 @@ func (cl *ConcurrentLookupSubjects) lookupDirectSubjects(
}

if !foundSubjectsByResourceID.IsEmpty() {
err := stream.Publish(&v1.DispatchLookupSubjectsResponse{
if err := stream.Publish(&v1.DispatchLookupSubjectsResponse{
FoundSubjectsByResourceId: foundSubjectsByResourceID.AsMap(),
Metadata: emptyMetadata,
})
if err != nil {
}); err != nil {
return err
}
}
Expand All @@ -149,8 +149,7 @@ func (cl *ConcurrentLookupSubjects) lookupViaComputed(
cu *core.ComputedUserset,
) error {
ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision)
err := namespace.CheckNamespaceAndRelation(ctx, parentRequest.ResourceRelation.Namespace, cu.Relation, true, ds)
if err != nil {
if err := namespace.CheckNamespaceAndRelation(ctx, parentRequest.ResourceRelation.Namespace, cu.Relation, true, ds); err != nil {
if errors.As(err, &namespace.ErrRelationNotFound{}) {
return nil
}
Expand Down Expand Up @@ -221,8 +220,7 @@ func (cl *ConcurrentLookupSubjects) lookupViaTupleToUserset(

// Map the found subject types by the computed userset relation, so that we dispatch to it.
toDispatchByComputedRelationType, err := toDispatchByTuplesetType.Map(func(resourceType *core.RelationReference) (*core.RelationReference, error) {
err := namespace.CheckNamespaceAndRelation(ctx, resourceType.Namespace, ttu.ComputedUserset.Relation, false, ds)
if err != nil {
if err := namespace.CheckNamespaceAndRelation(ctx, resourceType.Namespace, ttu.ComputedUserset.Relation, false, ds); err != nil {
if errors.As(err, &namespace.ErrRelationNotFound{}) {
return nil, nil
}
Expand Down Expand Up @@ -374,17 +372,27 @@ func (cl *ConcurrentLookupSubjects) dispatchTo(

// If the relationship has no caveat, simply map the resource ID.
if relationship.GetCaveat() == nil {
mappedFoundSubjects[relationship.ResourceAndRelation.ObjectId] = combineFoundSubjects(existing, foundSubjects)
combined, err := combineFoundSubjects(existing, foundSubjects)
if err != nil {
return nil, false, fmt.Errorf("could not combine caveat-less subjects: %w", err)
}
mappedFoundSubjects[relationship.ResourceAndRelation.ObjectId] = combined
continue
}

// Otherwise, apply the caveat to all found subjects for that resource and map to the resource ID.
foundSubjectSet := datasets.NewSubjectSet()
foundSubjectSet.UnionWith(foundSubjects.FoundSubjects)
mappedFoundSubjects[relationship.ResourceAndRelation.ObjectId] = combineFoundSubjects(

combined, err := combineFoundSubjects(
existing,
foundSubjectSet.WithParentCaveatExpression(wrapCaveat(relationship.Caveat)).AsFoundSubjects(),
)
if err != nil {
return nil, false, fmt.Errorf("could not combine caveated subjects: %w", err)
}

mappedFoundSubjects[relationship.ResourceAndRelation.ObjectId] = combined
}
}

Expand Down Expand Up @@ -414,14 +422,18 @@ func (cl *ConcurrentLookupSubjects) dispatchTo(
return g.Wait()
}

func combineFoundSubjects(existing *v1.FoundSubjects, toAdd *v1.FoundSubjects) *v1.FoundSubjects {
func combineFoundSubjects(existing *v1.FoundSubjects, toAdd *v1.FoundSubjects) (*v1.FoundSubjects, error) {
if existing == nil {
return toAdd
return toAdd, nil
}

if toAdd == nil {
return nil, fmt.Errorf("toAdd FoundSubject cannot be nil")
}

return &v1.FoundSubjects{
FoundSubjects: append(existing.FoundSubjects, toAdd.FoundSubjects...),
}
}, nil
}

type lookupSubjectsReducer interface {
Expand Down Expand Up @@ -460,7 +472,9 @@ func (lsu *lookupSubjectsUnion) CompletedChildOperations() error {

for _, result := range collector.Results() {
metadata = combineResponseMetadata(metadata, result.Metadata)
foundSubjects.UnionWith(result.FoundSubjectsByResourceId)
if err := foundSubjects.UnionWith(result.FoundSubjectsByResourceId); err != nil {
return fmt.Errorf("failed to UnionWith under lookupSubjectsUnion: %w", err)
}
}
}

Expand Down Expand Up @@ -506,7 +520,9 @@ func (lsi *lookupSubjectsIntersection) CompletedChildOperations() error {
results := datasets.NewSubjectSetByResourceID()
for _, result := range collector.Results() {
metadata = combineResponseMetadata(metadata, result.Metadata)
results.UnionWith(result.FoundSubjectsByResourceId)
if err := results.UnionWith(result.FoundSubjectsByResourceId); err != nil {
return fmt.Errorf("failed to UnionWith under lookupSubjectsIntersection: %w", err)
}
}

if index == 0 {
Expand Down Expand Up @@ -553,7 +569,9 @@ func (lse *lookupSubjectsExclusion) CompletedChildOperations() error {
results := datasets.NewSubjectSetByResourceID()
for _, result := range collector.Results() {
metadata = combineResponseMetadata(metadata, result.Metadata)
results.UnionWith(result.FoundSubjectsByResourceId)
if err := results.UnionWith(result.FoundSubjectsByResourceId); err != nil {
return fmt.Errorf("failed to UnionWith under lookupSubjectsExclusion: %w", err)
}
}

if index == 0 {
Expand Down

0 comments on commit 7f3f207

Please sign in to comment.