From 370b30994cf61bc8f3eeb1eb54221fca0683ddfb Mon Sep 17 00:00:00 2001 From: kayra1 Date: Tue, 19 Nov 2024 12:18:41 +0300 Subject: [PATCH 1/9] feat(db): move to sqlair --- go.mod | 1 + go.sum | 2 + internal/db/db.go | 499 ++++++++++++------ internal/db/db_test.go | 86 +-- internal/metrics/metrics.go | 7 +- internal/metrics/metrics_test.go | 9 +- internal/server/authorization_test.go | 21 +- internal/server/handlers_accounts.go | 77 ++- internal/server/handlers_accounts_test.go | 26 +- .../server/handlers_certificate_requests.go | 149 +++--- .../handlers_certificate_requests_test.go | 43 +- internal/server/handlers_helpers_test.go | 4 +- internal/server/handlers_login.go | 6 +- internal/server/handlers_login_test.go | 5 +- internal/server/handlers_status_test.go | 5 +- internal/server/response.go | 4 + 16 files changed, 572 insertions(+), 372 deletions(-) diff --git a/go.mod b/go.mod index 38fd724a..a234d34e 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect + github.com/canonical/sqlair v0.0.0-20241004123011-77313b5382fd github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/kr/text v0.2.0 // indirect diff --git a/go.sum b/go.sum index 56c00d95..4b3de951 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/canonical/sqlair v0.0.0-20241004123011-77313b5382fd h1:lVn7391CX7QQ5WBQriNUtCB4fvfurZg6XJwH9aVsRII= +github.com/canonical/sqlair v0.0.0-20241004123011-77313b5382fd/go.mod h1:T+40I2sXshY3KRxx0QQpqqn6hCibSKJ2KHzjBvJj8T4= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= diff --git a/internal/db/db.go b/internal/db/db.go index 868a6b28..6ba5b22a 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -2,273 +2,444 @@ package db import ( + "context" "database/sql" "errors" "fmt" + "github.com/canonical/sqlair" _ "github.com/mattn/go-sqlite3" "golang.org/x/crypto/bcrypt" ) +// Database is the object used to communicate with the established repository. +type Database struct { + certificateTable string + usersTable string + conn *sqlair.DB +} + +type CertificateRequest struct { + ID int `db:"id"` + + CSR string `db:"csr"` + CertificateChain string `db:"certificate_chain"` + RequestStatus string `db:"request_status"` +} + +type User struct { + ID int `db:"id"` + + Username string `db:"username"` + HashedPassword string `db:"hashed_password"` + Permissions int `db:"permissions"` +} + const ( - certificateRequestsTableName = "CertificateRequests" + certificateRequestsTableName = "certificate_requests" usersTableName = "users" ) -const queryCreateCSRsTable = `CREATE TABLE IF NOT EXISTS %s ( - csr TEXT PRIMARY KEY UNIQUE NOT NULL, - certificate TEXT DEFAULT '' +const queryCreateCSRsTable = ` + CREATE TABLE IF NOT EXISTS %s ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + + csr TEXT NOT NULL UNIQUE, + certificate_chain TEXT DEFAULT '', + request_status TEXT DEFAULT 'Outstanding', + + CHECK (request_status IN ('Outstanding', 'Rejected', 'Revoked', 'Active')), + CHECK (NOT (certificate_chain == '' AND request_status == 'Active' )), + CHECK (NOT (certificate_chain != '' AND request_status == 'Outstanding')) + CHECK (NOT (certificate_chain != '' AND request_status == 'Rejected')) + CHECK (NOT (certificate_chain != '' AND request_status == 'Revoked')) )` -const ( - queryGetAllCSRs = "SELECT rowid, * FROM %s" - queryGetCSR = "SELECT rowid, * FROM %s WHERE rowid=?" - queryCreateCSR = "INSERT INTO %s (csr) VALUES (?)" - queryUpdateCSR = "UPDATE %s SET certificate=? WHERE rowid=?" - queryDeleteCSR = "DELETE FROM %s WHERE rowid=?" -) +const queryCreateUsersTable = ` + CREATE TABLE IF NOT EXISTS %s ( + id INTEGER PRIMARY KEY AUTOINCREMENT, -const queryCreateUsersTable = `CREATE TABLE IF NOT EXISTS %s ( - user_id INTEGER PRIMARY KEY AUTOINCREMENT, - username TEXT NOT NULL UNIQUE, - hashed_password TEXT NOT NULL, - permissions INTEGER + username TEXT NOT NULL UNIQUE, + hashed_password TEXT NOT NULL, + permissions INTEGER )` const ( - queryGetAllUsers = "SELECT * FROM %s" - queryGetUser = "SELECT * FROM %s WHERE user_id=?" - queryGetUserByUsername = "SELECT * FROM %s WHERE username=?" - queryCreateUser = "INSERT INTO %s (username, hashed_password, permissions) VALUES (?, ?, ?)" - queryUpdateUser = "UPDATE %s SET hashed_password=? WHERE user_id=?" - queryDeleteUser = "DELETE FROM %s WHERE user_id=?" - queryGetNumUsers = "SELECT COUNT(*) FROM %s" + getAllCSRsStmt = "SELECT &CertificateRequest.* FROM %s" + getCSRsStmt = "SELECT &CertificateRequest.* FROM %s WHERE id==$CertificateRequest.id or csr==$CertificateRequest.csr" + createCSRStmt = "INSERT INTO %s (csr) VALUES ($CertificateRequest.csr)" + updateCSRStmt = "UPDATE %s SET certificate_chain=$CertificateRequest.certificate_chain, request_status=$CertificateRequest.request_status WHERE id==$CertificateRequest.id or csr==$CertificateRequest.csr" + deleteCSRStmt = "DELETE FROM %s WHERE id=$CertificateRequest.id or csr=$CertificateRequest.csr" ) -// CertificateRequestRepository is the object used to communicate with the established repository. -type Database struct { - certificateTable string - usersTable string - conn *sql.DB +const ( + getAllUsersStmt = "SELECT &User.* from %s" + getUserStmt = "SELECT &User.* from %s WHERE id==$User.id or username==$User.username" + createUserStmt = "INSERT INTO %s (username, hashed_password, permissions) VALUES ($User.username, $User.hashed_password, $User.permissions)" + updateUserStmt = "UPDATE %s SET hashed_password=$User.hashed_password WHERE id==$User.id or username==$User.username" + deleteUserStmt = "DELETE FROM %s WHERE id==$User.id" + getNumUsersStmt = "SELECT COUNT(*) AS &NumUsers.count FROM %s" +) + +// RetrieveAllCSRs gets every CertificateRequest entry in the table. +func (db *Database) RetrieveAllCSRs() ([]CertificateRequest, error) { + stmt, err := sqlair.Prepare(fmt.Sprintf(getAllCSRsStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return nil, err + } + var csrs []CertificateRequest + err = db.conn.Query(context.Background(), stmt).GetAll(&csrs) + if err != nil { + if errors.Is(err, sqlair.ErrNoRows) { + return csrs, nil + } + return nil, err + } + return csrs, nil } -// A CertificateRequest struct represents an entry in the database. -// The object contains a Certificate Request, its matching Certificate if any, and the row ID. -type CertificateRequest struct { - ID int - CSR string - Certificate string +// RetrieveCSRbyID gets a CSR row from the repository from a given ID. +func (db *Database) RetrieveCSRbyID(id int) (*CertificateRequest, error) { + csr := CertificateRequest{ + ID: id, + } + stmt, err := sqlair.Prepare(fmt.Sprintf(getCSRsStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return nil, err + } + err = db.conn.Query(context.Background(), stmt, csr).Get(&csr) + if err != nil { + return nil, err + } + return &csr, nil } -type User struct { - ID int - Username string - Password string - Permissions int + +// RetrieveCSRbyCSR gets a given CSR row from the repository using the CSR text. +func (db *Database) RetrieveCSRbyCSR(csr string) (*CertificateRequest, error) { + row := CertificateRequest{ + CSR: csr, + } + stmt, err := sqlair.Prepare(fmt.Sprintf(getCSRsStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return nil, err + } + err = db.conn.Query(context.Background(), stmt, row).Get(&row) + if err != nil { + return nil, err + } + return &row, nil } -var ErrIdNotFound = errors.New("id not found") +// CreateCSR creates a new CSR entry in the repository. The string must be a valid CSR and unique. +func (db *Database) CreateCSR(csr string) error { + if err := ValidateCertificateRequest(csr); err != nil { + return errors.New("csr validation failed: " + err.Error()) + } + stmt, err := sqlair.Prepare(fmt.Sprintf(createCSRStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err + } + row := CertificateRequest{ + CSR: csr, + } + err = db.conn.Query(context.Background(), stmt, row).Run() + if err != nil { + return err + } + return nil +} -// RetrieveAllCSRs gets every CertificateRequest entry in the table. -func (db *Database) RetrieveAllCSRs() ([]CertificateRequest, error) { - rows, err := db.conn.Query(fmt.Sprintf(queryGetAllCSRs, db.certificateTable)) +// AddCertificateChainToCSRbyCSR adds a new certificate chain to a row for a given CSR string. +func (db *Database) AddCertificateChainToCSRbyCSR(csr string, cert string) error { + err := ValidateCertificate(cert) if err != nil { - return nil, err + return errors.New("cert validation failed: " + err.Error()) + } + err = CertificateMatchesCSR(cert, csr) + if err != nil { + return errors.New("cert validation failed: " + err.Error()) + } + certBundle := sanitizeCertificateBundle(cert) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCSRStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err + } + newRow := CertificateRequest{ + CSR: csr, + CertificateChain: certBundle, + RequestStatus: "Active", + } + err = db.conn.Query(context.Background(), stmt, newRow).Run() + if err != nil { + return err } + return nil +} - var allCsrs []CertificateRequest - defer rows.Close() - for rows.Next() { - var csr CertificateRequest - if err := rows.Scan(&csr.ID, &csr.CSR, &csr.Certificate); err != nil { - return nil, err - } - allCsrs = append(allCsrs, csr) +// AddCertificateChainToCSRbyID adds a new certificate chain to a row for a given row ID. +func (db *Database) AddCertificateToCSRbyID(id int, cert string) error { + csr, err := db.RetrieveCSRbyID(id) + if err != nil { + return err + } + err = ValidateCertificate(cert) + if err != nil { + return errors.New("cert validation failed: " + err.Error()) } - return allCsrs, nil + err = CertificateMatchesCSR(cert, csr.CSR) + if err != nil { + return errors.New("cert validation failed: " + err.Error()) + } + certBundle := sanitizeCertificateBundle(cert) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCSRStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err + } + newRow := CertificateRequest{ + ID: id, + CertificateChain: certBundle, + RequestStatus: "Active", + } + err = db.conn.Query(context.Background(), stmt, newRow).Run() + if err != nil { + return err + } + return nil } -// RetrieveCSR gets a given CSR from the repository. -// It returns the row id and matching certificate alongside the CSR in a CertificateRequest object. -func (db *Database) RetrieveCSR(id string) (CertificateRequest, error) { - var newCSR CertificateRequest - row := db.conn.QueryRow(fmt.Sprintf(queryGetCSR, db.certificateTable), id) - if err := row.Scan(&newCSR.ID, &newCSR.CSR, &newCSR.Certificate); err != nil { - if err.Error() == "sql: no rows in result set" { - return newCSR, ErrIdNotFound - } - return newCSR, err +// RejectCSRbyCSR updates input CSR's row by setting the certificate bundle to "" and moving the row status to "Rejected". +func (db *Database) RejectCSRbyCSR(csr string) error { + oldRow, err := db.RetrieveCSRbyCSR(csr) + if err != nil { + return err } - return newCSR, nil + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCSRStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err + } + newRow := CertificateRequest{ + ID: oldRow.ID, + CSR: oldRow.CSR, + CertificateChain: "", + RequestStatus: "Rejected", + } + err = db.conn.Query(context.Background(), stmt, newRow).Run() + if err != nil { + return err + } + return nil } -// CreateCSR creates a new entry in the repository. -// The given CSR must be valid and unique -func (db *Database) CreateCSR(csr string) (int64, error) { - if err := ValidateCertificateRequest(csr); err != nil { - return 0, errors.New("csr validation failed: " + err.Error()) +// RejectCSRbyCSR updates input ID's row by setting the certificate bundle to "" and sets the row status to "Rejected". +func (db *Database) RejectCSRbyID(id int) error { + oldRow, err := db.RetrieveCSRbyID(id) + if err != nil { + return err } - result, err := db.conn.Exec(fmt.Sprintf(queryCreateCSR, db.certificateTable), csr) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCSRStmt, db.certificateTable), CertificateRequest{}) if err != nil { - return 0, err + return err } - id, err := result.LastInsertId() + newRow := CertificateRequest{ + ID: oldRow.ID, + CSR: oldRow.CSR, + CertificateChain: "", + RequestStatus: "Rejected", + } + err = db.conn.Query(context.Background(), stmt, newRow).Run() if err != nil { - return 0, err + return err } - return id, nil + return nil } -// UpdateCSR adds a new cert to the given CSR in the repository. -// The given certificate must share the public key of the CSR and must be valid. -func (db *Database) UpdateCSR(id string, cert string) (int64, error) { - csr, err := db.RetrieveCSR(id) +// RevokeCSR updates the input CSR's row by setting the certificate bundle to "" and sets the row status to "Revoked". +func (db *Database) RevokeCSR(csr string) error { + oldRow, err := db.RetrieveCSRbyCSR(csr) if err != nil { - return 0, err + return err } - if cert != "rejected" && cert != "" { - err = ValidateCertificate(cert) - if err != nil { - return 0, errors.New("cert validation failed: " + err.Error()) - } - err = CertificateMatchesCSR(cert, csr.CSR) - if err != nil { - return 0, errors.New("cert validation failed: " + err.Error()) - } - cert = sanitizeCertificateBundle(cert) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCSRStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err + } + newRow := CertificateRequest{ + ID: oldRow.ID, + CSR: oldRow.CSR, + CertificateChain: "", + RequestStatus: "Revoked", } - _, err = db.conn.Exec(fmt.Sprintf(queryUpdateCSR, db.certificateTable), cert, csr.ID) + err = db.conn.Query(context.Background(), stmt, newRow).Run() if err != nil { - return 0, err + return err } - return int64(csr.ID), nil + return nil } -// DeleteCSR removes a CSR from the database alongside the certificate that may have been generated for it. -func (db *Database) DeleteCSR(id string) (int64, error) { - result, err := db.conn.Exec(fmt.Sprintf(queryDeleteCSR, db.certificateTable), id) +// DeleteCSRbyCSR removes a CSR from the database alongside the certificate that may have been generated for it. +func (db *Database) DeleteCSRbyCSR(csr string) error { + stmt, err := sqlair.Prepare(fmt.Sprintf(deleteCSRStmt, db.certificateTable), CertificateRequest{}) if err != nil { - return 0, err + return err + } + row := CertificateRequest{ + CSR: csr, } - deleteId, err := result.RowsAffected() + err = db.conn.Query(context.Background(), stmt, row).Run() if err != nil { - return 0, err + return err } - if deleteId == 0 { - return 0, ErrIdNotFound + return nil +} + +// DeleteCSRByID removes a CSR from the database alongside the certificate that may have been generated for it. +func (db *Database) DeleteCSRbyID(id int) error { + stmt, err := sqlair.Prepare(fmt.Sprintf(deleteCSRStmt, db.certificateTable), CertificateRequest{}) + if err != nil { + return err + } + row := CertificateRequest{ + ID: id, + } + err = db.conn.Query(context.Background(), stmt, row).Run() + if err != nil { + return err } - return deleteId, nil + return nil } // RetrieveAllUsers returns all of the users and their fields available in the database. func (db *Database) RetrieveAllUsers() ([]User, error) { - rows, err := db.conn.Query(fmt.Sprintf(queryGetAllUsers, db.usersTable)) + stmt, err := sqlair.Prepare(fmt.Sprintf(getAllUsersStmt, db.usersTable), User{}) if err != nil { return nil, err } - - var allUsers []User - defer rows.Close() - for rows.Next() { - var user User - if err := rows.Scan(&user.ID, &user.Username, &user.Password, &user.Permissions); err != nil { - return nil, err - } - allUsers = append(allUsers, user) + var users []User + err = db.conn.Query(context.Background(), stmt).GetAll(&users) + if err != nil { + return nil, err } - return allUsers, nil + return users, nil } // RetrieveUser retrieves the name, password and the permission level of a user. -func (db *Database) RetrieveUser(id string) (User, error) { - var newUser User - row := db.conn.QueryRow(fmt.Sprintf(queryGetUser, db.usersTable), id) - if err := row.Scan(&newUser.ID, &newUser.Username, &newUser.Password, &newUser.Permissions); err != nil { - if err.Error() == "sql: no rows in result set" { - return newUser, ErrIdNotFound - } - return newUser, err +func (db *Database) RetrieveUserByID(id int) (*User, error) { + row := User{ + ID: id, } - return newUser, nil + stmt, err := sqlair.Prepare(fmt.Sprintf(getUserStmt, db.usersTable), User{}) + if err != nil { + return nil, err + } + err = db.conn.Query(context.Background(), stmt, row).Get(&row) + if err != nil { + return nil, err + } + return &row, nil } // RetrieveUser retrieves the id, password and the permission level of a user. -func (db *Database) RetrieveUserByUsername(name string) (User, error) { - var newUser User - row := db.conn.QueryRow(fmt.Sprintf(queryGetUserByUsername, db.usersTable), name) - if err := row.Scan(&newUser.ID, &newUser.Username, &newUser.Password, &newUser.Permissions); err != nil { - if err.Error() == "sql: no rows in result set" { - return newUser, ErrIdNotFound - } - return newUser, err +func (db *Database) RetrieveUserByUsername(name string) (*User, error) { + row := User{ + Username: name, + } + stmt, err := sqlair.Prepare(fmt.Sprintf(getUserStmt, db.usersTable), User{}) + if err != nil { + return nil, err } - return newUser, nil + err = db.conn.Query(context.Background(), stmt, row).Get(&row) + if err != nil { + return nil, err + } + return &row, nil } // CreateUser creates a new user from a given username, password and permission level. // The permission level 1 represents an admin, and a 0 represents a regular user. // The password passed in should be in plaintext. This function handles hashing and salting the password before storing it in the database. -func (db *Database) CreateUser(username string, password string, permission int) (int64, error) { +func (db *Database) CreateUser(username string, password string, permission int) error { pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - return 0, err + return err } - result, err := db.conn.Exec(fmt.Sprintf(queryCreateUser, db.usersTable), username, pw, permission) + stmt, err := sqlair.Prepare(fmt.Sprintf(createUserStmt, db.usersTable), User{}) if err != nil { - return 0, err + return err + } + row := User{ + Username: username, + HashedPassword: string(pw), + Permissions: permission, } - id, err := result.LastInsertId() + err = db.conn.Query(context.Background(), stmt, row).Run() if err != nil { - return 0, err + return err } - return id, nil + return nil } // UpdateUser updates the password of the given user. // Just like with CreateUser, this function handles hashing and salting the password before storage. -func (db *Database) UpdateUser(id, password string) (int64, error) { - user, err := db.RetrieveUser(id) +func (db *Database) UpdateUserPassword(id int, password string) error { + _, err := db.RetrieveUserByID(id) if err != nil { - return 0, err + return err } pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - return 0, err + return err } - result, err := db.conn.Exec(fmt.Sprintf(queryUpdateUser, db.usersTable), pw, user.ID) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateUserStmt, db.usersTable), User{}) if err != nil { - return 0, err + return err + } + row := User{ + ID: id, + HashedPassword: string(pw), } - affectedRows, err := result.RowsAffected() + err = db.conn.Query(context.Background(), stmt, row).Run() if err != nil { - return 0, err + return err } - return affectedRows, nil + return nil } -// DeleteUser removes a user from the table. -func (db *Database) DeleteUser(id string) (int64, error) { - result, err := db.conn.Exec(fmt.Sprintf(queryDeleteUser, db.usersTable), id) +// DeleteUserByID removes a user from the table. +func (db *Database) DeleteUserByID(id int) error { + _, err := db.RetrieveUserByID(id) if err != nil { - return 0, err + return err } - deleteId, err := result.RowsAffected() + stmt, err := sqlair.Prepare(fmt.Sprintf(deleteUserStmt, db.usersTable), User{}) if err != nil { - return 0, err + return err + } + row := User{ + ID: id, } - if deleteId == 0 { - return 0, ErrIdNotFound + err = db.conn.Query(context.Background(), stmt, row).Run() + if err != nil { + return err } - return deleteId, nil + return nil +} + +type NumUsers struct { + Count int `db:"count"` } // NumUsers returns the number of users in the database. func (db *Database) NumUsers() (int, error) { - var numUsers int - row := db.conn.QueryRow(fmt.Sprintf(queryGetNumUsers, db.usersTable)) - if err := row.Scan(&numUsers); err != nil { + stmt, err := sqlair.Prepare(fmt.Sprintf(getNumUsersStmt, db.usersTable), NumUsers{}) + if err != nil { + return 0, err + } + result := NumUsers{} + err = db.conn.Query(context.Background(), stmt).Get(&result) + if err != nil { return 0, err } - return numUsers, nil + return result.Count, nil } // Close closes the connection to the repository cleanly. @@ -276,7 +447,7 @@ func (db *Database) Close() error { if db.conn == nil { return nil } - if err := db.conn.Close(); err != nil { + if err := db.conn.PlainDB().Close(); err != nil { return err } return nil @@ -287,18 +458,18 @@ func (db *Database) Close() error { // The database path must be a valid file path or ":memory:". // The table will be created if it doesn't exist in the format expected by the package. func NewDatabase(databasePath string) (*Database, error) { - conn, err := sql.Open("sqlite3", databasePath) + sqlConnection, err := sql.Open("sqlite3", databasePath) if err != nil { return nil, err } - if _, err := conn.Exec(fmt.Sprintf(queryCreateCSRsTable, certificateRequestsTableName)); err != nil { + if _, err := sqlConnection.Exec(fmt.Sprintf(queryCreateCSRsTable, certificateRequestsTableName)); err != nil { return nil, err } - if _, err := conn.Exec(fmt.Sprintf(queryCreateUsersTable, usersTableName)); err != nil { + if _, err := sqlConnection.Exec(fmt.Sprintf(queryCreateUsersTable, usersTableName)); err != nil { return nil, err } db := new(Database) - db.conn = conn + db.conn = sqlair.NewDB(sqlConnection) db.certificateTable = certificateRequestsTableName db.usersTable = usersTableName return db, nil diff --git a/internal/db/db_test.go b/internal/db/db_test.go index 5d0f768c..62989ce4 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -3,7 +3,7 @@ package db_test import ( "fmt" "log" - "strconv" + "path/filepath" "strings" "testing" @@ -12,7 +12,8 @@ import ( ) func TestConnect(t *testing.T) { - db, err := db.NewDatabase(":memory:") + tempDir := t.TempDir() + db, err := db.NewDatabase(filepath.Join(tempDir, "db.sqlite3")) if err != nil { t.Fatalf("Can't connect to SQLite: %s", err) } @@ -20,21 +21,22 @@ func TestConnect(t *testing.T) { } func TestCSRsEndToEnd(t *testing.T) { - db, err := db.NewDatabase(":memory:") + tempDir := t.TempDir() + db, err := db.NewDatabase(filepath.Join(tempDir, "db.sqlite3")) if err != nil { t.Fatalf("Couldn't complete NewDatabase: %s", err) } defer db.Close() - id1, err := db.CreateCSR(AppleCSR) + err = db.CreateCSR(AppleCSR) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } - id2, err := db.CreateCSR(BananaCSR) + err = db.CreateCSR(BananaCSR) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } - id3, err := db.CreateCSR(StrawberryCSR) + err = db.CreateCSR(StrawberryCSR) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } @@ -46,7 +48,7 @@ func TestCSRsEndToEnd(t *testing.T) { if len(res) != 3 { t.Fatalf("One or more CSRs weren't found in DB") } - retrievedCSR, err := db.RetrieveCSR(strconv.FormatInt(id1, 10)) + retrievedCSR, err := db.RetrieveCSRbyCSR(AppleCSR) if err != nil { t.Fatalf("Couldn't complete Retrieve: %s", err) } @@ -54,7 +56,7 @@ func TestCSRsEndToEnd(t *testing.T) { t.Fatalf("The CSR from the database doesn't match the CSR that was given") } - if _, err = db.DeleteCSR(strconv.FormatInt(id1, 10)); err != nil { + if err = db.DeleteCSRbyCSR(AppleCSR); err != nil { t.Fatalf("Couldn't complete Delete: %s", err) } res, _ = db.RetrieveAllCSRs() @@ -62,24 +64,24 @@ func TestCSRsEndToEnd(t *testing.T) { t.Fatalf("CSR's weren't deleted from the DB properly") } BananaCertBundle := strings.TrimSpace(fmt.Sprintf("%s%s", BananaCert, IssuerCert)) - _, err = db.UpdateCSR(strconv.FormatInt(id2, 10), BananaCertBundle) + err = db.AddCertificateChainToCSRbyCSR(BananaCSR, BananaCertBundle) if err != nil { t.Fatalf("Couldn't complete Update: %s", err) } - retrievedCSR, _ = db.RetrieveCSR(strconv.FormatInt(id2, 10)) - if retrievedCSR.Certificate != BananaCertBundle { - t.Fatalf("The certificate that was uploaded does not match the certificate that was given.\n Retrieved: %s\nGiven: %s", retrievedCSR.Certificate, BananaCertBundle) + retrievedCSR, _ = db.RetrieveCSRbyCSR(BananaCSR) + if retrievedCSR.CertificateChain != BananaCertBundle { + t.Fatalf("The certificate that was uploaded does not match the certificate that was given.\n Retrieved: %s\nGiven: %s", retrievedCSR.CertificateChain, BananaCertBundle) } - _, err = db.UpdateCSR(strconv.FormatInt(id2, 10), "") + err = db.RevokeCSR(BananaCSR) if err != nil { - t.Fatalf("Couldn't complete Update to delete certificate: %s", err) + t.Fatalf("Couldn't complete Update to revoke certificate: %s", err) } - _, err = db.UpdateCSR(strconv.FormatInt(id3, 10), "rejected") + err = db.RejectCSRbyCSR(StrawberryCSR) if err != nil { t.Fatalf("Couldn't complete Update to reject CSR: %s", err) } - retrievedCSR, _ = db.RetrieveCSR(strconv.FormatInt(id2, 10)) - if retrievedCSR.Certificate != "" { + retrievedCSR, _ = db.RetrieveCSRbyCSR(BananaCSR) + if retrievedCSR.RequestStatus != "Revoked" { t.Fatalf("Couldn't delete certificate") } } @@ -89,12 +91,12 @@ func TestCreateFails(t *testing.T) { defer db.Close() InvalidCSR := strings.ReplaceAll(AppleCSR, "M", "i") - if _, err := db.CreateCSR(InvalidCSR); err == nil { + if err := db.CreateCSR(InvalidCSR); err == nil { t.Fatalf("Expected error due to invalid CSR") } db.CreateCSR(AppleCSR) //nolint:errcheck - if _, err := db.CreateCSR(AppleCSR); err == nil { + if err := db.CreateCSR(AppleCSR); err == nil { t.Fatalf("Expected error due to duplicate CSR") } } @@ -103,13 +105,13 @@ func TestUpdateFails(t *testing.T) { db, _ := db.NewDatabase(":memory:") defer db.Close() - id1, _ := db.CreateCSR(AppleCSR) //nolint:errcheck - id2, _ := db.CreateCSR(BananaCSR) //nolint:errcheck + db.CreateCSR(AppleCSR) //nolint:errcheck + db.CreateCSR(BananaCSR) //nolint:errcheck InvalidCert := strings.ReplaceAll(BananaCert, "/", "+") - if _, err := db.UpdateCSR(strconv.FormatInt(id2, 10), InvalidCert); err == nil { + if err := db.AddCertificateChainToCSRbyCSR(BananaCSR, InvalidCert); err == nil { t.Fatalf("Expected updating with invalid cert to fail") } - if _, err := db.UpdateCSR(strconv.FormatInt(id1, 10), BananaCert); err == nil { + if err := db.AddCertificateChainToCSRbyCSR(AppleCSR, BananaCert); err == nil { t.Fatalf("Expected updating with mismatched cert to fail") } } @@ -119,7 +121,10 @@ func TestRetrieve(t *testing.T) { defer db.Close() db.CreateCSR(AppleCSR) //nolint:errcheck - if _, err := db.RetrieveCSR("this is definitely not an id"); err == nil { + if _, err := db.RetrieveCSRbyCSR("this is definitely not an id"); err == nil { + t.Fatalf("Expected failure looking for nonexistent CSR") + } + if _, err := db.RetrieveCSRbyID(-1); err == nil { t.Fatalf("Expected failure looking for nonexistent CSR") } } @@ -131,11 +136,11 @@ func TestUsersEndToEnd(t *testing.T) { } defer db.Close() - id1, err := db.CreateUser("admin", "pw123", 1) + err = db.CreateUser("admin", "pw123", 1) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } - id2, err := db.CreateUser("norman", "pw456", 0) + err = db.CreateUser("norman", "pw456", 0) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } @@ -154,31 +159,36 @@ func TestUsersEndToEnd(t *testing.T) { if num != 2 { t.Fatalf("NumUsers didn't return the correct number of users") } - retrievedUser, err := db.RetrieveUser(strconv.FormatInt(id1, 10)) + retrievedUser, err := db.RetrieveUserByUsername("admin") + if err != nil { + t.Fatalf("Couldn't complete Retrieve: %s", err) + } + if retrievedUser.Username != "admin" { + t.Fatalf("The user from the database doesn't match the user that was given") + } + retrievedUser, err = db.RetrieveUserByID(1) if err != nil { t.Fatalf("Couldn't complete Retrieve: %s", err) } if retrievedUser.Username != "admin" { t.Fatalf("The user from the database doesn't match the user that was given") } - if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.Password), []byte("pw123")); err != nil { + if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.HashedPassword), []byte("pw123")); err != nil { t.Fatalf("The user's password doesn't match the one stored in the database") } - - if _, err = db.DeleteUser(strconv.FormatInt(id1, 10)); err != nil { + if err = db.DeleteUserByID(1); err != nil { t.Fatalf("Couldn't complete Delete: %s", err) } res, _ = db.RetrieveAllUsers() if len(res) != 1 { t.Fatalf("users weren't deleted from the DB properly") } - - _, err = db.UpdateUser(strconv.FormatInt(id2, 10), "thebestpassword") + err = db.UpdateUserPassword(2, "thebestpassword") if err != nil { t.Fatalf("Couldn't complete Update: %s", err) } - retrievedUser, _ = db.RetrieveUser(strconv.FormatInt(id2, 10)) - if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.Password), []byte("thebestpassword")); err != nil { + retrievedUser, _ = db.RetrieveUserByUsername("norman") + if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.HashedPassword), []byte("thebestpassword")); err != nil { t.Fatalf("The new password that was given does not match the password that was stored.") } } @@ -188,19 +198,19 @@ func Example() { if err != nil { log.Fatalln(err) } - _, err = db.CreateCSR(BananaCSR) + err = db.CreateCSR(BananaCSR) if err != nil { log.Fatalln(err) } - _, err = db.UpdateCSR(BananaCSR, BananaCert) + err = db.AddCertificateChainToCSRbyCSR(BananaCSR, BananaCert) if err != nil { log.Fatalln(err) } - entry, err := db.RetrieveCSR(BananaCSR) + entry, err := db.RetrieveCSRbyCSR(BananaCSR) if err != nil { log.Fatalln(err) } - if entry.Certificate != BananaCert { + if entry.CertificateChain != BananaCert { log.Fatalln("Retrieved Certificate doesn't match Stored Certificate") } err = db.Close() diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 1076d615..686ce0d0 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -95,15 +95,15 @@ func (pm *PrometheusMetrics) GenerateMetrics(csrs []db.CertificateRequest) { var expiringIn30DaysCertCount float64 var expiringIn90DaysCertCount float64 for _, entry := range csrs { - if entry.Certificate == "" { + if entry.CertificateChain == "" { outstandingCSRCount += 1 continue } - if entry.Certificate == "rejected" { + if entry.RequestStatus == "Rejected" { continue } certCount += 1 - expiryDate := certificateExpiryDate(entry.Certificate) + expiryDate := certificateExpiryDate(entry.CertificateChain) daysRemaining := time.Until(expiryDate).Hours() / 24 if daysRemaining < 0 { expiredCertCount += 1 @@ -218,6 +218,7 @@ func requestDurationMetric() prometheus.HistogramVec { } func certificateExpiryDate(certString string) time.Time { + // TODO: Does this return the expiry date of the issuer? certBlock, _ := pem.Decode([]byte(certString)) cert, _ := x509.ParseCertificate(certBlock.Bytes) // TODO: cert.NotAfter can exist in a wrong cert. We should catch that at the db level validation diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go index ecdd8548..27422766 100644 --- a/internal/metrics/metrics_test.go +++ b/internal/metrics/metrics_test.go @@ -21,7 +21,8 @@ import ( // TestPrometheusHandler tests that the Prometheus metrics handler responds correctly to an HTTP request. func TestPrometheusHandler(t *testing.T) { - db, err := db.NewDatabase(":memory:") + tempDir := t.TempDir() + db, err := db.NewDatabase(filepath.Join(tempDir, "db.sqlite3")) if err != nil { t.Fatal(err) } @@ -95,13 +96,13 @@ func generateCertPair(daysRemaining int) (string, string, string) { } func initializeTestDB(t *testing.T, db *db.Database) { - for i, v := range []int{5, 10, 32} { + for _, v := range []int{5, 10, 32} { csr, cert, ca := generateCertPair(v) - _, err := db.CreateCSR(csr) + err := db.CreateCSR(csr) if err != nil { t.Fatalf("couldn't create test csr: %s", err) } - _, err = db.UpdateCSR(fmt.Sprint(i+1), fmt.Sprintf("%s%s", cert, ca)) + err = db.AddCertificateChainToCSRbyCSR(csr, fmt.Sprintf("%s%s", cert, ca)) if err != nil { t.Fatalf("couldn't create test cert: %s", err) } diff --git a/internal/server/authorization_test.go b/internal/server/authorization_test.go index 164faa7e..4aaf8f99 100644 --- a/internal/server/authorization_test.go +++ b/internal/server/authorization_test.go @@ -2,12 +2,15 @@ package server_test import ( "net/http" + "path/filepath" "strings" "testing" ) func TestAuthorizationNoAuth(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -47,7 +50,9 @@ func TestAuthorizationNoAuth(t *testing.T) { } func TestAuthorizationNonAdminAuthorized(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -98,7 +103,9 @@ func TestAuthorizationNonAdminAuthorized(t *testing.T) { } func TestAuthorizationNonAdminUnauthorized(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -156,7 +163,9 @@ func TestAuthorizationNonAdminUnauthorized(t *testing.T) { } func TestAuthorizationAdminAuthorized(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -205,7 +214,9 @@ func TestAuthorizationAdminAuthorized(t *testing.T) { } func TestAuthorizationAdminUnAuthorized(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } diff --git a/internal/server/handlers_accounts.go b/internal/server/handlers_accounts.go index 7b72e829..82ad61eb 100644 --- a/internal/server/handlers_accounts.go +++ b/internal/server/handlers_accounts.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/canonical/notary/internal/db" + "github.com/canonical/sqlair" ) type CreateAccountParams struct { @@ -27,18 +28,6 @@ type GetAccountResponse struct { Permissions int `json:"permissions"` } -type CreateAccountResponse struct { - ID int `json:"id"` -} - -type ChangeAccountResponse struct { - ID int `json:"id"` -} - -type DeleteAccountResponse struct { - ID int `json:"id"` -} - func validatePassword(password string) bool { if len(password) < 8 { return false @@ -87,8 +76,12 @@ func ListAccounts(env *HandlerConfig) http.HandlerFunc { func GetAccount(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - var account db.User - var err error + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + var account *db.User if id == "me" { claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if headerErr != nil { @@ -96,11 +89,11 @@ func GetAccount(env *HandlerConfig) http.HandlerFunc { } account, err = env.DB.RetrieveUserByUsername(claims.Username) } else { - account, err = env.DB.RetrieveUser(id) + account, err = env.DB.RetrieveUserByID(idNum) } if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } @@ -150,12 +143,11 @@ func CreateAccount(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Failed to retrieve accounts: "+err.Error()) return } - permission := UserPermission if numUsers == 0 { permission = AdminPermission } - id, err := env.DB.CreateUser(createAccountParams.Username, createAccountParams.Password, permission) + err = env.DB.CreateUser(createAccountParams.Username, createAccountParams.Password, permission) if err != nil { if strings.Contains(err.Error(), "UNIQUE constraint failed") { writeError(w, http.StatusBadRequest, "account with given username already exists") @@ -165,11 +157,9 @@ func CreateAccount(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error") return } - accountResponse := CreateAccountResponse{ - ID: int(id), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusCreated) - err = writeJSON(w, accountResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -182,39 +172,39 @@ func CreateAccount(env *HandlerConfig) http.HandlerFunc { func DeleteAccount(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - idInt, err := strconv.ParseInt(id, 10, 64) + idInt, err := strconv.Atoi(id) if err != nil { log.Println(err) writeError(w, http.StatusInternalServerError, "Internal Error") return } - account, err := env.DB.RetrieveUser(id) + account, err := env.DB.RetrieveUserByID(idInt) if err != nil { - if !errors.Is(err, db.ErrIdNotFound) { - log.Println(err) - writeError(w, http.StatusInternalServerError, "Internal Error") + log.Println(err) + if errors.Is(err, sqlair.ErrNoRows) { + writeError(w, http.StatusNotFound, "Not Found") return } + writeError(w, http.StatusInternalServerError, "Internal Error") + return } if account.Permissions == 1 { writeError(w, http.StatusBadRequest, "deleting an Admin account is not allowed.") return } - _, err = env.DB.DeleteUser(id) + err = env.DB.DeleteUserByID(idInt) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - deleteAccountResponse := DeleteAccountResponse{ - ID: int(idInt), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusAccepted) - err = writeJSON(w, deleteAccountResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -225,6 +215,7 @@ func DeleteAccount(env *HandlerConfig) http.HandlerFunc { func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") + var idNum int if id == "me" { claims, err := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret) if err != nil { @@ -238,7 +229,15 @@ func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusUnauthorized, "Unauthorized") return } - id = strconv.Itoa(account.ID) + idNum = account.ID + } else { + idInt, err := strconv.Atoi(id) + if err != nil { + log.Println(err) + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + idNum = idInt } var changeAccountParams ChangeAccountParams if err := json.NewDecoder(r.Body).Decode(&changeAccountParams); err != nil { @@ -257,21 +256,19 @@ func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc { ) return } - ret, err := env.DB.UpdateUser(id, changeAccountParams.Password) + err := env.DB.UpdateUserPassword(idNum, changeAccountParams.Password) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - changeAccountResponse := ChangeAccountResponse{ - ID: int(ret), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusCreated) - err = writeJSON(w, changeAccountResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return diff --git a/internal/server/handlers_accounts_test.go b/internal/server/handlers_accounts_test.go index 42e2e7a7..38d40a13 100644 --- a/internal/server/handlers_accounts_test.go +++ b/internal/server/handlers_accounts_test.go @@ -3,11 +3,16 @@ package server_test import ( "encoding/json" "net/http" + "path/filepath" "strconv" "strings" "testing" ) +type SuccessResponse struct { + Message string `json:"message"` +} + type GetAccountResponseResult struct { ID int `json:"id"` Username string `json:"username"` @@ -29,8 +34,8 @@ type CreateAccountResponseResult struct { } type CreateAccountResponse struct { - Result CreateAccountResponseResult `json:"result"` - Error string `json:"error,omitempty"` + Result SuccessResponse `json:"result"` + Error string `json:"error,omitempty"` } type ChangeAccountPasswordParams struct { @@ -139,7 +144,9 @@ func deleteAccount(url string, client *http.Client, adminToken string, id int) ( // The order of the tests is important, as some tests depend on // the state of the server after previous tests. func TestAccountsEndToEnd(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -197,10 +204,7 @@ func TestAccountsEndToEnd(t *testing.T) { t.Fatalf("expected status %d, got %d", http.StatusCreated, statusCode) } if response.Error != "" { - t.Fatalf("expected error %q, got %q", "", response.Error) - } - if response.Result.ID != 3 { - t.Fatalf("expected ID 3, got %d", response.Result.ID) + t.Fatalf("unexpected error :%q", response.Error) } }) @@ -285,10 +289,7 @@ func TestAccountsEndToEnd(t *testing.T) { t.Fatalf("expected status %d, got %d", http.StatusCreated, statusCode) } if response.Error != "" { - t.Fatalf("expected error %q, got %q", "", response.Error) - } - if response.Result.ID != 1 { - t.Fatalf("expected ID 1, got %d", response.Result.ID) + t.Fatalf("unexpected error :%q", response.Error) } }) @@ -351,9 +352,6 @@ func TestAccountsEndToEnd(t *testing.T) { if response.Error != "" { t.Fatalf("expected error %q, got %q", "", response.Error) } - if response.Result.ID != 2 { - t.Fatalf("expected ID 2, got %d", response.Result.ID) - } }) t.Run("13. Delete account - no user", func(t *testing.T) { diff --git a/internal/server/handlers_certificate_requests.go b/internal/server/handlers_certificate_requests.go index a1a2d0fc..bcd440d1 100644 --- a/internal/server/handlers_certificate_requests.go +++ b/internal/server/handlers_certificate_requests.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/canonical/notary/internal/db" + "github.com/canonical/sqlair" ) type CreateCertificateRequestParams struct { @@ -16,54 +16,32 @@ type CreateCertificateRequestParams struct { } type CreateCertificateParams struct { - Certificate string `json:"certificate"` + CertificateChain string `json:"certificate"` } -type GetCertificateRequestResponse struct { - ID int `json:"id"` - CSR string `json:"csr"` - Certificate string `json:"certificate"` -} - -type CreateCertificateRequestResponse struct { - ID int `json:"id"` -} - -type DeleteCertificateRequestResponse struct { - ID int `json:"id"` -} - -type RejectCertificateRequestResponse struct { - ID int `json:"id"` -} - -type CreateCertificateResponse struct { - ID int `json:"id"` -} - -type DeleteCertificateResponse struct { - ID int `json:"id"` -} - -type RejectCertificateResponse struct { - ID int `json:"id"` +type CertificateRequest struct { + ID int `json:"id"` + CSR string `json:"csr"` + CertificateChain string `json:"certificate"` + CSRStatus string `json:"csr_status"` } // ListCertificateRequests returns all of the Certificate Requests func ListCertificateRequests(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - certs, err := env.DB.RetrieveAllCSRs() + csrs, err := env.DB.RetrieveAllCSRs() if err != nil { log.Println(err) writeError(w, http.StatusInternalServerError, "Internal Error") return } - certificateRequestsResponse := make([]GetCertificateRequestResponse, len(certs)) - for i, cert := range certs { - certificateRequestsResponse[i] = GetCertificateRequestResponse{ - ID: cert.ID, - CSR: cert.CSR, - Certificate: cert.Certificate, + certificateRequestsResponse := make([]CertificateRequest, len(csrs)) + for i, csr := range csrs { + certificateRequestsResponse[i] = CertificateRequest{ + ID: csr.ID, + CSR: csr.CSR, + CertificateChain: csr.CertificateChain, + CSRStatus: csr.RequestStatus, } } w.WriteHeader(http.StatusOK) @@ -87,7 +65,7 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusBadRequest, "csr is missing") return } - id, err := env.DB.CreateCSR(createCertificateRequestParams.CSR) + err := env.DB.CreateCSR(createCertificateRequestParams.CSR) if err != nil { if strings.Contains(err.Error(), "UNIQUE constraint failed") { writeError(w, http.StatusBadRequest, "given csr already recorded") @@ -102,11 +80,9 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error") return } - certificateRequestResponse := CreateCertificateRequestResponse{ - ID: int(id), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusCreated) - err = writeJSON(w, certificateRequestResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -119,20 +95,26 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc { func GetCertificateRequest(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - cert, err := env.DB.RetrieveCSR(id) + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + csr, err := env.DB.RetrieveCSRbyID(idNum) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - certificateRequestResponse := GetCertificateRequestResponse{ - ID: cert.ID, - CSR: cert.CSR, - Certificate: cert.Certificate, + certificateRequestResponse := CertificateRequest{ + ID: csr.ID, + CSR: csr.CSR, + CertificateChain: csr.CertificateChain, + CSRStatus: csr.RequestStatus, } w.WriteHeader(http.StatusOK) err = writeJSON(w, certificateRequestResponse) @@ -148,21 +130,24 @@ func GetCertificateRequest(env *HandlerConfig) http.HandlerFunc { func DeleteCertificateRequest(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - insertId, err := env.DB.DeleteCSR(id) + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + err = env.DB.DeleteCSRbyID(idNum) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - certificateRequestResponse := DeleteCertificateRequestResponse{ - ID: int(insertId), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusAccepted) - err = writeJSON(w, certificateRequestResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -179,15 +164,20 @@ func CreateCertificate(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusBadRequest, "Invalid JSON format") return } - if createCertificateParams.Certificate == "" { + if createCertificateParams.CertificateChain == "" { writeError(w, http.StatusBadRequest, "certificate is missing") return } id := r.PathValue("id") - insertId, err := env.DB.UpdateCSR(id, createCertificateParams.Certificate) + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + err = env.DB.AddCertificateToCSRbyID(idNum, createCertificateParams.CertificateChain) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) || + if errors.Is(err, sqlair.ErrNoRows) || err.Error() == "certificate does not match CSR" || strings.Contains(err.Error(), "cert validation failed") { writeError(w, http.StatusBadRequest, "Bad Request") @@ -196,18 +186,15 @@ func CreateCertificate(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error") return } - insertIdStr := strconv.FormatInt(insertId, 10) if env.SendPebbleNotifications { - err := SendPebbleNotification("canonical.com/notary/certificate/update", insertIdStr) + err := SendPebbleNotification("canonical.com/notary/certificate/update", id) if err != nil { log.Printf("pebble notify failed: %s. continuing silently.", err.Error()) } } - certificateResponse := CreateCertificateResponse{ - ID: int(insertId), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusCreated) - err = writeJSON(w, certificateResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -218,28 +205,30 @@ func CreateCertificate(env *HandlerConfig) http.HandlerFunc { func RejectCertificate(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - insertId, err := env.DB.UpdateCSR(id, "rejected") + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + err = env.DB.RejectCSRbyID(idNum) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusNotFound, "Not Found") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - insertIdStr := strconv.FormatInt(insertId, 10) if env.SendPebbleNotifications { - err := SendPebbleNotification("canonical.com/notary/certificate/update", insertIdStr) + err := SendPebbleNotification("canonical.com/notary/certificate/update", id) if err != nil { log.Printf("pebble notify failed: %s. continuing silently.", err.Error()) } } - certificateResponse := RejectCertificateResponse{ - ID: int(insertId), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusAccepted) - err = writeJSON(w, certificateResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return @@ -252,28 +241,30 @@ func RejectCertificate(env *HandlerConfig) http.HandlerFunc { func DeleteCertificate(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - insertId, err := env.DB.UpdateCSR(id, "") + idNum, err := strconv.Atoi(id) + if err != nil { + writeError(w, http.StatusInternalServerError, "Internal Error") + return + } + err = env.DB.DeleteCSRbyID(idNum) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusBadRequest, "Bad Request") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - insertIdStr := strconv.FormatInt(insertId, 10) if env.SendPebbleNotifications { - err := SendPebbleNotification("canonical.com/notary/certificate/update", insertIdStr) + err := SendPebbleNotification("canonical.com/notary/certificate/update", id) if err != nil { log.Printf("pebble notify failed: %s. continuing silently.", err.Error()) } } - certificateResponse := DeleteCertificateResponse{ - ID: int(insertId), - } + successResponse := SuccessResponse{Message: "success"} w.WriteHeader(http.StatusOK) - err = writeJSON(w, certificateResponse) + err = writeJSON(w, successResponse) if err != nil { writeError(w, http.StatusInternalServerError, "internal error") return diff --git a/internal/server/handlers_certificate_requests_test.go b/internal/server/handlers_certificate_requests_test.go index 21b83118..ac06dd65 100644 --- a/internal/server/handlers_certificate_requests_test.go +++ b/internal/server/handlers_certificate_requests_test.go @@ -9,22 +9,24 @@ import ( "path/filepath" "strconv" "testing" + + "github.com/canonical/notary/internal/server" ) -type CertificateRequest struct { - ID int `json:"id"` - CSR string `json:"csr"` - Certificate string `json:"certificate"` -} +// type CertificateRequest struct { +// ID int `json:"id"` +// CSR string `json:"csr"` +// Certificate string `json:"certificate"` +// } type GetCertificateRequestResponse struct { - Result CertificateRequest `json:"result"` - Error string `json:"error,omitempty"` + Result server.CertificateRequest `json:"result"` + Error string `json:"error,omitempty"` } type ListCertificateRequestsResponse struct { - Error string `json:"error,omitempty"` - Result []CertificateRequest `json:"result"` + Error string `json:"error,omitempty"` + Result []server.CertificateRequest `json:"result"` } type CreateCertificateRequestResponse struct { @@ -162,7 +164,10 @@ func rejectCertificate(url string, client *http.Client, adminToken string, id in // The order of the tests is important, as some tests depend on the // state of the server after previous tests. func TestCertificateRequestsEndToEnd(t *testing.T) { - ts, _, err := setupServer() + + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -257,8 +262,8 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { if getCertRequestResponse.Result.CSR == "" { t.Fatalf("expected CSR, got empty string") } - if getCertRequestResponse.Result.Certificate != "" { - t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.Certificate) + if getCertRequestResponse.Result.CertificateChain != "" { + t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.CertificateChain) } }) @@ -353,8 +358,8 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { if getCertRequestResponse.Result.CSR == "" { t.Fatalf("expected CSR, got empty string") } - if getCertRequestResponse.Result.Certificate != "" { - t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.Certificate) + if getCertRequestResponse.Result.CertificateChain != "" { + t.Fatalf("expected no certificate, got %s", getCertRequestResponse.Result.CertificateChain) } }) @@ -399,7 +404,9 @@ func TestCertificateRequestsEndToEnd(t *testing.T) { // The order of the tests is important, as some tests depend on the // state of the server after previous tests. func TestCertificatesEndToEnd(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } @@ -481,7 +488,7 @@ func TestCertificatesEndToEnd(t *testing.T) { if getCertResponse.Error != "" { t.Fatalf("expected no error, got %s", getCertResponse.Error) } - if getCertResponse.Result.Certificate == "" { + if getCertResponse.Result.CertificateChain == "" { t.Fatalf("expected certificate, got empty string") } }) @@ -507,8 +514,8 @@ func TestCertificatesEndToEnd(t *testing.T) { if getCertResponse.Error != "" { t.Fatalf("expected no error, got %s", getCertResponse.Error) } - if getCertResponse.Result.Certificate != "rejected" { - t.Fatalf("expected `rejected` certificate, got %s", getCertResponse.Result.Certificate) + if getCertResponse.Result.CSRStatus != "Rejected" { + t.Fatalf("expected `Rejected` status, got %s", getCertResponse.Result.CertificateChain) } }) diff --git a/internal/server/handlers_helpers_test.go b/internal/server/handlers_helpers_test.go index cb309c1a..1552e1e9 100644 --- a/internal/server/handlers_helpers_test.go +++ b/internal/server/handlers_helpers_test.go @@ -10,8 +10,8 @@ import ( "github.com/canonical/notary/internal/server" ) -func setupServer() (*httptest.Server, *server.HandlerConfig, error) { - testdb, err := db.NewDatabase(":memory:") +func setupServer(filepath string) (*httptest.Server, *server.HandlerConfig, error) { + testdb, err := db.NewDatabase(filepath) if err != nil { return nil, nil, err } diff --git a/internal/server/handlers_login.go b/internal/server/handlers_login.go index 21b31f81..a0f702a3 100644 --- a/internal/server/handlers_login.go +++ b/internal/server/handlers_login.go @@ -7,7 +7,7 @@ import ( "net/http" "time" - "github.com/canonical/notary/internal/db" + "github.com/canonical/sqlair" "github.com/golang-jwt/jwt" "golang.org/x/crypto/bcrypt" ) @@ -68,14 +68,14 @@ func Login(env *HandlerConfig) http.HandlerFunc { userAccount, err := env.DB.RetrieveUserByUsername(loginParams.Username) if err != nil { log.Println(err) - if errors.Is(err, db.ErrIdNotFound) { + if errors.Is(err, sqlair.ErrNoRows) { writeError(w, http.StatusUnauthorized, "The username or password is incorrect. Try again.") return } writeError(w, http.StatusInternalServerError, "Internal Error") return } - if err := bcrypt.CompareHashAndPassword([]byte(userAccount.Password), []byte(loginParams.Password)); err != nil { + if err := bcrypt.CompareHashAndPassword([]byte(userAccount.HashedPassword), []byte(loginParams.Password)); err != nil { writeError(w, http.StatusUnauthorized, "The username or password is incorrect. Try again.") return } diff --git a/internal/server/handlers_login_test.go b/internal/server/handlers_login_test.go index 408cb2c6..ff3c0dbd 100644 --- a/internal/server/handlers_login_test.go +++ b/internal/server/handlers_login_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "path/filepath" "strings" "testing" @@ -46,7 +47,9 @@ func login(url string, client *http.Client, data *LoginParams) (int, *LoginRespo } func TestLoginEndToEnd(t *testing.T) { - ts, config, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, config, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } diff --git a/internal/server/handlers_status_test.go b/internal/server/handlers_status_test.go index c7255745..d819fd58 100644 --- a/internal/server/handlers_status_test.go +++ b/internal/server/handlers_status_test.go @@ -3,6 +3,7 @@ package server_test import ( "encoding/json" "net/http" + "path/filepath" "testing" ) @@ -35,7 +36,9 @@ func getStatus(url string, client *http.Client, adminToken string) (int, *GetSta } func TestStatus(t *testing.T) { - ts, _, err := setupServer() + tempDir := t.TempDir() + db_path := filepath.Join(tempDir, "db.sqlite3") + ts, _, err := setupServer(db_path) if err != nil { t.Fatalf("couldn't create test server: %s", err) } diff --git a/internal/server/response.go b/internal/server/response.go index 0b9678cd..7e7daf65 100644 --- a/internal/server/response.go +++ b/internal/server/response.go @@ -6,6 +6,10 @@ import ( "net/http" ) +type SuccessResponse struct { + Message string `json:"message"` +} + // writeJSON is a helper function that writes a JSON response to the http.ResponseWriter func writeJSON(w http.ResponseWriter, v any) error { type response struct { From c43256ac8cd6af6a1611589c792a50002915a7cb Mon Sep 17 00:00:00 2001 From: kayra1 Date: Tue, 19 Nov 2024 14:51:29 +0300 Subject: [PATCH 2/9] feat: update frontend to use new json response --- .../server/handlers_certificate_requests.go | 2 +- ui/package-lock.json | 15 ++++++++++++ .../(notary)/certificate_requests/table.tsx | 24 +++++++++---------- ui/src/types.ts | 3 ++- 4 files changed, 30 insertions(+), 14 deletions(-) diff --git a/internal/server/handlers_certificate_requests.go b/internal/server/handlers_certificate_requests.go index bcd440d1..35793694 100644 --- a/internal/server/handlers_certificate_requests.go +++ b/internal/server/handlers_certificate_requests.go @@ -22,7 +22,7 @@ type CreateCertificateParams struct { type CertificateRequest struct { ID int `json:"id"` CSR string `json:"csr"` - CertificateChain string `json:"certificate"` + CertificateChain string `json:"certificate_chain"` CSRStatus string `json:"csr_status"` } diff --git a/ui/package-lock.json b/ui/package-lock.json index 0e09b13c..c41ddf44 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -7962,6 +7962,21 @@ "funding": { "url": "https://github.com/sponsors/sindresorhus" } + }, + "node_modules/@next/swc-win32-ia32-msvc": { + "version": "14.2.15", + "resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.15.tgz", + "integrity": "sha512-fyTE8cklgkyR1p03kJa5zXEaZ9El+kDNM5A+66+8evQS5e/6v0Gk28LqA0Jet8gKSOyP+OTm/tJHzMlGdQerdQ==", + "cpu": [ + "ia32" + ], + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } } } } diff --git a/ui/src/app/(notary)/certificate_requests/table.tsx b/ui/src/app/(notary)/certificate_requests/table.tsx index c4817377..8d1a17e5 100644 --- a/ui/src/app/(notary)/certificate_requests/table.tsx +++ b/ui/src/app/(notary)/certificate_requests/table.tsx @@ -139,9 +139,9 @@ export function CertificateRequestsTable({ csrs: rows }: TableProps) { }; const csrrows = rows.map((csrEntry) => { - const { id, csr, certificate } = csrEntry; + const { id, csr, certificate_chain, csr_status } = csrEntry; const csrObj = extractCSR(csr); - const certs = splitBundle(certificate); + const certs = splitBundle(certificate_chain); const clientCertificate = certs?.at(0); const certObj = clientCertificate ? extractCert(clientCertificate) : null; @@ -152,13 +152,13 @@ export function CertificateRequestsTable({ csrs: rows }: TableProps) { sortData: { id, common_name: csrObj.commonName, - csr_status: certificate === "" ? "outstanding" : (certificate === "rejected" ? "rejected" : "fulfilled"), + csr_status: csr_status, cert_expiry_date: certObj?.notAfter || "", }, columns: [ { content: id.toString() }, { content: csrObj.commonName || "N/A" }, - { content: certificate === "" ? "outstanding" : (certificate === "rejected" ? "rejected" : "fulfilled") }, + { content: csr_status }, { content: certObj?.notAfter || "", style: { backgroundColor: getExpiryColor(certObj?.notAfter) }, @@ -185,13 +185,13 @@ export function CertificateRequestsTable({ csrs: rows }: TableProps) { @@ -321,7 +321,7 @@ export function CertificateRequestsTable({ csrs: rows }: TableProps) { )} diff --git a/ui/src/types.ts b/ui/src/types.ts index 7e208e9d..9af23ff5 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -1,7 +1,8 @@ export type CSREntry = { id: number, csr: string, - certificate: string + certificate_chain: string + csr_status: "Outstanding" | "Active" | "Rejected" | "Revoked" } export type User = { From f761b038ea31847990a3a8637b7abfa28c0ebc63 Mon Sep 17 00:00:00 2001 From: kayra1 Date: Tue, 19 Nov 2024 14:59:59 +0300 Subject: [PATCH 3/9] fix: update frontend tests --- .../(notary)/certificate_requests/table.test.tsx | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/ui/src/app/(notary)/certificate_requests/table.test.tsx b/ui/src/app/(notary)/certificate_requests/table.test.tsx index 5b3c076d..15dc36c9 100644 --- a/ui/src/app/(notary)/certificate_requests/table.test.tsx +++ b/ui/src/app/(notary)/certificate_requests/table.test.tsx @@ -2,8 +2,9 @@ import { expect, test } from 'vitest' import { render, screen } from '@testing-library/react' import { CertificateRequestsTable } from './table' import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { CSREntry } from '@/types' -const rows = [ +const rows: CSREntry[] = [ { 'id': 1, 'csr': `-----BEGIN CERTIFICATE REQUEST----- @@ -23,7 +24,8 @@ Y/uPl4g3jpGqLCKTASWJDGnZLroLICOzYTVs5P3oj+VueSUwYhGK5tBnS2x5FHID uMNMgwl0fxGMQZjrlXyCBhXBm1k6PmwcJGJF5LQ31c+5aTTMFU7SyZhlymctB8mS y+ErBQsRpcQho6Ok+HTXQQUcx7WNcwI= -----END CERTIFICATE REQUEST-----`, - 'certificate': "" + 'certificate_chain': "", + 'csr_status': "Outstanding", }, { 'id': 2, @@ -46,7 +48,8 @@ tK9qb8EE92MoWboo4m4bcX74y+eUo3xBev6ZZwdScy8OHLhA/MMI8EElpeYt+Hc2 WsDOAOH6qKQKQg3BO/xmRoohC6GL4CuhP7HYGi7+wziNhNZQa4GtE/k9DyIXVtJy yuf2PnfXCKnaIWRJNoEqDCZRVMfA5BFSwTPITqyo -----END CERTIFICATE REQUEST-----`, - 'certificate': "rejected" + 'certificate_chain': "", + 'csr_status': "Rejected" }, { 'id': 3, @@ -68,7 +71,7 @@ cAQXk3fvTWuikHiCHqqdSdjDYj/8cyiwCrQWpV245VSbOE0WesWoEnSdFXVUfE1+ RSKeTRuuJMcdGqBkDnDI22myj0bjt7q8eqBIjTiLQLnAFnQYpcCrhc8dKU9IJlv1 H9Hay4ZO9LRew3pEtlx2WrExw/gpUcWM8rTI -----END CERTIFICATE REQUEST-----`, - 'certificate': `-----BEGIN CERTIFICATE----- + 'certificate_chain': `-----BEGIN CERTIFICATE----- MIIDrDCCApSgAwIBAgIURKr+jf7hj60SyAryIeN++9wDdtkwDQYJKoZIhvcNAQEL BQAwOTELMAkGA1UEBhMCVVMxKjAoBgNVBAMMIXNlbGYtc2lnbmVkLWNlcnRpZmlj YXRlcy1vcGVyYXRvcjAeFw0yNDAzMjcxMjQ4MDRaFw0yNTAzMjcxMjQ4MDRaMEcx @@ -89,7 +92,8 @@ WyhXkzguv3dwH+n43GJFP6MQ+n9W/nPZCUQ0Iy7ueAvj0HFhGyZzAE2wxNFZdvCs gCX3nqYpp70oZIFDrhmYwE5ij5KXlHD4/1IOfNUKCDmQDgGPLI1tVtwQLjeRq7Hg XVelpl/LXTQawmJyvDaVT/Q9P+WqoDiMjrqF6Sy7DzNeeccWVqvqX5TVS6Ky56iS Mvo/+PAJHkBciR5Xn+Wg2a+7vrZvT6CBoRSOTozlLSM= ------END CERTIFICATE-----` +-----END CERTIFICATE-----`, + 'csr_status': 'Active' }, ] From 98ea13fbf33e95cc07aea42168ff6ac7f5c3a326 Mon Sep 17 00:00:00 2001 From: kayra1 Date: Tue, 19 Nov 2024 17:00:42 +0300 Subject: [PATCH 4/9] fix: simplify return nil's --- internal/db/db.go | 55 ++++++++++------------------------------------- 1 file changed, 11 insertions(+), 44 deletions(-) diff --git a/internal/db/db.go b/internal/db/db.go index 6ba5b22a..4f0bd972 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -143,10 +143,7 @@ func (db *Database) CreateCSR(csr string) error { CSR: csr, } err = db.conn.Query(context.Background(), stmt, row).Run() - if err != nil { - return err - } - return nil + return err } // AddCertificateChainToCSRbyCSR adds a new certificate chain to a row for a given CSR string. @@ -170,10 +167,7 @@ func (db *Database) AddCertificateChainToCSRbyCSR(csr string, cert string) error RequestStatus: "Active", } err = db.conn.Query(context.Background(), stmt, newRow).Run() - if err != nil { - return err - } - return nil + return err } // AddCertificateChainToCSRbyID adds a new certificate chain to a row for a given row ID. @@ -201,10 +195,7 @@ func (db *Database) AddCertificateToCSRbyID(id int, cert string) error { RequestStatus: "Active", } err = db.conn.Query(context.Background(), stmt, newRow).Run() - if err != nil { - return err - } - return nil + return err } // RejectCSRbyCSR updates input CSR's row by setting the certificate bundle to "" and moving the row status to "Rejected". @@ -224,10 +215,7 @@ func (db *Database) RejectCSRbyCSR(csr string) error { RequestStatus: "Rejected", } err = db.conn.Query(context.Background(), stmt, newRow).Run() - if err != nil { - return err - } - return nil + return err } // RejectCSRbyCSR updates input ID's row by setting the certificate bundle to "" and sets the row status to "Rejected". @@ -247,10 +235,7 @@ func (db *Database) RejectCSRbyID(id int) error { RequestStatus: "Rejected", } err = db.conn.Query(context.Background(), stmt, newRow).Run() - if err != nil { - return err - } - return nil + return err } // RevokeCSR updates the input CSR's row by setting the certificate bundle to "" and sets the row status to "Revoked". @@ -270,10 +255,7 @@ func (db *Database) RevokeCSR(csr string) error { RequestStatus: "Revoked", } err = db.conn.Query(context.Background(), stmt, newRow).Run() - if err != nil { - return err - } - return nil + return err } // DeleteCSRbyCSR removes a CSR from the database alongside the certificate that may have been generated for it. @@ -286,10 +268,7 @@ func (db *Database) DeleteCSRbyCSR(csr string) error { CSR: csr, } err = db.conn.Query(context.Background(), stmt, row).Run() - if err != nil { - return err - } - return nil + return err } // DeleteCSRByID removes a CSR from the database alongside the certificate that may have been generated for it. @@ -302,10 +281,7 @@ func (db *Database) DeleteCSRbyID(id int) error { ID: id, } err = db.conn.Query(context.Background(), stmt, row).Run() - if err != nil { - return err - } - return nil + return err } // RetrieveAllUsers returns all of the users and their fields available in the database. @@ -372,10 +348,7 @@ func (db *Database) CreateUser(username string, password string, permission int) Permissions: permission, } err = db.conn.Query(context.Background(), stmt, row).Run() - if err != nil { - return err - } - return nil + return err } // UpdateUser updates the password of the given user. @@ -398,10 +371,7 @@ func (db *Database) UpdateUserPassword(id int, password string) error { HashedPassword: string(pw), } err = db.conn.Query(context.Background(), stmt, row).Run() - if err != nil { - return err - } - return nil + return err } // DeleteUserByID removes a user from the table. @@ -418,10 +388,7 @@ func (db *Database) DeleteUserByID(id int) error { ID: id, } err = db.conn.Query(context.Background(), stmt, row).Run() - if err != nil { - return err - } - return nil + return err } type NumUsers struct { From b81b1766c1894412d0aebbf6f99502c7a6661be1 Mon Sep 17 00:00:00 2001 From: kayra1 Date: Thu, 21 Nov 2024 11:26:24 +0300 Subject: [PATCH 5/9] chore: rename table query --- internal/db/db.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/db/db.go b/internal/db/db.go index 4f0bd972..f6b9bf88 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -40,7 +40,7 @@ const ( usersTableName = "users" ) -const queryCreateCSRsTable = ` +const queryCreateCertificateRequestsTable = ` CREATE TABLE IF NOT EXISTS %s ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -429,7 +429,7 @@ func NewDatabase(databasePath string) (*Database, error) { if err != nil { return nil, err } - if _, err := sqlConnection.Exec(fmt.Sprintf(queryCreateCSRsTable, certificateRequestsTableName)); err != nil { + if _, err := sqlConnection.Exec(fmt.Sprintf(queryCreateCertificateRequestsTable, certificateRequestsTableName)); err != nil { return nil, err } if _, err := sqlConnection.Exec(fmt.Sprintf(queryCreateUsersTable, usersTableName)); err != nil { From 879f763cff29cd697f66a01bd139c5c8fb79a71c Mon Sep 17 00:00:00 2001 From: kayra1 Date: Thu, 21 Nov 2024 11:43:55 +0300 Subject: [PATCH 6/9] chore: rename variables --- internal/db/db.go | 112 +++++++++--------- internal/db/db_test.go | 62 +++++----- internal/metrics/metrics.go | 4 +- internal/metrics/metrics_test.go | 6 +- internal/server/handlers_accounts.go | 10 +- .../server/handlers_certificate_requests.go | 20 ++-- .../handlers_certificate_requests_test.go | 2 +- internal/server/handlers_login.go | 2 +- 8 files changed, 109 insertions(+), 109 deletions(-) diff --git a/internal/db/db.go b/internal/db/db.go index f6b9bf88..f7d0857c 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -24,7 +24,7 @@ type CertificateRequest struct { CSR string `db:"csr"` CertificateChain string `db:"certificate_chain"` - RequestStatus string `db:"request_status"` + Status string `db:"status"` } type User struct { @@ -46,7 +46,7 @@ const queryCreateCertificateRequestsTable = ` csr TEXT NOT NULL UNIQUE, certificate_chain TEXT DEFAULT '', - request_status TEXT DEFAULT 'Outstanding', + status TEXT DEFAULT 'Outstanding', CHECK (request_status IN ('Outstanding', 'Rejected', 'Revoked', 'Active')), CHECK (NOT (certificate_chain == '' AND request_status == 'Active' )), @@ -65,15 +65,15 @@ const queryCreateUsersTable = ` )` const ( - getAllCSRsStmt = "SELECT &CertificateRequest.* FROM %s" - getCSRsStmt = "SELECT &CertificateRequest.* FROM %s WHERE id==$CertificateRequest.id or csr==$CertificateRequest.csr" - createCSRStmt = "INSERT INTO %s (csr) VALUES ($CertificateRequest.csr)" - updateCSRStmt = "UPDATE %s SET certificate_chain=$CertificateRequest.certificate_chain, request_status=$CertificateRequest.request_status WHERE id==$CertificateRequest.id or csr==$CertificateRequest.csr" - deleteCSRStmt = "DELETE FROM %s WHERE id=$CertificateRequest.id or csr=$CertificateRequest.csr" + listCertificateRequestsStmt = "SELECT &CertificateRequest.* FROM %s" + getCertificateRequestStmt = "SELECT &CertificateRequest.* FROM %s WHERE id==$CertificateRequest.id or csr==$CertificateRequest.csr" + createCertificateRequestStmt = "INSERT INTO %s (csr) VALUES ($CertificateRequest.csr)" + updateCertificateRequestStmt = "UPDATE %s SET certificate_chain=$CertificateRequest.certificate_chain, request_status=$CertificateRequest.request_status WHERE id==$CertificateRequest.id or csr==$CertificateRequest.csr" + deleteCertificateRequestStmt = "DELETE FROM %s WHERE id=$CertificateRequest.id or csr=$CertificateRequest.csr" ) const ( - getAllUsersStmt = "SELECT &User.* from %s" + listUsersStmt = "SELECT &User.* from %s" getUserStmt = "SELECT &User.* from %s WHERE id==$User.id or username==$User.username" createUserStmt = "INSERT INTO %s (username, hashed_password, permissions) VALUES ($User.username, $User.hashed_password, $User.permissions)" updateUserStmt = "UPDATE %s SET hashed_password=$User.hashed_password WHERE id==$User.id or username==$User.username" @@ -81,9 +81,9 @@ const ( getNumUsersStmt = "SELECT COUNT(*) AS &NumUsers.count FROM %s" ) -// RetrieveAllCSRs gets every CertificateRequest entry in the table. -func (db *Database) RetrieveAllCSRs() ([]CertificateRequest, error) { - stmt, err := sqlair.Prepare(fmt.Sprintf(getAllCSRsStmt, db.certificateTable), CertificateRequest{}) +// ListCertificateRequests gets every CertificateRequest entry in the table. +func (db *Database) ListCertificateRequests() ([]CertificateRequest, error) { + stmt, err := sqlair.Prepare(fmt.Sprintf(listCertificateRequestsStmt, db.certificateTable), CertificateRequest{}) if err != nil { return nil, err } @@ -98,12 +98,12 @@ func (db *Database) RetrieveAllCSRs() ([]CertificateRequest, error) { return csrs, nil } -// RetrieveCSRbyID gets a CSR row from the repository from a given ID. -func (db *Database) RetrieveCSRbyID(id int) (*CertificateRequest, error) { +// GetCertificateRequestByID gets a CSR row from the repository from a given ID. +func (db *Database) GetCertificateRequestByID(id int) (*CertificateRequest, error) { csr := CertificateRequest{ ID: id, } - stmt, err := sqlair.Prepare(fmt.Sprintf(getCSRsStmt, db.certificateTable), CertificateRequest{}) + stmt, err := sqlair.Prepare(fmt.Sprintf(getCertificateRequestStmt, db.certificateTable), CertificateRequest{}) if err != nil { return nil, err } @@ -114,12 +114,12 @@ func (db *Database) RetrieveCSRbyID(id int) (*CertificateRequest, error) { return &csr, nil } -// RetrieveCSRbyCSR gets a given CSR row from the repository using the CSR text. -func (db *Database) RetrieveCSRbyCSR(csr string) (*CertificateRequest, error) { +// GetCertificateRequestByCSR gets a given CSR row from the repository using the CSR text. +func (db *Database) GetCertificateRequestByCSR(csr string) (*CertificateRequest, error) { row := CertificateRequest{ CSR: csr, } - stmt, err := sqlair.Prepare(fmt.Sprintf(getCSRsStmt, db.certificateTable), CertificateRequest{}) + stmt, err := sqlair.Prepare(fmt.Sprintf(getCertificateRequestStmt, db.certificateTable), CertificateRequest{}) if err != nil { return nil, err } @@ -130,12 +130,12 @@ func (db *Database) RetrieveCSRbyCSR(csr string) (*CertificateRequest, error) { return &row, nil } -// CreateCSR creates a new CSR entry in the repository. The string must be a valid CSR and unique. -func (db *Database) CreateCSR(csr string) error { +// CreateCertificateRequest creates a new CSR entry in the repository. The string must be a valid CSR and unique. +func (db *Database) CreateCertificateRequest(csr string) error { if err := ValidateCertificateRequest(csr); err != nil { return errors.New("csr validation failed: " + err.Error()) } - stmt, err := sqlair.Prepare(fmt.Sprintf(createCSRStmt, db.certificateTable), CertificateRequest{}) + stmt, err := sqlair.Prepare(fmt.Sprintf(createCertificateRequestStmt, db.certificateTable), CertificateRequest{}) if err != nil { return err } @@ -146,8 +146,8 @@ func (db *Database) CreateCSR(csr string) error { return err } -// AddCertificateChainToCSRbyCSR adds a new certificate chain to a row for a given CSR string. -func (db *Database) AddCertificateChainToCSRbyCSR(csr string, cert string) error { +// AddCertificateChainToCertificateRequestByCSR adds a new certificate chain to a row for a given CSR string. +func (db *Database) AddCertificateChainToCertificateRequestByCSR(csr string, cert string) error { err := ValidateCertificate(cert) if err != nil { return errors.New("cert validation failed: " + err.Error()) @@ -157,22 +157,22 @@ func (db *Database) AddCertificateChainToCSRbyCSR(csr string, cert string) error return errors.New("cert validation failed: " + err.Error()) } certBundle := sanitizeCertificateBundle(cert) - stmt, err := sqlair.Prepare(fmt.Sprintf(updateCSRStmt, db.certificateTable), CertificateRequest{}) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCertificateRequestStmt, db.certificateTable), CertificateRequest{}) if err != nil { return err } newRow := CertificateRequest{ CSR: csr, CertificateChain: certBundle, - RequestStatus: "Active", + Status: "Active", } err = db.conn.Query(context.Background(), stmt, newRow).Run() return err } // AddCertificateChainToCSRbyID adds a new certificate chain to a row for a given row ID. -func (db *Database) AddCertificateToCSRbyID(id int, cert string) error { - csr, err := db.RetrieveCSRbyID(id) +func (db *Database) AddCertificateChainToCertificateRequestByID(id int, cert string) error { + csr, err := db.GetCertificateRequestByID(id) if err != nil { return err } @@ -185,26 +185,26 @@ func (db *Database) AddCertificateToCSRbyID(id int, cert string) error { return errors.New("cert validation failed: " + err.Error()) } certBundle := sanitizeCertificateBundle(cert) - stmt, err := sqlair.Prepare(fmt.Sprintf(updateCSRStmt, db.certificateTable), CertificateRequest{}) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCertificateRequestStmt, db.certificateTable), CertificateRequest{}) if err != nil { return err } newRow := CertificateRequest{ ID: id, CertificateChain: certBundle, - RequestStatus: "Active", + Status: "Active", } err = db.conn.Query(context.Background(), stmt, newRow).Run() return err } -// RejectCSRbyCSR updates input CSR's row by setting the certificate bundle to "" and moving the row status to "Rejected". -func (db *Database) RejectCSRbyCSR(csr string) error { - oldRow, err := db.RetrieveCSRbyCSR(csr) +// RejectCertificateRequestByCSR updates input CSR's row by setting the certificate bundle to "" and moving the row status to "Rejected". +func (db *Database) RejectCertificateRequestByCSR(csr string) error { + oldRow, err := db.GetCertificateRequestByCSR(csr) if err != nil { return err } - stmt, err := sqlair.Prepare(fmt.Sprintf(updateCSRStmt, db.certificateTable), CertificateRequest{}) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCertificateRequestStmt, db.certificateTable), CertificateRequest{}) if err != nil { return err } @@ -212,19 +212,19 @@ func (db *Database) RejectCSRbyCSR(csr string) error { ID: oldRow.ID, CSR: oldRow.CSR, CertificateChain: "", - RequestStatus: "Rejected", + Status: "Rejected", } err = db.conn.Query(context.Background(), stmt, newRow).Run() return err } // RejectCSRbyCSR updates input ID's row by setting the certificate bundle to "" and sets the row status to "Rejected". -func (db *Database) RejectCSRbyID(id int) error { - oldRow, err := db.RetrieveCSRbyID(id) +func (db *Database) RejectCertificateRequestByID(id int) error { + oldRow, err := db.GetCertificateRequestByID(id) if err != nil { return err } - stmt, err := sqlair.Prepare(fmt.Sprintf(updateCSRStmt, db.certificateTable), CertificateRequest{}) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCertificateRequestStmt, db.certificateTable), CertificateRequest{}) if err != nil { return err } @@ -232,19 +232,19 @@ func (db *Database) RejectCSRbyID(id int) error { ID: oldRow.ID, CSR: oldRow.CSR, CertificateChain: "", - RequestStatus: "Rejected", + Status: "Rejected", } err = db.conn.Query(context.Background(), stmt, newRow).Run() return err } -// RevokeCSR updates the input CSR's row by setting the certificate bundle to "" and sets the row status to "Revoked". -func (db *Database) RevokeCSR(csr string) error { - oldRow, err := db.RetrieveCSRbyCSR(csr) +// RevokeCertificateByCSR updates the input CSR's row by setting the certificate bundle to "" and sets the row status to "Revoked". +func (db *Database) RevokeCertificateByCSR(csr string) error { + oldRow, err := db.GetCertificateRequestByCSR(csr) if err != nil { return err } - stmt, err := sqlair.Prepare(fmt.Sprintf(updateCSRStmt, db.certificateTable), CertificateRequest{}) + stmt, err := sqlair.Prepare(fmt.Sprintf(updateCertificateRequestStmt, db.certificateTable), CertificateRequest{}) if err != nil { return err } @@ -252,15 +252,15 @@ func (db *Database) RevokeCSR(csr string) error { ID: oldRow.ID, CSR: oldRow.CSR, CertificateChain: "", - RequestStatus: "Revoked", + Status: "Revoked", } err = db.conn.Query(context.Background(), stmt, newRow).Run() return err } -// DeleteCSRbyCSR removes a CSR from the database alongside the certificate that may have been generated for it. -func (db *Database) DeleteCSRbyCSR(csr string) error { - stmt, err := sqlair.Prepare(fmt.Sprintf(deleteCSRStmt, db.certificateTable), CertificateRequest{}) +// DeleteCertificateRequestByCSR removes a CSR from the database alongside the certificate that may have been generated for it. +func (db *Database) DeleteCertificateRequestByCSR(csr string) error { + stmt, err := sqlair.Prepare(fmt.Sprintf(deleteCertificateRequestStmt, db.certificateTable), CertificateRequest{}) if err != nil { return err } @@ -272,8 +272,8 @@ func (db *Database) DeleteCSRbyCSR(csr string) error { } // DeleteCSRByID removes a CSR from the database alongside the certificate that may have been generated for it. -func (db *Database) DeleteCSRbyID(id int) error { - stmt, err := sqlair.Prepare(fmt.Sprintf(deleteCSRStmt, db.certificateTable), CertificateRequest{}) +func (db *Database) DeleteCertificateRequestByID(id int) error { + stmt, err := sqlair.Prepare(fmt.Sprintf(deleteCertificateRequestStmt, db.certificateTable), CertificateRequest{}) if err != nil { return err } @@ -284,9 +284,9 @@ func (db *Database) DeleteCSRbyID(id int) error { return err } -// RetrieveAllUsers returns all of the users and their fields available in the database. -func (db *Database) RetrieveAllUsers() ([]User, error) { - stmt, err := sqlair.Prepare(fmt.Sprintf(getAllUsersStmt, db.usersTable), User{}) +// ListUsers returns all of the users and their fields available in the database. +func (db *Database) ListUsers() ([]User, error) { + stmt, err := sqlair.Prepare(fmt.Sprintf(listUsersStmt, db.usersTable), User{}) if err != nil { return nil, err } @@ -298,8 +298,8 @@ func (db *Database) RetrieveAllUsers() ([]User, error) { return users, nil } -// RetrieveUser retrieves the name, password and the permission level of a user. -func (db *Database) RetrieveUserByID(id int) (*User, error) { +// GetUserByID retrieves the name, password and the permission level of a user. +func (db *Database) GetUserByID(id int) (*User, error) { row := User{ ID: id, } @@ -314,8 +314,8 @@ func (db *Database) RetrieveUserByID(id int) (*User, error) { return &row, nil } -// RetrieveUser retrieves the id, password and the permission level of a user. -func (db *Database) RetrieveUserByUsername(name string) (*User, error) { +// GetUserByUsername retrieves the id, password and the permission level of a user. +func (db *Database) GetUserByUsername(name string) (*User, error) { row := User{ Username: name, } @@ -354,7 +354,7 @@ func (db *Database) CreateUser(username string, password string, permission int) // UpdateUser updates the password of the given user. // Just like with CreateUser, this function handles hashing and salting the password before storage. func (db *Database) UpdateUserPassword(id int, password string) error { - _, err := db.RetrieveUserByID(id) + _, err := db.GetUserByID(id) if err != nil { return err } @@ -376,7 +376,7 @@ func (db *Database) UpdateUserPassword(id int, password string) error { // DeleteUserByID removes a user from the table. func (db *Database) DeleteUserByID(id int) error { - _, err := db.RetrieveUserByID(id) + _, err := db.GetUserByID(id) if err != nil { return err } diff --git a/internal/db/db_test.go b/internal/db/db_test.go index 62989ce4..19cb6475 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -28,27 +28,27 @@ func TestCSRsEndToEnd(t *testing.T) { } defer db.Close() - err = db.CreateCSR(AppleCSR) + err = db.CreateCertificateRequest(AppleCSR) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } - err = db.CreateCSR(BananaCSR) + err = db.CreateCertificateRequest(BananaCSR) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } - err = db.CreateCSR(StrawberryCSR) + err = db.CreateCertificateRequest(StrawberryCSR) if err != nil { t.Fatalf("Couldn't complete Create: %s", err) } - res, err := db.RetrieveAllCSRs() + res, err := db.ListCertificateRequests() if err != nil { t.Fatalf("Couldn't complete RetrieveAll: %s", err) } if len(res) != 3 { t.Fatalf("One or more CSRs weren't found in DB") } - retrievedCSR, err := db.RetrieveCSRbyCSR(AppleCSR) + retrievedCSR, err := db.GetCertificateRequestByCSR(AppleCSR) if err != nil { t.Fatalf("Couldn't complete Retrieve: %s", err) } @@ -56,32 +56,32 @@ func TestCSRsEndToEnd(t *testing.T) { t.Fatalf("The CSR from the database doesn't match the CSR that was given") } - if err = db.DeleteCSRbyCSR(AppleCSR); err != nil { + if err = db.DeleteCertificateRequestByCSR(AppleCSR); err != nil { t.Fatalf("Couldn't complete Delete: %s", err) } - res, _ = db.RetrieveAllCSRs() + res, _ = db.ListCertificateRequests() if len(res) != 2 { t.Fatalf("CSR's weren't deleted from the DB properly") } BananaCertBundle := strings.TrimSpace(fmt.Sprintf("%s%s", BananaCert, IssuerCert)) - err = db.AddCertificateChainToCSRbyCSR(BananaCSR, BananaCertBundle) + err = db.AddCertificateChainToCertificateRequestByCSR(BananaCSR, BananaCertBundle) if err != nil { t.Fatalf("Couldn't complete Update: %s", err) } - retrievedCSR, _ = db.RetrieveCSRbyCSR(BananaCSR) + retrievedCSR, _ = db.GetCertificateRequestByCSR(BananaCSR) if retrievedCSR.CertificateChain != BananaCertBundle { t.Fatalf("The certificate that was uploaded does not match the certificate that was given.\n Retrieved: %s\nGiven: %s", retrievedCSR.CertificateChain, BananaCertBundle) } - err = db.RevokeCSR(BananaCSR) + err = db.RevokeCertificateByCSR(BananaCSR) if err != nil { t.Fatalf("Couldn't complete Update to revoke certificate: %s", err) } - err = db.RejectCSRbyCSR(StrawberryCSR) + err = db.RejectCertificateRequestByCSR(StrawberryCSR) if err != nil { t.Fatalf("Couldn't complete Update to reject CSR: %s", err) } - retrievedCSR, _ = db.RetrieveCSRbyCSR(BananaCSR) - if retrievedCSR.RequestStatus != "Revoked" { + retrievedCSR, _ = db.GetCertificateRequestByCSR(BananaCSR) + if retrievedCSR.Status != "Revoked" { t.Fatalf("Couldn't delete certificate") } } @@ -91,12 +91,12 @@ func TestCreateFails(t *testing.T) { defer db.Close() InvalidCSR := strings.ReplaceAll(AppleCSR, "M", "i") - if err := db.CreateCSR(InvalidCSR); err == nil { + if err := db.CreateCertificateRequest(InvalidCSR); err == nil { t.Fatalf("Expected error due to invalid CSR") } - db.CreateCSR(AppleCSR) //nolint:errcheck - if err := db.CreateCSR(AppleCSR); err == nil { + db.CreateCertificateRequest(AppleCSR) //nolint:errcheck + if err := db.CreateCertificateRequest(AppleCSR); err == nil { t.Fatalf("Expected error due to duplicate CSR") } } @@ -105,13 +105,13 @@ func TestUpdateFails(t *testing.T) { db, _ := db.NewDatabase(":memory:") defer db.Close() - db.CreateCSR(AppleCSR) //nolint:errcheck - db.CreateCSR(BananaCSR) //nolint:errcheck + db.CreateCertificateRequest(AppleCSR) //nolint:errcheck + db.CreateCertificateRequest(BananaCSR) //nolint:errcheck InvalidCert := strings.ReplaceAll(BananaCert, "/", "+") - if err := db.AddCertificateChainToCSRbyCSR(BananaCSR, InvalidCert); err == nil { + if err := db.AddCertificateChainToCertificateRequestByCSR(BananaCSR, InvalidCert); err == nil { t.Fatalf("Expected updating with invalid cert to fail") } - if err := db.AddCertificateChainToCSRbyCSR(AppleCSR, BananaCert); err == nil { + if err := db.AddCertificateChainToCertificateRequestByCSR(AppleCSR, BananaCert); err == nil { t.Fatalf("Expected updating with mismatched cert to fail") } } @@ -120,11 +120,11 @@ func TestRetrieve(t *testing.T) { db, _ := db.NewDatabase(":memory:") //nolint:errcheck defer db.Close() - db.CreateCSR(AppleCSR) //nolint:errcheck - if _, err := db.RetrieveCSRbyCSR("this is definitely not an id"); err == nil { + db.CreateCertificateRequest(AppleCSR) //nolint:errcheck + if _, err := db.GetCertificateRequestByCSR("this is definitely not an id"); err == nil { t.Fatalf("Expected failure looking for nonexistent CSR") } - if _, err := db.RetrieveCSRbyID(-1); err == nil { + if _, err := db.GetCertificateRequestByID(-1); err == nil { t.Fatalf("Expected failure looking for nonexistent CSR") } } @@ -145,7 +145,7 @@ func TestUsersEndToEnd(t *testing.T) { t.Fatalf("Couldn't complete Create: %s", err) } - res, err := db.RetrieveAllUsers() + res, err := db.ListUsers() if err != nil { t.Fatalf("Couldn't complete RetrieveAll: %s", err) } @@ -159,14 +159,14 @@ func TestUsersEndToEnd(t *testing.T) { if num != 2 { t.Fatalf("NumUsers didn't return the correct number of users") } - retrievedUser, err := db.RetrieveUserByUsername("admin") + retrievedUser, err := db.GetUserByUsername("admin") if err != nil { t.Fatalf("Couldn't complete Retrieve: %s", err) } if retrievedUser.Username != "admin" { t.Fatalf("The user from the database doesn't match the user that was given") } - retrievedUser, err = db.RetrieveUserByID(1) + retrievedUser, err = db.GetUserByID(1) if err != nil { t.Fatalf("Couldn't complete Retrieve: %s", err) } @@ -179,7 +179,7 @@ func TestUsersEndToEnd(t *testing.T) { if err = db.DeleteUserByID(1); err != nil { t.Fatalf("Couldn't complete Delete: %s", err) } - res, _ = db.RetrieveAllUsers() + res, _ = db.ListUsers() if len(res) != 1 { t.Fatalf("users weren't deleted from the DB properly") } @@ -187,7 +187,7 @@ func TestUsersEndToEnd(t *testing.T) { if err != nil { t.Fatalf("Couldn't complete Update: %s", err) } - retrievedUser, _ = db.RetrieveUserByUsername("norman") + retrievedUser, _ = db.GetUserByUsername("norman") if err := bcrypt.CompareHashAndPassword([]byte(retrievedUser.HashedPassword), []byte("thebestpassword")); err != nil { t.Fatalf("The new password that was given does not match the password that was stored.") } @@ -198,15 +198,15 @@ func Example() { if err != nil { log.Fatalln(err) } - err = db.CreateCSR(BananaCSR) + err = db.CreateCertificateRequest(BananaCSR) if err != nil { log.Fatalln(err) } - err = db.AddCertificateChainToCSRbyCSR(BananaCSR, BananaCert) + err = db.AddCertificateChainToCertificateRequestByCSR(BananaCSR, BananaCert) if err != nil { log.Fatalln(err) } - entry, err := db.RetrieveCSRbyCSR(BananaCSR) + entry, err := db.GetCertificateRequestByCSR(BananaCSR) if err != nil { log.Fatalln(err) } diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 686ce0d0..1b8b9fc5 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -37,7 +37,7 @@ func NewMetricsSubsystem(db *db.Database) *PrometheusMetrics { ticker := time.NewTicker(120 * time.Second) go func() { for ; ; <-ticker.C { - csrs, err := db.RetrieveAllCSRs() + csrs, err := db.ListCertificateRequests() if err != nil { log.Println(errors.Join(errors.New("error generating metrics repository: "), err)) panic(1) @@ -99,7 +99,7 @@ func (pm *PrometheusMetrics) GenerateMetrics(csrs []db.CertificateRequest) { outstandingCSRCount += 1 continue } - if entry.RequestStatus == "Rejected" { + if entry.Status == "Rejected" { continue } certCount += 1 diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go index 27422766..0775f283 100644 --- a/internal/metrics/metrics_test.go +++ b/internal/metrics/metrics_test.go @@ -98,11 +98,11 @@ func generateCertPair(daysRemaining int) (string, string, string) { func initializeTestDB(t *testing.T, db *db.Database) { for _, v := range []int{5, 10, 32} { csr, cert, ca := generateCertPair(v) - err := db.CreateCSR(csr) + err := db.CreateCertificateRequest(csr) if err != nil { t.Fatalf("couldn't create test csr: %s", err) } - err = db.AddCertificateChainToCSRbyCSR(csr, fmt.Sprintf("%s%s", cert, ca)) + err = db.AddCertificateChainToCertificateRequestByCSR(csr, fmt.Sprintf("%s%s", cert, ca)) if err != nil { t.Fatalf("couldn't create test cert: %s", err) } @@ -118,7 +118,7 @@ func TestMetrics(t *testing.T) { } initializeTestDB(t, db) m := metrics.NewMetricsSubsystem(db) - csrs, _ := db.RetrieveAllCSRs() + csrs, _ := db.ListCertificateRequests() m.GenerateMetrics(csrs) request, _ := http.NewRequest("GET", "/", nil) diff --git a/internal/server/handlers_accounts.go b/internal/server/handlers_accounts.go index 82ad61eb..3e8beb30 100644 --- a/internal/server/handlers_accounts.go +++ b/internal/server/handlers_accounts.go @@ -48,7 +48,7 @@ func validatePassword(password string) bool { // ListAccounts returns all accounts from the database func ListAccounts(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - accounts, err := env.DB.RetrieveAllUsers() + accounts, err := env.DB.ListUsers() if err != nil { log.Println(err) writeError(w, http.StatusInternalServerError, "Internal Error") @@ -87,9 +87,9 @@ func GetAccount(env *HandlerConfig) http.HandlerFunc { if headerErr != nil { writeError(w, http.StatusUnauthorized, "Unauthorized") } - account, err = env.DB.RetrieveUserByUsername(claims.Username) + account, err = env.DB.GetUserByUsername(claims.Username) } else { - account, err = env.DB.RetrieveUserByID(idNum) + account, err = env.DB.GetUserByID(idNum) } if err != nil { log.Println(err) @@ -178,7 +178,7 @@ func DeleteAccount(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error") return } - account, err := env.DB.RetrieveUserByID(idInt) + account, err := env.DB.GetUserByID(idInt) if err != nil { log.Println(err) if errors.Is(err, sqlair.ErrNoRows) { @@ -223,7 +223,7 @@ func ChangeAccountPassword(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusUnauthorized, "Unauthorized") return } - account, err := env.DB.RetrieveUserByUsername(claims.Username) + account, err := env.DB.GetUserByUsername(claims.Username) if err != nil { log.Println(err) writeError(w, http.StatusUnauthorized, "Unauthorized") diff --git a/internal/server/handlers_certificate_requests.go b/internal/server/handlers_certificate_requests.go index 35793694..a914c909 100644 --- a/internal/server/handlers_certificate_requests.go +++ b/internal/server/handlers_certificate_requests.go @@ -23,13 +23,13 @@ type CertificateRequest struct { ID int `json:"id"` CSR string `json:"csr"` CertificateChain string `json:"certificate_chain"` - CSRStatus string `json:"csr_status"` + Status string `json:"status"` } // ListCertificateRequests returns all of the Certificate Requests func ListCertificateRequests(env *HandlerConfig) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - csrs, err := env.DB.RetrieveAllCSRs() + csrs, err := env.DB.ListCertificateRequests() if err != nil { log.Println(err) writeError(w, http.StatusInternalServerError, "Internal Error") @@ -41,7 +41,7 @@ func ListCertificateRequests(env *HandlerConfig) http.HandlerFunc { ID: csr.ID, CSR: csr.CSR, CertificateChain: csr.CertificateChain, - CSRStatus: csr.RequestStatus, + Status: csr.Status, } } w.WriteHeader(http.StatusOK) @@ -65,7 +65,7 @@ func CreateCertificateRequest(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusBadRequest, "csr is missing") return } - err := env.DB.CreateCSR(createCertificateRequestParams.CSR) + err := env.DB.CreateCertificateRequest(createCertificateRequestParams.CSR) if err != nil { if strings.Contains(err.Error(), "UNIQUE constraint failed") { writeError(w, http.StatusBadRequest, "given csr already recorded") @@ -100,7 +100,7 @@ func GetCertificateRequest(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error") return } - csr, err := env.DB.RetrieveCSRbyID(idNum) + csr, err := env.DB.GetCertificateRequestByID(idNum) if err != nil { log.Println(err) if errors.Is(err, sqlair.ErrNoRows) { @@ -114,7 +114,7 @@ func GetCertificateRequest(env *HandlerConfig) http.HandlerFunc { ID: csr.ID, CSR: csr.CSR, CertificateChain: csr.CertificateChain, - CSRStatus: csr.RequestStatus, + Status: csr.Status, } w.WriteHeader(http.StatusOK) err = writeJSON(w, certificateRequestResponse) @@ -135,7 +135,7 @@ func DeleteCertificateRequest(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error") return } - err = env.DB.DeleteCSRbyID(idNum) + err = env.DB.DeleteCertificateRequestByID(idNum) if err != nil { log.Println(err) if errors.Is(err, sqlair.ErrNoRows) { @@ -174,7 +174,7 @@ func CreateCertificate(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error") return } - err = env.DB.AddCertificateToCSRbyID(idNum, createCertificateParams.CertificateChain) + err = env.DB.AddCertificateChainToCertificateRequestByID(idNum, createCertificateParams.CertificateChain) if err != nil { log.Println(err) if errors.Is(err, sqlair.ErrNoRows) || @@ -210,7 +210,7 @@ func RejectCertificate(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error") return } - err = env.DB.RejectCSRbyID(idNum) + err = env.DB.RejectCertificateRequestByID(idNum) if err != nil { log.Println(err) if errors.Is(err, sqlair.ErrNoRows) { @@ -246,7 +246,7 @@ func DeleteCertificate(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusInternalServerError, "Internal Error") return } - err = env.DB.DeleteCSRbyID(idNum) + err = env.DB.DeleteCertificateRequestByID(idNum) if err != nil { log.Println(err) if errors.Is(err, sqlair.ErrNoRows) { diff --git a/internal/server/handlers_certificate_requests_test.go b/internal/server/handlers_certificate_requests_test.go index ac06dd65..b8afcdb5 100644 --- a/internal/server/handlers_certificate_requests_test.go +++ b/internal/server/handlers_certificate_requests_test.go @@ -514,7 +514,7 @@ func TestCertificatesEndToEnd(t *testing.T) { if getCertResponse.Error != "" { t.Fatalf("expected no error, got %s", getCertResponse.Error) } - if getCertResponse.Result.CSRStatus != "Rejected" { + if getCertResponse.Result.Status != "Rejected" { t.Fatalf("expected `Rejected` status, got %s", getCertResponse.Result.CertificateChain) } }) diff --git a/internal/server/handlers_login.go b/internal/server/handlers_login.go index a0f702a3..469f022b 100644 --- a/internal/server/handlers_login.go +++ b/internal/server/handlers_login.go @@ -65,7 +65,7 @@ func Login(env *HandlerConfig) http.HandlerFunc { writeError(w, http.StatusBadRequest, "Password is required") return } - userAccount, err := env.DB.RetrieveUserByUsername(loginParams.Username) + userAccount, err := env.DB.GetUserByUsername(loginParams.Username) if err != nil { log.Println(err) if errors.Is(err, sqlair.ErrNoRows) { From c387671706fb52b17de25c39dcf9994ad153dd45 Mon Sep 17 00:00:00 2001 From: kayra1 Date: Thu, 21 Nov 2024 12:09:06 +0300 Subject: [PATCH 7/9] chore: rename frontend too --- internal/db/db.go | 12 ++++++------ .../app/(notary)/certificate_requests/table.test.tsx | 6 +++--- ui/src/app/(notary)/certificate_requests/table.tsx | 2 +- ui/src/types.ts | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/internal/db/db.go b/internal/db/db.go index f7d0857c..67ba0db9 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -48,11 +48,11 @@ const queryCreateCertificateRequestsTable = ` certificate_chain TEXT DEFAULT '', status TEXT DEFAULT 'Outstanding', - CHECK (request_status IN ('Outstanding', 'Rejected', 'Revoked', 'Active')), - CHECK (NOT (certificate_chain == '' AND request_status == 'Active' )), - CHECK (NOT (certificate_chain != '' AND request_status == 'Outstanding')) - CHECK (NOT (certificate_chain != '' AND request_status == 'Rejected')) - CHECK (NOT (certificate_chain != '' AND request_status == 'Revoked')) + CHECK (status IN ('Outstanding', 'Rejected', 'Revoked', 'Active')), + CHECK (NOT (certificate_chain == '' AND status == 'Active' )), + CHECK (NOT (certificate_chain != '' AND status == 'Outstanding')) + CHECK (NOT (certificate_chain != '' AND status == 'Rejected')) + CHECK (NOT (certificate_chain != '' AND status == 'Revoked')) )` const queryCreateUsersTable = ` @@ -68,7 +68,7 @@ const ( listCertificateRequestsStmt = "SELECT &CertificateRequest.* FROM %s" getCertificateRequestStmt = "SELECT &CertificateRequest.* FROM %s WHERE id==$CertificateRequest.id or csr==$CertificateRequest.csr" createCertificateRequestStmt = "INSERT INTO %s (csr) VALUES ($CertificateRequest.csr)" - updateCertificateRequestStmt = "UPDATE %s SET certificate_chain=$CertificateRequest.certificate_chain, request_status=$CertificateRequest.request_status WHERE id==$CertificateRequest.id or csr==$CertificateRequest.csr" + updateCertificateRequestStmt = "UPDATE %s SET certificate_chain=$CertificateRequest.certificate_chain, status=$CertificateRequest.status WHERE id==$CertificateRequest.id or csr==$CertificateRequest.csr" deleteCertificateRequestStmt = "DELETE FROM %s WHERE id=$CertificateRequest.id or csr=$CertificateRequest.csr" ) diff --git a/ui/src/app/(notary)/certificate_requests/table.test.tsx b/ui/src/app/(notary)/certificate_requests/table.test.tsx index 15dc36c9..ee5a4c28 100644 --- a/ui/src/app/(notary)/certificate_requests/table.test.tsx +++ b/ui/src/app/(notary)/certificate_requests/table.test.tsx @@ -25,7 +25,7 @@ uMNMgwl0fxGMQZjrlXyCBhXBm1k6PmwcJGJF5LQ31c+5aTTMFU7SyZhlymctB8mS y+ErBQsRpcQho6Ok+HTXQQUcx7WNcwI= -----END CERTIFICATE REQUEST-----`, 'certificate_chain': "", - 'csr_status': "Outstanding", + 'status': "Outstanding", }, { 'id': 2, @@ -49,7 +49,7 @@ WsDOAOH6qKQKQg3BO/xmRoohC6GL4CuhP7HYGi7+wziNhNZQa4GtE/k9DyIXVtJy yuf2PnfXCKnaIWRJNoEqDCZRVMfA5BFSwTPITqyo -----END CERTIFICATE REQUEST-----`, 'certificate_chain': "", - 'csr_status': "Rejected" + 'status': "Rejected" }, { 'id': 3, @@ -93,7 +93,7 @@ gCX3nqYpp70oZIFDrhmYwE5ij5KXlHD4/1IOfNUKCDmQDgGPLI1tVtwQLjeRq7Hg XVelpl/LXTQawmJyvDaVT/Q9P+WqoDiMjrqF6Sy7DzNeeccWVqvqX5TVS6Ky56iS Mvo/+PAJHkBciR5Xn+Wg2a+7vrZvT6CBoRSOTozlLSM= -----END CERTIFICATE-----`, - 'csr_status': 'Active' + 'status': 'Active' }, ] diff --git a/ui/src/app/(notary)/certificate_requests/table.tsx b/ui/src/app/(notary)/certificate_requests/table.tsx index 8d1a17e5..4b7ecbf3 100644 --- a/ui/src/app/(notary)/certificate_requests/table.tsx +++ b/ui/src/app/(notary)/certificate_requests/table.tsx @@ -139,7 +139,7 @@ export function CertificateRequestsTable({ csrs: rows }: TableProps) { }; const csrrows = rows.map((csrEntry) => { - const { id, csr, certificate_chain, csr_status } = csrEntry; + const { id, csr, certificate_chain, status: csr_status } = csrEntry; const csrObj = extractCSR(csr); const certs = splitBundle(certificate_chain); const clientCertificate = certs?.at(0); diff --git a/ui/src/types.ts b/ui/src/types.ts index 9af23ff5..3cce885d 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -2,7 +2,7 @@ export type CSREntry = { id: number, csr: string, certificate_chain: string - csr_status: "Outstanding" | "Active" | "Rejected" | "Revoked" + status: "Outstanding" | "Active" | "Rejected" | "Revoked" } export type User = { From bce2930f33e13bb05447e1d3b743e4eb4437af23 Mon Sep 17 00:00:00 2001 From: kayra1 Date: Thu, 21 Nov 2024 12:12:25 +0300 Subject: [PATCH 8/9] chore: remove todo --- internal/metrics/metrics.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 1b8b9fc5..725891f1 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -218,7 +218,6 @@ func requestDurationMetric() prometheus.HistogramVec { } func certificateExpiryDate(certString string) time.Time { - // TODO: Does this return the expiry date of the issuer? certBlock, _ := pem.Decode([]byte(certString)) cert, _ := x509.ParseCertificate(certBlock.Bytes) // TODO: cert.NotAfter can exist in a wrong cert. We should catch that at the db level validation From 5bd7a3ad5a76e33138dbf45eb3ee790ddd7c164b Mon Sep 17 00:00:00 2001 From: kayra1 Date: Thu, 21 Nov 2024 15:50:01 +0300 Subject: [PATCH 9/9] chore: remove unnecessary comment --- internal/server/handlers_certificate_requests_test.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/internal/server/handlers_certificate_requests_test.go b/internal/server/handlers_certificate_requests_test.go index b8afcdb5..63df6d1d 100644 --- a/internal/server/handlers_certificate_requests_test.go +++ b/internal/server/handlers_certificate_requests_test.go @@ -13,12 +13,6 @@ import ( "github.com/canonical/notary/internal/server" ) -// type CertificateRequest struct { -// ID int `json:"id"` -// CSR string `json:"csr"` -// Certificate string `json:"certificate"` -// } - type GetCertificateRequestResponse struct { Result server.CertificateRequest `json:"result"` Error string `json:"error,omitempty"`