Skip to content

Commit

Permalink
Implement token-based authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
svenrademakers committed Sep 26, 2023
1 parent 23fab75 commit 565d1af
Show file tree
Hide file tree
Showing 9 changed files with 939 additions and 48 deletions.
297 changes: 269 additions & 28 deletions Cargo.lock

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion bmcd/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ serde_yaml = "0.9.25"
tpi_rs = { path = "../tpi_rs" }
clap = { version = "4.4.2", features = ["cargo"] }
openssl = "0.10.57"
rand = "0.8.5"
pwhash = "1.0.0"

anyhow.workspace = true
log.workspace = true
Expand All @@ -23,4 +25,6 @@ tokio.workspace = true
tokio-util.workspace = true
futures.workspace = true
serde.workspace = true
rand = "0.8.5"

[dev-dependencies]
mockall = "0.11.4"
317 changes: 317 additions & 0 deletions bmcd/src/authentication/authentication_context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,317 @@
use super::authentication_errors::AuthenticationError;
use super::authentication_errors::TokenError;
use super::passwd_validator::PasswordValidator;
use super::passwd_validator::UnixValidator;
use actix_web::dev::ServiceRequest;
use actix_web::http::header;
use actix_web::http::header::HeaderValue;
use futures::StreamExt;
use rand::distributions::Alphanumeric;
use rand::thread_rng;
use rand::Rng;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::marker::PhantomData;
use tokio::{
sync::RwLock,
time::{Duration, Instant},
};

pub struct AuthenticationContext<P>
where
P: PasswordValidator + 'static,
{
token_store: RwLock<HashMap<String, Instant>>,
passwds: HashMap<String, String>,
password_validator: PhantomData<P>,
expiration_timeout: Duration,
}

impl<P> AuthenticationContext<P>
where
P: PasswordValidator + 'static,
{
pub fn with_unix_validator(
password_entries: impl Iterator<Item = (String, String)>,
expiration_timeout: Duration,
) -> AuthenticationContext<UnixValidator> {
AuthenticationContext::<UnixValidator> {
token_store: RwLock::new(HashMap::new()),
passwds: HashMap::from_iter(password_entries),
password_validator: PhantomData::<UnixValidator>,
expiration_timeout,
}
}

fn parse_http_token(auth_string: &str) -> Result<&str, TokenError> {
auth_string
.split_once(' ')
.ok_or(TokenError::HttpParseError(auth_string.to_string()))
.and_then(|(_, token)| {
(!token.is_empty())
.then_some(token)
.ok_or(TokenError::Empty)
})
}

async fn verify_token(&self, auth_string: &str) -> Result<(), TokenError> {
let token = Self::parse_http_token(auth_string)?;

let store = self.token_store.read().await;
let Some(expiration) = store.get(token).copied() else {
return Err(TokenError::NoMatch(token.to_string()));
};

let duration = Instant::now().saturating_duration_since(expiration);
if duration > self.expiration_timeout {
// upgrade to write lock
drop(store);
self.token_store.write().await.remove(token);
return Err(TokenError::Expired(expiration));
}

Ok(())
}

pub async fn authorize_request(&self, request: &ServiceRequest) -> Result<(), TokenError> {
let auth = request
.headers()
.get(header::AUTHORIZATION)
.map(HeaderValue::to_str)
.ok_or(TokenError::Empty)?;

self.verify_token(auth.unwrap_or_default()).await
}

pub async fn authenticate_request(
&self,
request: &mut ServiceRequest,
) -> Result<(String, Instant), AuthenticationError> {
let mut buffer = Vec::new();
while let Some(Ok(bytes)) = request.parts_mut().1.next().await {
buffer.extend_from_slice(&bytes);
}

let credentials = serde_json::from_slice::<Login>(&buffer)?;
let Some(pass) = self.passwds.get(&credentials.username) else {
log::debug!("user {} not in database",&credentials.username);
return Err(AuthenticationError::IncorrectCredentials);
};

P::validate(pass, &credentials.password)?;

let token: String = thread_rng()
.sample_iter(&Alphanumeric)
.take(64)
.map(char::from)
.collect();
let expires = Instant::now() + self.expiration_timeout;
self.token_store
.write()
.await
.insert(token.clone(), expires);
Ok((token, expires))
}
}

#[derive(Debug, Deserialize, Serialize)]
struct Login {
username: String,
password: String,
}

