Skip to content

Commit

Permalink
db.In, db.NotIn iterating slices (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-littlefarmer authored Aug 22, 2024
1 parent 7c55126 commit 7177fd7
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 8 deletions.
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,13 @@ test: test-clean
test-all: test-clean
GOGC=off go test $(TEST_FLAGS) $(MOD_VENDOR) -run=$(TEST) ./...

test-all-tparse: test-clean
GOGC=off go test $(TEST_FLAGS) $(MOD_VENDOR) -run=$(TEST) ./... -json | tparse --follow

test-with-reset: db-reset test-all

test-with-reset-tparse: db-reset test-all-tparse

test-clean:
GOGC=off go clean -testcache

Expand Down
55 changes: 49 additions & 6 deletions db/cond.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package db

import (
"fmt"
"reflect"
"strings"

"github.com/Masterminds/squirrel"
Expand Down Expand Up @@ -80,7 +81,6 @@ func (n *binaryExprNode) ToSql() (string, []interface{}, error) {
func compileNodes(nodes []squirrel.Sqlizer) (q string, args []interface{}, err error) {
for i, node := range nodes {
qn, argsn, err := node.ToSql()

if err != nil {
return "", nil, fmt.Errorf("error compiling node %d: %w", i, err)
}
Expand Down Expand Up @@ -203,12 +203,12 @@ func NotILike(v interface{}) squirrel.Sqlizer {

// In represents an IN operator. The value must be variadic.
func In[T interface{}](v ...T) squirrel.Sqlizer {
return Func[T]("IN", v...)
return Func("IN", v...)
}

// NotIn represents a NOT IN operator. The value must be variadic.
func NotIn[T interface{}](v ...T) squirrel.Sqlizer {
return Func[T]("NOT IN", v...)
return Func("NOT IN", v...)
}

// Raw represents a raw SQL expression.
Expand All @@ -226,15 +226,58 @@ func Func[T interface{}](name string, params ...T) squirrel.Sqlizer {
}

places := make([]string, len(params))
args := make([]interface{}, 0, len(params))

// iterating through slices
if reflect.TypeOf(params[0]).Kind() == reflect.Slice {
elements := 0
for _, subSlice := range params {
v := reflect.ValueOf(subSlice)
elements += v.Len()
}

args := make([]interface{}, 0, elements)

for i, subSlice := range params {
subSliceVal := reflect.ValueOf(subSlice)
subPlaces := make([]string, subSliceVal.Len())

for j := 0; j < subSliceVal.Len(); j++ {
val := subSliceVal.Index(j).Interface()
if sqlizer, ok := interface{}(val).(squirrel.Sqlizer); ok {
paramSQL, paramArgs, err := sqlizer.ToSql()
if err != nil {
return "", nil, fmt.Errorf("%s: error compiling argument %d: %w", name, i, err)
}

subPlaces[j] = paramSQL
args = append(args, paramArgs...)
} else if reflect.TypeOf(val).Kind() == reflect.Slice {
v := reflect.ValueOf(val)
for k := 0; k < v.Len(); k++ {
subPlaces[j] = paramPlaceholder
args = append(args, v.Index(k).Interface())
}
} else {
subPlaces[j] = paramPlaceholder
args = append(args, val)
}
}

places[i] = "(" + strings.Join(subPlaces, ",") + ")"
}

return name + " (" + strings.Join(places, ",") + ")", args, nil
}

args := make([]interface{}, 0, len(params))
for i, param := range params {
if sqlizer, ok := interface{}(param).(squirrel.Sqlizer); ok {
paramSql, paramArgs, err := sqlizer.ToSql()
paramSQL, paramArgs, err := sqlizer.ToSql()
if err != nil {
return "", nil, fmt.Errorf("%s: error compiling argument %d: %w", name, i, err)
}
places[i] = paramSql

places[i] = paramSQL
args = append(args, paramArgs...)
} else {
places[i] = paramPlaceholder
Expand Down
64 changes: 62 additions & 2 deletions db/cond_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ import (
"testing"

sq "github.com/Masterminds/squirrel"
"github.com/goware/pgkit/v2/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/goware/pgkit/v2/db"
)

func TestCond(t *testing.T) {

t.Run("equal to", func(t *testing.T) {
cond := db.Cond{"one": 1}
s, args, err := cond.ToSql()
Expand All @@ -19,6 +19,14 @@ func TestCond(t *testing.T) {
assert.Equal(t, "one = ?", s)
})

t.Run("equal to with multiple parameters", func(t *testing.T) {
cond := db.And{db.Cond{"one": 1}, db.Cond{"two": 2}}
s, args, err := cond.ToSql()
require.NoError(t, err)
assert.Equal(t, []interface{}{1, 2}, args)
assert.Equal(t, "(one = ? AND two = ?)", s)
})

t.Run("equal to (inverted)", func(t *testing.T) {
cond := db.Cond{1: "one"}
s, args, err := cond.ToSql()
Expand Down Expand Up @@ -64,6 +72,16 @@ func TestCond(t *testing.T) {
})

t.Run("IN with slice", func(t *testing.T) {
sl1 := []int{1, 2, 3}
cond := db.Cond{"list": db.In(sl1...)}
s, args, err := cond.ToSql()
require.NoError(t, err)

assert.Equal(t, []interface{}{1, 2, 3}, args)
assert.Equal(t, "list IN (?, ?, ?)", s)
})

t.Run("IN with slice variadic", func(t *testing.T) {
cond := db.Cond{"list": db.In(1, 2, 3)}
s, args, err := cond.ToSql()
require.NoError(t, err)
Expand All @@ -72,6 +90,48 @@ func TestCond(t *testing.T) {
assert.Equal(t, "list IN (?, ?, ?)", s)
})

t.Run("multiple IN with slice", func(t *testing.T) {
sl1 := []int{1, 2, 3}
sl2 := []int{4, 5, 6}
cond := db.Cond{"list": db.In([]interface{}{sl1, sl2}...)}
s, args, err := cond.ToSql()
require.NoError(t, err)

assert.Equal(t, []interface{}{1, 2, 3, 4, 5, 6}, args)
assert.Equal(t, "list IN ((?,?,?),(?,?,?))", s)
})

t.Run("multiple IN with slice AND where ID", func(t *testing.T) {
cond := db.And{db.Cond{"list": db.In([][]string{{"1", "2", "3"}, {"3", "4", "5"}}...)}, db.Cond{"id": 1}}
s, args, err := cond.ToSql()
require.NoError(t, err)

assert.Equal(t, []interface{}{"1", "2", "3", "3", "4", "5", 1}, args)
assert.Equal(t, "(list IN ((?,?,?),(?,?,?)) AND id = ?)", s)
})

t.Run("multiple IN with struct", func(t *testing.T) {
randomStruct := []struct {
Id uint64
Name string
}{
{Id: 1, Name: "Lukas"},
{Id: 2, Name: "David"},
}

data := [][]interface{}{}
for _, s := range randomStruct {
data = append(data, []interface{}{s.Id, s.Name})
}

cond := db.Cond{"list": db.In(data...)}
s, args, err := cond.ToSql()
require.NoError(t, err)

assert.Equal(t, []interface{}{uint64(1), "Lukas", uint64(2), "David"}, args)
assert.Equal(t, "list IN ((?,?),(?,?))", s)
})

t.Run("NOT IN", func(t *testing.T) {
cond := db.Cond{"list": db.NotIn("Czech Republic", "Slovakia")}
s, args, err := cond.ToSql()
Expand Down

0 comments on commit 7177fd7

Please sign in to comment.