-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
array.go
257 lines (228 loc) · 5.3 KB
/
array.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
package greddis
import (
"bufio"
"database/sql/driver"
"fmt"
"io"
"strconv"
"unsafe"
)
// StrInt is an int that ArrayWriter encodes as a bulk string ('$' prefix,
// byte count, then the decimal digits) instead of as an integer element
// (':' prefix); see addItem and writeIntStr.
type StrInt int
// arrayWriter describes the array-building API offered by ArrayWriter.
//
// NOTE(review): *ArrayWriter does not currently satisfy this interface: its
// AddString returns *ArrayWriter (not arrayWriter) and its Init returns
// *ArrayWriter (not the value type ArrayWriter). Confirm whether anything
// still implements or consumes this interface before relying on it.
type arrayWriter interface {
	Add(item ...interface{}) (err error)
	AddString(item ...string) arrayWriter
	Flush() (err error)
	Init(length int) (newAW ArrayWriter)
	Len() (length int)
	Reset(w io.Writer)
}
// NewArrayWriter returns an ArrayWriter that encodes into bufw. The integer
// scratch buffer is pre-sized for the widest signed int64 so that converting
// numbers normally does not need to grow it.
func NewArrayWriter(bufw *bufio.Writer) *ArrayWriter {
	const maxInt64Digits = 20 // len("-9223372036854775808")
	aw := &ArrayWriter{convBuf: make([]byte, 0, maxInt64Digits)}
	aw.bufw = bufw
	return aw
}
// ArrayWriter encodes an array command — a '*' header carrying the element
// count, followed by the elements ('$' bulk strings or ':' integers) — into a
// buffered writer, and checks on Flush that the announced count was honoured.
type ArrayWriter struct {
	length  int           // element count announced by the last Init
	added   int           // elements written via Add/AddString; cleared by a successful Flush and by Reset
	convBuf []byte        // scratch space reused by convInt for integer-to-text conversion
	bufw    *bufio.Writer // destination for all encoded bytes
}
// Init starts a new array of length elements: it records the count that
// Flush will check against and writes the header ('*', the decimal count,
// sep) to the buffer. Write errors are not checked here; bufio.Writer keeps
// the first error and reports it from Flush. It returns w for chaining.
func (w *ArrayWriter) Init(length int) *ArrayWriter {
	w.length = length
	header := w.convInt(length)
	w.bufw.WriteByte('*')
	w.bufw.Write(header)
	w.bufw.Write(sep)
	return w
}
// Len reports the element count announced by the most recent Init
// (zero after a successful Flush or a Reset).
func (w *ArrayWriter) Len() (n int) {
	n = w.length
	return
}
// Add encodes each item, in order, as the next array element. It stops at the
// first item that cannot be encoded and returns that error. Items before it
// have already been written and counted.
func (w *ArrayWriter) Add(items ...interface{}) error {
	for i := range items {
		if err := w.addItem(items[i]); err != nil {
			return err
		}
		w.added++
	}
	return nil
}
// AddString writes each string as a bulk-string element, counting each one,
// and returns w so calls can be chained.
func (w *ArrayWriter) AddString(items ...string) *ArrayWriter {
	for i := 0; i < len(items); i++ {
		w.writeString(items[i])
		w.added++
	}
	return w
}
// Flush checks that exactly as many elements were added as Init announced.
// On a mismatch it returns an error and leaves both the counters and the
// buffered bytes untouched. Otherwise it clears the counters and flushes the
// buffer to the underlying writer, returning any write error.
func (w *ArrayWriter) Flush() error {
	if expected, got := w.length, w.added; expected != got {
		return fmt.Errorf("Expected %d items, but %d items were added", expected, got)
	}
	w.length, w.added = 0, 0
	return w.bufw.Flush()
}
// Reset retargets the buffered writer at wr (bufio.Writer.Reset discards any
// unflushed data and clears any stored error) and clears all per-array state.
func (w *ArrayWriter) Reset(wr io.Writer) {
	w.bufw.Reset(wr)
	w.convBuf = w.convBuf[:0]
	w.length, w.added = 0, 0
}
// convInt renders item in base 10 into the shared scratch buffer and returns
// that buffer. The result is overwritten by the next convInt call, so it must
// be consumed (written out) before converting another number.
func (w *ArrayWriter) convInt(item int) []byte {
	w.convBuf = strconv.AppendInt(w.convBuf[:0], int64(item), 10)
	return w.convBuf
}
// writeInt writes item as an integer element: ':' followed by its decimal
// digits and sep.
func (w *ArrayWriter) writeInt(item int) {
	digits := w.convInt(item)
	w.bufw.WriteByte(':')
	w.bufw.Write(digits)
	w.bufw.Write(sep)
}
// writeIntStr writes item as a bulk string holding its decimal text:
// '$', the digit count, sep, the digits, sep.
//
// convInt reuses one scratch buffer, so the digits are rendered once to
// measure them and again after the count has been written.
func (w *ArrayWriter) writeIntStr(item int) {
	digitCount := len(w.convInt(item))

	w.bufw.WriteByte('$')
	w.bufw.Write(w.convInt(digitCount))
	w.bufw.Write(sep)

	w.bufw.Write(w.convInt(item))
	w.bufw.Write(sep)
}
// writeBytes writes item as a bulk string element: '$', the byte length,
// sep, the raw bytes, then sep.
func (w *ArrayWriter) writeBytes(item []byte) {
	size := w.convInt(len(item))
	w.bufw.WriteByte('$')
	w.bufw.Write(size)
	w.bufw.Write(sep)
	w.bufw.Write(item)
	w.bufw.Write(sep)
}
// writeString writes item as a bulk string without first copying it into a
// freshly allocated []byte.
//
// The byte view is built by reinterpreting a {string, capacity} pair as a
// slice header. Converting &item directly to *[]byte would make the slice's
// cap field come from memory past the end of the (smaller) string header,
// which the unsafe.Pointer conversion rules disallow: the target type must be
// no larger than the source and share its layout. Supplying the capacity
// explicitly keeps both sides the same size and layout.
//
// The slice aliases the string's immutable data and must never be written to.
// writeBytes only passes it to bufio.Writer.Write, and io.Writer
// implementations must not modify the data they are given.
func (w *ArrayWriter) writeString(item string) {
	hdr := struct {
		data     string
		capacity int
	}{item, len(item)}
	w.writeBytes(*(*[]byte)(unsafe.Pointer(&hdr)))
}
// addString writes a string or *string item as a bulk string via
// writeString. Items of any other type are ignored without an error; within
// this file it is only called by addItem, with a string or *string.
// A nil *string is dereferenced and therefore panics.
//
// NOTE(review): the extra type switch (rather than calling writeString from
// addItem directly) is a deliberate workaround for the issue linked below;
// keep its shape unless that issue is re-examined.
func (c *ArrayWriter) addString(item interface{}) {
	// this case avoids weird race conditions per https://github.com/mikn/greddis/issues/9
	switch d := item.(type) {
	case string:
		c.writeString(d)
	case *string:
		c.writeString(*d)
	}
}
// addItem encodes one array element according to its dynamic type:
//
//   - string, *string, []byte, *[]byte: bulk string ('$' + length + data)
//   - StrInt, *StrInt: bulk string holding the decimal digits
//   - int, *int: integer element (':' + digits)
//   - driver.Valuer: Value() must yield []byte or string, written as a bulk string
//
// A nil pointer of one of the supported pointer types is rejected with an
// error instead of being dereferenced (which would panic). Unsupported types
// and unsupported Valuer results are rejected with ErrWrongType. When an
// error is returned, nothing has been written for item.
func (w *ArrayWriter) addItem(item interface{}) error {
	switch d := item.(type) {
	case string:
		// Strings are routed through addString rather than writeString;
		// see https://github.com/mikn/greddis/issues/9.
		w.addString(d)
	case *string:
		if d == nil {
			return fmt.Errorf("cannot add nil %T to array", d)
		}
		w.addString(d)
	case []byte:
		w.writeBytes(d)
	case *[]byte:
		if d == nil {
			return fmt.Errorf("cannot add nil %T to array", d)
		}
		w.writeBytes(*d)
	case StrInt:
		w.writeIntStr(int(d))
	case *StrInt:
		if d == nil {
			return fmt.Errorf("cannot add nil %T to array", d)
		}
		w.writeIntStr(int(*d))
	case int:
		w.writeInt(d)
	case *int:
		if d == nil {
			return fmt.Errorf("cannot add nil %T to array", d)
		}
		w.writeInt(*d)
	case driver.Valuer:
		// Checked after every concrete type, so a concrete type that also
		// implements Valuer keeps its dedicated encoding.
		val, err := d.Value()
		if err != nil {
			return err
		}
		switch v := val.(type) {
		case []byte:
			w.writeBytes(v)
		case string:
			// string is a valid driver.Value; encode it like a plain string item.
			w.addString(v)
		default:
			return ErrWrongType(v, "a driver.Valuer that supports []byte or string")
		}
	default:
		return ErrWrongType(d, "driver.Valuer, *string, string, *[]byte, []byte, *int or int, *StrInt, StrInt")
	}
	return nil
}
// arrayReader describes an iterator over the elements of an array reply.
//
// NOTE(review): *ArrayReader does not satisfy this interface as written. Its
// Init takes a ScanFunc and returns error (not *Reader -> *ArrayReader), and
// its Scan takes a single interface{} and returns error. Confirm whether
// this interface is still used.
type arrayReader interface {
	Init(r *Reader) (self *ArrayReader)
	Len() (length int)
	Scan(value ...interface{})
}
// NewArrayReader returns an ArrayReader that reads its elements from r. Its
// length is zero, and it has no scan function, until Init is called.
func NewArrayReader(r *Reader) *ArrayReader {
	ar := new(ArrayReader)
	ar.r = r
	return ar
}
// ArrayReader iterates over the elements of an array reply read through a
// Reader: Init reads the array header, and each Next reads one element.
type ArrayReader struct {
	r        *Reader  // source of the array header and elements
	length   int      // element count obtained by Init
	pos      int      // number of elements read so far by Next
	scanFunc ScanFunc // scan function Next passes to r.Next
	err      error    // outcome of the most recent Next; returned by Err and Scan
}
// Len returns the number of elements in the array, as obtained by Init.
func (a *ArrayReader) Len() (n int) {
	n = a.length
	return
}
// Err returns the error recorded by the most recent Next (or NextIs). This is
// ErrNoMoreRows once the array is exhausted, or nil if that call succeeded.
func (a *ArrayReader) Err() (err error) {
	err = a.err
	return
}
// Init reads the array header from the underlying Reader with ScanArray and
// records the element count. It then rewinds the position and installs
// defaultScanFunc as the scan function used by Next.
//
// It returns the header read error. Note that every field is updated even
// when that error is non-nil.
func (a *ArrayReader) Init(defaultScanFunc ScanFunc) error {
	headerErr := a.r.Next(ScanArray)
	a.length = a.r.Len()
	// Must follow Len(). Presumably this stops the header token from being
	// read back as element data — confirm against Reader.
	a.r.tokenLen = 0
	a.scanFunc = defaultScanFunc
	a.pos = 0
	return headerErr
}
// NextIs behaves like Next but reads this one element with scanFunc. The
// previous scan function is restored afterwards. It returns a for chaining.
func (a *ArrayReader) NextIs(scanFunc ScanFunc) *ArrayReader {
	var saved ScanFunc
	saved, a.scanFunc = a.scanFunc, scanFunc
	a.Next()
	a.scanFunc = saved
	return a
}
// Next advances to the next element and reads it with the current scan
// function, making it available to Scan, SwitchOnNext and Expect.
//
// The outcome is stored rather than returned. Once all Len() elements have
// been consumed, Next records ErrNoMoreRows; otherwise it records whatever
// error the underlying Reader reported (nil on success). Inspect the result
// with Err, or let Scan return it. Next returns a for chaining.
func (a *ArrayReader) Next() *ArrayReader {
	if a.pos < a.length {
		a.pos++
		a.err = a.r.Next(a.scanFunc)
	} else {
		a.err = ErrNoMoreRows
	}
	return a
}
// Scan decodes the current element into dst, in the same way Scan works on a
// single result. If the last Next failed, that error is returned and dst is
// left untouched.
func (a *ArrayReader) Scan(dst interface{}) error {
	if a.err == nil {
		return scan(a.r, dst)
	}
	return a.err
}
// SwitchOnNext returns the current element (the one most recently read by
// Next) as a string, or "" if that Next failed. Despite its name, it does not
// advance the reader.
//
// The string comes from the Reader and points into its underlying byte
// slice. It is therefore only safe to use until the next read, which is what
// a switch statement needs. To keep the value, or do more than a one-off
// comparison, use Scan instead.
func (a *ArrayReader) SwitchOnNext() string {
	if a.err == nil {
		return a.r.String()
	}
	return ""
}
// Expect checks whether the current element (the one most recently read by
// Next) equals any of vars. It returns nil on a match and a descriptive error
// otherwise. If the last Next failed, that error is returned instead.
//
// Expect does not advance the reader: every candidate is compared against the
// same element. Advancing between candidates would make Expect("a", "b")
// compare two different elements and silently consume the array.
func (a *ArrayReader) Expect(vars ...string) error {
	if a.err != nil {
		return a.err
	}
	got := a.r.String()
	for _, v := range vars {
		if got == v {
			return nil
		}
	}
	return fmt.Errorf("%s was not equal to any of %s", got, vars)
}