-
Notifications
You must be signed in to change notification settings - Fork 1
/
mask_merge.go
133 lines (113 loc) · 3.75 KB
/
mask_merge.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
/*
Package maskmerge deals with merging protocol buffer messages using field masks.
See google.golang.org/protobuf/types/known/fieldmaskpb for more information on
how field masks are interpreted.
*/
package maskmerge
import (
"errors"
"fmt"
"strings"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/fieldmaskpb"
)
// MaskMerge merges the fields of src that are named by mask into dst.
//
// Following the field mask specification, list, map and message fields named
// by a path are merged into dst (via mergeMessages) rather than replaced;
// every other field is overwritten with the value read from src. A dotted
// path such as "a.b" descends into the singular message field "a", creating
// it in dst when it is unset, and applies the rest of the path there.
//
// MaskMerge itself only reads from src. A nil mask is treated as an empty
// mask. Paths are applied in order, so when an error is returned the paths
// preceding the offending one have already been applied to dst.
//
// An error is returned when dst or src is nil, when dst is not a valid
// (mutable) message, when dst and src have different message types, or when a
// path does not name a field that can be merged.
func MaskMerge(dst, src proto.Message, mask *fieldmaskpb.FieldMask) error {
	if dst == nil || !dst.ProtoReflect().IsValid() {
		return errors.New("invalid destination message")
	}
	if src == nil {
		return errors.New("invalid source message")
	}

	dstr := dst.ProtoReflect()
	srcr := src.ProtoReflect()

	// Field descriptors are looked up on dst and then used to read from src,
	// so both messages must be of the same type.
	if stn, dtn := srcr.Descriptor().FullName(), dstr.Descriptor().FullName(); stn != dtn {
		return fmt.Errorf("incompatible types %s and %s", stn, dtn)
	}

	fields := dstr.Descriptor().Fields()

	// GetPaths is nil-safe, unlike a direct mask.Paths access.
	for _, fullPath := range mask.GetPaths() {
		path, childPath, hasChild := cut(fullPath, ".")
		if len(path) == 0 {
			return fmt.Errorf("invalid field name %s", fullPath)
		}

		// Resolve the first path segment to a field of this message.
		fd := fields.ByName(protoreflect.Name(path))
		if fd == nil {
			return fmt.Errorf("invalid field name %s for %s", fullPath, dstr.Descriptor().Name())
		}

		if hasChild {
			// Validate everything before touching dst so that a bad path
			// does not leave a freshly created field behind. Repeated and
			// map fields also report MessageKind for message elements but
			// cannot be descended into.
			if fd.Kind() != protoreflect.MessageKind || fd.IsList() || fd.IsMap() {
				return errors.New("nested field is not a message")
			}
			sm := &fieldmaskpb.FieldMask{Paths: []string{childPath}}
			// Get returns a read-only view for an unset field, which is
			// enough for IsValid to inspect the message type.
			if !sm.IsValid(dstr.Get(fd).Message().Interface()) {
				return fmt.Errorf("nested field %s is not valid for %s (%s)", childPath, path, fullPath)
			}

			// Mutable creates the nested message in dst when it is unset.
			// The source is read with Get so that it is never modified.
			nestedDst := dstr.Mutable(fd).Message().Interface()
			nestedSrc := srcr.Get(fd).Message().Interface()
			if err := MaskMerge(nestedDst, nestedSrc, sm); err != nil {
				return fmt.Errorf("applying to nested field %s: %w", fd.FullName(), err)
			}
			continue
		}

		switch {
		case fd.IsMap() || fd.IsList() || fd.Kind() == protoreflect.MessageKind:
			// Nothing to merge when the source field is unpopulated.
			if !srcr.Has(fd) {
				continue
			}
			// Mutable yields a writable value even when dst has not
			// populated the field yet; Get would return a read-only view
			// that panics on mutation.
			dstr.Set(fd, mergeMessages(fd, dstr.Mutable(fd), srcr.Get(fd)))
		default:
			dstr.Set(fd, srcr.Get(fd))
		}
	}

	return nil
}
// mergeMessages merges b into a according to the shape of the field fd: list
// fields are handled by mergeList, map fields by mergeMap and singular
// message fields by mergeMessage. For any other field a is returned as is.
func mergeMessages(fd protoreflect.FieldDescriptor, a protoreflect.Value, b protoreflect.Value) protoreflect.Value {
	// The list and map checks come first: they take precedence over the
	// kind of the field.
	if fd.IsList() {
		return mergeList(a, b)
	}
	if fd.IsMap() {
		return mergeMap(a, b)
	}
	if fd.Kind() == protoreflect.MessageKind {
		return mergeMessage(a, b)
	}
	return a
}
// mergeMap stores every entry of the map held by b into the map held by a,
// overwriting entries of a that share a key with b, and returns a's map
// wrapped as a value.
func mergeMap(a protoreflect.Value, b protoreflect.Value) protoreflect.Value {
	target := a.Map()
	source := b.Map()

	copyEntry := func(k protoreflect.MapKey, v protoreflect.Value) bool {
		target.Set(k, v)
		return true // keep iterating over the remaining entries
	}
	source.Range(copyEntry)

	return protoreflect.ValueOfMap(target)
}
// mergeMessage sets every field that Range visits on the message held by b
// on the message held by a, and returns a's message wrapped as a value.
// Field values are stored as they are; nested values are not merged
// recursively.
func mergeMessage(a protoreflect.Value, b protoreflect.Value) protoreflect.Value {
	target := a.Message()
	source := b.Message()

	copyField := func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool {
		target.Set(field, val)
		return true // keep iterating over the remaining fields
	}
	source.Range(copyField)

	return protoreflect.ValueOfMessage(target)
}
// mergeList appends every element of the list held by b to the list held by
// a, in order, and returns a's list wrapped as a value.
func mergeList(a protoreflect.Value, b protoreflect.Value) protoreflect.Value {
	target := a.List()
	source := b.List()

	// The length is re-read on every iteration rather than cached.
	idx := 0
	for idx < source.Len() {
		target.Append(source.Get(idx))
		idx++
	}

	return protoreflect.ValueOfList(target)
}
// cut splits s at the first occurrence of sep and returns the text on either
// side of it, with found set to true. When sep does not occur in s the result
// is s, "", false.
//
// TODO: replace with strings.Cut once Go 1.18 can be relied upon.
func cut(s, sep string) (before, after string, found bool) {
	idx := strings.Index(s, sep)
	if idx < 0 {
		return s, "", false
	}
	// Skip over the whole separator, not just its first byte.
	rest := s[idx+len(sep):]
	return s[:idx], rest, true
}