Skip to content

Commit

Permalink
feat: turn key pair into server parameter (rather than generation upo…
Browse files Browse the repository at this point in the history
…n each start)
  • Loading branch information
zeitgeist committed Dec 21, 2023
1 parent f8aa96d commit c74617b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
2 changes: 1 addition & 1 deletion endpoint/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use rsa::{pkcs8::LineEnding, RsaPrivateKey, RsaPublicKey};

const KEY_BITS: usize = 2048;

//#[derive(Debug)]
#[derive(Clone)]
pub struct KeyPair {
pub pub_key: RsaPublicKey,
enc_key: EncodingKey,
Expand Down
31 changes: 18 additions & 13 deletions endpoint/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ fn default_handler() -> error::AccessDenied {

const OPENAPI_PATH: &str = "../openapi.json";

fn create_server() -> rocket::Rocket<rocket::Build> {
fn create_server(key_pair: KeyPair) -> rocket::Rocket<rocket::Build> {
let settings = OpenApiSettings::default();
let (mut openapi_routes, openapi_spec) =
openapi_get_routes_spec![settings: get_pcf, get_footprints, post_event];
Expand Down Expand Up @@ -489,13 +489,13 @@ fn create_server() -> rocket::Rocket<rocket::Build> {
..Default::default()
}),
)
.manage(generate_keys())
.manage(key_pair)
.register("/", catchers![bad_request, default_handler])
}

#[rocket::main]
async fn main() -> Result<(), LambdaError> {
let rocket = create_server();
let rocket = create_server(generate_keys());
if is_running_on_lambda() {
// Launch on AWS Lambda
launch_rocket_on_lambda(rocket).await?;
Expand All @@ -509,14 +509,19 @@ async fn main() -> Result<(), LambdaError> {
#[cfg(test)]
const EXAMPLE_HOST: &str = "api.pathfinder.sine.dev";

#[cfg(test)]
lazy_static! {
static ref TEST_KEYPAIR: KeyPair = generate_keys();
}

// tests the /v2/auth/token endpoint
#[test]
fn post_auth_action_test() {
use std::collections::HashMap;

let auth_uri = "/2/auth/token";

let client = &Client::tracked(create_server()).unwrap();
let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap();

// invalid credentials
{
Expand Down Expand Up @@ -581,7 +586,7 @@ fn verify_token_signature_test() {
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use std::collections::HashSet;

let client = &Client::tracked(create_server()).unwrap();
let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap();

let token = UserToken {
username: "hello".to_string(),
Expand Down Expand Up @@ -614,7 +619,7 @@ fn verify_token_signature_test() {

#[test]
fn get_list_test() {
let client = &Client::tracked(create_server()).unwrap();
let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap();

let token = UserToken {
username: "hello".to_string(),
Expand Down Expand Up @@ -656,7 +661,7 @@ fn get_list_test() {

#[test]
fn get_list_with_filter_eq_test() {
let client = &Client::tracked(create_server()).unwrap();
let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap();

let token = UserToken {
username: "hello".to_string(),
Expand All @@ -682,7 +687,7 @@ fn get_list_with_filter_eq_test() {

#[test]
fn get_list_with_filter_lt_test() {
let client = &Client::tracked(create_server()).unwrap();
let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap();

let token = UserToken {
username: "hello".to_string(),
Expand All @@ -709,7 +714,7 @@ fn get_list_with_filter_lt_test() {

#[test]
fn get_list_with_filter_eq_and_lt_test() {
let client = &Client::tracked(create_server()).unwrap();
let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap();

let token = UserToken {
username: "hello".to_string(),
Expand All @@ -735,7 +740,7 @@ fn get_list_with_filter_eq_and_lt_test() {

#[test]
fn get_list_with_filter_any_test() {
let client = &Client::tracked(create_server()).unwrap();
let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap();

let token = UserToken {
username: "hello".to_string(),
Expand Down Expand Up @@ -778,7 +783,7 @@ fn get_list_with_filter_any_test() {

#[test]
fn get_list_with_limit_test() {
let client = &Client::tracked(create_server()).unwrap();
let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap();

let token = UserToken {
username: "hello".to_string(),
Expand Down Expand Up @@ -849,7 +854,7 @@ fn get_list_with_limit_test() {

#[test]
fn post_events_test() {
let client = &Client::tracked(create_server()).unwrap();
let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap();

let token = UserToken {
username: "hello".to_string(),
Expand Down Expand Up @@ -911,7 +916,7 @@ fn post_events_test() {

#[test]
fn get_pcf_test() {
let client = &Client::tracked(create_server()).unwrap();
let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap();

let token = UserToken {
username: "hello".to_string(),
Expand Down

0 comments on commit c74617b

Please sign in to comment.