Skip to content

Commit

Permalink
Implement token-based authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
svenrademakers committed Sep 27, 2023
1 parent 23fab75 commit 2785d62
Show file tree
Hide file tree
Showing 9 changed files with 815 additions and 48 deletions.
198 changes: 170 additions & 28 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion bmcd/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ serde_yaml = "0.9.25"
tpi_rs = { path = "../tpi_rs" }
clap = { version = "4.4.2", features = ["cargo"] }
openssl = "0.10.57"
rand = "0.8.5"
pwhash = "1.0.0"
base64 = "0.21.4"

anyhow.workspace = true
log.workspace = true
Expand All @@ -23,4 +26,3 @@ tokio.workspace = true
tokio-util.workspace = true
futures.workspace = true
serde.workspace = true
rand = "0.8.5"
277 changes: 277 additions & 0 deletions bmcd/src/authentication/authentication_context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
use super::authentication_errors::AuthenticationError;
use super::passwd_validator::PasswordValidator;
use super::passwd_validator::UnixValidator;
use base64::{engine::general_purpose, Engine as _};
use rand::distributions::Alphanumeric;
use rand::thread_rng;
use rand::Rng;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::marker::PhantomData;
use tokio::sync::Mutex;
use tokio::time::{Duration, Instant};

pub struct AuthenticationContext<P>
where
P: PasswordValidator + 'static,
{
token_store: Mutex<HashMap<String, Instant>>,
passwds: HashMap<String, String>,
password_validator: PhantomData<P>,
expire_timeout: Duration,
}

impl<P> AuthenticationContext<P>
where
P: PasswordValidator + 'static,
{
pub fn with_unix_validator(
password_entries: impl Iterator<Item = (String, String)>,
expire_timeout: Duration,
) -> AuthenticationContext<UnixValidator> {
AuthenticationContext::<UnixValidator> {
token_store: Mutex::new(HashMap::new()),
passwds: HashMap::from_iter(password_entries),
password_validator: PhantomData::<UnixValidator>,
expire_timeout,
}
}

/// This function piggy-backs removes of expired tokens on an authentication
/// request. This imposes a small penalty on each request. Its deemed not
/// significant enough to justify optimalization given the expected volume
/// of incoming requests.
async fn new_and_remove_expired_tokens(&self, key: String) {
let mut store = self.token_store.lock().await;

store.retain(|_, last_access| {
let duration = Instant::now().saturating_duration_since(*last_access);
duration <= self.expire_timeout
});

store.insert(key, Instant::now());
}

async fn authorize_bearer(&self, token: &str) -> Result<(), AuthenticationError> {
let mut store = self.token_store.lock().await;
let Some(last_access) = store.get_mut(token) else {
return Err(AuthenticationError::NoMatch(token.to_string()));
};

let instant = *last_access;
let duration = Instant::now().saturating_duration_since(instant);
if duration < self.expire_timeout {
*last_access = Instant::now();
return Ok(());
}

store.remove(token);
Err(AuthenticationError::TokenExpired(instant))
}

fn validate_credentials(
&self,
username: &str,
password: &str,
) -> Result<(), AuthenticationError> {
let Some(pass) = self.passwds.get(username) else {
log::debug!("user {} not in database", username);
return Err(AuthenticationError::IncorrectCredentials);
};

P::validate(pass, password)
}

async fn authorize_basic(&self, credentials: &str) -> Result<(), AuthenticationError> {
let decoded = general_purpose::STANDARD.decode(credentials)?;
let utf8 = std::str::from_utf8(&decoded)?;
let Some((user, pass)) = utf8.split_once(':') else {
return Err(AuthenticationError::IncorrectCredentials);
};

self.validate_credentials(user, pass)
}

pub async fn authorize_request(
&self,
http_authorization_line: &str,
) -> Result<(), AuthenticationError> {
match http_authorization_line.split_once(' ') {
Some(("Bearer", token)) => self.authorize_bearer(token).await,
Some(("Basic", credentials)) => self.authorize_basic(credentials).await,
Some((auth, _)) => Err(AuthenticationError::SchemeNotSupported(auth.to_string())),
None => Err(AuthenticationError::HttpParseError(
http_authorization_line.to_string(),
)),
}
}

pub async fn authenticate_request(&self, body: &[u8]) -> Result<Session, AuthenticationError> {
let credentials = serde_json::from_slice::<Login>(body)?;

self.validate_credentials(&credentials.username, &credentials.password)?;

let token: String = thread_rng()
.sample_iter(&Alphanumeric)
.take(64)
.map(char::from)
.collect();
self.new_and_remove_expired_tokens(token.clone()).await;

Ok(Session {
id: token, // according Redfish spec, id refers to the session id.
// which is not equal to the access-token. for now use
// the token.
name: "User Session".to_string(),
description: "User Session".to_string(),
username: credentials.username,
})
}
}

#[derive(Debug, Deserialize, Serialize)]
pub struct Session {
pub id: String,
name: String,
description: String,
username: String,
}

#[derive(Debug, Deserialize, Serialize)]
struct Login {
username: String,
password: String,
}

