Skip to content

Commit

Permalink
Implement token-based authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
svenrademakers committed Sep 25, 2023
1 parent 23fab75 commit 7a15dbe
Show file tree
Hide file tree
Showing 10 changed files with 1,003 additions and 39 deletions.
354 changes: 329 additions & 25 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion bmcd/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ serde_yaml = "0.9.25"
tpi_rs = { path = "../tpi_rs" }
clap = { version = "4.4.2", features = ["cargo"] }
openssl = "0.10.57"
rand = "0.8.5"
pwhash = "1.0.0"

anyhow.workspace = true
log.workspace = true
Expand All @@ -23,4 +25,7 @@ tokio.workspace = true
tokio-util.workspace = true
futures.workspace = true
serde.workspace = true
rand = "0.8.5"
jsonwebtoken = "8.3.0"

[dev-dependencies]
mockall = "0.11.4"
294 changes: 294 additions & 0 deletions bmcd/src/authentication/authentication_context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
use super::authentication_errors::AuthenticationError;
use super::authentication_errors::TokenError;
use super::passwd_validator::BcryptValidator;
use super::passwd_validator::PasswordValidator;
use super::token_generator::Hs512Generator;
use super::token_generator::TokenGenerator;
use actix_web::dev::ServiceRequest;
use actix_web::http::header;
use actix_web::http::header::HeaderValue;
use futures::StreamExt;
use jsonwebtoken::EncodingKey;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::marker::PhantomData;
use tokio::{
sync::RwLock,
time::{Duration, Instant},
};

pub const EXPIRATION: Duration = Duration::from_secs(60 * 60);

pub struct AuthenticationContext<G, P>
where
G: TokenGenerator,
P: PasswordValidator + 'static,
{
token_store: RwLock<HashMap<String, Instant>>,
passwds: HashMap<String, String>,
token_generator: G,
password_validator: PhantomData<P>,
}

impl<G, P> AuthenticationContext<G, P>
where
G: TokenGenerator,
P: PasswordValidator,
{
pub fn with_hs_512_generator(
password_entries: impl Iterator<Item = (String, String)>,
encoding_key: EncodingKey,
) -> AuthenticationContext<Hs512Generator, BcryptValidator> {
AuthenticationContext::<Hs512Generator, BcryptValidator> {
token_store: RwLock::new(HashMap::new()),
passwds: HashMap::from_iter(password_entries),
token_generator: Hs512Generator::new(encoding_key),
password_validator: PhantomData::<BcryptValidator>,
}
}

fn parse_http_token(auth_string: &str) -> Result<&str, TokenError> {
auth_string
.split_once(' ')
.ok_or(TokenError::HttpParseError(auth_string.to_string()))
.and_then(|(_, token)| {
(!token.is_empty())
.then_some(token)
.ok_or(TokenError::Empty)
})
}

async fn verify_token(&self, auth_string: &str) -> Result<(), TokenError> {
let token = Self::parse_http_token(auth_string)?;

let store = self.token_store.read().await;
let Some(expiration) = store.get(token).copied() else {
return Err(TokenError::NoMatch(token.to_string()));
};

let duration = Instant::now().saturating_duration_since(expiration);
if duration > EXPIRATION {
// upgrade to write lock
drop(store);
self.token_store.write().await.remove(token);
return Err(TokenError::Expired(expiration));
}

Ok(())
}

pub async fn authorize_request(&self, request: &ServiceRequest) -> Result<(), TokenError> {
let auth = request
.headers()
.get(header::AUTHORIZATION)
.map(HeaderValue::to_str)
.ok_or(TokenError::Empty)?;

self.verify_token(auth.unwrap_or_default()).await
}

pub async fn authenticate_request(
&self,
request: &mut ServiceRequest,
) -> Result<(String, Instant), AuthenticationError> {
let mut buffer = Vec::new();
while let Some(Ok(bytes)) = request.parts_mut().1.next().await {
buffer.extend_from_slice(&bytes);
}

let credentials = serde_json::from_slice::<Login>(&buffer)?;
let Some(pass) = self.passwds.get(&credentials.username) else {
return Err(AuthenticationError::IncorrectCredentials);
};

P::validate(pass, &credentials.password)?;

let token = self.token_generator.generate_token(&buffer)?;
let expires = Instant::now() + EXPIRATION;
self.token_store
.write()
.await
.insert(token.clone(), expires);
Ok((token, expires))
}
}

#[derive(Debug, Deserialize, Serialize)]
struct Login {
username: String,
password: String,
}

