Skip to content

Commit

Permalink
Implement token-based authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
svenrademakers committed Sep 26, 2023
1 parent 23fab75 commit 7e8a40d
Show file tree
Hide file tree
Showing 9 changed files with 955 additions and 44 deletions.
290 changes: 266 additions & 24 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion bmcd/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ serde_yaml = "0.9.25"
tpi_rs = { path = "../tpi_rs" }
clap = { version = "4.4.2", features = ["cargo"] }
openssl = "0.10.57"
rand = "0.8.5"
pwhash = "1.0.0"

anyhow.workspace = true
log.workspace = true
Expand All @@ -23,4 +25,7 @@ tokio.workspace = true
tokio-util.workspace = true
futures.workspace = true
serde.workspace = true
rand = "0.8.5"
base64 = "0.21.4"

[dev-dependencies]
mockall = "0.11.4"
320 changes: 320 additions & 0 deletions bmcd/src/authentication/authentication_context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,320 @@
use super::authentication_errors::AuthenticationError;
use super::authentication_errors::TokenError;
use super::passwd_validator::PasswordValidator;
use super::passwd_validator::UnixValidator;
use actix_web::dev::ServiceRequest;
use actix_web::http::header;
use actix_web::http::header::HeaderValue;
use base64::{engine::general_purpose, Engine as _};
use futures::StreamExt;
use rand::thread_rng;
use rand::Rng;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::marker::PhantomData;
use tokio::{
sync::RwLock,
time::{Duration, Instant},
};

pub struct AuthenticationContext<P>
where
P: PasswordValidator + 'static,
{
token_store: RwLock<HashMap<u128, Instant>>,
passwds: HashMap<String, String>,
password_validator: PhantomData<P>,
expiration_timeout: Duration,
}

impl<P> AuthenticationContext<P>
where
P: PasswordValidator + 'static,
{
pub fn with_hs_512_generator(
password_entries: impl Iterator<Item = (String, String)>,
expiration_timeout: Duration,
) -> AuthenticationContext<UnixValidator> {
AuthenticationContext::<UnixValidator> {
token_store: RwLock::new(HashMap::new()),
passwds: HashMap::from_iter(password_entries),
password_validator: PhantomData::<UnixValidator>,
expiration_timeout,
}
}

fn parse_http_token(auth_string: &str) -> Result<u128, TokenError> {
auth_string
.split_once(' ')
.ok_or(TokenError::HttpParseError(auth_string.to_string()))
.and_then(|(_, token)| {
(!token.is_empty())
.then_some(Self::decode_token(token)?)
.ok_or(TokenError::Empty)
})
}

fn encode_token(token: u128) -> String {
general_purpose::STANDARD.encode(token.to_ne_bytes())
}

fn decode_token(text: &str) -> Result<u128, TokenError> {
let mut buffer = [0u8; 16];
general_purpose::STANDARD.decode_slice(text, &mut buffer)?;
Ok(u128::from_ne_bytes(buffer))
}

async fn verify_token(&self, auth_string: &str) -> Result<(), TokenError> {
let token = Self::parse_http_token(auth_string)?;

let store = self.token_store.read().await;
let Some(expiration) = store.get(&token).copied() else {
return Err(TokenError::NoMatch(token.to_string()));
};

let duration = Instant::now().saturating_duration_since(expiration);
if duration > self.expiration_timeout {
// upgrade to write lock
drop(store);
self.token_store.write().await.remove(&token);
return Err(TokenError::Expired(expiration));
}

Ok(())
}

pub async fn authorize_request(&self, request: &ServiceRequest) -> Result<(), TokenError> {
let auth = request
.headers()
.get(header::AUTHORIZATION)
.map(HeaderValue::to_str)
.ok_or(TokenError::Empty)?;

self.verify_token(auth.unwrap_or_default()).await
}

pub async fn authenticate_request(
&self,
request: &mut ServiceRequest,
) -> Result<(String, Instant), AuthenticationError> {
let mut buffer = Vec::new();
while let Some(Ok(bytes)) = request.parts_mut().1.next().await {
buffer.extend_from_slice(&bytes);
}

let credentials = serde_json::from_slice::<Login>(&buffer)?;
let Some(pass) = self.passwds.get(&credentials.username) else {
log::debug!("user {} not in database",&credentials.username);
return Err(AuthenticationError::IncorrectCredentials);
};

P::validate(pass, &credentials.password)?;

let token: u128 = thread_rng().gen();
let expires = Instant::now() + self.expiration_timeout;
self.token_store.write().await.insert(token, expires);
Ok((Self::encode_token(token), expires))
}
}

#[derive(Debug, Deserialize, Serialize)]
struct Login {
username: String,
password: String,
}

