Skip to content

Commit

Permalink
Merge pull request #879 from go-kivik/sqliteGet
Browse files Browse the repository at this point in the history
Beginning Get support in sqlite driver
  • Loading branch information
flimzy authored Feb 11, 2024
2 parents 87963b0 + 88d9ac2 commit 72d0daf
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 51 deletions.
44 changes: 44 additions & 0 deletions x/sqlite/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

//go:build !js
// +build !js

package sqlite

import (
"context"
"testing"

"github.com/go-kivik/kivik/v4/driver"
)

// newDB creates a new driver.DB instance backed by an in-memory SQLite database,
// and registers a cleanup function to close the database when the test is done.
func newDB(t *testing.T) driver.DB {
d := drv{}
client, err := d.NewClient(":memory:", nil)
if err != nil {
t.Fatal(err)
}
if err := client.CreateDB(context.Background(), "test", nil); err != nil {
t.Fatal(err)
}
db, err := client.DB("test", nil)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
_ = db.Close()
})
return db
}
103 changes: 91 additions & 12 deletions x/sqlite/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@ package sqlite
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"

"modernc.org/sqlite"
sqlite3 "modernc.org/sqlite/lib"
Expand Down Expand Up @@ -71,11 +75,11 @@ func (d *db) Put(ctx context.Context, docID string, doc interface{}, options dri
return "", err
}
var newRev string
err = d.db.QueryRowContext(ctx, `
INSERT INTO `+d.name+` (id, rev_id, rev, doc)
err = d.db.QueryRowContext(ctx, fmt.Sprintf(`
INSERT INTO %q (id, rev_id, rev, doc)
VALUES ($1, $2, $3, $4)
RETURNING rev_id || '-' || rev
`, docID, rev.id, rev.rev, jsonDoc).Scan(&newRev)
`, d.name), docID, rev.id, rev.rev, jsonDoc).Scan(&newRev)
var sqliteErr *sqlite.Error
if errors.As(err, &sqliteErr) && sqliteErr.Code() == sqlite3.SQLITE_CONSTRAINT_UNIQUE {
// In the case of a conflict for new_edits=false, we assume that the
Expand All @@ -93,33 +97,108 @@ func (d *db) Put(ctx context.Context, docID string, doc interface{}, options dri
defer tx.Rollback()

var curRev string
err = tx.QueryRowContext(ctx, `
err = tx.QueryRowContext(ctx, fmt.Sprintf(`
SELECT COALESCE(MAX(rev_id || '-' || rev),'')
FROM `+d.name+`
FROM %q
WHERE id = $1
`, docID).Scan(&curRev)
`, d.name), docID).Scan(&curRev)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return "", err
}
if curRev != docRev {
return "", &internal.Error{Status: http.StatusConflict, Message: "conflict"}
}
var newRev string
err = tx.QueryRowContext(ctx, `
INSERT INTO `+d.name+` (id, rev_id, rev, doc)
err = tx.QueryRowContext(ctx, fmt.Sprintf(`
INSERT INTO %[1]q (id, rev_id, rev, doc)
SELECT $1, COALESCE(MAX(rev_id),0) + 1, $2, $3
FROM `+d.name+`
FROM %[1]q
WHERE id = $1
RETURNING rev_id || '-' || rev
`, docID, rev, jsonDoc).Scan(&newRev)
`, d.name), docID, rev, jsonDoc).Scan(&newRev)
if err != nil {
return "", err
}
return newRev, tx.Commit()
}

func (db) Get(context.Context, string, driver.Options) (*driver.Document, error) {
return nil, nil
func (d *db) Get(ctx context.Context, id string, options driver.Options) (*driver.Document, error) {
opts := map[string]interface{}{}
options.Apply(opts)

var rev, body string
var err error

if optsRev, _ := opts["rev"].(string); optsRev != "" {
var r revision
r, err = parseRev(optsRev)
if err != nil {
return nil, err
}
err = d.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT doc
FROM %q
WHERE id = $1
AND rev_id = $2
AND rev = $3
`, d.name), id, r.id, r.rev).Scan(&body)
rev = optsRev
} else {
err = d.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT rev_id || '-' || rev, doc
FROM %q
WHERE id = $1
AND deleted = FALSE
ORDER BY rev_id DESC, rev DESC
LIMIT 1
`, d.name), id).Scan(&rev, &body)
}

if errors.Is(err, sql.ErrNoRows) {
return nil, &internal.Error{Status: http.StatusNotFound, Message: "not found"}
}
if err != nil {
return nil, err
}

if conflicts, _ := opts["conflicts"].(bool); conflicts {
var revs []string
rows, err := d.db.QueryContext(ctx, fmt.Sprintf(`
SELECT rev_id || '-' || rev
FROM %q
WHERE id = $1
AND rev_id || '-' || rev != $2
AND DELETED = FALSE
`, d.name), id, rev)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var r string
if err := rows.Scan(&r); err != nil {
return nil, err
}
revs = append(revs, r)
}
if err := rows.Err(); err != nil {
return nil, err
}
var doc map[string]interface{}
if err := json.Unmarshal([]byte(body), &doc); err != nil {
return nil, err
}
doc["_conflicts"] = revs
jonDoc, err := json.Marshal(doc)
if err != nil {
return nil, err
}
body = string(jonDoc)
}
return &driver.Document{
Rev: rev,
Body: io.NopCloser(strings.NewReader(body)),
}, nil
}