#[cfg(test)]
mod tests {
use std::ops::Sub;

use super::*;
use actix_web::{
http::header::AUTHORIZATION,
test::{self},
};

struct DummyValidator {}
impl PasswordValidator for DummyValidator {
fn validate(
_: &str,
_: &str,
) -> Result<(), crate::authentication::authentication_errors::AuthenticationError> {
Ok(())
}
}

struct FalseValidator {}
impl PasswordValidator for FalseValidator {
fn validate(
_: &str,
_: &str,
) -> Result<(), crate::authentication::authentication_errors::AuthenticationError> {
Err(AuthenticationError::IncorrectCredentials)
}
}

fn build_test_context(
token_data: impl IntoIterator<Item = (String, Instant)>,
user_data: impl IntoIterator<Item = (String, String)>,
) -> AuthenticationContext<DummyValidator> {
AuthenticationContext {
token_store: RwLock::new(HashMap::from_iter(token_data)),
passwds: HashMap::from_iter(user_data),
password_validator: PhantomData::<DummyValidator>,
expiration_timeout: Duration::from_secs(20),
}
}

fn build_false_context(
token_data: impl IntoIterator<Item = (String, Instant)>,
user_data: impl IntoIterator<Item = (String, String)>,
) -> AuthenticationContext<FalseValidator> {
AuthenticationContext {
token_store: RwLock::new(HashMap::from_iter(token_data)),
passwds: HashMap::from_iter(user_data),
password_validator: PhantomData::<FalseValidator>,
expiration_timeout: Duration::from_secs(20),
}
}

#[actix_web::test]
async fn test_wrongly_encoded_token() {
let context = build_test_context(Vec::new(), Vec::new());
assert!(matches!(
context.verify_token("").await.unwrap_err(),
TokenError::HttpParseError(_),
));
assert!(matches!(
context.verify_token("Bearernospace").await.unwrap_err(),
TokenError::HttpParseError(_),
));
assert_eq!(
TokenError::Empty,
context.verify_token("Bearer ").await.unwrap_err()
);

let req = test::TestRequest::default().to_srv_request();
assert_eq!(
TokenError::Empty,
context.authorize_request(&req).await.unwrap_err()
);
}

#[actix_web::test]
async fn test_token_failures() {
let context = build_test_context(
[
("123".to_string(), Instant::now()),
("2".to_string(), Instant::now().sub(Duration::from_secs(20))),
],
Vec::new(),
);

let req = test::TestRequest::default()
.insert_header((AUTHORIZATION, "Bearer 1234"))
.to_srv_request();

assert_eq!(
context.authorize_request(&req).await.unwrap_err(),
TokenError::NoMatch("1234".to_string())
);

let req = test::TestRequest::default()
.insert_header((AUTHORIZATION, "Bearer 2"))
.to_srv_request();
assert!(matches!(
context.authorize_request(&req).await.unwrap_err(),
TokenError::Expired(x) if Instant::now().saturating_duration_since(x) > Duration::from_secs(20)
));

// After expired error, the token gets removed. Subsequent calls for that token will
// therefore return "NoMatch"
let req = test::TestRequest::default()
.insert_header((AUTHORIZATION, "Bearer 2"))
.to_srv_request();
assert_eq!(
context.authorize_request(&req).await.unwrap_err(),
TokenError::NoMatch("2".to_string())
);
}

#[actix_web::test]
async fn test_happy_flow() {
let context = build_test_context(
[
("123".to_string(), Instant::now()),
("2".to_string(), Instant::now().sub(Duration::from_secs(20))),
],
Vec::new(),
);
let req = test::TestRequest::default()
.insert_header((AUTHORIZATION, "Bearer 123"))
.to_srv_request();
assert_eq!(Ok(()), context.authorize_request(&req).await);
}

#[actix_web::test]
async fn authentication_errors() {
let context = build_test_context(
Vec::new(),
[("test_user".to_string(), "password".to_string())],
);

let mut req = test::TestRequest::default()
.set_payload("{not a valid json")
.to_srv_request();
assert!(matches!(
context.authenticate_request(&mut req).await.unwrap_err(),
AuthenticationError::ParseError(_)
));

let mut req = test::TestRequest::default()
.set_json(Login {
username: "John".to_string(),
password: "1234".to_string(),
})
.to_srv_request();
assert!(matches!(
context.authenticate_request(&mut req).await.unwrap_err(),
AuthenticationError::IncorrectCredentials
));
}

#[actix_web::test]
async fn invalid_password() {
let context = build_false_context(
Vec::new(),
[("test_user".to_string(), "password".to_string())],
);
let mut req = test::TestRequest::default()
.set_json(Login {
username: "test_user".to_string(),
password: "1234".to_string(),
})
.to_srv_request();
assert!(matches!(
context.authenticate_request(&mut req).await.unwrap_err(),
AuthenticationError::IncorrectCredentials
));
}

#[actix_web::test]
async fn pass_authentication() {
let context = build_test_context(
Vec::new(),
[("test_user".to_string(), "password".to_string())],
);
let mut req = test::TestRequest::default()
.set_json(Login {
username: "test_user".to_string(),
password: "password".to_string(),
})
.to_srv_request();

assert!(matches!(
context.authenticate_request(&mut req).await.unwrap(),
(token, _) if !token.is_empty()
));
}
}
60 changes: 60 additions & 0 deletions bmcd/src/authentication/authentication_errors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use std::fmt::Display;
use tokio::time::Instant;

#[derive(Debug)]
pub enum AuthenticationError {
ParseError(serde_json::Error),
IncorrectCredentials,
InvalidToken(TokenError),
}

impl Display for AuthenticationError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AuthenticationError::ParseError(e) => {
write!(f, "error trying to parse credentials: {}", e)
}
AuthenticationError::IncorrectCredentials => write!(f, "credentials incorrect"),
AuthenticationError::InvalidToken(token_error) => token_error.fmt(f),
}
}
}

impl std::error::Error for AuthenticationError {}

impl From<serde_json::Error> for AuthenticationError {
fn from(value: serde_json::Error) -> Self {
Self::ParseError(value)
}
}

impl From<TokenError> for AuthenticationError {
fn from(value: TokenError) -> Self {
Self::InvalidToken(value)
}
}

#[derive(Debug, PartialEq)]
pub enum TokenError {
Expired(Instant),
NoMatch(String),
HttpParseError(String),
Empty,
}

impl Display for TokenError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TokenError::Expired(instant) => write!(
f,
"token expired {}s ago",
Instant::now().duration_since(*instant).as_secs()
),
TokenError::NoMatch(token) => write!(f, "token {} is not registerd", token),
TokenError::Empty => write!(f, "no authorization token provided"),
TokenError::HttpParseError(token) => write!(f, "cannot parse token {}", token),
}
}
}

impl std::error::Error for TokenError {}
Loading

0 comments on commit 565d1af

Please sign in to comment.