diff --git a/.golangci.toml b/.golangci.toml index 7cb5dd1e1..1d9fe05cf 100644 --- a/.golangci.toml +++ b/.golangci.toml @@ -24,13 +24,17 @@ enable = [ "gocritic", "staticcheck", "goheader", + + # SQL-related linters + "rowserrcheck", + "sqlclosecheck", ] [issues] exclude-use-default = false [[issues.exclude-rules]] -source = "defer .*\\.Close\\(\\)$" +source = "defer .*\\.(Close|Rollback)\\(\\)$" linters = ["errcheck"] [linters-settings.gci] diff --git a/x/sqlite/db.go b/x/sqlite/db.go new file mode 100644 index 000000000..ab011c38f --- /dev/null +++ b/x/sqlite/db.go @@ -0,0 +1,167 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +package sqlite + +import ( + "context" + "database/sql" + "errors" + "net/http" + + "modernc.org/sqlite" + sqlite3 "modernc.org/sqlite/lib" + + "github.com/go-kivik/kivik/v4/driver" + "github.com/go-kivik/kivik/v4/internal" +) + +type db struct { + db *sql.DB + name string +} + +var _ driver.DB = (*db)(nil) + +func (db) AllDocs(context.Context, driver.Options) (driver.Rows, error) { + return nil, nil +} + +func (db) CreateDoc(context.Context, interface{}, driver.Options) (string, string, error) { + return "", "", nil +} + +func (d *db) Put(ctx context.Context, docID string, doc interface{}, options driver.Options) (string, error) { + docRev, err := extractRev(doc) + if err != nil { + return "", err + } + opts := map[string]interface{}{ + "new_edits": true, + } + options.Apply(opts) + optsRev, _ := opts["rev"].(string) + if optsRev != "" && docRev != "" && optsRev != docRev { + return "", &internal.Error{Status: http.StatusBadRequest, Message: "Document rev and option have different values"} + } + if docRev == "" && optsRev != "" { + docRev = optsRev + } + + docID, rev, jsonDoc, err := prepareDoc(docID, doc) + if err != nil { + return "", err + } + + if newEdits, _ := opts["new_edits"].(bool); !newEdits { + if docRev == "" { + return "", &internal.Error{Status: http.StatusBadRequest, Message: "When `new_edits: false`, the document needs `_rev` or `_revisions` specified"} + } + rev, err := parseRev(docRev) + if err != nil { + return "", err + } + var newRev string + err = d.db.QueryRowContext(ctx, ` + INSERT INTO `+d.name+` (id, rev_id, rev, doc) + VALUES ($1, $2, $3, $4) + RETURNING rev_id || '-' || rev + `, docID, rev.id, rev.rev, jsonDoc).Scan(&newRev) + var sqliteErr *sqlite.Error + if errors.As(err, &sqliteErr) && sqliteErr.Code() == sqlite3.SQLITE_CONSTRAINT_UNIQUE { + // In the case of a conflict for new_edits=false, we assume that the + // documents are identical, for the sake of idempotency, and return + // the current rev, to match CouchDB behavior. + return docRev, nil + } + return newRev, err + } + + tx, err := d.db.BeginTx(ctx, nil) + if err != nil { + return "", err + } + defer tx.Rollback() + + var curRev string + err = tx.QueryRowContext(ctx, ` + SELECT COALESCE(MAX(rev_id || '-' || rev),'') + FROM `+d.name+` + WHERE id = $1 + `, docID).Scan(&curRev) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return "", err + } + if curRev != docRev { + return "", &internal.Error{Status: http.StatusConflict, Message: "conflict"} + } + var newRev string + err = tx.QueryRowContext(ctx, ` + INSERT INTO `+d.name+` (id, rev_id, rev, doc) + SELECT $1, COALESCE(MAX(rev_id),0) + 1, $2, $3 + FROM `+d.name+` + WHERE id = $1 + RETURNING rev_id || '-' || rev + `, docID, rev, jsonDoc).Scan(&newRev) + if err != nil { + return "", err + } + return newRev, tx.Commit() +} + +func (db) Get(context.Context, string, driver.Options) (*driver.Document, error) { + return nil, nil +} + +func (db) Delete(context.Context, string, driver.Options) (string, error) { + return "", nil +} + +func (db) Stats(context.Context) (*driver.DBStats, error) { + return nil, nil +} + +func (db) Compact(context.Context) error { + return nil +} + +func (db) CompactView(context.Context, string) error { + return nil +} + +func (db) ViewCleanup(context.Context) error { + return nil +} + +func (db) Changes(context.Context, driver.Options) (driver.Changes, error) { + return nil, nil +} + +func (db) PutAttachment(context.Context, string, *driver.Attachment, driver.Options) (string, error) { + return "", nil +} + +func (db) GetAttachment(context.Context, string, string, driver.Options) (*driver.Attachment, error) { + return nil, nil +} + +func (db) DeleteAttachment(context.Context, string, string, driver.Options) (string, error) { + return "", nil +} + +func (db) Query(context.Context, string, string, driver.Options) (driver.Rows, error) { + return nil, nil +} + +func (db) Close() error { + return nil +} diff --git a/x/sqlite/db_test.go b/x/sqlite/db_test.go new file mode 100644 index 000000000..48fce1515 --- /dev/null +++ b/x/sqlite/db_test.go @@ -0,0 +1,236 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +//go:build !js +// +build !js + +package sqlite + +import ( + "context" + "net/http" + "testing" + + "gitlab.com/flimzy/testy" + + "github.com/go-kivik/kivik/v4" + "github.com/go-kivik/kivik/v4/driver" + "github.com/go-kivik/kivik/v4/internal/mock" +) + +func TestDBPut(t *testing.T) { + t.Parallel() + tests := []struct { + name string + setup func(*testing.T, driver.DB) + docID string + doc interface{} + options driver.Options + check func(*testing.T, driver.DB) + wantRev string + wantStatus int + wantErr string + }{ + { + name: "create new document", + docID: "foo", + doc: map[string]string{ + "foo": "bar", + }, + wantRev: "1-9bb58f26192e4ba00f01e2e7b136bbd8", + }, + { + name: "doc rev & option rev mismatch", + docID: "foo", + doc: map[string]interface{}{ + "_rev": "1-1234567890abcdef1234567890abcdef", + "foo": "bar", + }, + options: driver.Options(kivik.Rev("2-1234567890abcdef1234567890abcdef")), + wantStatus: http.StatusBadRequest, + wantErr: "Document rev and option have different values", + }, + { + name: "attempt to create doc with rev should conflict", + docID: "foo", + doc: map[string]interface{}{ + "_rev": "1-1234567890abcdef1234567890abcdef", + "foo": "bar", + }, + wantStatus: http.StatusConflict, + wantErr: "conflict", + }, + { + name: "attempt to update doc without rev should conflict", + setup: func(t *testing.T, d driver.DB) { + if _, err := d.Put(context.Background(), "foo", map[string]string{"foo": "bar"}, mock.NilOption); err != nil { + t.Fatal(err) + } + }, + docID: "foo", + doc: map[string]interface{}{ + "foo": "bar", + }, + wantStatus: http.StatusConflict, + wantErr: "conflict", + }, + { + name: "attempt to update doc with wrong rev should conflict", + setup: func(t *testing.T, d driver.DB) { + if _, err := d.Put(context.Background(), "foo", map[string]string{"foo": "bar"}, mock.NilOption); err != nil { + t.Fatal(err) + } + }, + docID: "foo", + doc: map[string]interface{}{ + "_rev": "2-1234567890abcdef1234567890abcdef", + "foo": "bar", + }, + wantStatus: http.StatusConflict, + wantErr: "conflict", + }, + { + name: "update doc with correct rev", + setup: func(t *testing.T, d driver.DB) { + _, err := d.Put(context.Background(), "foo", map[string]string{"foo": "bar"}, mock.NilOption) + if err != nil { + t.Fatal(err) + } + }, + docID: "foo", + doc: map[string]interface{}{ + "_rev": "1-9bb58f26192e4ba00f01e2e7b136bbd8", + "foo": "baz", + }, + wantRev: "2-afa7ae8a1906f4bb061be63525974f92", + }, + { + name: "update doc with new_edits=false, no existing doc", + docID: "foo", + doc: map[string]interface{}{ + "_rev": "1-6fe51f74859f3579abaccc426dd5104f", + "foo": "baz", + }, + options: kivik.Param("new_edits", false), + wantRev: "1-6fe51f74859f3579abaccc426dd5104f", + }, + { + name: "update doc with new_edits=false, no rev", + docID: "foo", + doc: map[string]interface{}{ + "foo": "baz", + }, + options: kivik.Param("new_edits", false), + wantStatus: http.StatusBadRequest, + wantErr: "When `new_edits: false`, the document needs `_rev` or `_revisions` specified", + }, + { + name: "update doc with new_edits=false, existing doc", + setup: func(t *testing.T, d driver.DB) { + _, err := d.Put(context.Background(), "foo", map[string]string{"foo": "bar"}, mock.NilOption) + if err != nil { + t.Fatal(err) + } + }, + docID: "foo", + doc: map[string]interface{}{ + "_rev": "1-asdf", + "foo": "baz", + }, + options: kivik.Param("new_edits", false), + wantRev: "1-asdf", + }, + { + name: "update doc with new_edits=false, existing doc and rev", + setup: func(t *testing.T, d driver.DB) { + _, err := d.Put(context.Background(), "foo", map[string]string{"foo": "bar"}, mock.NilOption) + if err != nil { + t.Fatal(err) + } + }, + docID: "foo", + doc: map[string]interface{}{ + "_rev": "1-9bb58f26192e4ba00f01e2e7b136bbd8", + "foo": "baz", + }, + options: kivik.Param("new_edits", false), + wantRev: "1-9bb58f26192e4ba00f01e2e7b136bbd8", + check: func(t *testing.T, d driver.DB) { + var doc string + err := d.(*db).db.QueryRow(` + SELECT doc + FROM test + WHERE id='foo' + AND rev_id=1 + AND rev='9bb58f26192e4ba00f01e2e7b136bbd8'`).Scan(&doc) + if err != nil { + t.Fatal(err) + } + if doc != `{"foo":"bar"}` { + t.Errorf("Unexpected doc: %s", doc) + } + }, + }, + { + name: "doc id in url and doc differ", + docID: "foo", + doc: map[string]interface{}{ + "_id": "bar", + "foo": "baz", + }, + wantStatus: http.StatusBadRequest, + wantErr: "Document ID must match _id in document", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + d := drv{} + client, err := d.NewClient(":memory:", nil) + if err != nil { + t.Fatal(err) + } + if err := client.CreateDB(context.Background(), "test", nil); err != nil { + t.Fatal(err) + } + db, err := client.DB("test", nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + _ = db.Close() + }) + if tt.setup != nil { + tt.setup(t, db) + } + opts := tt.options + if opts == nil { + opts = mock.NilOption + } + rev, err := db.Put(context.Background(), tt.docID, tt.doc, opts) + if !testy.ErrorMatches(tt.wantErr, err) { + t.Errorf("Unexpected error: %s", err) + } + if tt.check != nil { + tt.check(t, db) + } + if err != nil { + return + } + if rev != tt.wantRev { + t.Errorf("Unexpected rev: %s, want %s", rev, tt.wantRev) + } + }) + } +} diff --git a/x/sqlite/json.go b/x/sqlite/json.go new file mode 100644 index 000000000..43b4c3946 --- /dev/null +++ b/x/sqlite/json.go @@ -0,0 +1,102 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +package sqlite + +import ( + "bytes" + "crypto/md5" + "encoding/hex" + "encoding/json" + "io" + "net/http" + "strconv" + "strings" + + "github.com/go-kivik/kivik/v4/internal" +) + +type rev struct { + id int + rev string +} + +func (r rev) String() string { + return strconv.Itoa(r.id) + "-" + r.rev +} + +func parseRev(s string) (rev, error) { + if s == "" { + return rev{}, &internal.Error{Status: http.StatusBadRequest, Message: "missing _rev"} + } + const revElements = 2 + parts := strings.SplitN(s, "-", revElements) + id, err := strconv.ParseInt(parts[0], 10, 64) + if err != nil { + return rev{}, &internal.Error{Status: http.StatusBadRequest, Err: err} + } + if len(parts) == 1 { + // A rev that contains only a number is technically valid. + return rev{id: int(id)}, nil + } + return rev{id: int(id), rev: parts[1]}, nil +} + +// prepareDoc prepares the doc for insertion. It returns the new docID, rev, and +// marshaled doc with rev and id removed. +func prepareDoc(docID string, doc interface{}) (string, string, []byte, error) { + tmpJSON, err := json.Marshal(doc) + if err != nil { + return "", "", nil, err + } + var tmp map[string]interface{} + if err := json.Unmarshal(tmpJSON, &tmp); err != nil { + return "", "", nil, err + } + delete(tmp, "_rev") + if id, ok := tmp["_id"].(string); ok { + if docID != "" && id != docID { + return "", "", nil, &internal.Error{Status: http.StatusBadRequest, Message: "Document ID must match _id in document"} + } + docID = id + delete(tmp, "_id") + } + h := md5.New() + b, _ := json.Marshal(tmp) + if _, err := io.Copy(h, bytes.NewReader(b)); err != nil { + return "", "", nil, err + } + return docID, hex.EncodeToString(h.Sum(nil)), b, nil +} + +// extractRev extracts the rev from the document. +func extractRev(doc interface{}) (string, error) { + switch t := doc.(type) { + case map[string]interface{}: + r, _ := t["_rev"].(string) + return r, nil + case map[string]string: + return t["_rev"], nil + default: + tmpJSON, err := json.Marshal(doc) + if err != nil { + return "", &internal.Error{Status: http.StatusBadRequest, Err: err} + } + var revDoc struct { + Rev string `json:"_rev"` + } + if err := json.Unmarshal(tmpJSON, &revDoc); err != nil { + return "", &internal.Error{Status: http.StatusBadRequest, Err: err} + } + return revDoc.Rev, nil + } +} diff --git a/x/sqlite/json_test.go b/x/sqlite/json_test.go new file mode 100644 index 000000000..019b7902a --- /dev/null +++ b/x/sqlite/json_test.go @@ -0,0 +1,138 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +//go:build !js +// +build !js + +package sqlite + +import ( + "testing" + + "gitlab.com/flimzy/testy" +) + +func Test_calculateRev(t *testing.T) { + tests := []struct { + name string + docID string + doc interface{} + want string + wantErr string + }{ + { + name: "no rev in document", + doc: map[string]string{"foo": "bar"}, + want: "9bb58f26192e4ba00f01e2e7b136bbd8", + }, + { + name: "rev in document", + doc: map[string]interface{}{ + "_rev": "1-1234567890abcdef1234567890abcdef", + "foo": "bar", + }, + want: "9bb58f26192e4ba00f01e2e7b136bbd8", + }, + { + name: "add docID", + docID: "foo", + doc: map[string]string{"foo": "bar"}, + want: "9bb58f26192e4ba00f01e2e7b136bbd8", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, got, _, err := prepareDoc(tt.docID, tt.doc) + if !testy.ErrorMatches(tt.wantErr, err) { + t.Errorf("unexpected error = %v, wantErr %v", err, tt.wantErr) + } + if got != tt.want { + t.Errorf("unexpected rev= %v, want %v", got, tt.want) + } + }) + } +} + +func Test_extractRev(t *testing.T) { + tests := []struct { + name string + doc interface{} + wantRev string + wantErr string + }{ + { + name: "nil", + doc: nil, + wantRev: "", + }, + { + name: "empty", + doc: map[string]string{}, + wantRev: "", + }, + { + name: "no rev", + doc: map[string]string{"foo": "bar"}, + wantRev: "", + }, + { + name: "rev in string", + doc: map[string]string{"_rev": "1-1234567890abcdef1234567890abcdef"}, + wantRev: "1-1234567890abcdef1234567890abcdef", + }, + { + name: "rev in interface", + doc: map[string]interface{}{"_rev": "1-1234567890abcdef1234567890abcdef"}, + wantRev: "1-1234567890abcdef1234567890abcdef", + }, + { + name: "rev in struct", + doc: struct { + Rev string `json:"_rev"` + }{Rev: "1-1234567890abcdef1234567890abcdef"}, + wantRev: "1-1234567890abcdef1234567890abcdef", + }, + { + name: "rev id only", + doc: map[string]string{"_rev": "1"}, + wantRev: "1", + }, + { + name: "invalid rev struct", + doc: struct{ Rev func() }{}, + wantErr: "json: unsupported type: func()", + }, + { + name: "invalid rev type", + doc: struct { + Rev int `json:"_rev"` + }{Rev: 1}, + wantErr: "json: cannot unmarshal number into Go struct field ._rev of type string", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rev, err := extractRev(tt.doc) + if !testy.ErrorMatches(tt.wantErr, err) { + t.Errorf("unexpected error = %v, wantErr %v", err, tt.wantErr) + } + if err != nil { + return + } + if rev != tt.wantRev { + t.Errorf("unexpected rev= %v, want %v", rev, tt.wantRev) + } + }) + } +} diff --git a/x/sqlite/schema.go b/x/sqlite/schema.go new file mode 100644 index 000000000..21d4dfc11 --- /dev/null +++ b/x/sqlite/schema.go @@ -0,0 +1,24 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +package sqlite + +const ( + schema = `CREATE TABLE %q ( + seq INTEGER PRIMARY KEY, + id TEXT, + rev_id INTEGER, + rev TEXT, + doc BLOB, + UNIQUE(id, rev_id, rev) + )` +) diff --git a/x/sqlite/sqlite.go b/x/sqlite/sqlite.go index 1726524f8..d7e9aa9eb 100644 --- a/x/sqlite/sqlite.go +++ b/x/sqlite/sqlite.go @@ -16,6 +16,7 @@ import ( "context" "database/sql" "errors" + "fmt" "net/http" "regexp" "strings" @@ -33,8 +34,8 @@ var _ driver.Driver = (*drv)(nil) // NewClient returns a new SQLite client. dsn should be the full path to your // SQLite database file. -func (drv) NewClient(dns string, _ driver.Options) (driver.Client, error) { - db, err := sql.Open("sqlite", dns) +func (drv) NewClient(dsn string, _ driver.Options) (driver.Client, error) { + db, err := sql.Open("sqlite", dsn) if err != nil { return nil, err } @@ -114,7 +115,7 @@ func (c *client) CreateDB(ctx context.Context, name string, _ driver.Options) er if !validDBNameRE.MatchString(name) { return &internal.Error{Status: http.StatusBadRequest, Message: "invalid database name"} } - _, err := c.db.ExecContext(ctx, `CREATE TABLE "`+name+`" (id INTEGER)`) + _, err := c.db.ExecContext(ctx, fmt.Sprintf(schema, name)) if err == nil { return nil } @@ -144,6 +145,12 @@ func (c *client) DestroyDB(ctx context.Context, name string, _ driver.Options) e return err } -func (client) DB(string, driver.Options) (driver.DB, error) { - return nil, nil +func (c *client) DB(name string, _ driver.Options) (driver.DB, error) { + if !validDBNameRE.MatchString(name) { + return nil, &internal.Error{Status: http.StatusBadRequest, Message: "invalid database name"} + } + return &db{ + db: c.db, + name: name, + }, nil }