#[cfg(test)]
pub mod tests {
use super::*;
use std::ops::Sub;

pub struct DummyValidator {}
impl PasswordValidator for DummyValidator {
fn validate(
_: &str,
_: &str,
) -> Result<(), crate::authentication::authentication_errors::AuthenticationError> {
Ok(())
}
}

pub struct FalseValidator {}
impl PasswordValidator for FalseValidator {
fn validate(
_: &str,
_: &str,
) -> Result<(), crate::authentication::authentication_errors::AuthenticationError> {
Err(AuthenticationError::IncorrectCredentials)
}
}

pub fn build_test_context(
token_data: impl IntoIterator<Item = (String, Instant)>,
user_data: impl IntoIterator<Item = (String, String)>,
) -> AuthenticationContext<UnixValidator> {
AuthenticationContext {
token_store: Mutex::new(HashMap::from_iter(token_data)),
passwds: HashMap::from_iter(user_data),
password_validator: PhantomData::<UnixValidator>,
expire_timeout: Duration::from_secs(20),
}
}

#[actix_web::test]
async fn test_token_failures() {
let now = Instant::now();
let twenty_sec_ago = now.sub(Duration::from_secs(20));
let context = build_test_context(
[("123".to_string(), now), ("2".to_string(), twenty_sec_ago)],
Vec::new(),
);

assert_eq!(
context.authorize_request("Bearer 1234").await.unwrap_err(),
AuthenticationError::NoMatch("1234".to_string())
);

assert_eq!(
context.authorize_request("Bearer 2").await.unwrap_err(),
AuthenticationError::TokenExpired(twenty_sec_ago)
);

// After expired error, the token gets removed. Subsequent calls for that token will
// therefore return "NoMatch"
assert_eq!(
context.authorize_request("Bearer 2").await.unwrap_err(),
AuthenticationError::NoMatch("2".to_string())
);
}

#[actix_web::test]
async fn test_happy_flow() {
let context = build_test_context(
[
("123".to_string(), Instant::now()),
("2".to_string(), Instant::now().sub(Duration::from_secs(20))),
],
Vec::new(),
);
assert_eq!(Ok(()), context.authorize_request("Bearer 123").await);
}

#[actix_web::test]
async fn authentication_errors() {
let context = build_test_context(
Vec::new(),
[("test_user".to_string(), "password".to_string())],
);

assert!(matches!(
context
.authenticate_request(b"{not a valid json")
.await
.unwrap_err(),
AuthenticationError::ParseError(_)
));

let json = serde_json::to_vec(&Login {
username: "John".to_string(),
password: "1234".to_string(),
})
.unwrap();

assert_eq!(
context.authenticate_request(&json).await.unwrap_err(),
AuthenticationError::IncorrectCredentials
);
let json = serde_json::to_vec(&Login {
username: "test_user".to_string(),
password: "1234".to_string(),
})
.unwrap();

assert_eq!(
context.authenticate_request(&json).await.unwrap_err(),
AuthenticationError::IncorrectCredentials
);
}

#[actix_web::test]
async fn pass_authentication() {
let context = build_test_context(
Vec::new(),
[("test_user".to_string(), "password".to_string())],
);
let json = serde_json::to_vec(&Login {
username: "test_user".to_string(),
password: "password".to_string(),
})
.unwrap();

assert_eq!(
context.authenticate_request(&json).await.unwrap_err(),
AuthenticationError::IncorrectCredentials
);
}
}
57 changes: 57 additions & 0 deletions bmcd/src/authentication/authentication_errors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use std::{fmt::Display, str::Utf8Error};
use tokio::time::Instant;

#[derive(Debug, PartialEq)]
pub enum AuthenticationError {
ParseError(String),
IncorrectCredentials,
TokenExpired(Instant),
NoMatch(String),
HttpParseError(String),
SchemeNotSupported(String),
Empty,
}

impl Display for AuthenticationError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AuthenticationError::ParseError(e) => {
write!(f, "error trying to parse credentials: {}", e)
}
AuthenticationError::IncorrectCredentials => write!(f, "credentials incorrect"),
AuthenticationError::TokenExpired(instant) => write!(
f,
"token expired {}s ago",
Instant::now().duration_since(*instant).as_secs()
),
AuthenticationError::NoMatch(token) => write!(f, "token {} is not registered", token),
AuthenticationError::Empty => write!(f, "no authorization header provided"),
AuthenticationError::HttpParseError(token) => {
write!(f, "cannot parse authorization header: {}", token)
}
AuthenticationError::SchemeNotSupported(scheme) => {
write!(f, "{} authentication not supported", scheme)
}
}
}
}

impl std::error::Error for AuthenticationError {}

impl From<serde_json::Error> for AuthenticationError {
fn from(value: serde_json::Error) -> Self {
Self::ParseError(value.to_string())
}
}

impl From<base64::DecodeError> for AuthenticationError {
fn from(value: base64::DecodeError) -> Self {
Self::ParseError(value.to_string())
}
}

impl From<Utf8Error> for AuthenticationError {
fn from(value: Utf8Error) -> Self {
Self::ParseError(value.to_string())
}
}
Loading

0 comments on commit 2785d62

Please sign in to comment.