Skip to content

Commit

Permalink
auth: ban consecutive failed requests
Browse files Browse the repository at this point in the history
Consecutive unauthenticated requests that exceed a given threshold will
be banned for a given period before they are allowed to make any new
request. Every consecutive failed attempt exponentially increases the
cool down period in which the peer is blocked from authenticate itself
up to a upper limit.
  • Loading branch information
svenrademakers committed Nov 2, 2023
1 parent 47c2335 commit e5f60ac
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 75 deletions.
1 change: 1 addition & 0 deletions src/authentication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@
pub mod authentication_context;
pub mod authentication_errors;
pub mod authentication_service;
pub mod ban_patrol;
pub mod linux_authenticator;
pub mod passwd_validator;
137 changes: 95 additions & 42 deletions src/authentication/authentication_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.
use super::authentication_errors::AuthenticationError;
use super::authentication_errors::SchemedAuthError;
use super::ban_patrol::BanPatrol;
use super::passwd_validator::PasswordValidator;
use super::passwd_validator::UnixValidator;
use base64::{engine::general_purpose, Engine as _};
Expand All @@ -23,18 +24,17 @@ use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::marker::PhantomData;
use tokio::sync::Mutex;
use tokio::time::{Duration, Instant};

pub struct AuthenticationContext<P>
where
P: PasswordValidator + 'static,
{
token_store: Mutex<HashMap<String, Instant>>,
token_store: HashMap<String, Instant>,
passwds: HashMap<String, String>,
password_validator: PhantomData<P>,
expire_timeout: Duration,
authentication_attempts: usize,
ban_patrol: BanPatrol,
}

impl<P> AuthenticationContext<P>
Expand All @@ -47,60 +47,87 @@ where
authentication_attempts: usize,
) -> AuthenticationContext<UnixValidator> {
AuthenticationContext::<UnixValidator> {
token_store: Mutex::new(HashMap::new()),
token_store: HashMap::new(),
passwds: HashMap::from_iter(password_entries),
password_validator: PhantomData::<UnixValidator>,
expire_timeout,
authentication_attempts,
ban_patrol: BanPatrol::new(authentication_attempts),
}
}

