Skip to content

Commit

Permalink
Implement token-based authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
svenrademakers committed Sep 25, 2023
1 parent 23fab75 commit 4e046db
Show file tree
Hide file tree
Showing 10 changed files with 1,030 additions and 39 deletions.
354 changes: 329 additions & 25 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion bmcd/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ serde_yaml = "0.9.25"
tpi_rs = { path = "../tpi_rs" }
clap = { version = "4.4.2", features = ["cargo"] }
openssl = "0.10.57"
rand = "0.8.5"
pwhash = "1.0.0"

anyhow.workspace = true
log.workspace = true
Expand All @@ -23,4 +25,7 @@ tokio.workspace = true
tokio-util.workspace = true
futures.workspace = true
serde.workspace = true
rand = "0.8.5"
jsonwebtoken = "8.3.0"

[dev-dependencies]
mockall = "0.11.4"
315 changes: 315 additions & 0 deletions bmcd/src/authentication/authentication_context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
use super::authentication_errors::AuthenticationError;
use super::authentication_errors::TokenError;
use super::passwd_validator::BcryptValidator;
use super::passwd_validator::PasswordValidator;
use super::token_generator::Hs512Generator;
use super::token_generator::TokenGenerator;
use actix_web::dev::ServiceRequest;
use actix_web::http::header;
use actix_web::http::header::HeaderValue;
use futures::StreamExt;
use jsonwebtoken::EncodingKey;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::marker::PhantomData;
use tokio::{
sync::RwLock,
time::{Duration, Instant},
};

pub const EXPIRATION: Duration = Duration::from_secs(60 * 60);

pub struct AuthenticationContext<G, P>
where
G: TokenGenerator,
P: PasswordValidator + 'static,
{
token_store: RwLock<HashMap<String, Instant>>,
passwds: HashMap<String, String>,
token_generator: G,
password_validator: PhantomData<P>,
}

impl<G, P> AuthenticationContext<G, P>
where
G: TokenGenerator,
P: PasswordValidator,
{
pub fn with_hs_512_generator(
password_entries: impl Iterator<Item = (String, String)>,
encoding_key: EncodingKey,
) -> AuthenticationContext<Hs512Generator, BcryptValidator> {
AuthenticationContext::<Hs512Generator, BcryptValidator> {
token_store: RwLock::new(HashMap::new()),
passwds: HashMap::from_iter(password_entries),
token_generator: Hs512Generator::new(encoding_key),
password_validator: PhantomData::<BcryptValidator>,
}
}

fn parse_http_token(auth_string: &str) -> Result<&str, TokenError> {
auth_string
.split_once(' ')
.ok_or(TokenError::HttpParseError(auth_string.to_string()))
.and_then(|(_, token)| {
(!token.is_empty())
.then_some(token)
.ok_or(TokenError::Empty)
})
}

async fn verify_token(&self, auth_string: &str) -> Result<(), TokenError> {
let token = Self::parse_http_token(auth_string)?;

let store = self.token_store.read().await;
let Some(expiration) = store.get(token).copied() else {
return Err(TokenError::NoMatch(token.to_string()));
};

let duration = Instant::now().saturating_duration_since(expiration);
if duration > EXPIRATION {
// upgrade to write lock
drop(store);
self.token_store.write().await.remove(token);
return Err(TokenError::Expired(expiration));
}

Ok(())
}

pub async fn authorize_request(&self, request: &ServiceRequest) -> Result<(), TokenError> {
let auth = request
.headers()
.get(header::AUTHORIZATION)
.map(HeaderValue::to_str)
.ok_or(TokenError::Empty)?;

self.verify_token(auth.unwrap_or_default()).await
}

pub async fn authenticate_request(
&self,
request: &mut ServiceRequest,
) -> Result<(String, Instant), AuthenticationError> {
let mut buffer = Vec::new();
while let Some(Ok(bytes)) = request.parts_mut().1.next().await {
buffer.extend_from_slice(&bytes);
}

let credentials = serde_json::from_slice::<Login>(&buffer)?;
let Some(pass) = self.passwds.get(&credentials.username) else {
return Err(AuthenticationError::IncorrectCredentials);
};

P::validate(pass, &credentials.password)?;

let token = self.token_generator.generate_token(&buffer)?;
let expires = Instant::now() + EXPIRATION;
self.token_store
.write()
.await
.insert(token.clone(), expires);
Ok((token, expires))
}
}

#[derive(Debug, Deserialize, Serialize)]
struct Login {
username: String,
password: String,
}