func (db) Delete(context.Context, string, driver.Options) (string, error) {
Expand Down
142 changes: 127 additions & 15 deletions x/sqlite/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package sqlite

import (
"context"
"encoding/json"
"net/http"
"testing"

Expand Down Expand Up @@ -196,21 +197,7 @@ func TestDBPut(t *testing.T) {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
d := drv{}
client, err := d.NewClient(":memory:", nil)
if err != nil {
t.Fatal(err)
}
if err := client.CreateDB(context.Background(), "test", nil); err != nil {
t.Fatal(err)
}
db, err := client.DB("test", nil)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
_ = db.Close()
})
db := newDB(t)
if tt.setup != nil {
tt.setup(t, db)
}
Expand All @@ -234,3 +221,128 @@ func TestDBPut(t *testing.T) {
})
}
}

func TestGet(t *testing.T) {
t.Parallel()
tests := []struct {
name string
setup func(*testing.T, driver.DB)
id string
options driver.Options
wantDoc interface{}
wantStatus int
wantErr string
}{
{
name: "not found",
id: "foo",
wantStatus: http.StatusNotFound,
wantErr: "not found",
},
{
name: "success",
setup: func(t *testing.T, d driver.DB) {
_, err := d.Put(context.Background(), "foo", map[string]string{"foo": "bar"}, mock.NilOption)
if err != nil {
t.Fatal(err)
}
},
id: "foo",
wantDoc: map[string]string{"foo": "bar"},
},
{
name: "get specific rev",
setup: func(t *testing.T, d driver.DB) {
rev, err := d.Put(context.Background(), "foo", map[string]string{"foo": "bar"}, mock.NilOption)
if err != nil {
t.Fatal(err)
}
_, err = d.Put(context.Background(), "foo", map[string]string{"foo": "baz"}, kivik.Rev(rev))
if err != nil {
t.Fatal(err)
}
},
id: "foo",
options: kivik.Rev("1-9bb58f26192e4ba00f01e2e7b136bbd8"),
wantDoc: map[string]string{"foo": "bar"},
},
{
name: "specific rev not found",
id: "foo",
options: kivik.Rev("1-9bb58f26192e4ba00f01e2e7b136bbd8"),
wantStatus: http.StatusNotFound,
wantErr: "not found",
},
{
name: "include conflicts",
setup: func(t *testing.T, d driver.DB) {
_, err := d.Put(context.Background(), "foo", map[string]string{"foo": "bar"}, kivik.Params(map[string]interface{}{
"new_edits": false,
"rev": "1-abc",
}))
if err != nil {
t.Fatal(err)
}
_, err = d.Put(context.Background(), "foo", map[string]string{"foo": "baz"}, kivik.Params(map[string]interface{}{
"new_edits": false,
"rev": "1-xyz",
}))
if err != nil {
t.Fatal(err)
}
},
id: "foo",
options: kivik.Param("conflicts", true),
wantDoc: map[string]interface{}{
"foo": "baz",
"_conflicts": []string{"1-abc"},
},
},
/*
TODO:
attachments = true
att_encoding_info = true
atts_since = [revs]
conflicts = true
deleted_conflicts = true
latest = true
local_seq = true
meta = true
open_revs = []
revs = true
revs_info = true
*/
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
db := newDB(t)
if tt.setup != nil {
tt.setup(t, db)
}
opts := tt.options
if opts == nil {
opts = mock.NilOption
}
doc, err := db.Get(context.Background(), tt.id, opts)
if !testy.ErrorMatches(tt.wantErr, err) {
t.Errorf("Unexpected error: %s", err)
}
if status := kivik.HTTPStatus(err); status != tt.wantStatus {
t.Errorf("Unexpected status: %d", status)
}
if err != nil {
return
}
var gotDoc interface{}
if err := json.NewDecoder(doc.Body).Decode(&gotDoc); err != nil {
t.Fatal(err)
}
if d := testy.DiffAsJSON(tt.wantDoc, gotDoc); d != nil {
t.Errorf("Unexpected doc: %s", d)
}
})
}
}
14 changes: 7 additions & 7 deletions x/sqlite/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,30 @@ import (
"github.com/go-kivik/kivik/v4/internal"
)

type rev struct {
type revision struct {
id int
rev string
}

func (r rev) String() string {
func (r revision) String() string {
return strconv.Itoa(r.id) + "-" + r.rev
}

func parseRev(s string) (rev, error) {
func parseRev(s string) (revision, error) {
if s == "" {
return rev{}, &internal.Error{Status: http.StatusBadRequest, Message: "missing _rev"}
return revision{}, &internal.Error{Status: http.StatusBadRequest, Message: "missing _rev"}
}
const revElements = 2
parts := strings.SplitN(s, "-", revElements)
id, err := strconv.ParseInt(parts[0], 10, 64)
if err != nil {
return rev{}, &internal.Error{Status: http.StatusBadRequest, Err: err}
return revision{}, &internal.Error{Status: http.StatusBadRequest, Err: err}
}
if len(parts) == 1 {
// A rev that contains only a number is technically valid.
return rev{id: int(id)}, nil
return revision{id: int(id)}, nil
}
return rev{id: int(id), rev: parts[1]}, nil
return revision{id: int(id), rev: parts[1]}, nil
}

// prepareDoc prepares the doc for insertion. It returns the new docID, rev, and
Expand Down
Loading

0 comments on commit 72d0daf

Please sign in to comment.