diff --git a/go.mod b/go.mod index 38fd724..a234d34 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect + github.com/canonical/sqlair v0.0.0-20241004123011-77313b5382fd github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/kr/text v0.2.0 // indirect diff --git a/go.sum b/go.sum index 56c00d9..4b3de95 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/canonical/sqlair v0.0.0-20241004123011-77313b5382fd h1:lVn7391CX7QQ5WBQriNUtCB4fvfurZg6XJwH9aVsRII= +github.com/canonical/sqlair v0.0.0-20241004123011-77313b5382fd/go.mod h1:T+40I2sXshY3KRxx0QQpqqn6hCibSKJ2KHzjBvJj8T4= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= diff --git a/internal/db/db.go b/internal/db/db.go index 868a6b2..67ba0db 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -2,273 +2,411 @@ package db import ( + "context" "database/sql" "errors" "fmt" + "github.com/canonical/sqlair" _ "github.com/mattn/go-sqlite3" "golang.org/x/crypto/bcrypt" ) +// Database is the object used to communicate with the established repository. +type Database struct { + certificateTable string + usersTable string + conn *sqlair.DB +} + +type CertificateRequest struct { + ID int `db:"id"` + + CSR string `db:"csr"` + CertificateChain string `db:"certificate_chain"` + Status string `db:"status"` +} + +type User struct { + ID int `db:"id"` + + Username string `db:"username"` + HashedPassword string `db:"hashed_password"` + Permissions int `db:"permissions"` +} + const ( - certificateRequestsTableName = "CertificateRequests" + certificateRequestsTableName = "certificate_requests" usersTableName = "users" ) -const queryCreateCSRsTable = `CREATE TABLE IF NOT EXISTS %s ( - csr TEXT PRIMARY KEY UNIQUE NOT NULL, - certificate TEXT DEFAULT '' +const queryCreateCertificateRequestsTable = ` + CREATE TABLE IF NOT EXISTS %s ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + + csr TEXT NOT NULL UNIQUE, + certificate_chain TEXT DEFAULT '', + status TEXT DEFAULT 'Outstanding', + + CHECK (status IN ('Outstanding', 'Rejected', 'Revoked', 'Active')), + CHECK (NOT (certificate_chain == '' AND status == 'Active' )), + CHECK (NOT (certificate_chain != '' AND status == 'Outstanding')) + CHECK (NOT (certificate_chain != '' AND status == 'Rejected')) + CHECK (NOT (certificate_chain != '' AND status == 'Revoked')) )` -const ( - queryGetAllCSRs = "SELECT rowid, * FROM %s" - queryGetCSR = "SELECT rowid, * FROM %s WHERE rowid=?" - queryCreateCSR = "INSERT INTO %s (csr) VALUES (?)" - queryUpdateCSR = "UPDATE %s SET certificate=? WHERE rowid=?" - queryDeleteCSR = "DELETE FROM %s WHERE rowid=?" -) +const queryCreateUsersTable = ` + CREATE TABLE IF NOT EXISTS %s ( + id INTEGER PRIMARY KEY AUTOINCREMENT, -const queryCreateUsersTable = `CREATE TABLE IF NOT EXISTS %s ( - user_id INTEGER PRIMARY KEY AUTOINCREMENT, - username TEXT NOT NULL UNIQUE, - hashed_password TEXT NOT NULL, - permissions INTEGER + username TEXT NOT NULL UNIQUE, + hashed_password TEXT NOT NULL, + permissions INTEGER )` const ( - queryGetAllUsers = "SELECT * FROM %s" - queryGetUser = "SELECT * FROM %s WHERE user_id=?" - queryGetUserByUsername = "SELECT * FROM %s WHERE username=?" - queryCreateUser = "INSERT INTO %s (username, hashed_password, permissions) VALUES (?, ?, ?)" - queryUpdateUser = "UPDATE %s SET hashed_password=? WHERE user_id=?" - queryDeleteUser = "DELETE FROM %s WHERE user_id=?" - queryGetNumUsers = "SELECT COUNT(*) FROM %s" + listCertificateRequestsStmt = "SELECT &CertificateRequest.* FROM %s" + getCertificateRequestStmt = "SELECT &CertificateRequest.* FROM %s WHERE id==$CertificateRequest.id or csr==$CertificateRequest.csr" + createCertificateRequestStmt = "INSERT INTO %s (csr) VALUES ($CertificateRequest.csr)" + updateCertificateRequestStmt = "UPDATE %s SET certificate_chain=$CertificateRequest.certificate_chain, status=$CertificateRequest.status WHERE id==$CertificateRequest.id or csr==$CertificateRequest.csr" + deleteCertificateRequestStmt = "DELETE FROM %s WHERE id=$CertificateRequest.id or csr=$CertificateRequest.csr" ) -// CertificateRequestRepository is the object used to communicate with the established repository. -type Database struct { - certificateTable string - usersTable string - conn *sql.DB -} +const ( + listUsersStmt = "SELECT &User.* from %s" + getUserStmt = "SELECT &User.* from %s WHERE id==$User.id or username==$User.username" + createUserStmt = "INSERT INTO %s (username, hashed_password, permissions) VALUES ($User.username, $User.hashed_password, $User.permissions)" + updateUserStmt = "UPDATE %s SET hashed_password=$User.hashed_password WHERE id==$User.id or username==$User.username" + deleteUserStmt = "DELETE FROM %s WHERE id==$User.id" + getNumUsersStmt = "SELECT COUNT(*) AS &NumUsers.count FROM %s" +) -// A CertificateRequest struct represents an entry in the database. -// The object contains a Certificate Request, its matching Certificate if any, and the row ID. -type CertificateRequest struct { - ID int - CSR string - Certificate string -} -type User struct { - ID int - Username string - Password string - Permissions int +// ListCertificateRequests gets every CertificateRequest entry in the table. +func (db *Database) ListCertificateRequests() ([]CertificateRequest, error) { + stmt, err := sqlair.Prepare(fmt.Sprintf(listCertificateRequestsStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return nil, err + } + var csrs []CertificateRequest + err = db.conn.Query(context.Background(), stmt).GetAll(&csrs) + if err != nil { + if errors.Is(err, sqlair.ErrNoRows) { + return csrs, nil + } + return nil, err + } + return csrs, nil } -var ErrIdNotFound = errors.New("id not found") +// GetCertificateRequestByID gets a CSR row from the repository from a given ID. +func (db *Database) GetCertificateRequestByID(id int) (*CertificateRequest, error) { + csr := CertificateRequest{ + ID: id, + } + stmt, err := sqlair.Prepare(fmt.Sprintf(getCertificateRequestStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return nil, err + } + err = db.conn.Query(context.Background(), stmt, csr).Get(&csr) + if err != nil { + return nil, err + } + return &csr, nil +} -// RetrieveAllCSRs gets every CertificateRequest entry in the table. -func (db *Database) RetrieveAllCSRs() ([]CertificateRequest, error) { - rows, err := db.conn.Query(fmt.Sprintf(queryGetAllCSRs, db.certificateTable)) +// GetCertificateRequestByCSR gets a given CSR row from the repository using the CSR text. +func (db *Database) GetCertificateRequestByCSR(csr string) (*CertificateRequest, error) { + row := CertificateRequest{ + CSR: csr, + } + stmt, err := sqlair.Prepare(fmt.Sprintf(getCertificateRequestStmt, db.certificateTable), CertificateRequest{}) if err != nil { return nil, err } + err = db.conn.Query(context.Background(), stmt, row).Get(&row) + if err != nil { + return nil, err + } + return &row, nil +} - var allCsrs []CertificateRequest - defer rows.Close() - for rows.Next() { - var csr CertificateRequest - if err := rows.Scan(&csr.ID, &csr.CSR, &csr.Certificate); err != nil { - return nil, err - } - allCsrs = append(allCsrs, csr) +// CreateCertificateRequest creates a new CSR entry in the repository. The string must be a valid CSR and unique. +func (db *Database) CreateCertificateRequest(csr string) error { + if err := ValidateCertificateRequest(csr); err != nil { + return errors.New("csr validation failed: " + err.Error()) + } + stmt, err := sqlair.Prepare(fmt.Sprintf(createCertificateRequestStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err } - return allCsrs, nil + row := CertificateRequest{ + CSR: csr, + } + err = db.conn.Query(context.Background(), stmt, row).Run() + return err } -// RetrieveCSR gets a given CSR from the repository. -// It returns the row id and matching certificate alongside the CSR in a CertificateRequest object. -func (db *Database) RetrieveCSR(id string) (CertificateRequest, error) { - var newCSR CertificateRequest - row := db.conn.QueryRow(fmt.Sprintf(queryGetCSR, db.certificateTable), id) - if err := row.Scan(&newCSR.ID, &newCSR.CSR, &newCSR.Certificate); err != nil { - if err.Error() == "sql: no rows in result set" { - return newCSR, ErrIdNotFound - } - return newCSR, err +// AddCertificateChainToCertificateRequestByCSR adds a new certificate chain to a row for a given CSR string. +func (db *Database) AddCertificateChainToCertificateRequestByCSR(csr string, cert string) error { + err := ValidateCertificate(cert) + if err != nil { + return errors.New("cert validation failed: " + err.Error()) } - return newCSR, nil + err = CertificateMatchesCSR(cert, csr) + if err != nil { + return errors.New("cert validation failed: " + err.Error()) + } + certBundle := sanitizeCertificateBundle(cert) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCertificateRequestStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err + } + newRow := CertificateRequest{ + CSR: csr, + CertificateChain: certBundle, + Status: "Active", + } + err = db.conn.Query(context.Background(), stmt, newRow).Run() + return err } -// CreateCSR creates a new entry in the repository. -// The given CSR must be valid and unique -func (db *Database) CreateCSR(csr string) (int64, error) { - if err := ValidateCertificateRequest(csr); err != nil { - return 0, errors.New("csr validation failed: " + err.Error()) +// AddCertificateChainToCSRbyID adds a new certificate chain to a row for a given row ID. +func (db *Database) AddCertificateChainToCertificateRequestByID(id int, cert string) error { + csr, err := db.GetCertificateRequestByID(id) + if err != nil { + return err } - result, err := db.conn.Exec(fmt.Sprintf(queryCreateCSR, db.certificateTable), csr) + err = ValidateCertificate(cert) if err != nil { - return 0, err + return errors.New("cert validation failed: " + err.Error()) } - id, err := result.LastInsertId() + err = CertificateMatchesCSR(cert, csr.CSR) if err != nil { - return 0, err + return errors.New("cert validation failed: " + err.Error()) + } + certBundle := sanitizeCertificateBundle(cert) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCertificateRequestStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err } - return id, nil + newRow := CertificateRequest{ + ID: id, + CertificateChain: certBundle, + Status: "Active", + } + err = db.conn.Query(context.Background(), stmt, newRow).Run() + return err } -// UpdateCSR adds a new cert to the given CSR in the repository. -// The given certificate must share the public key of the CSR and must be valid. -func (db *Database) UpdateCSR(id string, cert string) (int64, error) { - csr, err := db.RetrieveCSR(id) +// RejectCertificateRequestByCSR updates input CSR's row by setting the certificate bundle to "" and moving the row status to "Rejected". +func (db *Database) RejectCertificateRequestByCSR(csr string) error { + oldRow, err := db.GetCertificateRequestByCSR(csr) if err != nil { - return 0, err + return err } - if cert != "rejected" && cert != "" { - err = ValidateCertificate(cert) - if err != nil { - return 0, errors.New("cert validation failed: " + err.Error()) - } - err = CertificateMatchesCSR(cert, csr.CSR) - if err != nil { - return 0, errors.New("cert validation failed: " + err.Error()) - } - cert = sanitizeCertificateBundle(cert) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCertificateRequestStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err } - _, err = db.conn.Exec(fmt.Sprintf(queryUpdateCSR, db.certificateTable), cert, csr.ID) + newRow := CertificateRequest{ + ID: oldRow.ID, + CSR: oldRow.CSR, + CertificateChain: "", + Status: "Rejected", + } + err = db.conn.Query(context.Background(), stmt, newRow).Run() + return err +} + +// RejectCSRbyCSR updates input ID's row by setting the certificate bundle to "" and sets the row status to "Rejected". +func (db *Database) RejectCertificateRequestByID(id int) error { + oldRow, err := db.GetCertificateRequestByID(id) if err != nil { - return 0, err + return err + } + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCertificateRequestStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err + } + newRow := CertificateRequest{ + ID: oldRow.ID, + CSR: oldRow.CSR, + CertificateChain: "", + Status: "Rejected", } - return int64(csr.ID), nil + err = db.conn.Query(context.Background(), stmt, newRow).Run() + return err } -// DeleteCSR removes a CSR from the database alongside the certificate that may have been generated for it. -func (db *Database) DeleteCSR(id string) (int64, error) { - result, err := db.conn.Exec(fmt.Sprintf(queryDeleteCSR, db.certificateTable), id) +// RevokeCertificateByCSR updates the input CSR's row by setting the certificate bundle to "" and sets the row status to "Revoked". +func (db *Database) RevokeCertificateByCSR(csr string) error { + oldRow, err := db.GetCertificateRequestByCSR(csr) if err != nil { - return 0, err + return err } - deleteId, err := result.RowsAffected() + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCertificateRequestStmt, db.certificateTable), CertificateRequest{}) if err != nil { - return 0, err + return err } - if deleteId == 0 { - return 0, ErrIdNotFound + newRow := CertificateRequest{ + ID: oldRow.ID, + CSR: oldRow.CSR, + CertificateChain: "", + Status: "Revoked", } - return deleteId, nil + err = db.conn.Query(context.Background(), stmt, newRow).Run() + return err } -// RetrieveAllUsers returns all of the users and their fields available in the database. -func (db *Database) RetrieveAllUsers() ([]User, error) { - rows, err := db.conn.Query(fmt.Sprintf(queryGetAllUsers, db.usersTable)) +// DeleteCertificateRequestByCSR removes a CSR from the database alongside the certificate that may have been generated for it. +func (db *Database) DeleteCertificateRequestByCSR(csr string) error { + stmt, err := sqlair.Prepare(fmt.Sprintf(deleteCertificateRequestStmt, db.certificateTable), CertificateRequest{}) if err != nil { - return nil, err + return err } + row := CertificateRequest{ + CSR: csr, + } + err = db.conn.Query(context.Background(), stmt, row).Run() + return err +} - var allUsers []User - defer rows.Close() - for rows.Next() { - var user User - if err := rows.Scan(&user.ID, &user.Username, &user.Password, &user.Permissions); err != nil { - return nil, err - } - allUsers = append(allUsers, user) +// DeleteCSRByID removes a CSR from the database alongside the certificate that may have been generated for it. +func (db *Database) DeleteCertificateRequestByID(id int) error { + stmt, err := sqlair.Prepare(fmt.Sprintf(deleteCertificateRequestStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err } - return allUsers, nil + row := CertificateRequest{ + ID: id, + } + err = db.conn.Query(context.Background(), stmt, row).Run() + return err } -// RetrieveUser retrieves the name, password and the permission level of a user. -func (db *Database) RetrieveUser(id string) (User, error) { - var newUser User - row := db.conn.QueryRow(fmt.Sprintf(queryGetUser, db.usersTable), id) - if err := row.Scan(&newUser.ID, &newUser.Username, &newUser.Password, &newUser.Permissions); err != nil { - if err.Error() == "sql: no rows in result set" { - return newUser, ErrIdNotFound - } - return newUser, err +// ListUsers returns all of the users and their fields available in the database. +func (db *Database) ListUsers() ([]User, error) { + stmt, err := sqlair.Prepare(fmt.Sprintf(listUsersStmt, db.usersTable), User{}) + if err != nil { + return nil, err } - return newUser, nil + var users []User + err = db.conn.Query(context.Background(), stmt).GetAll(&users) + if err != nil { + return nil, err + } + return users, nil } -// RetrieveUser retrieves the id, password and the permission level of a user. -func (db *Database) RetrieveUserByUsername(name string) (User, error) { - var newUser User - row := db.conn.QueryRow(fmt.Sprintf(queryGetUserByUsername, db.usersTable), name) - if err := row.Scan(&newUser.ID, &newUser.Username, &newUser.Password, &newUser.Permissions); err != nil { - if err.Error() == "sql: no rows in result set" { - return newUser, ErrIdNotFound - } - return newUser, err +// GetUserByID retrieves the name, password and the permission level of a user. +func (db *Database) GetUserByID(id int) (*User, error) { + row := User{ + ID: id, + } + stmt, err := sqlair.Prepare(fmt.Sprintf(getUserStmt, db.usersTable), User{}) + if err != nil { + return nil, err + } + err = db.conn.Query(context.Background(), stmt, row).Get(&row) + if err != nil { + return nil, err + } + return &row, nil +} + +// GetUserByUsername retrieves the id, password and the permission level of a user. +func (db *Database) GetUserByUsername(name string) (*User, error) { + row := User{ + Username: name, + } + stmt, err := sqlair.Prepare(fmt.Sprintf(getUserStmt, db.usersTable), User{}) + if err != nil { + return nil, err + } + err = db.conn.Query(context.Background(), stmt, row).Get(&row) + if err != nil { + return nil, err } - return newUser, nil + return &row, nil } // CreateUser creates a new user from a given username, password and permission level. // The permission level 1 represents an admin, and a 0 represents a regular user. // The password passed in should be in plaintext. This function handles hashing and salting the password before storing it in the database. -func (db *Database) CreateUser(username string, password string, permission int) (int64, error) { +func (db *Database) CreateUser(username string, password string, permission int) error { pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - return 0, err + return err } - result, err := db.conn.Exec(fmt.Sprintf(queryCreateUser, db.usersTable), username, pw, permission) + stmt, err := sqlair.Prepare(fmt.Sprintf(createUserStmt, db.usersTable), User{}) if err != nil { - return 0, err + return err } - id, err := result.LastInsertId() - if err != nil { - return 0, err + row := User{ + Username: username, + HashedPassword: string(pw), + Permissions: permission, } - return id, nil + err = db.conn.Query(context.Background(), stmt, row).Run() + return err } // UpdateUser updates the password of the given user. // Just like with CreateUser, this function handles hashing and salting the password before storage. -func (db *Database) UpdateUser(id, password string) (int64, error) { - user, err := db.RetrieveUser(id) +func (db *Database) UpdateUserPassword(id int, password string) error { + _, err := db.GetUserByID(id) if err != nil { - return 0, err + return err } pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - return 0, err + return err } - result, err := db.conn.Exec(fmt.Sprintf(queryUpdateUser, db.usersTable), pw, user.ID) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateUserStmt, db.usersTable), User{}) if err != nil { - return 0, err + return err } - affectedRows, err := result.RowsAffected() - if err != nil { - return 0, err + row := User{ + ID: id, + HashedPassword: string(pw), } - return affectedRows, nil + err = db.conn.Query(context.Background(), stmt, row).Run() + return err } -// DeleteUser removes a user from the table. -func (db *Database) DeleteUser(id string) (int64, error) { - result, err := db.conn.Exec(fmt.Sprintf(queryDeleteUser, db.usersTable), id) +// DeleteUserByID removes a user from the table. +func (db *Database) DeleteUserByID(id int) error { + _, err := db.GetUserByID(id) if err != nil { - return 0, err + return err } - deleteId, err := result.RowsAffected() + stmt, err := sqlair.Prepare(fmt.Sprintf(deleteUserStmt, db.usersTable), User{}) if err != nil { - return 0, err + return err } - if deleteId == 0 { - return 0, ErrIdNotFound + row := User{ + ID: id, } - return deleteId, nil + err = db.conn.Query(context.Background(), stmt, row).Run() + return err +} + +type NumUsers struct { + Count int `db:"count"` } // NumUsers returns the number of users in the database. func (db *Database) NumUsers() (int, error) { - var numUsers int - row := db.conn.QueryRow(fmt.Sprintf(queryGetNumUsers, db.usersTable)) - if err := row.Scan(&numUsers); err != nil { + stmt, err := sqlair.Prepare(fmt.Sprintf(getNumUsersStmt, db.usersTable), NumUsers{}) + if err != nil { + return 0, err + } + result := NumUsers{} + err = db.conn.Query(context.Background(), stmt).Get(&result) + if err != nil { return 0, err } - return numUsers, nil + return result.Count, nil } // Close closes the connection to the repository cleanly. @@ -276,7 +414,7 @@ func (db *Database) Close() error { if db.conn == nil { return nil } - if err := db.conn.Close(); err != nil { + if err := db.conn.PlainDB().Close(); err != nil { return err } return nil @@ -287,18 +425,18 @@ func (db *Database) Close() error { // The database path must be a valid file path or ":memory:". // The table will be created if it doesn't exist in the format expected by the package. func NewDatabase(databasePath string) (*Database, error) { - conn, err := sql.Open("sqlite3", databasePath) + sqlConnection, err := sql.Open("sqlite3", databasePath) if err != nil { return nil, err } - if _, err := conn.Exec(fmt.Sprintf(queryCreateCSRsTable, certificateRequestsTableName)); err != nil { + if _, err := sqlConnection.Exec(fmt.Sprintf(queryCreateCertificateRequestsTable, certificateRequestsTableName)); err != nil { return nil, err } - if _, err := conn.Exec(fmt.Sprintf(queryCreateUsersTable, usersTableName)); err != nil { + if _, err := sqlConnection.Exec(fmt.Sprintf(queryCreateUsersTable, usersTableName)); err != nil { return nil, err } db := new(Database) - db.conn = conn + db.conn = sqlair.NewDB(sqlConnection) db.certificateTable = certificateRequestsTableName db.usersTable = usersTableName return db, nil diff --git a/internal/db/db_test.go b/internal/db/db_test.go index 5d0f768..19cb647 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -3,7 +3,7 @@ package db_test import ( "fmt" "log" - "strconv" + "path/filepath" "strings" "testing" @@ -12,7 +12,8 @@ import ( ) func TestConnect(t *testing.T) { - db, err := db.NewDatabase(":memory:") + tempDir := t.TempDir() + db, err := db.NewDatabase(filepath.Join(tempDir, "db.sqlite3")) if err != nil { t.Fatalf("Can't connect to SQLite: %s", err) } @@ -20,33 +21,34 @@ func TestConnect(t *testing.T) { } func TestCSRsEndToEnd(t *testing.T) { - db, err := db.NewDatabase(":memory:") + tempDir := t.TempDir() + db, err := db.NewDatabase(filepath.Join(tempDir, "db.sqlite3")) if err != nil { t.Fatalf("Couldn't complete NewDatabase: %s", err) } defer db.Close() - id1, err := db.CreateCSR(AppleCSR) + err = db.CreateCertificateRequest(AppleCSR) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } - id2, err := db.CreateCSR(BananaCSR) + err = db.CreateCertificateRequest(BananaCSR) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } - id3, err := db.CreateCSR(StrawberryCSR) + err = db.CreateCertificateRequest(StrawberryCSR) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } - res, err := db.RetrieveAllCSRs() + res, err := db.ListCertificateRequests() if err != nil { t.Fatalf("Couldn't complete RetrieveAll: %s", err) } if len(res) != 3 { t.Fatalf("One or more CSRs weren't found in DB") } - retrievedCSR, err := db.RetrieveCSR(strconv.FormatInt(id1, 10)) + retrievedCSR, err := db.GetCertificateRequestByCSR(AppleCSR) if err != nil { t.Fatalf("Couldn't complete Retrieve: %s", err) } @@ -54,32 +56,32 @@ func TestCSRsEndToEnd(t *testing.T) { t.Fatalf("The CSR from the database doesn't match the CSR that was given") } - if _, err = db.DeleteCSR(strconv.FormatInt(id1, 10)); err != nil { + if err = db.DeleteCertificateRequestByCSR(AppleCSR); err != nil { t.Fatalf("Couldn't complete Delete: %s", err) } - res, _ = db.RetrieveAllCSRs() + res, _ = db.ListCertificateRequests() if len(res) != 2 { t.Fatalf("CSR's weren't deleted from the DB properly") } BananaCertBundle := strings.TrimSpace(fmt.Sprintf("%s%s", BananaCert, IssuerCert)) - _, err = db.UpdateCSR(strconv.FormatInt(id2, 10), BananaCertBundle) + err = db.AddCertificateChainToCertificateRequestByCSR(BananaCSR, BananaCertBundle) if err != nil { t.Fatalf("Couldn't complete Update: %s", err) } - retrievedCSR, _ = db.RetrieveCSR(strconv.FormatInt(id2, 10)) - if retrievedCSR.Certificate != BananaCertBundle { - t.Fatalf("The certificate that was uploaded does not match the certificate that was given.\n Retrieved: %s\nGiven: %s", retrievedCSR.Certificate, BananaCertBundle) + retrievedCSR, _ = db.GetCertificateRequestByCSR(BananaCSR) + if retrievedCSR.CertificateChain != BananaCertBundle { + t.Fatalf("The certificate that was uploaded does not match the certificate that was given.\n Retrieved: %s\nGiven: %s", retrievedCSR.CertificateChain, BananaCertBundle) } - _, err = db.UpdateCSR(strconv.FormatInt(id2, 10), "") + err = db.RevokeCertificateByCSR(BananaCSR) if err != nil { - t.Fatalf("Couldn't complete Update to delete certificate: %s", err) + t.Fatalf("Couldn't complete Update to revoke certificate: %s", err) } - _, err = db.UpdateCSR(strconv.FormatInt(id3, 10), "rejected") + err = db.RejectCertificateRequestByCSR(StrawberryCSR) if err != nil { t.Fatalf("Couldn't complete Update to reject CSR: %s", err) } - retrievedCSR, _ = db.RetrieveCSR(strconv.FormatInt(id2, 10)) - if retrievedCSR.Certificate != "" { + retrievedCSR, _ = db.GetCertificateRequestByCSR(BananaCSR) + if retrievedCSR.Status != "Revoked" { t.Fatalf("Couldn't delete certificate") } } @@ -89,12 +91,12 @@ func TestCreateFails(t *testing.T) { defer db.Close() InvalidCSR := strings.ReplaceAll(AppleCSR, "M", "i") - if _, err := db.CreateCSR(InvalidCSR); err == nil { + if err := db.CreateCertificateRequest(InvalidCSR); err == nil { t.Fatalf("Expected error due to invalid CSR") } - db.CreateCSR(AppleCSR) //nolint:errcheck - if _, err := db.CreateCSR(AppleCSR); err == nil { + db.CreateCertificateRequest(AppleCSR) //nolint:errcheck + if err := db.CreateCertificateRequest(AppleCSR); err == nil { t.Fatalf("Expected error due to duplicate CSR") } } @@ -103,13 +105,13 @@ func TestUpdateFails(t *testing.T) { db, _ := db.NewDatabase(":memory:") defer db.Close() - id1, _ := db.CreateCSR(AppleCSR) //nolint:errcheck - id2, _ := db.CreateCSR(BananaCSR) //nolint:errcheck + db.CreateCertificateRequest(AppleCSR) //nolint:errcheck + db.CreateCertificateRequest(BananaCSR) //nolint:errcheck InvalidCert := strings.ReplaceAll(BananaCert, "/", "+") - if _, err := db.UpdateCSR(strconv.FormatInt(id2, 10), InvalidCert); err == nil { + if err := db.AddCertificateChainToCertificateRequestByCSR(BananaCSR, InvalidCert); err == nil { t.Fatalf("Expected updating with invalid cert to fail") } - if _, err := db.UpdateCSR(strconv.FormatInt(id1, 10), BananaCert); err == nil { + if err := db.AddCertificateChainToCertificateRequestByCSR(AppleCSR, BananaCert); err == nil { t.Fatalf("Expected updating with mismatched cert to fail") } } @@ -118,8 +120,11 @@ func TestRetrieve(t *testing.T) { db, _ := db.NewDatabase(":memory:") //nolint:errcheck defer db.Close() - db.CreateCSR(AppleCSR) //nolint:errcheck - if _, err := db.RetrieveCSR("this is definitely not an id"); err == nil { + db.CreateCertificateRequest(AppleCSR) //nolint:errcheck + if _, err := db.GetCertificateRequestByCSR("this is definitely not an id"); err == nil { + t.Fatalf("Expected failure looking for nonexistent CSR") + } + if _, err := db.GetCertificateRequestByID(-1); err == nil { t.Fatalf("Expected failure looking for nonexistent CSR") } } @@ -131,16 +136,16 @@ func TestUsersEndToEnd(t *testing.T) { } defer db.Close() - id1, err := db.CreateUser("admin", "pw123", 1) + err = db.CreateUser("admin", "pw123", 1) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } - id2, err := db.CreateUser("norman", "pw456", 0) + err = db.CreateUser("norman", "pw456", 0) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } - res, err := db.RetrieveAllUsers() + res, err := db.ListUsers() if err != nil { t.Fatalf("Couldn't complete RetrieveAll: %s", err) } @@ -154,31 +159,36 @@ func TestUsersEndToEnd(t *testing.T) { if num != 2 { t.Fatalf("NumUsers didn't return the correct number of users") } - retrievedUser, err := db.RetrieveUser(strconv.FormatInt(id1, 10)) + retrievedUser, err := db.GetUserByUsername("admin") + if err != nil { + t.Fatalf("Couldn't complete Retrieve: %s", err) + } + if retrievedUser.Username != "admin" { + t.Fatalf("The user from the database doesn't match the user that was given") + } + retrievedUser, err = db.GetUserByID(1) if err != nil { t.Fatalf("Couldn't complete Retrieve: %s", err) } if retrievedUser.Username != "admin" { t.Fatalf("The user from the database doesn't match the user that was given") } - if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.Password), []byte("pw123")); err != nil { + if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.HashedPassword), []byte("pw123")); err != nil { t.Fatalf("The user's password doesn't match the one stored in the database") } - - if _, err = db.DeleteUser(strconv.FormatInt(id1, 10)); err != nil { + if err = db.DeleteUserByID(1); err != nil { t.Fatalf("Couldn't complete Delete: %s", err) } - res, _ = db.RetrieveAllUsers() + res, _ = db.ListUsers() if len(res) != 1 { t.Fatalf("users weren't deleted from the DB properly") } - - _, err = db.UpdateUser(strconv.FormatInt(id2, 10), "thebestpassword") + err = db.UpdateUserPassword(2, "thebestpassword") if err != nil { t.Fatalf("Couldn't complete Update: %s", err) } - retrievedUser, _ = db.RetrieveUser(strconv.FormatInt(id2, 10)) - if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.Password), []byte("thebestpassword")); err != nil { + retrievedUser, _ = db.GetUserByUsername("norman") + if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.HashedPassword), []byte("thebestpassword")); err != nil { t.Fatalf("The new password that was given does not match the password that was stored.") } } @@ -188,19 +198,19 @@ func Example() { if err != nil { log.Fatalln(err) } - _, err = db.CreateCSR(BananaCSR) + err = db.CreateCertificateRequest(BananaCSR) if err != nil { log.Fatalln(err) } - _, err = db.UpdateCSR(BananaCSR, BananaCert) + err = db.AddCertificateChainToCertificateRequestByCSR(BananaCSR, BananaCert) if err != nil { log.Fatalln(err) } - entry, err := db.RetrieveCSR(BananaCSR) + entry, err := db.GetCertificateRequestByCSR(BananaCSR) if err != nil { log.Fatalln(err) } - if entry.Certificate != BananaCert { + if entry.CertificateChain != BananaCert { log.Fatalln("Retrieved Certificate doesn't match Stored Certificate") } err = db.Close() diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 1076d61..725891f 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -37,7 +37,7 @@ func NewMetricsSubsystem(db *db.Database) *PrometheusMetrics { ticker := time.NewTicker(120 * time.Second) go func() { for ; ; <-ticker.C { - csrs, err := db.RetrieveAllCSRs() + csrs, err := db.ListCertificateRequests() if err != nil { log.Println(errors.Join(errors.New("error generating metrics repository: "), err)) panic(1) @@ -95,15 +95,15 @@ func (pm *PrometheusMetrics) GenerateMetrics(csrs []db.CertificateRequest) { var expiringIn30DaysCertCount float64 var expiringIn90DaysCertCount float64 for _, entry := range csrs { - if entry.Certificate == "" { + if entry.CertificateChain == "" { outstandingCSRCount += 1 continue } - if entry.Certificate == "rejected" { + if entry.Status == "Rejected" { continue } certCount += 1 - expiryDate := certificateExpiryDate(entry.Certificate) + expiryDate := certificateExpiryDate(entry.CertificateChain) daysRemaining := time.Until(expiryDate).Hours() / 24 if daysRemaining < 0 { expiredCertCount += 1 diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go index ecdd854..0775f28 100644 --- a/internal/metrics/metrics_test.go +++ b/internal/metrics/metrics_test.go @@ -21,7 +21,8 @@ import ( // TestPrometheusHandler tests that the Prometheus metrics handler responds correctly to an HTTP request. func TestPrometheusHandler(t *testing.T) { - db, err := db.NewDatabase(":memory:") + tempDir := t.TempDir() + db, err := db.NewDatabase(filepath.Join(tempDir, "db.sqlite3")) if err != nil { t.Fatal(err) } @@ -95,13 +96,13 @@ func generateCertPair(daysRemaining int) (string, string, string) { } func initializeTestDB(t *testing.T, db *db.Database) { - for i, v := range []int{5, 10, 32} { + for _, v := range []int{5, 10, 32} { csr, cert, ca := generateCertPair(v) - _, err := db.CreateCSR(csr) + err := db.CreateCertificateRequest(csr) if err != nil { t.Fatalf("couldn't create test csr: %s", err) } - _, err = db.UpdateCSR(fmt.Sprint(i+1), fmt.Sprintf("%s%s", cert, ca)) + err = db.AddCertificateChainToCertificateRequestByCSR(csr, fmt.Sprintf("%s%s", cert, ca)) if err != nil { t.Fatalf("couldn't create test cert: %s", err) } @@ -117,7 +118,7 @@ func TestMetrics(t *testing.T) { } initializeTestDB(t, db) m := metrics.NewMetricsSubsystem(db) - csrs, _ := db.RetrieveAllCSRs() + csrs, _ := db.ListCertificateRequests() m.GenerateMetrics(csrs) request, _ := http.NewRequest("GET", "/", nil) diff --git a/internal/server/authorization_test.go b/internal/server/authorization_test.go index 164faa7..4aaf8f9 100644 --- a/internal/server/authorization_test.go +++ b/internal/server/authorization_test.go @@ -2,12 +2,15 @@ package server_test import ( "net/http" + "path/filepath" "strings" "testing" ) func TestAuthorizationNoAuth(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -47,7 +50,9 @@ func TestAuthorizationNoAuth(t *testing.T) { } func TestAuthorizationNonAdminAuthorized(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -98,7 +103,9 @@ func TestAuthorizationNonAdminAuthorized(t *testing.T) { } func TestAuthorizationNonAdminUnauthorized(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -156,7 +163,9 @@ func TestAuthorizationNonAdminUnauthorized(t *testing.T) { } func TestAuthorizationAdminAuthorized(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -205,7 +214,9 @@ func TestAuthorizationAdminAuthorized(t *testing.T) { } func TestAuthorizationAdminUnAuthorized(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } diff --git a/internal/server/handlers_accounts.go b/internal/server/handlers_accounts.go index 7b72e82..3e8beb3 100644 --- a/internal/server/handlers_accounts.go +++ b/internal/server/handlers_accounts.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/canonical/notary/internal/db" + "github.com/canonical/sqlair" ) type CreateAccountParams struct { @@ -27,18 +28,6 @@ type GetAccountResponse struct { Permissions int `json:"permissions"` } -type CreateAccountResponse struct { - ID int `json:"id"` -} - -type ChangeAccountResponse struct { - ID int `json:"id"` -} - -type DeleteAccountResponse struct { - ID int `json:"id"` -} - func validatePassword(password string) bool { if len(password) < 8 { return false @@ -59,7 +48,7 @@ func validatePassword(password string) bool { // ListAccounts returns all accounts from the database func ListAccounts(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - accounts, err := env.DB.RetrieveAllUsers() + accounts, err := env.DB.ListUsers() if err != nil { log.Println(err) writeError(w, http.StatusInternalServerError, "Internal Error") @@ -87,20 +76,24 @@ func ListAccounts(env *HandlerConfig) http.HandlerFunc { func GetAccount(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - var account db.User - var err error + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + var account *db.User if id == "me" { claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { writeError(w, http.StatusUnauthorized, "Unauthorized") } - account, err = env.DB.RetrieveUserByUsername(claims.Username) + account, err = env.DB.GetUserByUsername(claims.Username) } else { - account, err = env.DB.RetrieveUser(id) + account, err = env.DB.GetUserByID(idNum) } if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } @@ -150,12 +143,11 @@ func CreateAccount(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Failed to retrieve accounts: "+err.Error()) return } - permission := UserPermission if numUsers == 0 { permission = AdminPermission } - id, err := env.DB.CreateUser(createAccountParams.Username, createAccountParams.Password, permission) + err = env.DB.CreateUser(createAccountParams.Username, createAccountParams.Password, permission) if err != nil { if strings.Contains(err.Error(), "UNIQUE constraint failed") { writeError(w, http.StatusBadRequest, "account with given username already exists") @@ -165,11 +157,9 @@ func CreateAccount(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error") return } - accountResponse := CreateAccountResponse{ - ID: int(id), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusCreated) - err = writeJSON(w, accountResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -182,39 +172,39 @@ func CreateAccount(env *HandlerConfig) http.HandlerFunc { func DeleteAccount(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - idInt, err := strconv.ParseInt(id, 10, 64) + idInt, err := strconv.Atoi(id) if err != nil { log.Println(err) writeError(w, http.StatusInternalServerError, "Internal Error") return } - account, err := env.DB.RetrieveUser(id) + account, err := env.DB.GetUserByID(idInt) if err != nil { - if !errors.Is(err, db.ErrIdNotFound) { - log.Println(err) - writeError(w, http.StatusInternalServerError, "Internal Error") + log.Println(err) + if errors.Is(err, sqlair.ErrNoRows) { + writeError(w, http.StatusNotFound, "Not Found") return } + writeError(w, http.StatusInternalServerError, "Internal Error") + return } if account.Permissions == 1 { writeError(w, http.StatusBadRequest, "deleting an Admin account is not allowed.") return } - _, err = env.DB.DeleteUser(id) + err = env.DB.DeleteUserByID(idInt) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - deleteAccountResponse := DeleteAccountResponse{ - ID: int(idInt), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusAccepted) - err = writeJSON(w, deleteAccountResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -225,6 +215,7 @@ func DeleteAccount(env *HandlerConfig) http.HandlerFunc { func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") + var idNum int if id == "me" { claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if err != nil { @@ -232,13 +223,21 @@ func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusUnauthorized, "Unauthorized") return } - account, err := env.DB.RetrieveUserByUsername(claims.Username) + account, err := env.DB.GetUserByUsername(claims.Username) if err != nil { log.Println(err) writeError(w, http.StatusUnauthorized, "Unauthorized") return } - id = strconv.Itoa(account.ID) + idNum = account.ID + } else { + idInt, err := strconv.Atoi(id) + if err != nil { + log.Println(err) + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + idNum = idInt } var changeAccountParams ChangeAccountParams if err := json.NewDecoder(r.Body).Decode(&changeAccountParams); err != nil { @@ -257,21 +256,19 @@ func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc { ) return } - ret, err := env.DB.UpdateUser(id, changeAccountParams.Password) + err := env.DB.UpdateUserPassword(idNum, changeAccountParams.Password) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - changeAccountResponse := ChangeAccountResponse{ - ID: int(ret), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusCreated) - err = writeJSON(w, changeAccountResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return diff --git a/internal/server/handlers_accounts_test.go b/internal/server/handlers_accounts_test.go index 42e2e7a..38d40a1 100644 --- a/internal/server/handlers_accounts_test.go +++ b/internal/server/handlers_accounts_test.go @@ -3,11 +3,16 @@ package server_test import ( "encoding/json" "net/http" + "path/filepath" "strconv" "strings" "testing" ) +type SuccessResponse struct { + Message string `json:"message"` +} + type GetAccountResponseResult struct { ID int `json:"id"` Username string `json:"username"` @@ -29,8 +34,8 @@ type CreateAccountResponseResult struct { } type CreateAccountResponse struct { - Result CreateAccountResponseResult `json:"result"` - Error string `json:"error,omitempty"` + Result SuccessResponse `json:"result"` + Error string `json:"error,omitempty"` } type ChangeAccountPasswordParams struct { @@ -139,7 +144,9 @@ func deleteAccount(url string, client *http.Client, adminToken string, id int) ( // The order of the tests is important, as some tests depend on // the state of the server after previous tests. func TestAccountsEndToEnd(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -197,10 +204,7 @@ func TestAccountsEndToEnd(t *testing.T) { t.Fatalf("expected status %d, got %d", http.StatusCreated, statusCode) } if response.Error != "" { - t.Fatalf("expected error %q, got %q", "", response.Error) - } - if response.Result.ID != 3 { - t.Fatalf("expected ID 3, got %d", response.Result.ID) + t.Fatalf("unexpected error :%q", response.Error) } }) @@ -285,10 +289,7 @@ func TestAccountsEndToEnd(t *testing.T) { t.Fatalf("expected status %d, got %d", http.StatusCreated, statusCode) } if response.Error != "" { - t.Fatalf("expected error %q, got %q", "", response.Error) - } - if response.Result.ID != 1 { - t.Fatalf("expected ID 1, got %d", response.Result.ID) + t.Fatalf("unexpected error :%q", response.Error) } }) @@ -351,9 +352,6 @@ func TestAccountsEndToEnd(t *testing.T) { if response.Error != "" { t.Fatalf("expected error %q, got %q", "", response.Error) } - if response.Result.ID != 2 { - t.Fatalf("expected ID 2, got %d", response.Result.ID) - } }) t.Run("13. Delete account - no user", func(t *testing.T) { diff --git a/internal/server/handlers_certificate_requests.go b/internal/server/handlers_certificate_requests.go index a1a2d0f..a914c90 100644 --- a/internal/server/handlers_certificate_requests.go +++ b/internal/server/handlers_certificate_requests.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/canonical/notary/internal/db" + "github.com/canonical/sqlair" ) type CreateCertificateRequestParams struct { @@ -16,54 +16,32 @@ type CreateCertificateRequestParams struct { } type CreateCertificateParams struct { - Certificate string `json:"certificate"` + CertificateChain string `json:"certificate"` } -type GetCertificateRequestResponse struct { - ID int `json:"id"` - CSR string `json:"csr"` - Certificate string `json:"certificate"` -} - -type CreateCertificateRequestResponse struct { - ID int `json:"id"` -} - -type DeleteCertificateRequestResponse struct { - ID int `json:"id"` -} - -type RejectCertificateRequestResponse struct { - ID int `json:"id"` -} - -type CreateCertificateResponse struct { - ID int `json:"id"` -} - -type DeleteCertificateResponse struct { - ID int `json:"id"` -} - -type RejectCertificateResponse struct { - ID int `json:"id"` +type CertificateRequest struct { + ID int `json:"id"` + CSR string `json:"csr"` + CertificateChain string `json:"certificate_chain"` + Status string `json:"status"` } // ListCertificateRequests returns all of the Certificate Requests func ListCertificateRequests(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - certs, err := env.DB.RetrieveAllCSRs() + csrs, err := env.DB.ListCertificateRequests() if err != nil { log.Println(err) writeError(w, http.StatusInternalServerError, "Internal Error") return } - certificateRequestsResponse := make([]GetCertificateRequestResponse, len(certs)) - for i, cert := range certs { - certificateRequestsResponse[i] = GetCertificateRequestResponse{ - ID: cert.ID, - CSR: cert.CSR, - Certificate: cert.Certificate, + certificateRequestsResponse := make([]CertificateRequest, len(csrs)) + for i, csr := range csrs { + certificateRequestsResponse[i] = CertificateRequest{ + ID: csr.ID, + CSR: csr.CSR, + CertificateChain: csr.CertificateChain, + Status: csr.Status, } } w.WriteHeader(http.StatusOK) @@ -87,7 +65,7 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusBadRequest, "csr is missing") return } - id, err := env.DB.CreateCSR(createCertificateRequestParams.CSR) + err := env.DB.CreateCertificateRequest(createCertificateRequestParams.CSR) if err != nil { if strings.Contains(err.Error(), "UNIQUE constraint failed") { writeError(w, http.StatusBadRequest, "given csr already recorded") @@ -102,11 +80,9 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error") return } - certificateRequestResponse := CreateCertificateRequestResponse{ - ID: int(id), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusCreated) - err = writeJSON(w, certificateRequestResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -119,20 +95,26 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc { func GetCertificateRequest(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - cert, err := env.DB.RetrieveCSR(id) + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + csr, err := env.DB.GetCertificateRequestByID(idNum) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - certificateRequestResponse := GetCertificateRequestResponse{ - ID: cert.ID, - CSR: cert.CSR, - Certificate: cert.Certificate, + certificateRequestResponse := CertificateRequest{ + ID: csr.ID, + CSR: csr.CSR, + CertificateChain: csr.CertificateChain, + Status: csr.Status, } w.WriteHeader(http.StatusOK) err = writeJSON(w, certificateRequestResponse) @@ -148,21 +130,24 @@ func GetCertificateRequest(env *HandlerConfig) http.HandlerFunc { func DeleteCertificateRequest(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - insertId, err := env.DB.DeleteCSR(id) + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + err = env.DB.DeleteCertificateRequestByID(idNum) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - certificateRequestResponse := DeleteCertificateRequestResponse{ - ID: int(insertId), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusAccepted) - err = writeJSON(w, certificateRequestResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -179,15 +164,20 @@ func CreateCertificate(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusBadRequest, "Invalid JSON format") return } - if createCertificateParams.Certificate == "" { + if createCertificateParams.CertificateChain == "" { writeError(w, http.StatusBadRequest, "certificate is missing") return } id := r.PathValue("id") - insertId, err := env.DB.UpdateCSR(id, createCertificateParams.Certificate) + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + err = env.DB.AddCertificateChainToCertificateRequestByID(idNum, createCertificateParams.CertificateChain) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) || + if errors.Is(err, sqlair.ErrNoRows) || err.Error() == "certificate does not match CSR" || strings.Contains(err.Error(), "cert validation failed") { writeError(w, http.StatusBadRequest, "Bad Request") @@ -196,18 +186,15 @@ func CreateCertificate(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error") return } - insertIdStr := strconv.FormatInt(insertId, 10) if env.SendPebbleNotifications { - err := SendPebbleNotification("canonical.com/notary/certificate/update", insertIdStr) + err := SendPebbleNotification("canonical.com/notary/certificate/update", id) if err != nil { log.Printf("pebble notify failed: %s. continuing silently.", err.Error()) } } - certificateResponse := CreateCertificateResponse{ - ID: int(insertId), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusCreated) - err = writeJSON(w, certificateResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -218,28 +205,30 @@ func CreateCertificate(env *HandlerConfig) http.HandlerFunc { func RejectCertificate(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - insertId, err := env.DB.UpdateCSR(id, "rejected") + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + err = env.DB.RejectCertificateRequestByID(idNum) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - insertIdStr := strconv.FormatInt(insertId, 10) if env.SendPebbleNotifications { - err := SendPebbleNotification("canonical.com/notary/certificate/update", insertIdStr) + err := SendPebbleNotification("canonical.com/notary/certificate/update", id) if err != nil { log.Printf("pebble notify failed: %s. continuing silently.", err.Error()) } } - certificateResponse := RejectCertificateResponse{ - ID: int(insertId), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusAccepted) - err = writeJSON(w, certificateResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -252,28 +241,30 @@ func RejectCertificate(env *HandlerConfig) http.HandlerFunc { func DeleteCertificate(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - insertId, err := env.DB.UpdateCSR(id, "") + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + err = env.DB.DeleteCertificateRequestByID(idNum) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusBadRequest, "Bad Request") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - insertIdStr := strconv.FormatInt(insertId, 10) if env.SendPebbleNotifications { - err := SendPebbleNotification("canonical.com/notary/certificate/update", insertIdStr) + err := SendPebbleNotification("canonical.com/notary/certificate/update", id) if err != nil { log.Printf("pebble notify failed: %s. continuing silently.", err.Error()) } } - certificateResponse := DeleteCertificateResponse{ - ID: int(insertId), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusOK) - err = writeJSON(w, certificateResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return diff --git a/internal/server/handlers_certificate_requests_test.go b/internal/server/handlers_certificate_requests_test.go index 21b8311..63df6d1 100644 --- a/internal/server/handlers_certificate_requests_test.go +++ b/internal/server/handlers_certificate_requests_test.go @@ -9,22 +9,18 @@ import ( "path/filepath" "strconv" "testing" -) -type CertificateRequest struct { - ID int `json:"id"` - CSR string `json:"csr"` - Certificate string `json:"certificate"` -} + "github.com/canonical/notary/internal/server" +) type GetCertificateRequestResponse struct { - Result CertificateRequest `json:"result"` - Error string `json:"error,omitempty"` + Result server.CertificateRequest `json:"result"` + Error string `json:"error,omitempty"` } type ListCertificateRequestsResponse struct { - Error string `json:"error,omitempty"` - Result []CertificateRequest `json:"result"` + Error string `json:"error,omitempty"` + Result []server.CertificateRequest `json:"result"` } type CreateCertificateRequestResponse struct { @@ -162,7 +158,10 @@ func rejectCertificate(url string, client *http.Client, adminToken string, id in // The order of the tests is important, as some tests depend on the // state of the server after previous tests. func TestCertificateRequestsEndToEnd(t *testing.T) { - ts, _, err := setupServer() + + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -257,8 +256,8 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { if getCertRequestResponse.Result.CSR == "" { t.Fatalf("expected CSR, got empty string") } - if getCertRequestResponse.Result.Certificate != "" { - t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.Certificate) + if getCertRequestResponse.Result.CertificateChain != "" { + t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.CertificateChain) } }) @@ -353,8 +352,8 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { if getCertRequestResponse.Result.CSR == "" { t.Fatalf("expected CSR, got empty string") } - if getCertRequestResponse.Result.Certificate != "" { - t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.Certificate) + if getCertRequestResponse.Result.CertificateChain != "" { + t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.CertificateChain) } }) @@ -399,7 +398,9 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { // The order of the tests is important, as some tests depend on the // state of the server after previous tests. func TestCertificatesEndToEnd(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -481,7 +482,7 @@ func TestCertificatesEndToEnd(t *testing.T) { if getCertResponse.Error != "" { t.Fatalf("expected no error, got %s", getCertResponse.Error) } - if getCertResponse.Result.Certificate == "" { + if getCertResponse.Result.CertificateChain == "" { t.Fatalf("expected certificate, got empty string") } }) @@ -507,8 +508,8 @@ func TestCertificatesEndToEnd(t *testing.T) { if getCertResponse.Error != "" { t.Fatalf("expected no error, got %s", getCertResponse.Error) } - if getCertResponse.Result.Certificate != "rejected" { - t.Fatalf("expected `rejected` certificate, got %s", getCertResponse.Result.Certificate) + if getCertResponse.Result.Status != "Rejected" { + t.Fatalf("expected `Rejected` status, got %s", getCertResponse.Result.CertificateChain) } }) diff --git a/internal/server/handlers_helpers_test.go b/internal/server/handlers_helpers_test.go index cb309c1..1552e1e 100644 --- a/internal/server/handlers_helpers_test.go +++ b/internal/server/handlers_helpers_test.go @@ -10,8 +10,8 @@ import ( "github.com/canonical/notary/internal/server" ) -func setupServer() (*httptest.Server, *server.HandlerConfig, error) { - testdb, err := db.NewDatabase(":memory:") +func setupServer(filepath string) (*httptest.Server, *server.HandlerConfig, error) { + testdb, err := db.NewDatabase(filepath) if err != nil { return nil, nil, err } diff --git a/internal/server/handlers_login.go b/internal/server/handlers_login.go index 21b31f8..469f022 100644 --- a/internal/server/handlers_login.go +++ b/internal/server/handlers_login.go @@ -7,7 +7,7 @@ import ( "net/http" "time" - "github.com/canonical/notary/internal/db" + "github.com/canonical/sqlair" "github.com/golang-jwt/jwt" "golang.org/x/crypto/bcrypt" ) @@ -65,17 +65,17 @@ func Login(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusBadRequest, "Password is required") return } - userAccount, err := env.DB.RetrieveUserByUsername(loginParams.Username) + userAccount, err := env.DB.GetUserByUsername(loginParams.Username) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusUnauthorized, "The username or password is incorrect. Try again.") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - if err := bcrypt.CompareHashAndPassword([]byte(userAccount.Password), []byte(loginParams.Password)); err != nil { + if err := bcrypt.CompareHashAndPassword([]byte(userAccount.HashedPassword), []byte(loginParams.Password)); err != nil { writeError(w, http.StatusUnauthorized, "The username or password is incorrect. Try again.") return } diff --git a/internal/server/handlers_login_test.go b/internal/server/handlers_login_test.go index 408cb2c..ff3c0db 100644 --- a/internal/server/handlers_login_test.go +++ b/internal/server/handlers_login_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "path/filepath" "strings" "testing" @@ -46,7 +47,9 @@ func login(url string, client *http.Client, data *LoginParams) (int, *LoginRespo } func TestLoginEndToEnd(t *testing.T) { - ts, config, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, config, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } diff --git a/internal/server/handlers_status_test.go b/internal/server/handlers_status_test.go index c725574..d819fd5 100644 --- a/internal/server/handlers_status_test.go +++ b/internal/server/handlers_status_test.go @@ -3,6 +3,7 @@ package server_test import ( "encoding/json" "net/http" + "path/filepath" "testing" ) @@ -35,7 +36,9 @@ func getStatus(url string, client *http.Client, adminToken string) (int, *GetSta } func TestStatus(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } diff --git a/internal/server/response.go b/internal/server/response.go index 0b9678c..7e7daf6 100644 --- a/internal/server/response.go +++ b/internal/server/response.go @@ -6,6 +6,10 @@ import ( "net/http" ) +type SuccessResponse struct { + Message string `json:"message"` +} + // writeJSON is a helper function that writes a JSON response to the http.ResponseWriter func writeJSON(w http.ResponseWriter, v any) error { type response struct { diff --git a/ui/package-lock.json b/ui/package-lock.json index 0e09b13..c41ddf4 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -7962,6 +7962,21 @@ "funding": { "url": "https://github.com/sponsors/sindresorhus" } + }, + "node_modules/@next/swc-win32-ia32-msvc": { + "version": "14.2.15", + "resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.15.tgz", + "integrity": "sha512-fyTE8cklgkyR1p03kJa5zXEaZ9El+kDNM5A+66+8evQS5e/6v0Gk28LqA0Jet8gKSOyP+OTm/tJHzMlGdQerdQ==", + "cpu": [ + "ia32" + ], + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } } } } diff --git a/ui/src/app/(notary)/certificate_requests/table.test.tsx b/ui/src/app/(notary)/certificate_requests/table.test.tsx index 5b3c076..ee5a4c2 100644 --- a/ui/src/app/(notary)/certificate_requests/table.test.tsx +++ b/ui/src/app/(notary)/certificate_requests/table.test.tsx @@ -2,8 +2,9 @@ import { expect, test } from 'vitest' import { render, screen } from '@testing-library/react' import { CertificateRequestsTable } from './table' import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { CSREntry } from '@/types' -const rows = [ +const rows: CSREntry[] = [ { 'id': 1, 'csr': `-----BEGIN CERTIFICATE REQUEST----- @@ -23,7 +24,8 @@ Y/uPl4g3jpGqLCKTASWJDGnZLroLICOzYTVs5P3oj+VueSUwYhGK5tBnS2x5FHID uMNMgwl0fxGMQZjrlXyCBhXBm1k6PmwcJGJF5LQ31c+5aTTMFU7SyZhlymctB8mS y+ErBQsRpcQho6Ok+HTXQQUcx7WNcwI= -----END CERTIFICATE REQUEST-----`, - 'certificate': "" + 'certificate_chain': "", + 'status': "Outstanding", }, { 'id': 2, @@ -46,7 +48,8 @@ tK9qb8EE92MoWboo4m4bcX74y+eUo3xBev6ZZwdScy8OHLhA/MMI8EElpeYt+Hc2 WsDOAOH6qKQKQg3BO/xmRoohC6GL4CuhP7HYGi7+wziNhNZQa4GtE/k9DyIXVtJy yuf2PnfXCKnaIWRJNoEqDCZRVMfA5BFSwTPITqyo -----END CERTIFICATE REQUEST-----`, - 'certificate': "rejected" + 'certificate_chain': "", + 'status': "Rejected" }, { 'id': 3, @@ -68,7 +71,7 @@ cAQXk3fvTWuikHiCHqqdSdjDYj/8cyiwCrQWpV245VSbOE0WesWoEnSdFXVUfE1+ RSKeTRuuJMcdGqBkDnDI22myj0bjt7q8eqBIjTiLQLnAFnQYpcCrhc8dKU9IJlv1 H9Hay4ZO9LRew3pEtlx2WrExw/gpUcWM8rTI -----END CERTIFICATE REQUEST-----`, - 'certificate': `-----BEGIN CERTIFICATE----- + 'certificate_chain': `-----BEGIN CERTIFICATE----- MIIDrDCCApSgAwIBAgIURKr+jf7hj60SyAryIeN++9wDdtkwDQYJKoZIhvcNAQEL BQAwOTELMAkGA1UEBhMCVVMxKjAoBgNVBAMMIXNlbGYtc2lnbmVkLWNlcnRpZmlj YXRlcy1vcGVyYXRvcjAeFw0yNDAzMjcxMjQ4MDRaFw0yNTAzMjcxMjQ4MDRaMEcx @@ -89,7 +92,8 @@ WyhXkzguv3dwH+n43GJFP6MQ+n9W/nPZCUQ0Iy7ueAvj0HFhGyZzAE2wxNFZdvCs gCX3nqYpp70oZIFDrhmYwE5ij5KXlHD4/1IOfNUKCDmQDgGPLI1tVtwQLjeRq7Hg XVelpl/LXTQawmJyvDaVT/Q9P+WqoDiMjrqF6Sy7DzNeeccWVqvqX5TVS6Ky56iS Mvo/+PAJHkBciR5Xn+Wg2a+7vrZvT6CBoRSOTozlLSM= ------END CERTIFICATE-----` +-----END CERTIFICATE-----`, + 'status': 'Active' }, ] diff --git a/ui/src/app/(notary)/certificate_requests/table.tsx b/ui/src/app/(notary)/certificate_requests/table.tsx index c481737..4b7ecbf 100644 --- a/ui/src/app/(notary)/certificate_requests/table.tsx +++ b/ui/src/app/(notary)/certificate_requests/table.tsx @@ -139,9 +139,9 @@ export function CertificateRequestsTable({ csrs: rows }: TableProps) { }; const csrrows = rows.map((csrEntry) => { - const { id, csr, certificate } = csrEntry; + const { id, csr, certificate_chain, status: csr_status } = csrEntry; const csrObj = extractCSR(csr); - const certs = splitBundle(certificate); + const certs = splitBundle(certificate_chain); const clientCertificate = certs?.at(0); const certObj = clientCertificate ? extractCert(clientCertificate) : null; @@ -152,13 +152,13 @@ export function CertificateRequestsTable({ csrs: rows }: TableProps) { sortData: { id, common_name: csrObj.commonName, - csr_status: certificate === "" ? "outstanding" : (certificate === "rejected" ? "rejected" : "fulfilled"), + csr_status: csr_status, cert_expiry_date: certObj?.notAfter || "", }, columns: [ { content: id.toString() }, { content: csrObj.commonName || "N/A" }, - { content: certificate === "" ? "outstanding" : (certificate === "rejected" ? "rejected" : "fulfilled") }, + { content: csr_status }, { content: certObj?.notAfter || "", style: { backgroundColor: getExpiryColor(certObj?.notAfter) }, @@ -185,13 +185,13 @@ export function CertificateRequestsTable({ csrs: rows }: TableProps) { @@ -321,7 +321,7 @@ export function CertificateRequestsTable({ csrs: rows }: TableProps) { )} diff --git a/ui/src/types.ts b/ui/src/types.ts index 7e208e9..3cce885 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -1,7 +1,8 @@ export type CSREntry = { id: number, csr: string, - certificate: string + certificate_chain: string + status: "Outstanding" | "Active" | "Rejected" | "Revoked" } export type User = {