From c74617bce3fc16e363be349e0ca239e2b99e0bfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Pomp=C3=A9ry?= Date: Thu, 21 Dec 2023 22:23:14 +0100 Subject: [PATCH] feat: turn key pair into server parameter (rather than generation upon each start) --- endpoint/src/auth.rs | 2 +- endpoint/src/main.rs | 31 ++++++++++++++++++------------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/endpoint/src/auth.rs b/endpoint/src/auth.rs index 1a19f34..38ec0f6 100644 --- a/endpoint/src/auth.rs +++ b/endpoint/src/auth.rs @@ -35,7 +35,7 @@ use rsa::{pkcs8::LineEnding, RsaPrivateKey, RsaPublicKey}; const KEY_BITS: usize = 2048; -//#[derive(Debug)] +#[derive(Clone)] pub struct KeyPair { pub pub_key: RsaPublicKey, enc_key: EncodingKey, diff --git a/endpoint/src/main.rs b/endpoint/src/main.rs index ff808d0..10beb79 100644 --- a/endpoint/src/main.rs +++ b/endpoint/src/main.rs @@ -451,7 +451,7 @@ fn default_handler() -> error::AccessDenied { const OPENAPI_PATH: &str = "../openapi.json"; -fn create_server() -> rocket::Rocket { +fn create_server(key_pair: KeyPair) -> rocket::Rocket { let settings = OpenApiSettings::default(); let (mut openapi_routes, openapi_spec) = openapi_get_routes_spec![settings: get_pcf, get_footprints, post_event]; @@ -489,13 +489,13 @@ fn create_server() -> rocket::Rocket { ..Default::default() }), ) - .manage(generate_keys()) + .manage(key_pair) .register("/", catchers![bad_request, default_handler]) } #[rocket::main] async fn main() -> Result<(), LambdaError> { - let rocket = create_server(); + let rocket = create_server(generate_keys()); if is_running_on_lambda() { // Launch on AWS Lambda launch_rocket_on_lambda(rocket).await?; @@ -509,6 +509,11 @@ async fn main() -> Result<(), LambdaError> { #[cfg(test)] const EXAMPLE_HOST: &str = "api.pathfinder.sine.dev"; +#[cfg(test)] +lazy_static! { + static ref TEST_KEYPAIR: KeyPair = generate_keys(); +} + // tests the /v2/auth/token endpoint #[test] fn post_auth_action_test() { @@ -516,7 +521,7 @@ fn post_auth_action_test() { let auth_uri = "/2/auth/token"; - let client = &Client::tracked(create_server()).unwrap(); + let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap(); // invalid credentials { @@ -581,7 +586,7 @@ fn verify_token_signature_test() { use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; use std::collections::HashSet; - let client = &Client::tracked(create_server()).unwrap(); + let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap(); let token = UserToken { username: "hello".to_string(), @@ -614,7 +619,7 @@ fn verify_token_signature_test() { #[test] fn get_list_test() { - let client = &Client::tracked(create_server()).unwrap(); + let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap(); let token = UserToken { username: "hello".to_string(), @@ -656,7 +661,7 @@ fn get_list_test() { #[test] fn get_list_with_filter_eq_test() { - let client = &Client::tracked(create_server()).unwrap(); + let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap(); let token = UserToken { username: "hello".to_string(), @@ -682,7 +687,7 @@ fn get_list_with_filter_eq_test() { #[test] fn get_list_with_filter_lt_test() { - let client = &Client::tracked(create_server()).unwrap(); + let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap(); let token = UserToken { username: "hello".to_string(), @@ -709,7 +714,7 @@ fn get_list_with_filter_lt_test() { #[test] fn get_list_with_filter_eq_and_lt_test() { - let client = &Client::tracked(create_server()).unwrap(); + let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap(); let token = UserToken { username: "hello".to_string(), @@ -735,7 +740,7 @@ fn get_list_with_filter_eq_and_lt_test() { #[test] fn get_list_with_filter_any_test() { - let client = &Client::tracked(create_server()).unwrap(); + let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap(); let token = UserToken { username: "hello".to_string(), @@ -778,7 +783,7 @@ fn get_list_with_filter_any_test() { #[test] fn get_list_with_limit_test() { - let client = &Client::tracked(create_server()).unwrap(); + let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap(); let token = UserToken { username: "hello".to_string(), @@ -849,7 +854,7 @@ fn get_list_with_limit_test() { #[test] fn post_events_test() { - let client = &Client::tracked(create_server()).unwrap(); + let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap(); let token = UserToken { username: "hello".to_string(), @@ -911,7 +916,7 @@ fn post_events_test() { #[test] fn get_pcf_test() { - let client = &Client::tracked(create_server()).unwrap(); + let client = &Client::tracked(create_server(TEST_KEYPAIR.clone())).unwrap(); let token = UserToken { username: "hello".to_string(),