#[cfg(test)]
mod tests {
use std::ops::Sub;

use super::*;
use crate::authentication::{
authentication_context::EXPIRATION, token_generator::MockTokenGenerator,
};
use actix_web::{
http::header::AUTHORIZATION,
test::{self},
};

struct DummyValidator {}
impl PasswordValidator for DummyValidator {
fn validate(
_: &str,
_: &str,
) -> Result<(), crate::authentication::authentication_errors::AuthenticationError> {
Ok(())
}
}

struct FalseValidator {}
impl PasswordValidator for FalseValidator {
fn validate(
_: &str,
_: &str,
) -> Result<(), crate::authentication::authentication_errors::AuthenticationError> {
Err(AuthenticationError::IncorrectCredentials)
}
}

fn build_test_context(
token_data: impl Iterator<Item = (String, Instant)>,
user_data: impl Iterator<Item = (String, String)>,
) -> AuthenticationContext<MockTokenGenerator, DummyValidator> {
let mut mock = MockTokenGenerator::default();
mock.expect_generate_token()
.return_once(move |_: &Vec<u8>| Ok("token".to_string()));
AuthenticationContext {
token_store: RwLock::new(HashMap::from_iter(token_data)),
passwds: HashMap::from_iter(user_data),
token_generator: mock,
password_validator: PhantomData::<DummyValidator>,
}
}

fn build_false_context(
token_data: impl Iterator<Item = (String, Instant)>,
user_data: impl Iterator<Item = (String, String)>,
) -> AuthenticationContext<MockTokenGenerator, FalseValidator> {
let mock = MockTokenGenerator::default();
AuthenticationContext {
token_store: RwLock::new(HashMap::from_iter(token_data)),
passwds: HashMap::from_iter(user_data),
token_generator: mock,
password_validator: PhantomData::<FalseValidator>,
}
}

#[actix_web::test]
async fn test_wrongly_encoded_token() {
let context = build_test_context(Vec::new().into_iter(), Vec::new().into_iter());
assert!(matches!(
context.verify_token("").await.unwrap_err(),
TokenError::HttpParseError(_),
));
assert!(matches!(
context.verify_token("Bearernospace").await.unwrap_err(),
TokenError::HttpParseError(_),
));
assert_eq!(
TokenError::Empty,
context.verify_token("Bearer ").await.unwrap_err()
);

let req = test::TestRequest::default().to_srv_request();
assert_eq!(
TokenError::Empty,
context.authorize_request(&req).await.unwrap_err()
);
}

#[actix_web::test]
async fn test_token_failures() {
let context = build_test_context(
[
("123".to_string(), Instant::now()),
("2".to_string(), Instant::now().sub(EXPIRATION)),
]
.into_iter(),
Vec::new().into_iter(),
);

let req = test::TestRequest::default()
.insert_header((AUTHORIZATION, "Bearer 1234"))
.to_srv_request();

assert_eq!(
context.authorize_request(&req).await.unwrap_err(),
TokenError::NoMatch("1234".to_string())
);

let req = test::TestRequest::default()
.insert_header((AUTHORIZATION, "Bearer 2"))
.to_srv_request();
assert!(matches!(
context.authorize_request(&req).await.unwrap_err(),
TokenError::Expired(x) if Instant::now().saturating_duration_since(x) > EXPIRATION
));
}

#[actix_web::test]
async fn test_happy_flow() {
let context = build_test_context(
[
("123".to_string(), Instant::now()),
("2".to_string(), Instant::now().sub(EXPIRATION)),
]
.into_iter(),
Vec::new().into_iter(),
);
let req = test::TestRequest::default()
.insert_header((AUTHORIZATION, "Bearer 123"))
.to_srv_request();
assert_eq!(Ok(()), context.authorize_request(&req).await);
}

#[actix_web::test]
async fn authentication_errors() {
let context = build_test_context(
Vec::new().into_iter(),
[("test_user".to_string(), "password".to_string())].into_iter(),
);

let mut req = test::TestRequest::default()
.set_payload("{not a valid json")
.to_srv_request();
assert!(matches!(
context.authenticate_request(&mut req).await.unwrap_err(),
AuthenticationError::ParseError(_)
));

let mut req = test::TestRequest::default()
.set_json(Login {
username: "John".to_string(),
password: "1234".to_string(),
})
.to_srv_request();
assert!(matches!(
context.authenticate_request(&mut req).await.unwrap_err(),
AuthenticationError::IncorrectCredentials
));
}

#[actix_web::test]
async fn invalid_password() {
let context = build_false_context(
Vec::new().into_iter(),
[("test_user".to_string(), "password".to_string())].into_iter(),
);
let mut req = test::TestRequest::default()
.set_json(Login {
username: "test_user".to_string(),
password: "1234".to_string(),
})
.to_srv_request();
assert!(matches!(
context.authenticate_request(&mut req).await.unwrap_err(),
AuthenticationError::IncorrectCredentials
));
}

#[actix_web::test]
async fn pass_authentication() {
let context = build_test_context(
Vec::new().into_iter(),
[("test_user".to_string(), "password".to_string())].into_iter(),
);
let mut req = test::TestRequest::default()
.set_json(Login {
username: "test_user".to_string(),
password: "password".to_string(),
})
.to_srv_request();

assert!(matches!(
context.authenticate_request(&mut req).await.unwrap(),
(token, _) if token == "token"
));
}
}
Loading

0 comments on commit 4e046db

Please sign in to comment.