/// This function piggy-backs removes of expired tokens on an authentication
/// request. This imposes a small penalty on each request. Its deemed not
/// significant enough to justify optimization given the expected volume
/// of incoming authentication requests.
async fn new_and_remove_expired_tokens(&self, key: String) {
let mut store = self.token_store.lock().await;

store.retain(|_, last_access| {
async fn new_and_remove_expired_tokens(&mut self, key: String) {
self.token_store.retain(|_, last_access| {
let duration = Instant::now().saturating_duration_since(*last_access);
duration <= self.expire_timeout
});

store.insert(key, Instant::now());
self.token_store.insert(key, Instant::now());
}

async fn authorize_bearer(&self, token: &str) -> Result<(), AuthenticationError> {
let mut store = self.token_store.lock().await;
let Some(last_access) = store.get_mut(token) else {
return Err(AuthenticationError::NoMatch(token.to_string()));
async fn authorize_bearer(
&mut self,
peer: &str,
token: &str,
) -> Result<(), AuthenticationError> {
self.ban_patrol.patrole_ban(peer.to_string())?;

let Some(last_access) = self.token_store.get_mut(token) else {
return Err(self
.ban_patrol
.penalize(peer.to_string())
.err()
.unwrap_or(AuthenticationError::NoMatch(token.to_string())));
};

let instant = *last_access;
let duration = Instant::now().saturating_duration_since(instant);
if duration < self.expire_timeout {
*last_access = Instant::now();
self.ban_patrol.clear_penalties(peer);
return Ok(());
}

store.remove(token);
self.token_store.remove(token);
Err(AuthenticationError::TokenExpired(instant))
}

fn validate_credentials(
&self,
&mut self,
peer: &str,
username: &str,
password: &str,
) -> Result<(), AuthenticationError> {
let Some(pass) = self.passwds.get(username) else {
log::debug!("user {} not in database", username);
return Err(AuthenticationError::IncorrectCredentials);
};

P::validate(pass, password)
self.ban_patrol.patrole_ban(peer.to_string())?;

match self
.passwds
.get(username)
.ok_or(AuthenticationError::IncorrectCredentials)
.and_then(|pass| P::validate(pass, password))
{
Ok(_) => {
log::debug!("{username} validated successfully");
self.ban_patrol.clear_penalties(peer);
Ok(())
}
Err(AuthenticationError::IncorrectCredentials) => Err(self
.ban_patrol
.penalize(peer.to_string())
.err()
.unwrap_or(AuthenticationError::IncorrectCredentials)),
Err(err) => Err(err),
}
}

async fn authorize_basic(&self, credentials: &str) -> Result<(), AuthenticationError> {
async fn authorize_basic(
&mut self,
peer: &str,
credentials: &str,
) -> Result<(), AuthenticationError> {
let decoded = general_purpose::STANDARD.decode(credentials)?;
let utf8 = std::str::from_utf8(&decoded)?;
let Some((user, pass)) = utf8.split_once(':') else {
Expand All @@ -109,20 +136,21 @@ where
));
};

self.validate_credentials(user, pass)
self.validate_credentials(peer, user, pass)
}

pub async fn authorize_request(
&self,
&mut self,
peer: &str,
http_authorization_line: &str,
) -> Result<(), SchemedAuthError> {
match http_authorization_line.split_once(' ') {
Some(("Bearer", token)) => self
.authorize_bearer(token)
.authorize_bearer(peer, token)
.await
.map_err(AuthenticationError::into_bearer_error),
Some(("Basic", credentials)) => self
.authorize_basic(credentials)
.authorize_basic(peer, credentials)
.await
.map_err(AuthenticationError::into_basic_error),
Some((auth, _)) => {
Expand All @@ -135,10 +163,14 @@ where
}
}

pub async fn authenticate_request(&self, body: &[u8]) -> Result<Session, AuthenticationError> {
pub async fn authenticate_request(
&mut self,
peer: &str,
body: &[u8],
) -> Result<Session, AuthenticationError> {
let credentials = serde_json::from_slice::<Login>(body)?;

self.validate_credentials(&credentials.username, &credentials.password)?;
self.validate_credentials(peer, &credentials.username, &credentials.password)?;

let token: String = thread_rng()
.sample_iter(&Alphanumeric)
Expand Down Expand Up @@ -202,66 +234,78 @@ pub mod tests {
user_data: impl IntoIterator<Item = (String, String)>,
) -> AuthenticationContext<UnixValidator> {
AuthenticationContext {
token_store: Mutex::new(HashMap::from_iter(token_data)),
token_store: HashMap::from_iter(token_data),
passwds: HashMap::from_iter(user_data),
password_validator: PhantomData::<UnixValidator>,
expire_timeout: Duration::from_secs(20),
ban_patrol: BanPatrol::new(10),
}
}

#[actix_web::test]
async fn test_token_failures() {
let now = Instant::now();
let twenty_sec_ago = now.sub(Duration::from_secs(20));
let context = build_test_context(
let mut context = build_test_context(
[("123".to_string(), now), ("2".to_string(), twenty_sec_ago)],
Vec::new(),
);

assert_eq!(
context
.authorize_request("Bearer 1234")
.authorize_request("peer", "Bearer 1234")
.await
.unwrap_err()
.1,
AuthenticationError::NoMatch("1234".to_string())
);

assert_eq!(
context.authorize_request("Bearer 2").await.unwrap_err().1,
context
.authorize_request("peer", "Bearer 2")
.await
.unwrap_err()
.1,
AuthenticationError::TokenExpired(twenty_sec_ago)
);

// After expired error, the token gets removed. Subsequent calls for that token will
// therefore return "NoMatch"
assert_eq!(
context.authorize_request("Bearer 2").await.unwrap_err().1,
context
.authorize_request("peer", "Bearer 2")
.await
.unwrap_err()
.1,
AuthenticationError::NoMatch("2".to_string())
);
}

#[actix_web::test]
async fn test_happy_flow() {
let context = build_test_context(
let mut context = build_test_context(
[
("123".to_string(), Instant::now()),
("2".to_string(), Instant::now().sub(Duration::from_secs(20))),
],
Vec::new(),
);
assert_eq!(Ok(()), context.authorize_request("Bearer 123").await);
assert_eq!(
Ok(()),
context.authorize_request("peer1", "Bearer 123").await
);
}

#[actix_web::test]
async fn authentication_errors() {
let context = build_test_context(
let mut context = build_test_context(
Vec::new(),
[("test_user".to_string(), "password".to_string())],
);

assert!(matches!(
context
.authenticate_request(b"{not a valid json")
.authenticate_request("peer1", b"{not a valid json")
.await
.unwrap_err(),
AuthenticationError::ParseError(_)
Expand All @@ -274,7 +318,10 @@ pub mod tests {
.unwrap();

assert_eq!(
context.authenticate_request(&json).await.unwrap_err(),
context
.authenticate_request("peer", &json)
.await
.unwrap_err(),
AuthenticationError::IncorrectCredentials
);
let json = serde_json::to_vec(&Login {
Expand All @@ -284,14 +331,17 @@ pub mod tests {
.unwrap();

assert_eq!(
context.authenticate_request(&json).await.unwrap_err(),
context
.authenticate_request("peer", &json)
.await
.unwrap_err(),
AuthenticationError::IncorrectCredentials
);
}

#[actix_web::test]
async fn pass_authentication() {
let context = build_test_context(
let mut context = build_test_context(
Vec::new(),
[("test_user".to_string(), "password".to_string())],
);
Expand All @@ -302,7 +352,10 @@ pub mod tests {
.unwrap();

assert_eq!(
context.authenticate_request(&json).await.unwrap_err(),
context
.authenticate_request("peer", &json)
.await
.unwrap_err(),
AuthenticationError::IncorrectCredentials
);
}
Expand Down
Loading

0 comments on commit e5f60ac

Please sign in to comment.