Skip to content

Commit

Permalink
Quote db names, just in case of unusual chars
Browse files Browse the repository at this point in the history
  • Loading branch information
flimzy committed Feb 11, 2024
1 parent b3c84ac commit 88d9ac2
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions x/sqlite/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
Expand Down Expand Up @@ -74,11 +75,11 @@ func (d *db) Put(ctx context.Context, docID string, doc interface{}, options dri
return "", err
}
var newRev string
err = d.db.QueryRowContext(ctx, `
INSERT INTO `+d.name+` (id, rev_id, rev, doc)
err = d.db.QueryRowContext(ctx, fmt.Sprintf(`
INSERT INTO %q (id, rev_id, rev, doc)
VALUES ($1, $2, $3, $4)
RETURNING rev_id || '-' || rev
`, docID, rev.id, rev.rev, jsonDoc).Scan(&newRev)
`, d.name), docID, rev.id, rev.rev, jsonDoc).Scan(&newRev)
var sqliteErr *sqlite.Error
if errors.As(err, &sqliteErr) && sqliteErr.Code() == sqlite3.SQLITE_CONSTRAINT_UNIQUE {
// In the case of a conflict for new_edits=false, we assume that the
Expand All @@ -96,25 +97,25 @@ func (d *db) Put(ctx context.Context, docID string, doc interface{}, options dri
defer tx.Rollback()

var curRev string
err = tx.QueryRowContext(ctx, `
err = tx.QueryRowContext(ctx, fmt.Sprintf(`
SELECT COALESCE(MAX(rev_id || '-' || rev),'')
FROM `+d.name+`
FROM %q
WHERE id = $1
`, docID).Scan(&curRev)
`, d.name), docID).Scan(&curRev)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return "", err
}
if curRev != docRev {
return "", &internal.Error{Status: http.StatusConflict, Message: "conflict"}
}
var newRev string
err = tx.QueryRowContext(ctx, `
INSERT INTO `+d.name+` (id, rev_id, rev, doc)
err = tx.QueryRowContext(ctx, fmt.Sprintf(`
INSERT INTO %[1]q (id, rev_id, rev, doc)
SELECT $1, COALESCE(MAX(rev_id),0) + 1, $2, $3
FROM `+d.name+`
FROM %[1]q
WHERE id = $1
RETURNING rev_id || '-' || rev
`, docID, rev, jsonDoc).Scan(&newRev)
`, d.name), docID, rev, jsonDoc).Scan(&newRev)
if err != nil {
return "", err
}
Expand All @@ -134,23 +135,23 @@ func (d *db) Get(ctx context.Context, id string, options driver.Options) (*drive
if err != nil {
return nil, err

Check warning on line 136 in x/sqlite/db.go

View check run for this annotation

Codecov / codecov/patch

x/sqlite/db.go#L136

Added line #L136 was not covered by tests
}
err = d.db.QueryRowContext(ctx, `
err = d.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT doc
FROM `+d.name+`
FROM %q
WHERE id = $1
AND rev_id = $2
AND rev = $3
`, id, r.id, r.rev).Scan(&body)
`, d.name), id, r.id, r.rev).Scan(&body)
rev = optsRev
} else {
err = d.db.QueryRowContext(ctx, `
err = d.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT rev_id || '-' || rev, doc
FROM `+d.name+`
FROM %q
WHERE id = $1
AND deleted = FALSE
ORDER BY rev_id DESC, rev DESC
LIMIT 1
`, id).Scan(&rev, &body)
`, d.name), id).Scan(&rev, &body)
}

if errors.Is(err, sql.ErrNoRows) {
Expand All @@ -162,13 +163,13 @@ func (d *db) Get(ctx context.Context, id string, options driver.Options) (*drive

if conflicts, _ := opts["conflicts"].(bool); conflicts {
var revs []string
rows, err := d.db.QueryContext(ctx, `
rows, err := d.db.QueryContext(ctx, fmt.Sprintf(`
SELECT rev_id || '-' || rev
FROM `+d.name+`
FROM %q
WHERE id = $1
AND rev_id || '-' || rev != $2
AND DELETED = FALSE
`, id, rev)
`, d.name), id, rev)
if err != nil {
return nil, err

Check warning on line 174 in x/sqlite/db.go

View check run for this annotation

Codecov / codecov/patch

x/sqlite/db.go#L174

Added line #L174 was not covered by tests
}
Expand Down

0 comments on commit 88d9ac2

Please sign in to comment.