#[cfg(test)]
mod tests {
use std::ops::Sub;

use super::*;
use crate::authentication::{
authentication_context::EXPIRATION, token_generator::MockTokenGenerator,
};
use actix_web::{
http::header::AUTHORIZATION,
test::{self},
};

struct DummyValidator {}
impl PasswordValidator for DummyValidator {
fn validate(
_: &str,
_: &str,
) -> Result<(), crate::authentication::authentication_errors::AuthenticationError> {
Ok(())
}
}

struct FalseValidator {}
impl PasswordValidator for FalseValidator {
fn validate(
_: &str,
_: &str,
) -> Result<(), crate::authentication::authentication_errors::AuthenticationError> {
Err(AuthenticationError::IncorrectCredentials)
}
}

fn build_test_context(
token_data: impl Iterator<Item = (String, Instant)>,
user_data: impl Iterator<Item = (String, String)>,
) -> AuthenticationContext<MockTokenGenerator, DummyValidator> {
let mock = MockTokenGenerator::default();
AuthenticationContext {
token_store: RwLock::new(HashMap::from_iter(token_data)),
passwds: HashMap::from_iter(user_data),
token_generator: mock,
password_validator: PhantomData::<DummyValidator>,
}
}

fn build_false_context(
token_data: impl Iterator<Item = (String, Instant)>,
user_data: impl Iterator<Item = (String, String)>,
) -> AuthenticationContext<MockTokenGenerator, FalseValidator> {
let mock = MockTokenGenerator::default();
AuthenticationContext {
token_store: RwLock::new(HashMap::from_iter(token_data)),
passwds: HashMap::from_iter(user_data),
token_generator: mock,
password_validator: PhantomData::<FalseValidator>,
}
}

#[actix_web::test]
async fn test_wrongly_encoded_token() {
let context = build_test_context(Vec::new().into_iter(), Vec::new().into_iter());
assert!(matches!(
context.verify_token("").await.unwrap_err(),
TokenError::HttpParseError(_),
));
assert!(matches!(
context.verify_token("Bearernospace").await.unwrap_err(),
TokenError::HttpParseError(_),
));
assert_eq!(
TokenError::Empty,
context.verify_token("Bearer ").await.unwrap_err()
);

let req = test::TestRequest::default().to_srv_request();
assert_eq!(
TokenError::Empty,
context.authorize_request(&req).await.unwrap_err()
);
}

#[actix_web::test]
async fn test_token_failures() {
let context = build_test_context(
[
("123".to_string(), Instant::now()),
("2".to_string(), Instant::now().sub(EXPIRATION)),
]
.into_iter(),
Vec::new().into_iter(),
);

let req = test::TestRequest::default()
.insert_header((AUTHORIZATION, "Bearer 1234"))
.to_srv_request();

assert_eq!(
context.authorize_request(&req).await.unwrap_err(),
TokenError::NoMatch("1234".to_string())
);

let req = test::TestRequest::default()
.insert_header((AUTHORIZATION, "Bearer 2"))
.to_srv_request();
assert!(matches!(
context.authorize_request(&req).await.unwrap_err(),
TokenError::Expired(x) if Instant::now().saturating_duration_since(x) > EXPIRATION
));
}

#[actix_web::test]
async fn test_happy_flow() {
let context = build_test_context(
[
("123".to_string(), Instant::now()),
("2".to_string(), Instant::now().sub(EXPIRATION)),
]
.into_iter(),
Vec::new().into_iter(),
);
let req = test::TestRequest::default()
.insert_header((AUTHORIZATION, "Bearer 123"))
.to_srv_request();
assert_eq!(Ok(()), context.authorize_request(&req).await);
}

#[actix_web::test]
async fn authentication() {
let context = build_test_context(
Vec::new().into_iter(),
[("test_user".to_string(), "password".to_string())].into_iter(),
);

let mut req = test::TestRequest::default()
.set_payload("{not a valid json")
.to_srv_request();
assert!(matches!(
context.authenticate_request(&mut req).await.unwrap_err(),
AuthenticationError::ParseError(_)
));

let mut req = test::TestRequest::default()
.set_json(Login {
username: "John".to_string(),
password: "1234".to_string(),
})
.to_srv_request();
assert!(matches!(
context.authenticate_request(&mut req).await.unwrap_err(),
AuthenticationError::IncorrectCredentials
));
}

#[actix_web::test]
async fn invalid_password() {
let context = build_false_context(
Vec::new().into_iter(),
[("test_user".to_string(), "password".to_string())].into_iter(),
);
let mut req = test::TestRequest::default()
.set_json(Login {
username: "test_user".to_string(),
password: "1234".to_string(),
})
.to_srv_request();
assert!(matches!(
context.authenticate_request(&mut req).await.unwrap_err(),
AuthenticationError::IncorrectCredentials
));
}
}
70 changes: 70 additions & 0 deletions bmcd/src/authentication/authentication_errors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use std::fmt::Display;
use tokio::time::Instant;

#[derive(Debug)]
pub enum AuthenticationError {
ParseError(serde_json::Error),
IncorrectCredentials,
InvalidToken(TokenError),
}

impl Display for AuthenticationError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AuthenticationError::ParseError(e) => {
write!(f, "error trying to parse credetials: {}", e)
}
AuthenticationError::IncorrectCredentials => write!(f, "credetials incorrect"),
AuthenticationError::InvalidToken(token_error) => token_error.fmt(f),
}
}
}

impl std::error::Error for AuthenticationError {}

impl From<serde_json::Error> for AuthenticationError {
fn from(value: serde_json::Error) -> Self {
Self::ParseError(value)
}
}

impl From<TokenError> for AuthenticationError {
fn from(value: TokenError) -> Self {
Self::InvalidToken(value)
}
}

#[derive(Debug, PartialEq)]
pub enum TokenError {
Expired(Instant),
NoMatch(String),
HttpParseError(String),
TokenGenerationError(jsonwebtoken::errors::Error),
Empty,
}

impl Display for TokenError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TokenError::Expired(instant) => write!(
f,
"token expired {}s ago",
Instant::now().duration_since(*instant).as_secs()
),
TokenError::NoMatch(token) => write!(f, "token {} is not registerd", token),
TokenError::Empty => write!(f, "empty token provided"),
TokenError::HttpParseError(token) => write!(f, "cannot parse token {}", token),
TokenError::TokenGenerationError(e) => {
write!(f, "could not generate access-token: {}", e)
}
}
}
}

impl std::error::Error for TokenError {}

impl From<jsonwebtoken::errors::Error> for TokenError {
fn from(value: jsonwebtoken::errors::Error) -> Self {
Self::TokenGenerationError(value)
}
}
Loading

0 comments on commit 7a15dbe

Please sign in to comment.