diff --git a/delete_handler.go b/delete_handler.go index 2912bd3..feec823 100644 --- a/delete_handler.go +++ b/delete_handler.go @@ -8,7 +8,7 @@ import ( "path" ) -func getDeleteStuff(urlPath string) (string, string, string, error) { +func parseDeleteEntryPath(urlPath string) (string, string, string, error) { pathDir, delKey := path.Split(urlPath) if len(pathDir) < 1 { return "", "", "", fmt.Errorf("Invalid URL %q", urlPath) @@ -28,7 +28,7 @@ func getDeleteStuff(urlPath string) (string, string, string, error) { } func handleDeleteEntry(w http.ResponseWriter, r *http.Request) { - UUID, _, deleteKey, err := getDeleteStuff(r.URL.Path) + UUID, _, deleteKey, err := parseDeleteEntryPath(r.URL.Path) if err != nil { log.Println(err) diff --git a/postgresql_storage_construct.go b/postgresql_storage_construct.go index 5ff9fe3..142d034 100644 --- a/postgresql_storage_construct.go +++ b/postgresql_storage_construct.go @@ -1,23 +1,13 @@ package main import ( + "context" "database/sql" "log" ) type dbExec func(*sql.DB) error -func addExtension(db *sql.DB) error { - q, err := db.Prepare("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";") - - if err != nil { - return err - } - _, err = q.Exec() - - return err -} - func createTable(db *sql.DB) error { q, err := db.Prepare("CREATE TABLE IF NOT EXISTS entries (uuid uuid PRIMARY KEY, data BYTEA, remaining_reads SMALLINT DEFAULT 1, delete_key CHAR(256) NOT NULL, created TIMESTAMPTZ, accessed TIMESTAMPTZ, expire TIMESTAMPTZ);") @@ -42,15 +32,48 @@ func addRemainingRead(db *sql.DB) error { } func addDeleteKey(db *sql.DB) error { - alterTable, err := db.Prepare("ALTER TABLE entries ADD COLUMN IF NOT EXISTS delete_key CHAR(256) NOT NULL;") + ctx := context.Background() + tx, err := db.BeginTx(ctx, nil) + alterTable, err := db.PrepareContext(ctx, "ALTER TABLE entries ADD COLUMN IF NOT EXISTS delete_key CHAR(256);") if err != nil { + tx.Rollback() return err } - _, err = alterTable.Exec() + _, err = alterTable.ExecContext(ctx) - return err + rows, err := db.QueryContext(ctx, "SELECT uuid FROM entries WHERE delete_key IS NULL;") + if err != nil { + tx.Rollback() + return err + } + + for rows.Next() { + var UUID string + if err := rows.Scan(&UUID); err != nil { + tx.Rollback() + return err + } + + _, deleteKey, err := createKey() + if err != nil { + tx.Rollback() + return err + } + + _, err = db.ExecContext(ctx, "UPDATE entries SET delete_key=$2 WHERE uuid=$1", UUID, deleteKey) + if err != nil { + tx.Rollback() + return err + } + } + _, err = db.ExecContext(ctx, "ALTER TABLE entries ALTER COLUMN delete_key SET NOT NULL;") + if err != nil { + tx.Rollback() + return err + } + return tx.Commit() } func newPostgresqlStorage(psqlconn string) *postgresqlStorage { @@ -67,7 +90,7 @@ func newPostgresqlStorage(psqlconn string) *postgresqlStorage { log.Fatal("DB ping failed", err) } - for _, f := range []dbExec{addExtension, createTable, addRemainingRead, addDeleteKey} { + for _, f := range []dbExec{createTable, addRemainingRead, addDeleteKey} { err = f(db) if err != nil { defer db.Close()