diff --git a/go.mod b/go.mod index 38fd724..a234d34 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect + github.com/canonical/sqlair v0.0.0-20241004123011-77313b5382fd github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/kr/text v0.2.0 // indirect diff --git a/go.sum b/go.sum index 56c00d9..4b3de95 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/canonical/sqlair v0.0.0-20241004123011-77313b5382fd h1:lVn7391CX7QQ5WBQriNUtCB4fvfurZg6XJwH9aVsRII= +github.com/canonical/sqlair v0.0.0-20241004123011-77313b5382fd/go.mod h1:T+40I2sXshY3KRxx0QQpqqn6hCibSKJ2KHzjBvJj8T4= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= diff --git a/internal/db/db.go b/internal/db/db.go index 868a6b2..6ba5b22 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -2,273 +2,444 @@ package db import ( + "context" "database/sql" "errors" "fmt" + "github.com/canonical/sqlair" _ "github.com/mattn/go-sqlite3" "golang.org/x/crypto/bcrypt" ) +// Database is the object used to communicate with the established repository. +type Database struct { + certificateTable string + usersTable string + conn *sqlair.DB +} + +type CertificateRequest struct { + ID int `db:"id"` + + CSR string `db:"csr"` + CertificateChain string `db:"certificate_chain"` + RequestStatus string `db:"request_status"` +} + +type User struct { + ID int `db:"id"` + + Username string `db:"username"` + HashedPassword string `db:"hashed_password"` + Permissions int `db:"permissions"` +} + const ( - certificateRequestsTableName = "CertificateRequests" + certificateRequestsTableName = "certificate_requests" usersTableName = "users" ) -const queryCreateCSRsTable = `CREATE TABLE IF NOT EXISTS %s ( - csr TEXT PRIMARY KEY UNIQUE NOT NULL, - certificate TEXT DEFAULT '' +const queryCreateCSRsTable = ` + CREATE TABLE IF NOT EXISTS %s ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + + csr TEXT NOT NULL UNIQUE, + certificate_chain TEXT DEFAULT '', + request_status TEXT DEFAULT 'Outstanding', + + CHECK (request_status IN ('Outstanding', 'Rejected', 'Revoked', 'Active')), + CHECK (NOT (certificate_chain == '' AND request_status == 'Active' )), + CHECK (NOT (certificate_chain != '' AND request_status == 'Outstanding')) + CHECK (NOT (certificate_chain != '' AND request_status == 'Rejected')) + CHECK (NOT (certificate_chain != '' AND request_status == 'Revoked')) )` -const ( - queryGetAllCSRs = "SELECT rowid, * FROM %s" - queryGetCSR = "SELECT rowid, * FROM %s WHERE rowid=?" - queryCreateCSR = "INSERT INTO %s (csr) VALUES (?)" - queryUpdateCSR = "UPDATE %s SET certificate=? WHERE rowid=?" - queryDeleteCSR = "DELETE FROM %s WHERE rowid=?" -) +const queryCreateUsersTable = ` + CREATE TABLE IF NOT EXISTS %s ( + id INTEGER PRIMARY KEY AUTOINCREMENT, -const queryCreateUsersTable = `CREATE TABLE IF NOT EXISTS %s ( - user_id INTEGER PRIMARY KEY AUTOINCREMENT, - username TEXT NOT NULL UNIQUE, - hashed_password TEXT NOT NULL, - permissions INTEGER + username TEXT NOT NULL UNIQUE, + hashed_password TEXT NOT NULL, + permissions INTEGER )` const ( - queryGetAllUsers = "SELECT * FROM %s" - queryGetUser = "SELECT * FROM %s WHERE user_id=?" - queryGetUserByUsername = "SELECT * FROM %s WHERE username=?" - queryCreateUser = "INSERT INTO %s (username, hashed_password, permissions) VALUES (?, ?, ?)" - queryUpdateUser = "UPDATE %s SET hashed_password=? WHERE user_id=?" - queryDeleteUser = "DELETE FROM %s WHERE user_id=?" - queryGetNumUsers = "SELECT COUNT(*) FROM %s" + getAllCSRsStmt = "SELECT &CertificateRequest.* FROM %s" + getCSRsStmt = "SELECT &CertificateRequest.* FROM %s WHERE id==$CertificateRequest.id or csr==$CertificateRequest.csr" + createCSRStmt = "INSERT INTO %s (csr) VALUES ($CertificateRequest.csr)" + updateCSRStmt = "UPDATE %s SET certificate_chain=$CertificateRequest.certificate_chain, request_status=$CertificateRequest.request_status WHERE id==$CertificateRequest.id or csr==$CertificateRequest.csr" + deleteCSRStmt = "DELETE FROM %s WHERE id=$CertificateRequest.id or csr=$CertificateRequest.csr" ) -// CertificateRequestRepository is the object used to communicate with the established repository. -type Database struct { - certificateTable string - usersTable string - conn *sql.DB +const ( + getAllUsersStmt = "SELECT &User.* from %s" + getUserStmt = "SELECT &User.* from %s WHERE id==$User.id or username==$User.username" + createUserStmt = "INSERT INTO %s (username, hashed_password, permissions) VALUES ($User.username, $User.hashed_password, $User.permissions)" + updateUserStmt = "UPDATE %s SET hashed_password=$User.hashed_password WHERE id==$User.id or username==$User.username" + deleteUserStmt = "DELETE FROM %s WHERE id==$User.id" + getNumUsersStmt = "SELECT COUNT(*) AS &NumUsers.count FROM %s" +) + +// RetrieveAllCSRs gets every CertificateRequest entry in the table. +func (db *Database) RetrieveAllCSRs() ([]CertificateRequest, error) { + stmt, err := sqlair.Prepare(fmt.Sprintf(getAllCSRsStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return nil, err + } + var csrs []CertificateRequest + err = db.conn.Query(context.Background(), stmt).GetAll(&csrs) + if err != nil { + if errors.Is(err, sqlair.ErrNoRows) { + return csrs, nil + } + return nil, err + } + return csrs, nil } -// A CertificateRequest struct represents an entry in the database. -// The object contains a Certificate Request, its matching Certificate if any, and the row ID. -type CertificateRequest struct { - ID int - CSR string - Certificate string +// RetrieveCSRbyID gets a CSR row from the repository from a given ID. +func (db *Database) RetrieveCSRbyID(id int) (*CertificateRequest, error) { + csr := CertificateRequest{ + ID: id, + } + stmt, err := sqlair.Prepare(fmt.Sprintf(getCSRsStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return nil, err + } + err = db.conn.Query(context.Background(), stmt, csr).Get(&csr) + if err != nil { + return nil, err + } + return &csr, nil } -type User struct { - ID int - Username string - Password string - Permissions int + +// RetrieveCSRbyCSR gets a given CSR row from the repository using the CSR text. +func (db *Database) RetrieveCSRbyCSR(csr string) (*CertificateRequest, error) { + row := CertificateRequest{ + CSR: csr, + } + stmt, err := sqlair.Prepare(fmt.Sprintf(getCSRsStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return nil, err + } + err = db.conn.Query(context.Background(), stmt, row).Get(&row) + if err != nil { + return nil, err + } + return &row, nil } -var ErrIdNotFound = errors.New("id not found") +// CreateCSR creates a new CSR entry in the repository. The string must be a valid CSR and unique. +func (db *Database) CreateCSR(csr string) error { + if err := ValidateCertificateRequest(csr); err != nil { + return errors.New("csr validation failed: " + err.Error()) + } + stmt, err := sqlair.Prepare(fmt.Sprintf(createCSRStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err + } + row := CertificateRequest{ + CSR: csr, + } + err = db.conn.Query(context.Background(), stmt, row).Run() + if err != nil { + return err + } + return nil +} -// RetrieveAllCSRs gets every CertificateRequest entry in the table. -func (db *Database) RetrieveAllCSRs() ([]CertificateRequest, error) { - rows, err := db.conn.Query(fmt.Sprintf(queryGetAllCSRs, db.certificateTable)) +// AddCertificateChainToCSRbyCSR adds a new certificate chain to a row for a given CSR string. +func (db *Database) AddCertificateChainToCSRbyCSR(csr string, cert string) error { + err := ValidateCertificate(cert) if err != nil { - return nil, err + return errors.New("cert validation failed: " + err.Error()) + } + err = CertificateMatchesCSR(cert, csr) + if err != nil { + return errors.New("cert validation failed: " + err.Error()) + } + certBundle := sanitizeCertificateBundle(cert) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCSRStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err + } + newRow := CertificateRequest{ + CSR: csr, + CertificateChain: certBundle, + RequestStatus: "Active", + } + err = db.conn.Query(context.Background(), stmt, newRow).Run() + if err != nil { + return err } + return nil +} - var allCsrs []CertificateRequest - defer rows.Close() - for rows.Next() { - var csr CertificateRequest - if err := rows.Scan(&csr.ID, &csr.CSR, &csr.Certificate); err != nil { - return nil, err - } - allCsrs = append(allCsrs, csr) +// AddCertificateChainToCSRbyID adds a new certificate chain to a row for a given row ID. +func (db *Database) AddCertificateToCSRbyID(id int, cert string) error { + csr, err := db.RetrieveCSRbyID(id) + if err != nil { + return err + } + err = ValidateCertificate(cert) + if err != nil { + return errors.New("cert validation failed: " + err.Error()) } - return allCsrs, nil + err = CertificateMatchesCSR(cert, csr.CSR) + if err != nil { + return errors.New("cert validation failed: " + err.Error()) + } + certBundle := sanitizeCertificateBundle(cert) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCSRStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err + } + newRow := CertificateRequest{ + ID: id, + CertificateChain: certBundle, + RequestStatus: "Active", + } + err = db.conn.Query(context.Background(), stmt, newRow).Run() + if err != nil { + return err + } + return nil } -// RetrieveCSR gets a given CSR from the repository. -// It returns the row id and matching certificate alongside the CSR in a CertificateRequest object. -func (db *Database) RetrieveCSR(id string) (CertificateRequest, error) { - var newCSR CertificateRequest - row := db.conn.QueryRow(fmt.Sprintf(queryGetCSR, db.certificateTable), id) - if err := row.Scan(&newCSR.ID, &newCSR.CSR, &newCSR.Certificate); err != nil { - if err.Error() == "sql: no rows in result set" { - return newCSR, ErrIdNotFound - } - return newCSR, err +// RejectCSRbyCSR updates input CSR's row by setting the certificate bundle to "" and moving the row status to "Rejected". +func (db *Database) RejectCSRbyCSR(csr string) error { + oldRow, err := db.RetrieveCSRbyCSR(csr) + if err != nil { + return err } - return newCSR, nil + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCSRStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err + } + newRow := CertificateRequest{ + ID: oldRow.ID, + CSR: oldRow.CSR, + CertificateChain: "", + RequestStatus: "Rejected", + } + err = db.conn.Query(context.Background(), stmt, newRow).Run() + if err != nil { + return err + } + return nil } -// CreateCSR creates a new entry in the repository. -// The given CSR must be valid and unique -func (db *Database) CreateCSR(csr string) (int64, error) { - if err := ValidateCertificateRequest(csr); err != nil { - return 0, errors.New("csr validation failed: " + err.Error()) +// RejectCSRbyCSR updates input ID's row by setting the certificate bundle to "" and sets the row status to "Rejected". +func (db *Database) RejectCSRbyID(id int) error { + oldRow, err := db.RetrieveCSRbyID(id) + if err != nil { + return err } - result, err := db.conn.Exec(fmt.Sprintf(queryCreateCSR, db.certificateTable), csr) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCSRStmt, db.certificateTable), CertificateRequest{}) if err != nil { - return 0, err + return err } - id, err := result.LastInsertId() + newRow := CertificateRequest{ + ID: oldRow.ID, + CSR: oldRow.CSR, + CertificateChain: "", + RequestStatus: "Rejected", + } + err = db.conn.Query(context.Background(), stmt, newRow).Run() if err != nil { - return 0, err + return err } - return id, nil + return nil } -// UpdateCSR adds a new cert to the given CSR in the repository. -// The given certificate must share the public key of the CSR and must be valid. -func (db *Database) UpdateCSR(id string, cert string) (int64, error) { - csr, err := db.RetrieveCSR(id) +// RevokeCSR updates the input CSR's row by setting the certificate bundle to "" and sets the row status to "Revoked". +func (db *Database) RevokeCSR(csr string) error { + oldRow, err := db.RetrieveCSRbyCSR(csr) if err != nil { - return 0, err + return err } - if cert != "rejected" && cert != "" { - err = ValidateCertificate(cert) - if err != nil { - return 0, errors.New("cert validation failed: " + err.Error()) - } - err = CertificateMatchesCSR(cert, csr.CSR) - if err != nil { - return 0, errors.New("cert validation failed: " + err.Error()) - } - cert = sanitizeCertificateBundle(cert) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCSRStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err + } + newRow := CertificateRequest{ + ID: oldRow.ID, + CSR: oldRow.CSR, + CertificateChain: "", + RequestStatus: "Revoked", } - _, err = db.conn.Exec(fmt.Sprintf(queryUpdateCSR, db.certificateTable), cert, csr.ID) + err = db.conn.Query(context.Background(), stmt, newRow).Run() if err != nil { - return 0, err + return err } - return int64(csr.ID), nil + return nil } -// DeleteCSR removes a CSR from the database alongside the certificate that may have been generated for it. -func (db *Database) DeleteCSR(id string) (int64, error) { - result, err := db.conn.Exec(fmt.Sprintf(queryDeleteCSR, db.certificateTable), id) +// DeleteCSRbyCSR removes a CSR from the database alongside the certificate that may have been generated for it. +func (db *Database) DeleteCSRbyCSR(csr string) error { + stmt, err := sqlair.Prepare(fmt.Sprintf(deleteCSRStmt, db.certificateTable), CertificateRequest{}) if err != nil { - return 0, err + return err + } + row := CertificateRequest{ + CSR: csr, } - deleteId, err := result.RowsAffected() + err = db.conn.Query(context.Background(), stmt, row).Run() if err != nil { - return 0, err + return err } - if deleteId == 0 { - return 0, ErrIdNotFound + return nil +} + +// DeleteCSRByID removes a CSR from the database alongside the certificate that may have been generated for it. +func (db *Database) DeleteCSRbyID(id int) error { + stmt, err := sqlair.Prepare(fmt.Sprintf(deleteCSRStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err + } + row := CertificateRequest{ + ID: id, + } + err = db.conn.Query(context.Background(), stmt, row).Run() + if err != nil { + return err } - return deleteId, nil + return nil } // RetrieveAllUsers returns all of the users and their fields available in the database. func (db *Database) RetrieveAllUsers() ([]User, error) { - rows, err := db.conn.Query(fmt.Sprintf(queryGetAllUsers, db.usersTable)) + stmt, err := sqlair.Prepare(fmt.Sprintf(getAllUsersStmt, db.usersTable), User{}) if err != nil { return nil, err } - - var allUsers []User - defer rows.Close() - for rows.Next() { - var user User - if err := rows.Scan(&user.ID, &user.Username, &user.Password, &user.Permissions); err != nil { - return nil, err - } - allUsers = append(allUsers, user) + var users []User + err = db.conn.Query(context.Background(), stmt).GetAll(&users) + if err != nil { + return nil, err } - return allUsers, nil + return users, nil } // RetrieveUser retrieves the name, password and the permission level of a user. -func (db *Database) RetrieveUser(id string) (User, error) { - var newUser User - row := db.conn.QueryRow(fmt.Sprintf(queryGetUser, db.usersTable), id) - if err := row.Scan(&newUser.ID, &newUser.Username, &newUser.Password, &newUser.Permissions); err != nil { - if err.Error() == "sql: no rows in result set" { - return newUser, ErrIdNotFound - } - return newUser, err +func (db *Database) RetrieveUserByID(id int) (*User, error) { + row := User{ + ID: id, } - return newUser, nil + stmt, err := sqlair.Prepare(fmt.Sprintf(getUserStmt, db.usersTable), User{}) + if err != nil { + return nil, err + } + err = db.conn.Query(context.Background(), stmt, row).Get(&row) + if err != nil { + return nil, err + } + return &row, nil } // RetrieveUser retrieves the id, password and the permission level of a user. -func (db *Database) RetrieveUserByUsername(name string) (User, error) { - var newUser User - row := db.conn.QueryRow(fmt.Sprintf(queryGetUserByUsername, db.usersTable), name) - if err := row.Scan(&newUser.ID, &newUser.Username, &newUser.Password, &newUser.Permissions); err != nil { - if err.Error() == "sql: no rows in result set" { - return newUser, ErrIdNotFound - } - return newUser, err +func (db *Database) RetrieveUserByUsername(name string) (*User, error) { + row := User{ + Username: name, + } + stmt, err := sqlair.Prepare(fmt.Sprintf(getUserStmt, db.usersTable), User{}) + if err != nil { + return nil, err } - return newUser, nil + err = db.conn.Query(context.Background(), stmt, row).Get(&row) + if err != nil { + return nil, err + } + return &row, nil } // CreateUser creates a new user from a given username, password and permission level. // The permission level 1 represents an admin, and a 0 represents a regular user. // The password passed in should be in plaintext. This function handles hashing and salting the password before storing it in the database. -func (db *Database) CreateUser(username string, password string, permission int) (int64, error) { +func (db *Database) CreateUser(username string, password string, permission int) error { pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - return 0, err + return err } - result, err := db.conn.Exec(fmt.Sprintf(queryCreateUser, db.usersTable), username, pw, permission) + stmt, err := sqlair.Prepare(fmt.Sprintf(createUserStmt, db.usersTable), User{}) if err != nil { - return 0, err + return err + } + row := User{ + Username: username, + HashedPassword: string(pw), + Permissions: permission, } - id, err := result.LastInsertId() + err = db.conn.Query(context.Background(), stmt, row).Run() if err != nil { - return 0, err + return err } - return id, nil + return nil } // UpdateUser updates the password of the given user. // Just like with CreateUser, this function handles hashing and salting the password before storage. -func (db *Database) UpdateUser(id, password string) (int64, error) { - user, err := db.RetrieveUser(id) +func (db *Database) UpdateUserPassword(id int, password string) error { + _, err := db.RetrieveUserByID(id) if err != nil { - return 0, err + return err } pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - return 0, err + return err } - result, err := db.conn.Exec(fmt.Sprintf(queryUpdateUser, db.usersTable), pw, user.ID) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateUserStmt, db.usersTable), User{}) if err != nil { - return 0, err + return err + } + row := User{ + ID: id, + HashedPassword: string(pw), } - affectedRows, err := result.RowsAffected() + err = db.conn.Query(context.Background(), stmt, row).Run() if err != nil { - return 0, err + return err } - return affectedRows, nil + return nil } -// DeleteUser removes a user from the table. -func (db *Database) DeleteUser(id string) (int64, error) { - result, err := db.conn.Exec(fmt.Sprintf(queryDeleteUser, db.usersTable), id) +// DeleteUserByID removes a user from the table. +func (db *Database) DeleteUserByID(id int) error { + _, err := db.RetrieveUserByID(id) if err != nil { - return 0, err + return err } - deleteId, err := result.RowsAffected() + stmt, err := sqlair.Prepare(fmt.Sprintf(deleteUserStmt, db.usersTable), User{}) if err != nil { - return 0, err + return err + } + row := User{ + ID: id, } - if deleteId == 0 { - return 0, ErrIdNotFound + err = db.conn.Query(context.Background(), stmt, row).Run() + if err != nil { + return err } - return deleteId, nil + return nil +} + +type NumUsers struct { + Count int `db:"count"` } // NumUsers returns the number of users in the database. func (db *Database) NumUsers() (int, error) { - var numUsers int - row := db.conn.QueryRow(fmt.Sprintf(queryGetNumUsers, db.usersTable)) - if err := row.Scan(&numUsers); err != nil { + stmt, err := sqlair.Prepare(fmt.Sprintf(getNumUsersStmt, db.usersTable), NumUsers{}) + if err != nil { + return 0, err + } + result := NumUsers{} + err = db.conn.Query(context.Background(), stmt).Get(&result) + if err != nil { return 0, err } - return numUsers, nil + return result.Count, nil } // Close closes the connection to the repository cleanly. @@ -276,7 +447,7 @@ func (db *Database) Close() error { if db.conn == nil { return nil } - if err := db.conn.Close(); err != nil { + if err := db.conn.PlainDB().Close(); err != nil { return err } return nil @@ -287,18 +458,18 @@ func (db *Database) Close() error { // The database path must be a valid file path or ":memory:". // The table will be created if it doesn't exist in the format expected by the package. func NewDatabase(databasePath string) (*Database, error) { - conn, err := sql.Open("sqlite3", databasePath) + sqlConnection, err := sql.Open("sqlite3", databasePath) if err != nil { return nil, err } - if _, err := conn.Exec(fmt.Sprintf(queryCreateCSRsTable, certificateRequestsTableName)); err != nil { + if _, err := sqlConnection.Exec(fmt.Sprintf(queryCreateCSRsTable, certificateRequestsTableName)); err != nil { return nil, err } - if _, err := conn.Exec(fmt.Sprintf(queryCreateUsersTable, usersTableName)); err != nil { + if _, err := sqlConnection.Exec(fmt.Sprintf(queryCreateUsersTable, usersTableName)); err != nil { return nil, err } db := new(Database) - db.conn = conn + db.conn = sqlair.NewDB(sqlConnection) db.certificateTable = certificateRequestsTableName db.usersTable = usersTableName return db, nil diff --git a/internal/db/db_test.go b/internal/db/db_test.go index 5d0f768..62989ce 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -3,7 +3,7 @@ package db_test import ( "fmt" "log" - "strconv" + "path/filepath" "strings" "testing" @@ -12,7 +12,8 @@ import ( ) func TestConnect(t *testing.T) { - db, err := db.NewDatabase(":memory:") + tempDir := t.TempDir() + db, err := db.NewDatabase(filepath.Join(tempDir, "db.sqlite3")) if err != nil { t.Fatalf("Can't connect to SQLite: %s", err) } @@ -20,21 +21,22 @@ func TestConnect(t *testing.T) { } func TestCSRsEndToEnd(t *testing.T) { - db, err := db.NewDatabase(":memory:") + tempDir := t.TempDir() + db, err := db.NewDatabase(filepath.Join(tempDir, "db.sqlite3")) if err != nil { t.Fatalf("Couldn't complete NewDatabase: %s", err) } defer db.Close() - id1, err := db.CreateCSR(AppleCSR) + err = db.CreateCSR(AppleCSR) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } - id2, err := db.CreateCSR(BananaCSR) + err = db.CreateCSR(BananaCSR) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } - id3, err := db.CreateCSR(StrawberryCSR) + err = db.CreateCSR(StrawberryCSR) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } @@ -46,7 +48,7 @@ func TestCSRsEndToEnd(t *testing.T) { if len(res) != 3 { t.Fatalf("One or more CSRs weren't found in DB") } - retrievedCSR, err := db.RetrieveCSR(strconv.FormatInt(id1, 10)) + retrievedCSR, err := db.RetrieveCSRbyCSR(AppleCSR) if err != nil { t.Fatalf("Couldn't complete Retrieve: %s", err) } @@ -54,7 +56,7 @@ func TestCSRsEndToEnd(t *testing.T) { t.Fatalf("The CSR from the database doesn't match the CSR that was given") } - if _, err = db.DeleteCSR(strconv.FormatInt(id1, 10)); err != nil { + if err = db.DeleteCSRbyCSR(AppleCSR); err != nil { t.Fatalf("Couldn't complete Delete: %s", err) } res, _ = db.RetrieveAllCSRs() @@ -62,24 +64,24 @@ func TestCSRsEndToEnd(t *testing.T) { t.Fatalf("CSR's weren't deleted from the DB properly") } BananaCertBundle := strings.TrimSpace(fmt.Sprintf("%s%s", BananaCert, IssuerCert)) - _, err = db.UpdateCSR(strconv.FormatInt(id2, 10), BananaCertBundle) + err = db.AddCertificateChainToCSRbyCSR(BananaCSR, BananaCertBundle) if err != nil { t.Fatalf("Couldn't complete Update: %s", err) } - retrievedCSR, _ = db.RetrieveCSR(strconv.FormatInt(id2, 10)) - if retrievedCSR.Certificate != BananaCertBundle { - t.Fatalf("The certificate that was uploaded does not match the certificate that was given.\n Retrieved: %s\nGiven: %s", retrievedCSR.Certificate, BananaCertBundle) + retrievedCSR, _ = db.RetrieveCSRbyCSR(BananaCSR) + if retrievedCSR.CertificateChain != BananaCertBundle { + t.Fatalf("The certificate that was uploaded does not match the certificate that was given.\n Retrieved: %s\nGiven: %s", retrievedCSR.CertificateChain, BananaCertBundle) } - _, err = db.UpdateCSR(strconv.FormatInt(id2, 10), "") + err = db.RevokeCSR(BananaCSR) if err != nil { - t.Fatalf("Couldn't complete Update to delete certificate: %s", err) + t.Fatalf("Couldn't complete Update to revoke certificate: %s", err) } - _, err = db.UpdateCSR(strconv.FormatInt(id3, 10), "rejected") + err = db.RejectCSRbyCSR(StrawberryCSR) if err != nil { t.Fatalf("Couldn't complete Update to reject CSR: %s", err) } - retrievedCSR, _ = db.RetrieveCSR(strconv.FormatInt(id2, 10)) - if retrievedCSR.Certificate != "" { + retrievedCSR, _ = db.RetrieveCSRbyCSR(BananaCSR) + if retrievedCSR.RequestStatus != "Revoked" { t.Fatalf("Couldn't delete certificate") } } @@ -89,12 +91,12 @@ func TestCreateFails(t *testing.T) { defer db.Close() InvalidCSR := strings.ReplaceAll(AppleCSR, "M", "i") - if _, err := db.CreateCSR(InvalidCSR); err == nil { + if err := db.CreateCSR(InvalidCSR); err == nil { t.Fatalf("Expected error due to invalid CSR") } db.CreateCSR(AppleCSR) //nolint:errcheck - if _, err := db.CreateCSR(AppleCSR); err == nil { + if err := db.CreateCSR(AppleCSR); err == nil { t.Fatalf("Expected error due to duplicate CSR") } } @@ -103,13 +105,13 @@ func TestUpdateFails(t *testing.T) { db, _ := db.NewDatabase(":memory:") defer db.Close() - id1, _ := db.CreateCSR(AppleCSR) //nolint:errcheck - id2, _ := db.CreateCSR(BananaCSR) //nolint:errcheck + db.CreateCSR(AppleCSR) //nolint:errcheck + db.CreateCSR(BananaCSR) //nolint:errcheck InvalidCert := strings.ReplaceAll(BananaCert, "/", "+") - if _, err := db.UpdateCSR(strconv.FormatInt(id2, 10), InvalidCert); err == nil { + if err := db.AddCertificateChainToCSRbyCSR(BananaCSR, InvalidCert); err == nil { t.Fatalf("Expected updating with invalid cert to fail") } - if _, err := db.UpdateCSR(strconv.FormatInt(id1, 10), BananaCert); err == nil { + if err := db.AddCertificateChainToCSRbyCSR(AppleCSR, BananaCert); err == nil { t.Fatalf("Expected updating with mismatched cert to fail") } } @@ -119,7 +121,10 @@ func TestRetrieve(t *testing.T) { defer db.Close() db.CreateCSR(AppleCSR) //nolint:errcheck - if _, err := db.RetrieveCSR("this is definitely not an id"); err == nil { + if _, err := db.RetrieveCSRbyCSR("this is definitely not an id"); err == nil { + t.Fatalf("Expected failure looking for nonexistent CSR") + } + if _, err := db.RetrieveCSRbyID(-1); err == nil { t.Fatalf("Expected failure looking for nonexistent CSR") } } @@ -131,11 +136,11 @@ func TestUsersEndToEnd(t *testing.T) { } defer db.Close() - id1, err := db.CreateUser("admin", "pw123", 1) + err = db.CreateUser("admin", "pw123", 1) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } - id2, err := db.CreateUser("norman", "pw456", 0) + err = db.CreateUser("norman", "pw456", 0) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } @@ -154,31 +159,36 @@ func TestUsersEndToEnd(t *testing.T) { if num != 2 { t.Fatalf("NumUsers didn't return the correct number of users") } - retrievedUser, err := db.RetrieveUser(strconv.FormatInt(id1, 10)) + retrievedUser, err := db.RetrieveUserByUsername("admin") + if err != nil { + t.Fatalf("Couldn't complete Retrieve: %s", err) + } + if retrievedUser.Username != "admin" { + t.Fatalf("The user from the database doesn't match the user that was given") + } + retrievedUser, err = db.RetrieveUserByID(1) if err != nil { t.Fatalf("Couldn't complete Retrieve: %s", err) } if retrievedUser.Username != "admin" { t.Fatalf("The user from the database doesn't match the user that was given") } - if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.Password), []byte("pw123")); err != nil { + if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.HashedPassword), []byte("pw123")); err != nil { t.Fatalf("The user's password doesn't match the one stored in the database") } - - if _, err = db.DeleteUser(strconv.FormatInt(id1, 10)); err != nil { + if err = db.DeleteUserByID(1); err != nil { t.Fatalf("Couldn't complete Delete: %s", err) } res, _ = db.RetrieveAllUsers() if len(res) != 1 { t.Fatalf("users weren't deleted from the DB properly") } - - _, err = db.UpdateUser(strconv.FormatInt(id2, 10), "thebestpassword") + err = db.UpdateUserPassword(2, "thebestpassword") if err != nil { t.Fatalf("Couldn't complete Update: %s", err) } - retrievedUser, _ = db.RetrieveUser(strconv.FormatInt(id2, 10)) - if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.Password), []byte("thebestpassword")); err != nil { + retrievedUser, _ = db.RetrieveUserByUsername("norman") + if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.HashedPassword), []byte("thebestpassword")); err != nil { t.Fatalf("The new password that was given does not match the password that was stored.") } } @@ -188,19 +198,19 @@ func Example() { if err != nil { log.Fatalln(err) } - _, err = db.CreateCSR(BananaCSR) + err = db.CreateCSR(BananaCSR) if err != nil { log.Fatalln(err) } - _, err = db.UpdateCSR(BananaCSR, BananaCert) + err = db.AddCertificateChainToCSRbyCSR(BananaCSR, BananaCert) if err != nil { log.Fatalln(err) } - entry, err := db.RetrieveCSR(BananaCSR) + entry, err := db.RetrieveCSRbyCSR(BananaCSR) if err != nil { log.Fatalln(err) } - if entry.Certificate != BananaCert { + if entry.CertificateChain != BananaCert { log.Fatalln("Retrieved Certificate doesn't match Stored Certificate") } err = db.Close() diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 1076d61..686ce0d 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -95,15 +95,15 @@ func (pm *PrometheusMetrics) GenerateMetrics(csrs []db.CertificateRequest) { var expiringIn30DaysCertCount float64 var expiringIn90DaysCertCount float64 for _, entry := range csrs { - if entry.Certificate == "" { + if entry.CertificateChain == "" { outstandingCSRCount += 1 continue } - if entry.Certificate == "rejected" { + if entry.RequestStatus == "Rejected" { continue } certCount += 1 - expiryDate := certificateExpiryDate(entry.Certificate) + expiryDate := certificateExpiryDate(entry.CertificateChain) daysRemaining := time.Until(expiryDate).Hours() / 24 if daysRemaining < 0 { expiredCertCount += 1 @@ -218,6 +218,7 @@ func requestDurationMetric() prometheus.HistogramVec { } func certificateExpiryDate(certString string) time.Time { + // TODO: Does this return the expiry date of the issuer? certBlock, _ := pem.Decode([]byte(certString)) cert, _ := x509.ParseCertificate(certBlock.Bytes) // TODO: cert.NotAfter can exist in a wrong cert. We should catch that at the db level validation diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go index ecdd854..2742276 100644 --- a/internal/metrics/metrics_test.go +++ b/internal/metrics/metrics_test.go @@ -21,7 +21,8 @@ import ( // TestPrometheusHandler tests that the Prometheus metrics handler responds correctly to an HTTP request. func TestPrometheusHandler(t *testing.T) { - db, err := db.NewDatabase(":memory:") + tempDir := t.TempDir() + db, err := db.NewDatabase(filepath.Join(tempDir, "db.sqlite3")) if err != nil { t.Fatal(err) } @@ -95,13 +96,13 @@ func generateCertPair(daysRemaining int) (string, string, string) { } func initializeTestDB(t *testing.T, db *db.Database) { - for i, v := range []int{5, 10, 32} { + for _, v := range []int{5, 10, 32} { csr, cert, ca := generateCertPair(v) - _, err := db.CreateCSR(csr) + err := db.CreateCSR(csr) if err != nil { t.Fatalf("couldn't create test csr: %s", err) } - _, err = db.UpdateCSR(fmt.Sprint(i+1), fmt.Sprintf("%s%s", cert, ca)) + err = db.AddCertificateChainToCSRbyCSR(csr, fmt.Sprintf("%s%s", cert, ca)) if err != nil { t.Fatalf("couldn't create test cert: %s", err) } diff --git a/internal/server/authorization_test.go b/internal/server/authorization_test.go index 164faa7..4aaf8f9 100644 --- a/internal/server/authorization_test.go +++ b/internal/server/authorization_test.go @@ -2,12 +2,15 @@ package server_test import ( "net/http" + "path/filepath" "strings" "testing" ) func TestAuthorizationNoAuth(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -47,7 +50,9 @@ func TestAuthorizationNoAuth(t *testing.T) { } func TestAuthorizationNonAdminAuthorized(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -98,7 +103,9 @@ func TestAuthorizationNonAdminAuthorized(t *testing.T) { } func TestAuthorizationNonAdminUnauthorized(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -156,7 +163,9 @@ func TestAuthorizationNonAdminUnauthorized(t *testing.T) { } func TestAuthorizationAdminAuthorized(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -205,7 +214,9 @@ func TestAuthorizationAdminAuthorized(t *testing.T) { } func TestAuthorizationAdminUnAuthorized(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } diff --git a/internal/server/handlers_accounts.go b/internal/server/handlers_accounts.go index 7b72e82..82ad61e 100644 --- a/internal/server/handlers_accounts.go +++ b/internal/server/handlers_accounts.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/canonical/notary/internal/db" + "github.com/canonical/sqlair" ) type CreateAccountParams struct { @@ -27,18 +28,6 @@ type GetAccountResponse struct { Permissions int `json:"permissions"` } -type CreateAccountResponse struct { - ID int `json:"id"` -} - -type ChangeAccountResponse struct { - ID int `json:"id"` -} - -type DeleteAccountResponse struct { - ID int `json:"id"` -} - func validatePassword(password string) bool { if len(password) < 8 { return false @@ -87,8 +76,12 @@ func ListAccounts(env *HandlerConfig) http.HandlerFunc { func GetAccount(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - var account db.User - var err error + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + var account *db.User if id == "me" { claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { @@ -96,11 +89,11 @@ func GetAccount(env *HandlerConfig) http.HandlerFunc { } account, err = env.DB.RetrieveUserByUsername(claims.Username) } else { - account, err = env.DB.RetrieveUser(id) + account, err = env.DB.RetrieveUserByID(idNum) } if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } @@ -150,12 +143,11 @@ func CreateAccount(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Failed to retrieve accounts: "+err.Error()) return } - permission := UserPermission if numUsers == 0 { permission = AdminPermission } - id, err := env.DB.CreateUser(createAccountParams.Username, createAccountParams.Password, permission) + err = env.DB.CreateUser(createAccountParams.Username, createAccountParams.Password, permission) if err != nil { if strings.Contains(err.Error(), "UNIQUE constraint failed") { writeError(w, http.StatusBadRequest, "account with given username already exists") @@ -165,11 +157,9 @@ func CreateAccount(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error") return } - accountResponse := CreateAccountResponse{ - ID: int(id), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusCreated) - err = writeJSON(w, accountResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -182,39 +172,39 @@ func CreateAccount(env *HandlerConfig) http.HandlerFunc { func DeleteAccount(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - idInt, err := strconv.ParseInt(id, 10, 64) + idInt, err := strconv.Atoi(id) if err != nil { log.Println(err) writeError(w, http.StatusInternalServerError, "Internal Error") return } - account, err := env.DB.RetrieveUser(id) + account, err := env.DB.RetrieveUserByID(idInt) if err != nil { - if !errors.Is(err, db.ErrIdNotFound) { - log.Println(err) - writeError(w, http.StatusInternalServerError, "Internal Error") + log.Println(err) + if errors.Is(err, sqlair.ErrNoRows) { + writeError(w, http.StatusNotFound, "Not Found") return } + writeError(w, http.StatusInternalServerError, "Internal Error") + return } if account.Permissions == 1 { writeError(w, http.StatusBadRequest, "deleting an Admin account is not allowed.") return } - _, err = env.DB.DeleteUser(id) + err = env.DB.DeleteUserByID(idInt) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - deleteAccountResponse := DeleteAccountResponse{ - ID: int(idInt), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusAccepted) - err = writeJSON(w, deleteAccountResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -225,6 +215,7 @@ func DeleteAccount(env *HandlerConfig) http.HandlerFunc { func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") + var idNum int if id == "me" { claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if err != nil { @@ -238,7 +229,15 @@ func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusUnauthorized, "Unauthorized") return } - id = strconv.Itoa(account.ID) + idNum = account.ID + } else { + idInt, err := strconv.Atoi(id) + if err != nil { + log.Println(err) + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + idNum = idInt } var changeAccountParams ChangeAccountParams if err := json.NewDecoder(r.Body).Decode(&changeAccountParams); err != nil { @@ -257,21 +256,19 @@ func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc { ) return } - ret, err := env.DB.UpdateUser(id, changeAccountParams.Password) + err := env.DB.UpdateUserPassword(idNum, changeAccountParams.Password) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - changeAccountResponse := ChangeAccountResponse{ - ID: int(ret), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusCreated) - err = writeJSON(w, changeAccountResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return diff --git a/internal/server/handlers_accounts_test.go b/internal/server/handlers_accounts_test.go index 42e2e7a..38d40a1 100644 --- a/internal/server/handlers_accounts_test.go +++ b/internal/server/handlers_accounts_test.go @@ -3,11 +3,16 @@ package server_test import ( "encoding/json" "net/http" + "path/filepath" "strconv" "strings" "testing" ) +type SuccessResponse struct { + Message string `json:"message"` +} + type GetAccountResponseResult struct { ID int `json:"id"` Username string `json:"username"` @@ -29,8 +34,8 @@ type CreateAccountResponseResult struct { } type CreateAccountResponse struct { - Result CreateAccountResponseResult `json:"result"` - Error string `json:"error,omitempty"` + Result SuccessResponse `json:"result"` + Error string `json:"error,omitempty"` } type ChangeAccountPasswordParams struct { @@ -139,7 +144,9 @@ func deleteAccount(url string, client *http.Client, adminToken string, id int) ( // The order of the tests is important, as some tests depend on // the state of the server after previous tests. func TestAccountsEndToEnd(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -197,10 +204,7 @@ func TestAccountsEndToEnd(t *testing.T) { t.Fatalf("expected status %d, got %d", http.StatusCreated, statusCode) } if response.Error != "" { - t.Fatalf("expected error %q, got %q", "", response.Error) - } - if response.Result.ID != 3 { - t.Fatalf("expected ID 3, got %d", response.Result.ID) + t.Fatalf("unexpected error :%q", response.Error) } }) @@ -285,10 +289,7 @@ func TestAccountsEndToEnd(t *testing.T) { t.Fatalf("expected status %d, got %d", http.StatusCreated, statusCode) } if response.Error != "" { - t.Fatalf("expected error %q, got %q", "", response.Error) - } - if response.Result.ID != 1 { - t.Fatalf("expected ID 1, got %d", response.Result.ID) + t.Fatalf("unexpected error :%q", response.Error) } }) @@ -351,9 +352,6 @@ func TestAccountsEndToEnd(t *testing.T) { if response.Error != "" { t.Fatalf("expected error %q, got %q", "", response.Error) } - if response.Result.ID != 2 { - t.Fatalf("expected ID 2, got %d", response.Result.ID) - } }) t.Run("13. Delete account - no user", func(t *testing.T) { diff --git a/internal/server/handlers_certificate_requests.go b/internal/server/handlers_certificate_requests.go index a1a2d0f..bcd440d 100644 --- a/internal/server/handlers_certificate_requests.go +++ b/internal/server/handlers_certificate_requests.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/canonical/notary/internal/db" + "github.com/canonical/sqlair" ) type CreateCertificateRequestParams struct { @@ -16,54 +16,32 @@ type CreateCertificateRequestParams struct { } type CreateCertificateParams struct { - Certificate string `json:"certificate"` + CertificateChain string `json:"certificate"` } -type GetCertificateRequestResponse struct { - ID int `json:"id"` - CSR string `json:"csr"` - Certificate string `json:"certificate"` -} - -type CreateCertificateRequestResponse struct { - ID int `json:"id"` -} - -type DeleteCertificateRequestResponse struct { - ID int `json:"id"` -} - -type RejectCertificateRequestResponse struct { - ID int `json:"id"` -} - -type CreateCertificateResponse struct { - ID int `json:"id"` -} - -type DeleteCertificateResponse struct { - ID int `json:"id"` -} - -type RejectCertificateResponse struct { - ID int `json:"id"` +type CertificateRequest struct { + ID int `json:"id"` + CSR string `json:"csr"` + CertificateChain string `json:"certificate"` + CSRStatus string `json:"csr_status"` } // ListCertificateRequests returns all of the Certificate Requests func ListCertificateRequests(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - certs, err := env.DB.RetrieveAllCSRs() + csrs, err := env.DB.RetrieveAllCSRs() if err != nil { log.Println(err) writeError(w, http.StatusInternalServerError, "Internal Error") return } - certificateRequestsResponse := make([]GetCertificateRequestResponse, len(certs)) - for i, cert := range certs { - certificateRequestsResponse[i] = GetCertificateRequestResponse{ - ID: cert.ID, - CSR: cert.CSR, - Certificate: cert.Certificate, + certificateRequestsResponse := make([]CertificateRequest, len(csrs)) + for i, csr := range csrs { + certificateRequestsResponse[i] = CertificateRequest{ + ID: csr.ID, + CSR: csr.CSR, + CertificateChain: csr.CertificateChain, + CSRStatus: csr.RequestStatus, } } w.WriteHeader(http.StatusOK) @@ -87,7 +65,7 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusBadRequest, "csr is missing") return } - id, err := env.DB.CreateCSR(createCertificateRequestParams.CSR) + err := env.DB.CreateCSR(createCertificateRequestParams.CSR) if err != nil { if strings.Contains(err.Error(), "UNIQUE constraint failed") { writeError(w, http.StatusBadRequest, "given csr already recorded") @@ -102,11 +80,9 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error") return } - certificateRequestResponse := CreateCertificateRequestResponse{ - ID: int(id), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusCreated) - err = writeJSON(w, certificateRequestResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -119,20 +95,26 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc { func GetCertificateRequest(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - cert, err := env.DB.RetrieveCSR(id) + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + csr, err := env.DB.RetrieveCSRbyID(idNum) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - certificateRequestResponse := GetCertificateRequestResponse{ - ID: cert.ID, - CSR: cert.CSR, - Certificate: cert.Certificate, + certificateRequestResponse := CertificateRequest{ + ID: csr.ID, + CSR: csr.CSR, + CertificateChain: csr.CertificateChain, + CSRStatus: csr.RequestStatus, } w.WriteHeader(http.StatusOK) err = writeJSON(w, certificateRequestResponse) @@ -148,21 +130,24 @@ func GetCertificateRequest(env *HandlerConfig) http.HandlerFunc { func DeleteCertificateRequest(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - insertId, err := env.DB.DeleteCSR(id) + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + err = env.DB.DeleteCSRbyID(idNum) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - certificateRequestResponse := DeleteCertificateRequestResponse{ - ID: int(insertId), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusAccepted) - err = writeJSON(w, certificateRequestResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -179,15 +164,20 @@ func CreateCertificate(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusBadRequest, "Invalid JSON format") return } - if createCertificateParams.Certificate == "" { + if createCertificateParams.CertificateChain == "" { writeError(w, http.StatusBadRequest, "certificate is missing") return } id := r.PathValue("id") - insertId, err := env.DB.UpdateCSR(id, createCertificateParams.Certificate) + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + err = env.DB.AddCertificateToCSRbyID(idNum, createCertificateParams.CertificateChain) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) || + if errors.Is(err, sqlair.ErrNoRows) || err.Error() == "certificate does not match CSR" || strings.Contains(err.Error(), "cert validation failed") { writeError(w, http.StatusBadRequest, "Bad Request") @@ -196,18 +186,15 @@ func CreateCertificate(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error") return } - insertIdStr := strconv.FormatInt(insertId, 10) if env.SendPebbleNotifications { - err := SendPebbleNotification("canonical.com/notary/certificate/update", insertIdStr) + err := SendPebbleNotification("canonical.com/notary/certificate/update", id) if err != nil { log.Printf("pebble notify failed: %s. continuing silently.", err.Error()) } } - certificateResponse := CreateCertificateResponse{ - ID: int(insertId), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusCreated) - err = writeJSON(w, certificateResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -218,28 +205,30 @@ func CreateCertificate(env *HandlerConfig) http.HandlerFunc { func RejectCertificate(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - insertId, err := env.DB.UpdateCSR(id, "rejected") + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + err = env.DB.RejectCSRbyID(idNum) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - insertIdStr := strconv.FormatInt(insertId, 10) if env.SendPebbleNotifications { - err := SendPebbleNotification("canonical.com/notary/certificate/update", insertIdStr) + err := SendPebbleNotification("canonical.com/notary/certificate/update", id) if err != nil { log.Printf("pebble notify failed: %s. continuing silently.", err.Error()) } } - certificateResponse := RejectCertificateResponse{ - ID: int(insertId), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusAccepted) - err = writeJSON(w, certificateResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -252,28 +241,30 @@ func RejectCertificate(env *HandlerConfig) http.HandlerFunc { func DeleteCertificate(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - insertId, err := env.DB.UpdateCSR(id, "") + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + err = env.DB.DeleteCSRbyID(idNum) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusBadRequest, "Bad Request") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - insertIdStr := strconv.FormatInt(insertId, 10) if env.SendPebbleNotifications { - err := SendPebbleNotification("canonical.com/notary/certificate/update", insertIdStr) + err := SendPebbleNotification("canonical.com/notary/certificate/update", id) if err != nil { log.Printf("pebble notify failed: %s. continuing silently.", err.Error()) } } - certificateResponse := DeleteCertificateResponse{ - ID: int(insertId), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusOK) - err = writeJSON(w, certificateResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return diff --git a/internal/server/handlers_certificate_requests_test.go b/internal/server/handlers_certificate_requests_test.go index 21b8311..ac06dd6 100644 --- a/internal/server/handlers_certificate_requests_test.go +++ b/internal/server/handlers_certificate_requests_test.go @@ -9,22 +9,24 @@ import ( "path/filepath" "strconv" "testing" + + "github.com/canonical/notary/internal/server" ) -type CertificateRequest struct { - ID int `json:"id"` - CSR string `json:"csr"` - Certificate string `json:"certificate"` -} +// type CertificateRequest struct { +// ID int `json:"id"` +// CSR string `json:"csr"` +// Certificate string `json:"certificate"` +// } type GetCertificateRequestResponse struct { - Result CertificateRequest `json:"result"` - Error string `json:"error,omitempty"` + Result server.CertificateRequest `json:"result"` + Error string `json:"error,omitempty"` } type ListCertificateRequestsResponse struct { - Error string `json:"error,omitempty"` - Result []CertificateRequest `json:"result"` + Error string `json:"error,omitempty"` + Result []server.CertificateRequest `json:"result"` } type CreateCertificateRequestResponse struct { @@ -162,7 +164,10 @@ func rejectCertificate(url string, client *http.Client, adminToken string, id in // The order of the tests is important, as some tests depend on the // state of the server after previous tests. func TestCertificateRequestsEndToEnd(t *testing.T) { - ts, _, err := setupServer() + + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -257,8 +262,8 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { if getCertRequestResponse.Result.CSR == "" { t.Fatalf("expected CSR, got empty string") } - if getCertRequestResponse.Result.Certificate != "" { - t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.Certificate) + if getCertRequestResponse.Result.CertificateChain != "" { + t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.CertificateChain) } }) @@ -353,8 +358,8 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { if getCertRequestResponse.Result.CSR == "" { t.Fatalf("expected CSR, got empty string") } - if getCertRequestResponse.Result.Certificate != "" { - t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.Certificate) + if getCertRequestResponse.Result.CertificateChain != "" { + t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.CertificateChain) } }) @@ -399,7 +404,9 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { // The order of the tests is important, as some tests depend on the // state of the server after previous tests. func TestCertificatesEndToEnd(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -481,7 +488,7 @@ func TestCertificatesEndToEnd(t *testing.T) { if getCertResponse.Error != "" { t.Fatalf("expected no error, got %s", getCertResponse.Error) } - if getCertResponse.Result.Certificate == "" { + if getCertResponse.Result.CertificateChain == "" { t.Fatalf("expected certificate, got empty string") } }) @@ -507,8 +514,8 @@ func TestCertificatesEndToEnd(t *testing.T) { if getCertResponse.Error != "" { t.Fatalf("expected no error, got %s", getCertResponse.Error) } - if getCertResponse.Result.Certificate != "rejected" { - t.Fatalf("expected `rejected` certificate, got %s", getCertResponse.Result.Certificate) + if getCertResponse.Result.CSRStatus != "Rejected" { + t.Fatalf("expected `Rejected` status, got %s", getCertResponse.Result.CertificateChain) } }) diff --git a/internal/server/handlers_helpers_test.go b/internal/server/handlers_helpers_test.go index cb309c1..1552e1e 100644 --- a/internal/server/handlers_helpers_test.go +++ b/internal/server/handlers_helpers_test.go @@ -10,8 +10,8 @@ import ( "github.com/canonical/notary/internal/server" ) -func setupServer() (*httptest.Server, *server.HandlerConfig, error) { - testdb, err := db.NewDatabase(":memory:") +func setupServer(filepath string) (*httptest.Server, *server.HandlerConfig, error) { + testdb, err := db.NewDatabase(filepath) if err != nil { return nil, nil, err } diff --git a/internal/server/handlers_login.go b/internal/server/handlers_login.go index 21b31f8..a0f702a 100644 --- a/internal/server/handlers_login.go +++ b/internal/server/handlers_login.go @@ -7,7 +7,7 @@ import ( "net/http" "time" - "github.com/canonical/notary/internal/db" + "github.com/canonical/sqlair" "github.com/golang-jwt/jwt" "golang.org/x/crypto/bcrypt" ) @@ -68,14 +68,14 @@ func Login(env *HandlerConfig) http.HandlerFunc { userAccount, err := env.DB.RetrieveUserByUsername(loginParams.Username) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusUnauthorized, "The username or password is incorrect. Try again.") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - if err := bcrypt.CompareHashAndPassword([]byte(userAccount.Password), []byte(loginParams.Password)); err != nil { + if err := bcrypt.CompareHashAndPassword([]byte(userAccount.HashedPassword), []byte(loginParams.Password)); err != nil { writeError(w, http.StatusUnauthorized, "The username or password is incorrect. Try again.") return } diff --git a/internal/server/handlers_login_test.go b/internal/server/handlers_login_test.go index 408cb2c..ff3c0db 100644 --- a/internal/server/handlers_login_test.go +++ b/internal/server/handlers_login_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "path/filepath" "strings" "testing" @@ -46,7 +47,9 @@ func login(url string, client *http.Client, data *LoginParams) (int, *LoginRespo } func TestLoginEndToEnd(t *testing.T) { - ts, config, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, config, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } diff --git a/internal/server/handlers_status_test.go b/internal/server/handlers_status_test.go index c725574..d819fd5 100644 --- a/internal/server/handlers_status_test.go +++ b/internal/server/handlers_status_test.go @@ -3,6 +3,7 @@ package server_test import ( "encoding/json" "net/http" + "path/filepath" "testing" ) @@ -35,7 +36,9 @@ func getStatus(url string, client *http.Client, adminToken string) (int, *GetSta } func TestStatus(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } diff --git a/internal/server/response.go b/internal/server/response.go index 0b9678c..7e7daf6 100644 --- a/internal/server/response.go +++ b/internal/server/response.go @@ -6,6 +6,10 @@ import ( "net/http" ) +type SuccessResponse struct { + Message string `json:"message"` +} + // writeJSON is a helper function that writes a JSON response to the http.ResponseWriter func writeJSON(w http.ResponseWriter, v any) error { type response struct {