#[cfg(test)]
mod tests {
use std::ops::Sub;

use super::*;
use actix_web::{
http::header::AUTHORIZATION,
test::{self},
};

struct DummyValidator {}
impl PasswordValidator for DummyValidator {
fn validate(
_: &str,
_: &str,
) -> Result<(), crate::authentication::authentication_errors::AuthenticationError> {
Ok(())
}
}

struct FalseValidator {}
impl PasswordValidator for FalseValidator {
fn validate(
_: &str,
_: &str,
) -> Result<(), crate::authentication::authentication_errors::AuthenticationError> {
Err(AuthenticationError::IncorrectCredentials)
}
}

fn build_test_context(
token_data: impl IntoIterator<Item = (String, Instant)>,
user_data: impl IntoIterator<Item = (String, String)>,
) -> AuthenticationContext<DummyValidator> {
AuthenticationContext {
token_store: RwLock::new(HashMap::from_iter(token_data)),
passwds: HashMap::from_iter(user_data),
password_validator: PhantomData::<DummyValidator>,
expiration_timeout: Duration::from_secs(20),
}
}

fn build_false_context(
token_data: impl IntoIterator<Item = (String, Instant)>,
user_data: impl IntoIterator<Item = (String, String)>,
) -> AuthenticationContext<FalseValidator> {
AuthenticationContext {
token_store: RwLock::new(HashMap::from_iter(token_data)),
passwds: HashMap::from_iter(user_data),
password_validator: PhantomData::<FalseValidator>,
expiration_timeout: Duration::from_secs(20),
}
}

#[actix_web::test]
async fn test_wrongly_encoded_token() {
let context = build_test_context(Vec::new(), Vec::new());
assert!(matches!(
context.verify_token("").await.unwrap_err(),
TokenError::HttpParseError(_),
));
assert!(matches!(
context.verify_token("Bearernospace").await.unwrap_err(),
TokenError::HttpParseError(_),
));
assert_eq!(
TokenError::Empty,
context.verify_token("Bearer ").await.unwrap_err()
);

let req = test::TestRequest::default().to_srv_request();
assert_eq!(
TokenError::Empty,
context.authorize_request(&req).await.unwrap_err()
);
}

#[actix_web::test]
async fn test_token_failures() {
let context = build_test_context(
[
("123".to_string(), Instant::now()),
("2".to_string(), Instant::now().sub(Duration::from_secs(20))),
],
Vec::new(),
);

let req = test::TestRequest::default()
.insert_header((AUTHORIZATION, "Bearer 1234"))
.to_srv_request();

assert_eq!(
context.authorize_request(&req).await.unwrap_err(),
TokenError::NoMatch("1234".to_string())
);

let req = test::TestRequest::default()
.insert_header((AUTHORIZATION, "Bearer 2"))
.to_srv_request();
assert!(matches!(
context.authorize_request(&req).await.unwrap_err(),
TokenError::Expired(x) if Instant::now().saturating_duration_since(x) > Duration::from_secs(20)
));

// After expired error, the token gets removed. Subsequent calls for that token will
// therefore return "NoMatch"
let req = test::TestRequest::default()
.insert_header((AUTHORIZATION, "Bearer 2"))
.to_srv_request();
assert_eq!(
context.authorize_request(&req).await.unwrap_err(),
TokenError::NoMatch("2".to_string())
);
}

#[actix_web::test]
async fn test_happy_flow() {
let context = build_test_context(
[
("123".to_string(), Instant::now()),
("2".to_string(), Instant::now().sub(Duration::from_secs(20))),
],
Vec::new(),
);
let req = test::TestRequest::default()
.insert_header((AUTHORIZATION, "Bearer 123"))
.to_srv_request();
assert_eq!(Ok(()), context.authorize_request(&req).await);
}

#[actix_web::test]
async fn authentication_errors() {
let context = build_test_context(
Vec::new(),
[("test_user".to_string(), "password".to_string())],
);

let mut req = test::TestRequest::default()
.set_payload("{not a valid json")
.to_srv_request();
assert!(matches!(
context.authenticate_request(&mut req).await.unwrap_err(),
AuthenticationError::ParseError(_)
));

let mut req = test::TestRequest::default()
.set_json(Login {
username: "John".to_string(),
password: "1234".to_string(),
})
.to_srv_request();
assert!(matches!(
context.authenticate_request(&mut req).await.unwrap_err(),
AuthenticationError::IncorrectCredentials
));
}

#[actix_web::test]
async fn invalid_password() {
let context = build_false_context(
Vec::new(),
[("test_user".to_string(), "password".to_string())],
);
let mut req = test::TestRequest::default()
.set_json(Login {
username: "test_user".to_string(),
password: "1234".to_string(),
})
.to_srv_request();
assert!(matches!(
context.authenticate_request(&mut req).await.unwrap_err(),
AuthenticationError::IncorrectCredentials
));
}

#[actix_web::test]
async fn pass_authentication() {
let context = build_test_context(
Vec::new(),
[("test_user".to_string(), "password".to_string())],
);
let mut req = test::TestRequest::default()
.set_json(Login {
username: "test_user".to_string(),
password: "password".to_string(),
})
.to_srv_request();

assert!(matches!(
context.authenticate_request(&mut req).await.unwrap(),
(token, _) if token == "token"
));
}
}
Loading

0 comments on commit 7e8a40d

Please sign in to comment.