diff --git a/go.mod b/go.mod
index 38fd724..a234d34 100644
--- a/go.mod
+++ b/go.mod
@@ -12,6 +12,7 @@ require (
require (
github.com/beorn7/perks v1.0.1 // indirect
+ github.com/canonical/sqlair v0.0.0-20241004123011-77313b5382fd
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/klauspost/compress v1.17.9 // indirect
github.com/kr/text v0.2.0 // indirect
diff --git a/go.sum b/go.sum
index 56c00d9..4b3de95 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,7 @@
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
+github.com/canonical/sqlair v0.0.0-20241004123011-77313b5382fd h1:lVn7391CX7QQ5WBQriNUtCB4fvfurZg6XJwH9aVsRII=
+github.com/canonical/sqlair v0.0.0-20241004123011-77313b5382fd/go.mod h1:T+40I2sXshY3KRxx0QQpqqn6hCibSKJ2KHzjBvJj8T4=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
diff --git a/internal/db/db.go b/internal/db/db.go
index 868a6b2..67ba0db 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -2,273 +2,411 @@
package db
import (
+ "context"
"database/sql"
"errors"
"fmt"
+ "github.com/canonical/sqlair"
_ "github.com/mattn/go-sqlite3"
"golang.org/x/crypto/bcrypt"
)
+// Database is the object used to communicate with the established repository.
+type Database struct {
+ certificateTable string
+ usersTable string
+ conn *sqlair.DB
+}
+
+type CertificateRequest struct {
+ ID int `db:"id"`
+
+ CSR string `db:"csr"`
+ CertificateChain string `db:"certificate_chain"`
+ Status string `db:"status"`
+}
+
+type User struct {
+ ID int `db:"id"`
+
+ Username string `db:"username"`
+ HashedPassword string `db:"hashed_password"`
+ Permissions int `db:"permissions"`
+}
+
const (
- certificateRequestsTableName = "CertificateRequests"
+ certificateRequestsTableName = "certificate_requests"
usersTableName = "users"
)
-const queryCreateCSRsTable = `CREATE TABLE IF NOT EXISTS %s (
- csr TEXT PRIMARY KEY UNIQUE NOT NULL,
- certificate TEXT DEFAULT ''
+const queryCreateCertificateRequestsTable = `
+ CREATE TABLE IF NOT EXISTS %s (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+
+ csr TEXT NOT NULL UNIQUE,
+ certificate_chain TEXT DEFAULT '',
+ status TEXT DEFAULT 'Outstanding',
+
+ CHECK (status IN ('Outstanding', 'Rejected', 'Revoked', 'Active')),
+ CHECK (NOT (certificate_chain == '' AND status == 'Active' )),
+ CHECK (NOT (certificate_chain != '' AND status == 'Outstanding'))
+ CHECK (NOT (certificate_chain != '' AND status == 'Rejected'))
+ CHECK (NOT (certificate_chain != '' AND status == 'Revoked'))
)`
-const (
- queryGetAllCSRs = "SELECT rowid, * FROM %s"
- queryGetCSR = "SELECT rowid, * FROM %s WHERE rowid=?"
- queryCreateCSR = "INSERT INTO %s (csr) VALUES (?)"
- queryUpdateCSR = "UPDATE %s SET certificate=? WHERE rowid=?"
- queryDeleteCSR = "DELETE FROM %s WHERE rowid=?"
-)
+const queryCreateUsersTable = `
+ CREATE TABLE IF NOT EXISTS %s (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
-const queryCreateUsersTable = `CREATE TABLE IF NOT EXISTS %s (
- user_id INTEGER PRIMARY KEY AUTOINCREMENT,
- username TEXT NOT NULL UNIQUE,
- hashed_password TEXT NOT NULL,
- permissions INTEGER
+ username TEXT NOT NULL UNIQUE,
+ hashed_password TEXT NOT NULL,
+ permissions INTEGER
)`
const (
- queryGetAllUsers = "SELECT * FROM %s"
- queryGetUser = "SELECT * FROM %s WHERE user_id=?"
- queryGetUserByUsername = "SELECT * FROM %s WHERE username=?"
- queryCreateUser = "INSERT INTO %s (username, hashed_password, permissions) VALUES (?, ?, ?)"
- queryUpdateUser = "UPDATE %s SET hashed_password=? WHERE user_id=?"
- queryDeleteUser = "DELETE FROM %s WHERE user_id=?"
- queryGetNumUsers = "SELECT COUNT(*) FROM %s"
+ listCertificateRequestsStmt = "SELECT &CertificateRequest.* FROM %s"
+ getCertificateRequestStmt = "SELECT &CertificateRequest.* FROM %s WHERE id==$CertificateRequest.id or csr==$CertificateRequest.csr"
+ createCertificateRequestStmt = "INSERT INTO %s (csr) VALUES ($CertificateRequest.csr)"
+ updateCertificateRequestStmt = "UPDATE %s SET certificate_chain=$CertificateRequest.certificate_chain, status=$CertificateRequest.status WHERE id==$CertificateRequest.id or csr==$CertificateRequest.csr"
+ deleteCertificateRequestStmt = "DELETE FROM %s WHERE id=$CertificateRequest.id or csr=$CertificateRequest.csr"
)
-// CertificateRequestRepository is the object used to communicate with the established repository.
-type Database struct {
- certificateTable string
- usersTable string
- conn *sql.DB
-}
+const (
+ listUsersStmt = "SELECT &User.* from %s"
+ getUserStmt = "SELECT &User.* from %s WHERE id==$User.id or username==$User.username"
+ createUserStmt = "INSERT INTO %s (username, hashed_password, permissions) VALUES ($User.username, $User.hashed_password, $User.permissions)"
+ updateUserStmt = "UPDATE %s SET hashed_password=$User.hashed_password WHERE id==$User.id or username==$User.username"
+ deleteUserStmt = "DELETE FROM %s WHERE id==$User.id"
+ getNumUsersStmt = "SELECT COUNT(*) AS &NumUsers.count FROM %s"
+)
-// A CertificateRequest struct represents an entry in the database.
-// The object contains a Certificate Request, its matching Certificate if any, and the row ID.
-type CertificateRequest struct {
- ID int
- CSR string
- Certificate string
-}
-type User struct {
- ID int
- Username string
- Password string
- Permissions int
+// ListCertificateRequests gets every CertificateRequest entry in the table.
+func (db *Database) ListCertificateRequests() ([]CertificateRequest, error) {
+ stmt, err := sqlair.Prepare(fmt.Sprintf(listCertificateRequestsStmt, db.certificateTable), CertificateRequest{})
+ if err != nil {
+ return nil, err
+ }
+ var csrs []CertificateRequest
+ err = db.conn.Query(context.Background(), stmt).GetAll(&csrs)
+ if err != nil {
+ if errors.Is(err, sqlair.ErrNoRows) {
+ return csrs, nil
+ }
+ return nil, err
+ }
+ return csrs, nil
}
-var ErrIdNotFound = errors.New("id not found")
+// GetCertificateRequestByID gets a CSR row from the repository from a given ID.
+func (db *Database) GetCertificateRequestByID(id int) (*CertificateRequest, error) {
+ csr := CertificateRequest{
+ ID: id,
+ }
+ stmt, err := sqlair.Prepare(fmt.Sprintf(getCertificateRequestStmt, db.certificateTable), CertificateRequest{})
+ if err != nil {
+ return nil, err
+ }
+ err = db.conn.Query(context.Background(), stmt, csr).Get(&csr)
+ if err != nil {
+ return nil, err
+ }
+ return &csr, nil
+}
-// RetrieveAllCSRs gets every CertificateRequest entry in the table.
-func (db *Database) RetrieveAllCSRs() ([]CertificateRequest, error) {
- rows, err := db.conn.Query(fmt.Sprintf(queryGetAllCSRs, db.certificateTable))
+// GetCertificateRequestByCSR gets a given CSR row from the repository using the CSR text.
+func (db *Database) GetCertificateRequestByCSR(csr string) (*CertificateRequest, error) {
+ row := CertificateRequest{
+ CSR: csr,
+ }
+ stmt, err := sqlair.Prepare(fmt.Sprintf(getCertificateRequestStmt, db.certificateTable), CertificateRequest{})
if err != nil {
return nil, err
}
+ err = db.conn.Query(context.Background(), stmt, row).Get(&row)
+ if err != nil {
+ return nil, err
+ }
+ return &row, nil
+}
- var allCsrs []CertificateRequest
- defer rows.Close()
- for rows.Next() {
- var csr CertificateRequest
- if err := rows.Scan(&csr.ID, &csr.CSR, &csr.Certificate); err != nil {
- return nil, err
- }
- allCsrs = append(allCsrs, csr)
+// CreateCertificateRequest creates a new CSR entry in the repository. The string must be a valid CSR and unique.
+func (db *Database) CreateCertificateRequest(csr string) error {
+ if err := ValidateCertificateRequest(csr); err != nil {
+ return errors.New("csr validation failed: " + err.Error())
+ }
+ stmt, err := sqlair.Prepare(fmt.Sprintf(createCertificateRequestStmt, db.certificateTable), CertificateRequest{})
+ if err != nil {
+ return err
}
- return allCsrs, nil
+ row := CertificateRequest{
+ CSR: csr,
+ }
+ err = db.conn.Query(context.Background(), stmt, row).Run()
+ return err
}
-// RetrieveCSR gets a given CSR from the repository.
-// It returns the row id and matching certificate alongside the CSR in a CertificateRequest object.
-func (db *Database) RetrieveCSR(id string) (CertificateRequest, error) {
- var newCSR CertificateRequest
- row := db.conn.QueryRow(fmt.Sprintf(queryGetCSR, db.certificateTable), id)
- if err := row.Scan(&newCSR.ID, &newCSR.CSR, &newCSR.Certificate); err != nil {
- if err.Error() == "sql: no rows in result set" {
- return newCSR, ErrIdNotFound
- }
- return newCSR, err
+// AddCertificateChainToCertificateRequestByCSR adds a new certificate chain to a row for a given CSR string.
+func (db *Database) AddCertificateChainToCertificateRequestByCSR(csr string, cert string) error {
+ err := ValidateCertificate(cert)
+ if err != nil {
+ return errors.New("cert validation failed: " + err.Error())
}
- return newCSR, nil
+ err = CertificateMatchesCSR(cert, csr)
+ if err != nil {
+ return errors.New("cert validation failed: " + err.Error())
+ }
+ certBundle := sanitizeCertificateBundle(cert)
+ stmt, err := sqlair.Prepare(fmt.Sprintf(updateCertificateRequestStmt, db.certificateTable), CertificateRequest{})
+ if err != nil {
+ return err
+ }
+ newRow := CertificateRequest{
+ CSR: csr,
+ CertificateChain: certBundle,
+ Status: "Active",
+ }
+ err = db.conn.Query(context.Background(), stmt, newRow).Run()
+ return err
}
-// CreateCSR creates a new entry in the repository.
-// The given CSR must be valid and unique
-func (db *Database) CreateCSR(csr string) (int64, error) {
- if err := ValidateCertificateRequest(csr); err != nil {
- return 0, errors.New("csr validation failed: " + err.Error())
+// AddCertificateChainToCSRbyID adds a new certificate chain to a row for a given row ID.
+func (db *Database) AddCertificateChainToCertificateRequestByID(id int, cert string) error {
+ csr, err := db.GetCertificateRequestByID(id)
+ if err != nil {
+ return err
}
- result, err := db.conn.Exec(fmt.Sprintf(queryCreateCSR, db.certificateTable), csr)
+ err = ValidateCertificate(cert)
if err != nil {
- return 0, err
+ return errors.New("cert validation failed: " + err.Error())
}
- id, err := result.LastInsertId()
+ err = CertificateMatchesCSR(cert, csr.CSR)
if err != nil {
- return 0, err
+ return errors.New("cert validation failed: " + err.Error())
+ }
+ certBundle := sanitizeCertificateBundle(cert)
+ stmt, err := sqlair.Prepare(fmt.Sprintf(updateCertificateRequestStmt, db.certificateTable), CertificateRequest{})
+ if err != nil {
+ return err
}
- return id, nil
+ newRow := CertificateRequest{
+ ID: id,
+ CertificateChain: certBundle,
+ Status: "Active",
+ }
+ err = db.conn.Query(context.Background(), stmt, newRow).Run()
+ return err
}
-// UpdateCSR adds a new cert to the given CSR in the repository.
-// The given certificate must share the public key of the CSR and must be valid.
-func (db *Database) UpdateCSR(id string, cert string) (int64, error) {
- csr, err := db.RetrieveCSR(id)
+// RejectCertificateRequestByCSR updates input CSR's row by setting the certificate bundle to "" and moving the row status to "Rejected".
+func (db *Database) RejectCertificateRequestByCSR(csr string) error {
+ oldRow, err := db.GetCertificateRequestByCSR(csr)
if err != nil {
- return 0, err
+ return err
}
- if cert != "rejected" && cert != "" {
- err = ValidateCertificate(cert)
- if err != nil {
- return 0, errors.New("cert validation failed: " + err.Error())
- }
- err = CertificateMatchesCSR(cert, csr.CSR)
- if err != nil {
- return 0, errors.New("cert validation failed: " + err.Error())
- }
- cert = sanitizeCertificateBundle(cert)
+ stmt, err := sqlair.Prepare(fmt.Sprintf(updateCertificateRequestStmt, db.certificateTable), CertificateRequest{})
+ if err != nil {
+ return err
}
- _, err = db.conn.Exec(fmt.Sprintf(queryUpdateCSR, db.certificateTable), cert, csr.ID)
+ newRow := CertificateRequest{
+ ID: oldRow.ID,
+ CSR: oldRow.CSR,
+ CertificateChain: "",
+ Status: "Rejected",
+ }
+ err = db.conn.Query(context.Background(), stmt, newRow).Run()
+ return err
+}
+
+// RejectCSRbyCSR updates input ID's row by setting the certificate bundle to "" and sets the row status to "Rejected".
+func (db *Database) RejectCertificateRequestByID(id int) error {
+ oldRow, err := db.GetCertificateRequestByID(id)
if err != nil {
- return 0, err
+ return err
+ }
+ stmt, err := sqlair.Prepare(fmt.Sprintf(updateCertificateRequestStmt, db.certificateTable), CertificateRequest{})
+ if err != nil {
+ return err
+ }
+ newRow := CertificateRequest{
+ ID: oldRow.ID,
+ CSR: oldRow.CSR,
+ CertificateChain: "",
+ Status: "Rejected",
}
- return int64(csr.ID), nil
+ err = db.conn.Query(context.Background(), stmt, newRow).Run()
+ return err
}
-// DeleteCSR removes a CSR from the database alongside the certificate that may have been generated for it.
-func (db *Database) DeleteCSR(id string) (int64, error) {
- result, err := db.conn.Exec(fmt.Sprintf(queryDeleteCSR, db.certificateTable), id)
+// RevokeCertificateByCSR updates the input CSR's row by setting the certificate bundle to "" and sets the row status to "Revoked".
+func (db *Database) RevokeCertificateByCSR(csr string) error {
+ oldRow, err := db.GetCertificateRequestByCSR(csr)
if err != nil {
- return 0, err
+ return err
}
- deleteId, err := result.RowsAffected()
+ stmt, err := sqlair.Prepare(fmt.Sprintf(updateCertificateRequestStmt, db.certificateTable), CertificateRequest{})
if err != nil {
- return 0, err
+ return err
}
- if deleteId == 0 {
- return 0, ErrIdNotFound
+ newRow := CertificateRequest{
+ ID: oldRow.ID,
+ CSR: oldRow.CSR,
+ CertificateChain: "",
+ Status: "Revoked",
}
- return deleteId, nil
+ err = db.conn.Query(context.Background(), stmt, newRow).Run()
+ return err
}
-// RetrieveAllUsers returns all of the users and their fields available in the database.
-func (db *Database) RetrieveAllUsers() ([]User, error) {
- rows, err := db.conn.Query(fmt.Sprintf(queryGetAllUsers, db.usersTable))
+// DeleteCertificateRequestByCSR removes a CSR from the database alongside the certificate that may have been generated for it.
+func (db *Database) DeleteCertificateRequestByCSR(csr string) error {
+ stmt, err := sqlair.Prepare(fmt.Sprintf(deleteCertificateRequestStmt, db.certificateTable), CertificateRequest{})
if err != nil {
- return nil, err
+ return err
}
+ row := CertificateRequest{
+ CSR: csr,
+ }
+ err = db.conn.Query(context.Background(), stmt, row).Run()
+ return err
+}
- var allUsers []User
- defer rows.Close()
- for rows.Next() {
- var user User
- if err := rows.Scan(&user.ID, &user.Username, &user.Password, &user.Permissions); err != nil {
- return nil, err
- }
- allUsers = append(allUsers, user)
+// DeleteCSRByID removes a CSR from the database alongside the certificate that may have been generated for it.
+func (db *Database) DeleteCertificateRequestByID(id int) error {
+ stmt, err := sqlair.Prepare(fmt.Sprintf(deleteCertificateRequestStmt, db.certificateTable), CertificateRequest{})
+ if err != nil {
+ return err
}
- return allUsers, nil
+ row := CertificateRequest{
+ ID: id,
+ }
+ err = db.conn.Query(context.Background(), stmt, row).Run()
+ return err
}
-// RetrieveUser retrieves the name, password and the permission level of a user.
-func (db *Database) RetrieveUser(id string) (User, error) {
- var newUser User
- row := db.conn.QueryRow(fmt.Sprintf(queryGetUser, db.usersTable), id)
- if err := row.Scan(&newUser.ID, &newUser.Username, &newUser.Password, &newUser.Permissions); err != nil {
- if err.Error() == "sql: no rows in result set" {
- return newUser, ErrIdNotFound
- }
- return newUser, err
+// ListUsers returns all of the users and their fields available in the database.
+func (db *Database) ListUsers() ([]User, error) {
+ stmt, err := sqlair.Prepare(fmt.Sprintf(listUsersStmt, db.usersTable), User{})
+ if err != nil {
+ return nil, err
}
- return newUser, nil
+ var users []User
+ err = db.conn.Query(context.Background(), stmt).GetAll(&users)
+ if err != nil {
+ return nil, err
+ }
+ return users, nil
}
-// RetrieveUser retrieves the id, password and the permission level of a user.
-func (db *Database) RetrieveUserByUsername(name string) (User, error) {
- var newUser User
- row := db.conn.QueryRow(fmt.Sprintf(queryGetUserByUsername, db.usersTable), name)
- if err := row.Scan(&newUser.ID, &newUser.Username, &newUser.Password, &newUser.Permissions); err != nil {
- if err.Error() == "sql: no rows in result set" {
- return newUser, ErrIdNotFound
- }
- return newUser, err
+// GetUserByID retrieves the name, password and the permission level of a user.
+func (db *Database) GetUserByID(id int) (*User, error) {
+ row := User{
+ ID: id,
+ }
+ stmt, err := sqlair.Prepare(fmt.Sprintf(getUserStmt, db.usersTable), User{})
+ if err != nil {
+ return nil, err
+ }
+ err = db.conn.Query(context.Background(), stmt, row).Get(&row)
+ if err != nil {
+ return nil, err
+ }
+ return &row, nil
+}
+
+// GetUserByUsername retrieves the id, password and the permission level of a user.
+func (db *Database) GetUserByUsername(name string) (*User, error) {
+ row := User{
+ Username: name,
+ }
+ stmt, err := sqlair.Prepare(fmt.Sprintf(getUserStmt, db.usersTable), User{})
+ if err != nil {
+ return nil, err
+ }
+ err = db.conn.Query(context.Background(), stmt, row).Get(&row)
+ if err != nil {
+ return nil, err
}
- return newUser, nil
+ return &row, nil
}
// CreateUser creates a new user from a given username, password and permission level.
// The permission level 1 represents an admin, and a 0 represents a regular user.
// The password passed in should be in plaintext. This function handles hashing and salting the password before storing it in the database.
-func (db *Database) CreateUser(username string, password string, permission int) (int64, error) {
+func (db *Database) CreateUser(username string, password string, permission int) error {
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
- return 0, err
+ return err
}
- result, err := db.conn.Exec(fmt.Sprintf(queryCreateUser, db.usersTable), username, pw, permission)
+ stmt, err := sqlair.Prepare(fmt.Sprintf(createUserStmt, db.usersTable), User{})
if err != nil {
- return 0, err
+ return err
}
- id, err := result.LastInsertId()
- if err != nil {
- return 0, err
+ row := User{
+ Username: username,
+ HashedPassword: string(pw),
+ Permissions: permission,
}
- return id, nil
+ err = db.conn.Query(context.Background(), stmt, row).Run()
+ return err
}
// UpdateUser updates the password of the given user.
// Just like with CreateUser, this function handles hashing and salting the password before storage.
-func (db *Database) UpdateUser(id, password string) (int64, error) {
- user, err := db.RetrieveUser(id)
+func (db *Database) UpdateUserPassword(id int, password string) error {
+ _, err := db.GetUserByID(id)
if err != nil {
- return 0, err
+ return err
}
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
- return 0, err
+ return err
}
- result, err := db.conn.Exec(fmt.Sprintf(queryUpdateUser, db.usersTable), pw, user.ID)
+ stmt, err := sqlair.Prepare(fmt.Sprintf(updateUserStmt, db.usersTable), User{})
if err != nil {
- return 0, err
+ return err
}
- affectedRows, err := result.RowsAffected()
- if err != nil {
- return 0, err
+ row := User{
+ ID: id,
+ HashedPassword: string(pw),
}
- return affectedRows, nil
+ err = db.conn.Query(context.Background(), stmt, row).Run()
+ return err
}
-// DeleteUser removes a user from the table.
-func (db *Database) DeleteUser(id string) (int64, error) {
- result, err := db.conn.Exec(fmt.Sprintf(queryDeleteUser, db.usersTable), id)
+// DeleteUserByID removes a user from the table.
+func (db *Database) DeleteUserByID(id int) error {
+ _, err := db.GetUserByID(id)
if err != nil {
- return 0, err
+ return err
}
- deleteId, err := result.RowsAffected()
+ stmt, err := sqlair.Prepare(fmt.Sprintf(deleteUserStmt, db.usersTable), User{})
if err != nil {
- return 0, err
+ return err
}
- if deleteId == 0 {
- return 0, ErrIdNotFound
+ row := User{
+ ID: id,
}
- return deleteId, nil
+ err = db.conn.Query(context.Background(), stmt, row).Run()
+ return err
+}
+
+type NumUsers struct {
+ Count int `db:"count"`
}
// NumUsers returns the number of users in the database.
func (db *Database) NumUsers() (int, error) {
- var numUsers int
- row := db.conn.QueryRow(fmt.Sprintf(queryGetNumUsers, db.usersTable))
- if err := row.Scan(&numUsers); err != nil {
+ stmt, err := sqlair.Prepare(fmt.Sprintf(getNumUsersStmt, db.usersTable), NumUsers{})
+ if err != nil {
+ return 0, err
+ }
+ result := NumUsers{}
+ err = db.conn.Query(context.Background(), stmt).Get(&result)
+ if err != nil {
return 0, err
}
- return numUsers, nil
+ return result.Count, nil
}
// Close closes the connection to the repository cleanly.
@@ -276,7 +414,7 @@ func (db *Database) Close() error {
if db.conn == nil {
return nil
}
- if err := db.conn.Close(); err != nil {
+ if err := db.conn.PlainDB().Close(); err != nil {
return err
}
return nil
@@ -287,18 +425,18 @@ func (db *Database) Close() error {
// The database path must be a valid file path or ":memory:".
// The table will be created if it doesn't exist in the format expected by the package.
func NewDatabase(databasePath string) (*Database, error) {
- conn, err := sql.Open("sqlite3", databasePath)
+ sqlConnection, err := sql.Open("sqlite3", databasePath)
if err != nil {
return nil, err
}
- if _, err := conn.Exec(fmt.Sprintf(queryCreateCSRsTable, certificateRequestsTableName)); err != nil {
+ if _, err := sqlConnection.Exec(fmt.Sprintf(queryCreateCertificateRequestsTable, certificateRequestsTableName)); err != nil {
return nil, err
}
- if _, err := conn.Exec(fmt.Sprintf(queryCreateUsersTable, usersTableName)); err != nil {
+ if _, err := sqlConnection.Exec(fmt.Sprintf(queryCreateUsersTable, usersTableName)); err != nil {
return nil, err
}
db := new(Database)
- db.conn = conn
+ db.conn = sqlair.NewDB(sqlConnection)
db.certificateTable = certificateRequestsTableName
db.usersTable = usersTableName
return db, nil
diff --git a/internal/db/db_test.go b/internal/db/db_test.go
index 5d0f768..19cb647 100644
--- a/internal/db/db_test.go
+++ b/internal/db/db_test.go
@@ -3,7 +3,7 @@ package db_test
import (
"fmt"
"log"
- "strconv"
+ "path/filepath"
"strings"
"testing"
@@ -12,7 +12,8 @@ import (
)
func TestConnect(t *testing.T) {
- db, err := db.NewDatabase(":memory:")
+ tempDir := t.TempDir()
+ db, err := db.NewDatabase(filepath.Join(tempDir, "db.sqlite3"))
if err != nil {
t.Fatalf("Can't connect to SQLite: %s", err)
}
@@ -20,33 +21,34 @@ func TestConnect(t *testing.T) {
}
func TestCSRsEndToEnd(t *testing.T) {
- db, err := db.NewDatabase(":memory:")
+ tempDir := t.TempDir()
+ db, err := db.NewDatabase(filepath.Join(tempDir, "db.sqlite3"))
if err != nil {
t.Fatalf("Couldn't complete NewDatabase: %s", err)
}
defer db.Close()
- id1, err := db.CreateCSR(AppleCSR)
+ err = db.CreateCertificateRequest(AppleCSR)
if err != nil {
t.Fatalf("Couldn't complete Create: %s", err)
}
- id2, err := db.CreateCSR(BananaCSR)
+ err = db.CreateCertificateRequest(BananaCSR)
if err != nil {
t.Fatalf("Couldn't complete Create: %s", err)
}
- id3, err := db.CreateCSR(StrawberryCSR)
+ err = db.CreateCertificateRequest(StrawberryCSR)
if err != nil {
t.Fatalf("Couldn't complete Create: %s", err)
}
- res, err := db.RetrieveAllCSRs()
+ res, err := db.ListCertificateRequests()
if err != nil {
t.Fatalf("Couldn't complete RetrieveAll: %s", err)
}
if len(res) != 3 {
t.Fatalf("One or more CSRs weren't found in DB")
}
- retrievedCSR, err := db.RetrieveCSR(strconv.FormatInt(id1, 10))
+ retrievedCSR, err := db.GetCertificateRequestByCSR(AppleCSR)
if err != nil {
t.Fatalf("Couldn't complete Retrieve: %s", err)
}
@@ -54,32 +56,32 @@ func TestCSRsEndToEnd(t *testing.T) {
t.Fatalf("The CSR from the database doesn't match the CSR that was given")
}
- if _, err = db.DeleteCSR(strconv.FormatInt(id1, 10)); err != nil {
+ if err = db.DeleteCertificateRequestByCSR(AppleCSR); err != nil {
t.Fatalf("Couldn't complete Delete: %s", err)
}
- res, _ = db.RetrieveAllCSRs()
+ res, _ = db.ListCertificateRequests()
if len(res) != 2 {
t.Fatalf("CSR's weren't deleted from the DB properly")
}
BananaCertBundle := strings.TrimSpace(fmt.Sprintf("%s%s", BananaCert, IssuerCert))
- _, err = db.UpdateCSR(strconv.FormatInt(id2, 10), BananaCertBundle)
+ err = db.AddCertificateChainToCertificateRequestByCSR(BananaCSR, BananaCertBundle)
if err != nil {
t.Fatalf("Couldn't complete Update: %s", err)
}
- retrievedCSR, _ = db.RetrieveCSR(strconv.FormatInt(id2, 10))
- if retrievedCSR.Certificate != BananaCertBundle {
- t.Fatalf("The certificate that was uploaded does not match the certificate that was given.\n Retrieved: %s\nGiven: %s", retrievedCSR.Certificate, BananaCertBundle)
+ retrievedCSR, _ = db.GetCertificateRequestByCSR(BananaCSR)
+ if retrievedCSR.CertificateChain != BananaCertBundle {
+ t.Fatalf("The certificate that was uploaded does not match the certificate that was given.\n Retrieved: %s\nGiven: %s", retrievedCSR.CertificateChain, BananaCertBundle)
}
- _, err = db.UpdateCSR(strconv.FormatInt(id2, 10), "")
+ err = db.RevokeCertificateByCSR(BananaCSR)
if err != nil {
- t.Fatalf("Couldn't complete Update to delete certificate: %s", err)
+ t.Fatalf("Couldn't complete Update to revoke certificate: %s", err)
}
- _, err = db.UpdateCSR(strconv.FormatInt(id3, 10), "rejected")
+ err = db.RejectCertificateRequestByCSR(StrawberryCSR)
if err != nil {
t.Fatalf("Couldn't complete Update to reject CSR: %s", err)
}
- retrievedCSR, _ = db.RetrieveCSR(strconv.FormatInt(id2, 10))
- if retrievedCSR.Certificate != "" {
+ retrievedCSR, _ = db.GetCertificateRequestByCSR(BananaCSR)
+ if retrievedCSR.Status != "Revoked" {
t.Fatalf("Couldn't delete certificate")
}
}
@@ -89,12 +91,12 @@ func TestCreateFails(t *testing.T) {
defer db.Close()
InvalidCSR := strings.ReplaceAll(AppleCSR, "M", "i")
- if _, err := db.CreateCSR(InvalidCSR); err == nil {
+ if err := db.CreateCertificateRequest(InvalidCSR); err == nil {
t.Fatalf("Expected error due to invalid CSR")
}
- db.CreateCSR(AppleCSR) //nolint:errcheck
- if _, err := db.CreateCSR(AppleCSR); err == nil {
+ db.CreateCertificateRequest(AppleCSR) //nolint:errcheck
+ if err := db.CreateCertificateRequest(AppleCSR); err == nil {
t.Fatalf("Expected error due to duplicate CSR")
}
}
@@ -103,13 +105,13 @@ func TestUpdateFails(t *testing.T) {
db, _ := db.NewDatabase(":memory:")
defer db.Close()
- id1, _ := db.CreateCSR(AppleCSR) //nolint:errcheck
- id2, _ := db.CreateCSR(BananaCSR) //nolint:errcheck
+ db.CreateCertificateRequest(AppleCSR) //nolint:errcheck
+ db.CreateCertificateRequest(BananaCSR) //nolint:errcheck
InvalidCert := strings.ReplaceAll(BananaCert, "/", "+")
- if _, err := db.UpdateCSR(strconv.FormatInt(id2, 10), InvalidCert); err == nil {
+ if err := db.AddCertificateChainToCertificateRequestByCSR(BananaCSR, InvalidCert); err == nil {
t.Fatalf("Expected updating with invalid cert to fail")
}
- if _, err := db.UpdateCSR(strconv.FormatInt(id1, 10), BananaCert); err == nil {
+ if err := db.AddCertificateChainToCertificateRequestByCSR(AppleCSR, BananaCert); err == nil {
t.Fatalf("Expected updating with mismatched cert to fail")
}
}
@@ -118,8 +120,11 @@ func TestRetrieve(t *testing.T) {
db, _ := db.NewDatabase(":memory:") //nolint:errcheck
defer db.Close()
- db.CreateCSR(AppleCSR) //nolint:errcheck
- if _, err := db.RetrieveCSR("this is definitely not an id"); err == nil {
+ db.CreateCertificateRequest(AppleCSR) //nolint:errcheck
+ if _, err := db.GetCertificateRequestByCSR("this is definitely not an id"); err == nil {
+ t.Fatalf("Expected failure looking for nonexistent CSR")
+ }
+ if _, err := db.GetCertificateRequestByID(-1); err == nil {
t.Fatalf("Expected failure looking for nonexistent CSR")
}
}
@@ -131,16 +136,16 @@ func TestUsersEndToEnd(t *testing.T) {
}
defer db.Close()
- id1, err := db.CreateUser("admin", "pw123", 1)
+ err = db.CreateUser("admin", "pw123", 1)
if err != nil {
t.Fatalf("Couldn't complete Create: %s", err)
}
- id2, err := db.CreateUser("norman", "pw456", 0)
+ err = db.CreateUser("norman", "pw456", 0)
if err != nil {
t.Fatalf("Couldn't complete Create: %s", err)
}
- res, err := db.RetrieveAllUsers()
+ res, err := db.ListUsers()
if err != nil {
t.Fatalf("Couldn't complete RetrieveAll: %s", err)
}
@@ -154,31 +159,36 @@ func TestUsersEndToEnd(t *testing.T) {
if num != 2 {
t.Fatalf("NumUsers didn't return the correct number of users")
}
- retrievedUser, err := db.RetrieveUser(strconv.FormatInt(id1, 10))
+ retrievedUser, err := db.GetUserByUsername("admin")
+ if err != nil {
+ t.Fatalf("Couldn't complete Retrieve: %s", err)
+ }
+ if retrievedUser.Username != "admin" {
+ t.Fatalf("The user from the database doesn't match the user that was given")
+ }
+ retrievedUser, err = db.GetUserByID(1)
if err != nil {
t.Fatalf("Couldn't complete Retrieve: %s", err)
}
if retrievedUser.Username != "admin" {
t.Fatalf("The user from the database doesn't match the user that was given")
}
- if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.Password), []byte("pw123")); err != nil {
+ if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.HashedPassword), []byte("pw123")); err != nil {
t.Fatalf("The user's password doesn't match the one stored in the database")
}
-
- if _, err = db.DeleteUser(strconv.FormatInt(id1, 10)); err != nil {
+ if err = db.DeleteUserByID(1); err != nil {
t.Fatalf("Couldn't complete Delete: %s", err)
}
- res, _ = db.RetrieveAllUsers()
+ res, _ = db.ListUsers()
if len(res) != 1 {
t.Fatalf("users weren't deleted from the DB properly")
}
-
- _, err = db.UpdateUser(strconv.FormatInt(id2, 10), "thebestpassword")
+ err = db.UpdateUserPassword(2, "thebestpassword")
if err != nil {
t.Fatalf("Couldn't complete Update: %s", err)
}
- retrievedUser, _ = db.RetrieveUser(strconv.FormatInt(id2, 10))
- if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.Password), []byte("thebestpassword")); err != nil {
+ retrievedUser, _ = db.GetUserByUsername("norman")
+ if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.HashedPassword), []byte("thebestpassword")); err != nil {
t.Fatalf("The new password that was given does not match the password that was stored.")
}
}
@@ -188,19 +198,19 @@ func Example() {
if err != nil {
log.Fatalln(err)
}
- _, err = db.CreateCSR(BananaCSR)
+ err = db.CreateCertificateRequest(BananaCSR)
if err != nil {
log.Fatalln(err)
}
- _, err = db.UpdateCSR(BananaCSR, BananaCert)
+ err = db.AddCertificateChainToCertificateRequestByCSR(BananaCSR, BananaCert)
if err != nil {
log.Fatalln(err)
}
- entry, err := db.RetrieveCSR(BananaCSR)
+ entry, err := db.GetCertificateRequestByCSR(BananaCSR)
if err != nil {
log.Fatalln(err)
}
- if entry.Certificate != BananaCert {
+ if entry.CertificateChain != BananaCert {
log.Fatalln("Retrieved Certificate doesn't match Stored Certificate")
}
err = db.Close()
diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go
index 1076d61..725891f 100644
--- a/internal/metrics/metrics.go
+++ b/internal/metrics/metrics.go
@@ -37,7 +37,7 @@ func NewMetricsSubsystem(db *db.Database) *PrometheusMetrics {
ticker := time.NewTicker(120 * time.Second)
go func() {
for ; ; <-ticker.C {
- csrs, err := db.RetrieveAllCSRs()
+ csrs, err := db.ListCertificateRequests()
if err != nil {
log.Println(errors.Join(errors.New("error generating metrics repository: "), err))
panic(1)
@@ -95,15 +95,15 @@ func (pm *PrometheusMetrics) GenerateMetrics(csrs []db.CertificateRequest) {
var expiringIn30DaysCertCount float64
var expiringIn90DaysCertCount float64
for _, entry := range csrs {
- if entry.Certificate == "" {
+ if entry.CertificateChain == "" {
outstandingCSRCount += 1
continue
}
- if entry.Certificate == "rejected" {
+ if entry.Status == "Rejected" {
continue
}
certCount += 1
- expiryDate := certificateExpiryDate(entry.Certificate)
+ expiryDate := certificateExpiryDate(entry.CertificateChain)
daysRemaining := time.Until(expiryDate).Hours() / 24
if daysRemaining < 0 {
expiredCertCount += 1
diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go
index ecdd854..0775f28 100644
--- a/internal/metrics/metrics_test.go
+++ b/internal/metrics/metrics_test.go
@@ -21,7 +21,8 @@ import (
// TestPrometheusHandler tests that the Prometheus metrics handler responds correctly to an HTTP request.
func TestPrometheusHandler(t *testing.T) {
- db, err := db.NewDatabase(":memory:")
+ tempDir := t.TempDir()
+ db, err := db.NewDatabase(filepath.Join(tempDir, "db.sqlite3"))
if err != nil {
t.Fatal(err)
}
@@ -95,13 +96,13 @@ func generateCertPair(daysRemaining int) (string, string, string) {
}
func initializeTestDB(t *testing.T, db *db.Database) {
- for i, v := range []int{5, 10, 32} {
+ for _, v := range []int{5, 10, 32} {
csr, cert, ca := generateCertPair(v)
- _, err := db.CreateCSR(csr)
+ err := db.CreateCertificateRequest(csr)
if err != nil {
t.Fatalf("couldn't create test csr: %s", err)
}
- _, err = db.UpdateCSR(fmt.Sprint(i+1), fmt.Sprintf("%s%s", cert, ca))
+ err = db.AddCertificateChainToCertificateRequestByCSR(csr, fmt.Sprintf("%s%s", cert, ca))
if err != nil {
t.Fatalf("couldn't create test cert: %s", err)
}
@@ -117,7 +118,7 @@ func TestMetrics(t *testing.T) {
}
initializeTestDB(t, db)
m := metrics.NewMetricsSubsystem(db)
- csrs, _ := db.RetrieveAllCSRs()
+ csrs, _ := db.ListCertificateRequests()
m.GenerateMetrics(csrs)
request, _ := http.NewRequest("GET", "/", nil)
diff --git a/internal/server/authorization_test.go b/internal/server/authorization_test.go
index 164faa7..4aaf8f9 100644
--- a/internal/server/authorization_test.go
+++ b/internal/server/authorization_test.go
@@ -2,12 +2,15 @@ package server_test
import (
"net/http"
+ "path/filepath"
"strings"
"testing"
)
func TestAuthorizationNoAuth(t *testing.T) {
- ts, _, err := setupServer()
+ tempDir := t.TempDir()
+ db_path := filepath.Join(tempDir, "db.sqlite3")
+ ts, _, err := setupServer(db_path)
if err != nil {
t.Fatalf("couldn't create test server: %s", err)
}
@@ -47,7 +50,9 @@ func TestAuthorizationNoAuth(t *testing.T) {
}
func TestAuthorizationNonAdminAuthorized(t *testing.T) {
- ts, _, err := setupServer()
+ tempDir := t.TempDir()
+ db_path := filepath.Join(tempDir, "db.sqlite3")
+ ts, _, err := setupServer(db_path)
if err != nil {
t.Fatalf("couldn't create test server: %s", err)
}
@@ -98,7 +103,9 @@ func TestAuthorizationNonAdminAuthorized(t *testing.T) {
}
func TestAuthorizationNonAdminUnauthorized(t *testing.T) {
- ts, _, err := setupServer()
+ tempDir := t.TempDir()
+ db_path := filepath.Join(tempDir, "db.sqlite3")
+ ts, _, err := setupServer(db_path)
if err != nil {
t.Fatalf("couldn't create test server: %s", err)
}
@@ -156,7 +163,9 @@ func TestAuthorizationNonAdminUnauthorized(t *testing.T) {
}
func TestAuthorizationAdminAuthorized(t *testing.T) {
- ts, _, err := setupServer()
+ tempDir := t.TempDir()
+ db_path := filepath.Join(tempDir, "db.sqlite3")
+ ts, _, err := setupServer(db_path)
if err != nil {
t.Fatalf("couldn't create test server: %s", err)
}
@@ -205,7 +214,9 @@ func TestAuthorizationAdminAuthorized(t *testing.T) {
}
func TestAuthorizationAdminUnAuthorized(t *testing.T) {
- ts, _, err := setupServer()
+ tempDir := t.TempDir()
+ db_path := filepath.Join(tempDir, "db.sqlite3")
+ ts, _, err := setupServer(db_path)
if err != nil {
t.Fatalf("couldn't create test server: %s", err)
}
diff --git a/internal/server/handlers_accounts.go b/internal/server/handlers_accounts.go
index 7b72e82..3e8beb3 100644
--- a/internal/server/handlers_accounts.go
+++ b/internal/server/handlers_accounts.go
@@ -10,6 +10,7 @@ import (
"strings"
"github.com/canonical/notary/internal/db"
+ "github.com/canonical/sqlair"
)
type CreateAccountParams struct {
@@ -27,18 +28,6 @@ type GetAccountResponse struct {
Permissions int `json:"permissions"`
}
-type CreateAccountResponse struct {
- ID int `json:"id"`
-}
-
-type ChangeAccountResponse struct {
- ID int `json:"id"`
-}
-
-type DeleteAccountResponse struct {
- ID int `json:"id"`
-}
-
func validatePassword(password string) bool {
if len(password) < 8 {
return false
@@ -59,7 +48,7 @@ func validatePassword(password string) bool {
// ListAccounts returns all accounts from the database
func ListAccounts(env *HandlerConfig) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
- accounts, err := env.DB.RetrieveAllUsers()
+ accounts, err := env.DB.ListUsers()
if err != nil {
log.Println(err)
writeError(w, http.StatusInternalServerError, "Internal Error")
@@ -87,20 +76,24 @@ func ListAccounts(env *HandlerConfig) http.HandlerFunc {
func GetAccount(env *HandlerConfig) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
- var account db.User
- var err error
+ idNum, err := strconv.Atoi(id)
+ if err != nil {
+ writeError(w, http.StatusInternalServerError, "Internal Error")
+ return
+ }
+ var account *db.User
if id == "me" {
claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret)
if headerErr != nil {
writeError(w, http.StatusUnauthorized, "Unauthorized")
}
- account, err = env.DB.RetrieveUserByUsername(claims.Username)
+ account, err = env.DB.GetUserByUsername(claims.Username)
} else {
- account, err = env.DB.RetrieveUser(id)
+ account, err = env.DB.GetUserByID(idNum)
}
if err != nil {
log.Println(err)
- if errors.Is(err, db.ErrIdNotFound) {
+ if errors.Is(err, sqlair.ErrNoRows) {
writeError(w, http.StatusNotFound, "Not Found")
return
}
@@ -150,12 +143,11 @@ func CreateAccount(env *HandlerConfig) http.HandlerFunc {
writeError(w, http.StatusInternalServerError, "Failed to retrieve accounts: "+err.Error())
return
}
-
permission := UserPermission
if numUsers == 0 {
permission = AdminPermission
}
- id, err := env.DB.CreateUser(createAccountParams.Username, createAccountParams.Password, permission)
+ err = env.DB.CreateUser(createAccountParams.Username, createAccountParams.Password, permission)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
writeError(w, http.StatusBadRequest, "account with given username already exists")
@@ -165,11 +157,9 @@ func CreateAccount(env *HandlerConfig) http.HandlerFunc {
writeError(w, http.StatusInternalServerError, "Internal Error")
return
}
- accountResponse := CreateAccountResponse{
- ID: int(id),
- }
+ successResponse := SuccessResponse{Message: "success"}
w.WriteHeader(http.StatusCreated)
- err = writeJSON(w, accountResponse)
+ err = writeJSON(w, successResponse)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal error")
return
@@ -182,39 +172,39 @@ func CreateAccount(env *HandlerConfig) http.HandlerFunc {
func DeleteAccount(env *HandlerConfig) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
- idInt, err := strconv.ParseInt(id, 10, 64)
+ idInt, err := strconv.Atoi(id)
if err != nil {
log.Println(err)
writeError(w, http.StatusInternalServerError, "Internal Error")
return
}
- account, err := env.DB.RetrieveUser(id)
+ account, err := env.DB.GetUserByID(idInt)
if err != nil {
- if !errors.Is(err, db.ErrIdNotFound) {
- log.Println(err)
- writeError(w, http.StatusInternalServerError, "Internal Error")
+ log.Println(err)
+ if errors.Is(err, sqlair.ErrNoRows) {
+ writeError(w, http.StatusNotFound, "Not Found")
return
}
+ writeError(w, http.StatusInternalServerError, "Internal Error")
+ return
}
if account.Permissions == 1 {
writeError(w, http.StatusBadRequest, "deleting an Admin account is not allowed.")
return
}
- _, err = env.DB.DeleteUser(id)
+ err = env.DB.DeleteUserByID(idInt)
if err != nil {
log.Println(err)
- if errors.Is(err, db.ErrIdNotFound) {
+ if errors.Is(err, sqlair.ErrNoRows) {
writeError(w, http.StatusNotFound, "Not Found")
return
}
writeError(w, http.StatusInternalServerError, "Internal Error")
return
}
- deleteAccountResponse := DeleteAccountResponse{
- ID: int(idInt),
- }
+ successResponse := SuccessResponse{Message: "success"}
w.WriteHeader(http.StatusAccepted)
- err = writeJSON(w, deleteAccountResponse)
+ err = writeJSON(w, successResponse)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal error")
return
@@ -225,6 +215,7 @@ func DeleteAccount(env *HandlerConfig) http.HandlerFunc {
func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
+ var idNum int
if id == "me" {
claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret)
if err != nil {
@@ -232,13 +223,21 @@ func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc {
writeError(w, http.StatusUnauthorized, "Unauthorized")
return
}
- account, err := env.DB.RetrieveUserByUsername(claims.Username)
+ account, err := env.DB.GetUserByUsername(claims.Username)
if err != nil {
log.Println(err)
writeError(w, http.StatusUnauthorized, "Unauthorized")
return
}
- id = strconv.Itoa(account.ID)
+ idNum = account.ID
+ } else {
+ idInt, err := strconv.Atoi(id)
+ if err != nil {
+ log.Println(err)
+ writeError(w, http.StatusInternalServerError, "Internal Error")
+ return
+ }
+ idNum = idInt
}
var changeAccountParams ChangeAccountParams
if err := json.NewDecoder(r.Body).Decode(&changeAccountParams); err != nil {
@@ -257,21 +256,19 @@ func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc {
)
return
}
- ret, err := env.DB.UpdateUser(id, changeAccountParams.Password)
+ err := env.DB.UpdateUserPassword(idNum, changeAccountParams.Password)
if err != nil {
log.Println(err)
- if errors.Is(err, db.ErrIdNotFound) {
+ if errors.Is(err, sqlair.ErrNoRows) {
writeError(w, http.StatusNotFound, "Not Found")
return
}
writeError(w, http.StatusInternalServerError, "Internal Error")
return
}
- changeAccountResponse := ChangeAccountResponse{
- ID: int(ret),
- }
+ successResponse := SuccessResponse{Message: "success"}
w.WriteHeader(http.StatusCreated)
- err = writeJSON(w, changeAccountResponse)
+ err = writeJSON(w, successResponse)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal error")
return
diff --git a/internal/server/handlers_accounts_test.go b/internal/server/handlers_accounts_test.go
index 42e2e7a..38d40a1 100644
--- a/internal/server/handlers_accounts_test.go
+++ b/internal/server/handlers_accounts_test.go
@@ -3,11 +3,16 @@ package server_test
import (
"encoding/json"
"net/http"
+ "path/filepath"
"strconv"
"strings"
"testing"
)
+type SuccessResponse struct {
+ Message string `json:"message"`
+}
+
type GetAccountResponseResult struct {
ID int `json:"id"`
Username string `json:"username"`
@@ -29,8 +34,8 @@ type CreateAccountResponseResult struct {
}
type CreateAccountResponse struct {
- Result CreateAccountResponseResult `json:"result"`
- Error string `json:"error,omitempty"`
+ Result SuccessResponse `json:"result"`
+ Error string `json:"error,omitempty"`
}
type ChangeAccountPasswordParams struct {
@@ -139,7 +144,9 @@ func deleteAccount(url string, client *http.Client, adminToken string, id int) (
// The order of the tests is important, as some tests depend on
// the state of the server after previous tests.
func TestAccountsEndToEnd(t *testing.T) {
- ts, _, err := setupServer()
+ tempDir := t.TempDir()
+ db_path := filepath.Join(tempDir, "db.sqlite3")
+ ts, _, err := setupServer(db_path)
if err != nil {
t.Fatalf("couldn't create test server: %s", err)
}
@@ -197,10 +204,7 @@ func TestAccountsEndToEnd(t *testing.T) {
t.Fatalf("expected status %d, got %d", http.StatusCreated, statusCode)
}
if response.Error != "" {
- t.Fatalf("expected error %q, got %q", "", response.Error)
- }
- if response.Result.ID != 3 {
- t.Fatalf("expected ID 3, got %d", response.Result.ID)
+ t.Fatalf("unexpected error :%q", response.Error)
}
})
@@ -285,10 +289,7 @@ func TestAccountsEndToEnd(t *testing.T) {
t.Fatalf("expected status %d, got %d", http.StatusCreated, statusCode)
}
if response.Error != "" {
- t.Fatalf("expected error %q, got %q", "", response.Error)
- }
- if response.Result.ID != 1 {
- t.Fatalf("expected ID 1, got %d", response.Result.ID)
+ t.Fatalf("unexpected error :%q", response.Error)
}
})
@@ -351,9 +352,6 @@ func TestAccountsEndToEnd(t *testing.T) {
if response.Error != "" {
t.Fatalf("expected error %q, got %q", "", response.Error)
}
- if response.Result.ID != 2 {
- t.Fatalf("expected ID 2, got %d", response.Result.ID)
- }
})
t.Run("13. Delete account - no user", func(t *testing.T) {
diff --git a/internal/server/handlers_certificate_requests.go b/internal/server/handlers_certificate_requests.go
index a1a2d0f..a914c90 100644
--- a/internal/server/handlers_certificate_requests.go
+++ b/internal/server/handlers_certificate_requests.go
@@ -8,7 +8,7 @@ import (
"strconv"
"strings"
- "github.com/canonical/notary/internal/db"
+ "github.com/canonical/sqlair"
)
type CreateCertificateRequestParams struct {
@@ -16,54 +16,32 @@ type CreateCertificateRequestParams struct {
}
type CreateCertificateParams struct {
- Certificate string `json:"certificate"`
+ CertificateChain string `json:"certificate"`
}
-type GetCertificateRequestResponse struct {
- ID int `json:"id"`
- CSR string `json:"csr"`
- Certificate string `json:"certificate"`
-}
-
-type CreateCertificateRequestResponse struct {
- ID int `json:"id"`
-}
-
-type DeleteCertificateRequestResponse struct {
- ID int `json:"id"`
-}
-
-type RejectCertificateRequestResponse struct {
- ID int `json:"id"`
-}
-
-type CreateCertificateResponse struct {
- ID int `json:"id"`
-}
-
-type DeleteCertificateResponse struct {
- ID int `json:"id"`
-}
-
-type RejectCertificateResponse struct {
- ID int `json:"id"`
+type CertificateRequest struct {
+ ID int `json:"id"`
+ CSR string `json:"csr"`
+ CertificateChain string `json:"certificate_chain"`
+ Status string `json:"status"`
}
// ListCertificateRequests returns all of the Certificate Requests
func ListCertificateRequests(env *HandlerConfig) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
- certs, err := env.DB.RetrieveAllCSRs()
+ csrs, err := env.DB.ListCertificateRequests()
if err != nil {
log.Println(err)
writeError(w, http.StatusInternalServerError, "Internal Error")
return
}
- certificateRequestsResponse := make([]GetCertificateRequestResponse, len(certs))
- for i, cert := range certs {
- certificateRequestsResponse[i] = GetCertificateRequestResponse{
- ID: cert.ID,
- CSR: cert.CSR,
- Certificate: cert.Certificate,
+ certificateRequestsResponse := make([]CertificateRequest, len(csrs))
+ for i, csr := range csrs {
+ certificateRequestsResponse[i] = CertificateRequest{
+ ID: csr.ID,
+ CSR: csr.CSR,
+ CertificateChain: csr.CertificateChain,
+ Status: csr.Status,
}
}
w.WriteHeader(http.StatusOK)
@@ -87,7 +65,7 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc {
writeError(w, http.StatusBadRequest, "csr is missing")
return
}
- id, err := env.DB.CreateCSR(createCertificateRequestParams.CSR)
+ err := env.DB.CreateCertificateRequest(createCertificateRequestParams.CSR)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
writeError(w, http.StatusBadRequest, "given csr already recorded")
@@ -102,11 +80,9 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc {
writeError(w, http.StatusInternalServerError, "Internal Error")
return
}
- certificateRequestResponse := CreateCertificateRequestResponse{
- ID: int(id),
- }
+ successResponse := SuccessResponse{Message: "success"}
w.WriteHeader(http.StatusCreated)
- err = writeJSON(w, certificateRequestResponse)
+ err = writeJSON(w, successResponse)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal error")
return
@@ -119,20 +95,26 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc {
func GetCertificateRequest(env *HandlerConfig) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
- cert, err := env.DB.RetrieveCSR(id)
+ idNum, err := strconv.Atoi(id)
+ if err != nil {
+ writeError(w, http.StatusInternalServerError, "Internal Error")
+ return
+ }
+ csr, err := env.DB.GetCertificateRequestByID(idNum)
if err != nil {
log.Println(err)
- if errors.Is(err, db.ErrIdNotFound) {
+ if errors.Is(err, sqlair.ErrNoRows) {
writeError(w, http.StatusNotFound, "Not Found")
return
}
writeError(w, http.StatusInternalServerError, "Internal Error")
return
}
- certificateRequestResponse := GetCertificateRequestResponse{
- ID: cert.ID,
- CSR: cert.CSR,
- Certificate: cert.Certificate,
+ certificateRequestResponse := CertificateRequest{
+ ID: csr.ID,
+ CSR: csr.CSR,
+ CertificateChain: csr.CertificateChain,
+ Status: csr.Status,
}
w.WriteHeader(http.StatusOK)
err = writeJSON(w, certificateRequestResponse)
@@ -148,21 +130,24 @@ func GetCertificateRequest(env *HandlerConfig) http.HandlerFunc {
func DeleteCertificateRequest(env *HandlerConfig) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
- insertId, err := env.DB.DeleteCSR(id)
+ idNum, err := strconv.Atoi(id)
+ if err != nil {
+ writeError(w, http.StatusInternalServerError, "Internal Error")
+ return
+ }
+ err = env.DB.DeleteCertificateRequestByID(idNum)
if err != nil {
log.Println(err)
- if errors.Is(err, db.ErrIdNotFound) {
+ if errors.Is(err, sqlair.ErrNoRows) {
writeError(w, http.StatusNotFound, "Not Found")
return
}
writeError(w, http.StatusInternalServerError, "Internal Error")
return
}
- certificateRequestResponse := DeleteCertificateRequestResponse{
- ID: int(insertId),
- }
+ successResponse := SuccessResponse{Message: "success"}
w.WriteHeader(http.StatusAccepted)
- err = writeJSON(w, certificateRequestResponse)
+ err = writeJSON(w, successResponse)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal error")
return
@@ -179,15 +164,20 @@ func CreateCertificate(env *HandlerConfig) http.HandlerFunc {
writeError(w, http.StatusBadRequest, "Invalid JSON format")
return
}
- if createCertificateParams.Certificate == "" {
+ if createCertificateParams.CertificateChain == "" {
writeError(w, http.StatusBadRequest, "certificate is missing")
return
}
id := r.PathValue("id")
- insertId, err := env.DB.UpdateCSR(id, createCertificateParams.Certificate)
+ idNum, err := strconv.Atoi(id)
+ if err != nil {
+ writeError(w, http.StatusInternalServerError, "Internal Error")
+ return
+ }
+ err = env.DB.AddCertificateChainToCertificateRequestByID(idNum, createCertificateParams.CertificateChain)
if err != nil {
log.Println(err)
- if errors.Is(err, db.ErrIdNotFound) ||
+ if errors.Is(err, sqlair.ErrNoRows) ||
err.Error() == "certificate does not match CSR" ||
strings.Contains(err.Error(), "cert validation failed") {
writeError(w, http.StatusBadRequest, "Bad Request")
@@ -196,18 +186,15 @@ func CreateCertificate(env *HandlerConfig) http.HandlerFunc {
writeError(w, http.StatusInternalServerError, "Internal Error")
return
}
- insertIdStr := strconv.FormatInt(insertId, 10)
if env.SendPebbleNotifications {
- err := SendPebbleNotification("canonical.com/notary/certificate/update", insertIdStr)
+ err := SendPebbleNotification("canonical.com/notary/certificate/update", id)
if err != nil {
log.Printf("pebble notify failed: %s. continuing silently.", err.Error())
}
}
- certificateResponse := CreateCertificateResponse{
- ID: int(insertId),
- }
+ successResponse := SuccessResponse{Message: "success"}
w.WriteHeader(http.StatusCreated)
- err = writeJSON(w, certificateResponse)
+ err = writeJSON(w, successResponse)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal error")
return
@@ -218,28 +205,30 @@ func CreateCertificate(env *HandlerConfig) http.HandlerFunc {
func RejectCertificate(env *HandlerConfig) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
- insertId, err := env.DB.UpdateCSR(id, "rejected")
+ idNum, err := strconv.Atoi(id)
+ if err != nil {
+ writeError(w, http.StatusInternalServerError, "Internal Error")
+ return
+ }
+ err = env.DB.RejectCertificateRequestByID(idNum)
if err != nil {
log.Println(err)
- if errors.Is(err, db.ErrIdNotFound) {
+ if errors.Is(err, sqlair.ErrNoRows) {
writeError(w, http.StatusNotFound, "Not Found")
return
}
writeError(w, http.StatusInternalServerError, "Internal Error")
return
}
- insertIdStr := strconv.FormatInt(insertId, 10)
if env.SendPebbleNotifications {
- err := SendPebbleNotification("canonical.com/notary/certificate/update", insertIdStr)
+ err := SendPebbleNotification("canonical.com/notary/certificate/update", id)
if err != nil {
log.Printf("pebble notify failed: %s. continuing silently.", err.Error())
}
}
- certificateResponse := RejectCertificateResponse{
- ID: int(insertId),
- }
+ successResponse := SuccessResponse{Message: "success"}
w.WriteHeader(http.StatusAccepted)
- err = writeJSON(w, certificateResponse)
+ err = writeJSON(w, successResponse)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal error")
return
@@ -252,28 +241,30 @@ func RejectCertificate(env *HandlerConfig) http.HandlerFunc {
func DeleteCertificate(env *HandlerConfig) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
- insertId, err := env.DB.UpdateCSR(id, "")
+ idNum, err := strconv.Atoi(id)
+ if err != nil {
+ writeError(w, http.StatusInternalServerError, "Internal Error")
+ return
+ }
+ err = env.DB.DeleteCertificateRequestByID(idNum)
if err != nil {
log.Println(err)
- if errors.Is(err, db.ErrIdNotFound) {
+ if errors.Is(err, sqlair.ErrNoRows) {
writeError(w, http.StatusBadRequest, "Bad Request")
return
}
writeError(w, http.StatusInternalServerError, "Internal Error")
return
}
- insertIdStr := strconv.FormatInt(insertId, 10)
if env.SendPebbleNotifications {
- err := SendPebbleNotification("canonical.com/notary/certificate/update", insertIdStr)
+ err := SendPebbleNotification("canonical.com/notary/certificate/update", id)
if err != nil {
log.Printf("pebble notify failed: %s. continuing silently.", err.Error())
}
}
- certificateResponse := DeleteCertificateResponse{
- ID: int(insertId),
- }
+ successResponse := SuccessResponse{Message: "success"}
w.WriteHeader(http.StatusOK)
- err = writeJSON(w, certificateResponse)
+ err = writeJSON(w, successResponse)
if err != nil {
writeError(w, http.StatusInternalServerError, "internal error")
return
diff --git a/internal/server/handlers_certificate_requests_test.go b/internal/server/handlers_certificate_requests_test.go
index 21b8311..63df6d1 100644
--- a/internal/server/handlers_certificate_requests_test.go
+++ b/internal/server/handlers_certificate_requests_test.go
@@ -9,22 +9,18 @@ import (
"path/filepath"
"strconv"
"testing"
-)
-type CertificateRequest struct {
- ID int `json:"id"`
- CSR string `json:"csr"`
- Certificate string `json:"certificate"`
-}
+ "github.com/canonical/notary/internal/server"
+)
type GetCertificateRequestResponse struct {
- Result CertificateRequest `json:"result"`
- Error string `json:"error,omitempty"`
+ Result server.CertificateRequest `json:"result"`
+ Error string `json:"error,omitempty"`
}
type ListCertificateRequestsResponse struct {
- Error string `json:"error,omitempty"`
- Result []CertificateRequest `json:"result"`
+ Error string `json:"error,omitempty"`
+ Result []server.CertificateRequest `json:"result"`
}
type CreateCertificateRequestResponse struct {
@@ -162,7 +158,10 @@ func rejectCertificate(url string, client *http.Client, adminToken string, id in
// The order of the tests is important, as some tests depend on the
// state of the server after previous tests.
func TestCertificateRequestsEndToEnd(t *testing.T) {
- ts, _, err := setupServer()
+
+ tempDir := t.TempDir()
+ db_path := filepath.Join(tempDir, "db.sqlite3")
+ ts, _, err := setupServer(db_path)
if err != nil {
t.Fatalf("couldn't create test server: %s", err)
}
@@ -257,8 +256,8 @@ func TestCertificateRequestsEndToEnd(t *testing.T) {
if getCertRequestResponse.Result.CSR == "" {
t.Fatalf("expected CSR, got empty string")
}
- if getCertRequestResponse.Result.Certificate != "" {
- t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.Certificate)
+ if getCertRequestResponse.Result.CertificateChain != "" {
+ t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.CertificateChain)
}
})
@@ -353,8 +352,8 @@ func TestCertificateRequestsEndToEnd(t *testing.T) {
if getCertRequestResponse.Result.CSR == "" {
t.Fatalf("expected CSR, got empty string")
}
- if getCertRequestResponse.Result.Certificate != "" {
- t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.Certificate)
+ if getCertRequestResponse.Result.CertificateChain != "" {
+ t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.CertificateChain)
}
})
@@ -399,7 +398,9 @@ func TestCertificateRequestsEndToEnd(t *testing.T) {
// The order of the tests is important, as some tests depend on the
// state of the server after previous tests.
func TestCertificatesEndToEnd(t *testing.T) {
- ts, _, err := setupServer()
+ tempDir := t.TempDir()
+ db_path := filepath.Join(tempDir, "db.sqlite3")
+ ts, _, err := setupServer(db_path)
if err != nil {
t.Fatalf("couldn't create test server: %s", err)
}
@@ -481,7 +482,7 @@ func TestCertificatesEndToEnd(t *testing.T) {
if getCertResponse.Error != "" {
t.Fatalf("expected no error, got %s", getCertResponse.Error)
}
- if getCertResponse.Result.Certificate == "" {
+ if getCertResponse.Result.CertificateChain == "" {
t.Fatalf("expected certificate, got empty string")
}
})
@@ -507,8 +508,8 @@ func TestCertificatesEndToEnd(t *testing.T) {
if getCertResponse.Error != "" {
t.Fatalf("expected no error, got %s", getCertResponse.Error)
}
- if getCertResponse.Result.Certificate != "rejected" {
- t.Fatalf("expected `rejected` certificate, got %s", getCertResponse.Result.Certificate)
+ if getCertResponse.Result.Status != "Rejected" {
+ t.Fatalf("expected `Rejected` status, got %s", getCertResponse.Result.CertificateChain)
}
})
diff --git a/internal/server/handlers_helpers_test.go b/internal/server/handlers_helpers_test.go
index cb309c1..1552e1e 100644
--- a/internal/server/handlers_helpers_test.go
+++ b/internal/server/handlers_helpers_test.go
@@ -10,8 +10,8 @@ import (
"github.com/canonical/notary/internal/server"
)
-func setupServer() (*httptest.Server, *server.HandlerConfig, error) {
- testdb, err := db.NewDatabase(":memory:")
+func setupServer(filepath string) (*httptest.Server, *server.HandlerConfig, error) {
+ testdb, err := db.NewDatabase(filepath)
if err != nil {
return nil, nil, err
}
diff --git a/internal/server/handlers_login.go b/internal/server/handlers_login.go
index 21b31f8..469f022 100644
--- a/internal/server/handlers_login.go
+++ b/internal/server/handlers_login.go
@@ -7,7 +7,7 @@ import (
"net/http"
"time"
- "github.com/canonical/notary/internal/db"
+ "github.com/canonical/sqlair"
"github.com/golang-jwt/jwt"
"golang.org/x/crypto/bcrypt"
)
@@ -65,17 +65,17 @@ func Login(env *HandlerConfig) http.HandlerFunc {
writeError(w, http.StatusBadRequest, "Password is required")
return
}
- userAccount, err := env.DB.RetrieveUserByUsername(loginParams.Username)
+ userAccount, err := env.DB.GetUserByUsername(loginParams.Username)
if err != nil {
log.Println(err)
- if errors.Is(err, db.ErrIdNotFound) {
+ if errors.Is(err, sqlair.ErrNoRows) {
writeError(w, http.StatusUnauthorized, "The username or password is incorrect. Try again.")
return
}
writeError(w, http.StatusInternalServerError, "Internal Error")
return
}
- if err := bcrypt.CompareHashAndPassword([]byte(userAccount.Password), []byte(loginParams.Password)); err != nil {
+ if err := bcrypt.CompareHashAndPassword([]byte(userAccount.HashedPassword), []byte(loginParams.Password)); err != nil {
writeError(w, http.StatusUnauthorized, "The username or password is incorrect. Try again.")
return
}
diff --git a/internal/server/handlers_login_test.go b/internal/server/handlers_login_test.go
index 408cb2c..ff3c0db 100644
--- a/internal/server/handlers_login_test.go
+++ b/internal/server/handlers_login_test.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
+ "path/filepath"
"strings"
"testing"
@@ -46,7 +47,9 @@ func login(url string, client *http.Client, data *LoginParams) (int, *LoginRespo
}
func TestLoginEndToEnd(t *testing.T) {
- ts, config, err := setupServer()
+ tempDir := t.TempDir()
+ db_path := filepath.Join(tempDir, "db.sqlite3")
+ ts, config, err := setupServer(db_path)
if err != nil {
t.Fatalf("couldn't create test server: %s", err)
}
diff --git a/internal/server/handlers_status_test.go b/internal/server/handlers_status_test.go
index c725574..d819fd5 100644
--- a/internal/server/handlers_status_test.go
+++ b/internal/server/handlers_status_test.go
@@ -3,6 +3,7 @@ package server_test
import (
"encoding/json"
"net/http"
+ "path/filepath"
"testing"
)
@@ -35,7 +36,9 @@ func getStatus(url string, client *http.Client, adminToken string) (int, *GetSta
}
func TestStatus(t *testing.T) {
- ts, _, err := setupServer()
+ tempDir := t.TempDir()
+ db_path := filepath.Join(tempDir, "db.sqlite3")
+ ts, _, err := setupServer(db_path)
if err != nil {
t.Fatalf("couldn't create test server: %s", err)
}
diff --git a/internal/server/response.go b/internal/server/response.go
index 0b9678c..7e7daf6 100644
--- a/internal/server/response.go
+++ b/internal/server/response.go
@@ -6,6 +6,10 @@ import (
"net/http"
)
+type SuccessResponse struct {
+ Message string `json:"message"`
+}
+
// writeJSON is a helper function that writes a JSON response to the http.ResponseWriter
func writeJSON(w http.ResponseWriter, v any) error {
type response struct {
diff --git a/ui/package-lock.json b/ui/package-lock.json
index 0e09b13..c41ddf4 100644
--- a/ui/package-lock.json
+++ b/ui/package-lock.json
@@ -7962,6 +7962,21 @@
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
+ },
+ "node_modules/@next/swc-win32-ia32-msvc": {
+ "version": "14.2.15",
+ "resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.15.tgz",
+ "integrity": "sha512-fyTE8cklgkyR1p03kJa5zXEaZ9El+kDNM5A+66+8evQS5e/6v0Gk28LqA0Jet8gKSOyP+OTm/tJHzMlGdQerdQ==",
+ "cpu": [
+ "ia32"
+ ],
+ "optional": true,
+ "os": [
+ "win32"
+ ],
+ "engines": {
+ "node": ">= 10"
+ }
}
}
}
diff --git a/ui/src/app/(notary)/certificate_requests/table.test.tsx b/ui/src/app/(notary)/certificate_requests/table.test.tsx
index 5b3c076..ee5a4c2 100644
--- a/ui/src/app/(notary)/certificate_requests/table.test.tsx
+++ b/ui/src/app/(notary)/certificate_requests/table.test.tsx
@@ -2,8 +2,9 @@ import { expect, test } from 'vitest'
import { render, screen } from '@testing-library/react'
import { CertificateRequestsTable } from './table'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
+import { CSREntry } from '@/types'
-const rows = [
+const rows: CSREntry[] = [
{
'id': 1,
'csr': `-----BEGIN CERTIFICATE REQUEST-----
@@ -23,7 +24,8 @@ Y/uPl4g3jpGqLCKTASWJDGnZLroLICOzYTVs5P3oj+VueSUwYhGK5tBnS2x5FHID
uMNMgwl0fxGMQZjrlXyCBhXBm1k6PmwcJGJF5LQ31c+5aTTMFU7SyZhlymctB8mS
y+ErBQsRpcQho6Ok+HTXQQUcx7WNcwI=
-----END CERTIFICATE REQUEST-----`,
- 'certificate': ""
+ 'certificate_chain': "",
+ 'status': "Outstanding",
},
{
'id': 2,
@@ -46,7 +48,8 @@ tK9qb8EE92MoWboo4m4bcX74y+eUo3xBev6ZZwdScy8OHLhA/MMI8EElpeYt+Hc2
WsDOAOH6qKQKQg3BO/xmRoohC6GL4CuhP7HYGi7+wziNhNZQa4GtE/k9DyIXVtJy
yuf2PnfXCKnaIWRJNoEqDCZRVMfA5BFSwTPITqyo
-----END CERTIFICATE REQUEST-----`,
- 'certificate': "rejected"
+ 'certificate_chain': "",
+ 'status': "Rejected"
},
{
'id': 3,
@@ -68,7 +71,7 @@ cAQXk3fvTWuikHiCHqqdSdjDYj/8cyiwCrQWpV245VSbOE0WesWoEnSdFXVUfE1+
RSKeTRuuJMcdGqBkDnDI22myj0bjt7q8eqBIjTiLQLnAFnQYpcCrhc8dKU9IJlv1
H9Hay4ZO9LRew3pEtlx2WrExw/gpUcWM8rTI
-----END CERTIFICATE REQUEST-----`,
- 'certificate': `-----BEGIN CERTIFICATE-----
+ 'certificate_chain': `-----BEGIN CERTIFICATE-----
MIIDrDCCApSgAwIBAgIURKr+jf7hj60SyAryIeN++9wDdtkwDQYJKoZIhvcNAQEL
BQAwOTELMAkGA1UEBhMCVVMxKjAoBgNVBAMMIXNlbGYtc2lnbmVkLWNlcnRpZmlj
YXRlcy1vcGVyYXRvcjAeFw0yNDAzMjcxMjQ4MDRaFw0yNTAzMjcxMjQ4MDRaMEcx
@@ -89,7 +92,8 @@ WyhXkzguv3dwH+n43GJFP6MQ+n9W/nPZCUQ0Iy7ueAvj0HFhGyZzAE2wxNFZdvCs
gCX3nqYpp70oZIFDrhmYwE5ij5KXlHD4/1IOfNUKCDmQDgGPLI1tVtwQLjeRq7Hg
XVelpl/LXTQawmJyvDaVT/Q9P+WqoDiMjrqF6Sy7DzNeeccWVqvqX5TVS6Ky56iS
Mvo/+PAJHkBciR5Xn+Wg2a+7vrZvT6CBoRSOTozlLSM=
------END CERTIFICATE-----`
+-----END CERTIFICATE-----`,
+ 'status': 'Active'
},
]
diff --git a/ui/src/app/(notary)/certificate_requests/table.tsx b/ui/src/app/(notary)/certificate_requests/table.tsx
index c481737..4b7ecbf 100644
--- a/ui/src/app/(notary)/certificate_requests/table.tsx
+++ b/ui/src/app/(notary)/certificate_requests/table.tsx
@@ -139,9 +139,9 @@ export function CertificateRequestsTable({ csrs: rows }: TableProps) {
};
const csrrows = rows.map((csrEntry) => {
- const { id, csr, certificate } = csrEntry;
+ const { id, csr, certificate_chain, status: csr_status } = csrEntry;
const csrObj = extractCSR(csr);
- const certs = splitBundle(certificate);
+ const certs = splitBundle(certificate_chain);
const clientCertificate = certs?.at(0);
const certObj = clientCertificate ? extractCert(clientCertificate) : null;
@@ -152,13 +152,13 @@ export function CertificateRequestsTable({ csrs: rows }: TableProps) {
sortData: {
id,
common_name: csrObj.commonName,
- csr_status: certificate === "" ? "outstanding" : (certificate === "rejected" ? "rejected" : "fulfilled"),
+ csr_status: csr_status,
cert_expiry_date: certObj?.notAfter || "",
},
columns: [
{ content: id.toString() },
{ content: csrObj.commonName || "N/A" },
- { content: certificate === "" ? "outstanding" : (certificate === "rejected" ? "rejected" : "fulfilled") },
+ { content: csr_status },
{
content: certObj?.notAfter || "",
style: { backgroundColor: getExpiryColor(certObj?.notAfter) },
@@ -185,13 +185,13 @@ export function CertificateRequestsTable({ csrs: rows }: TableProps) {