diff --git a/src/authentication.rs b/src/authentication.rs index 5da227b..2f28555 100644 --- a/src/authentication.rs +++ b/src/authentication.rs @@ -14,5 +14,6 @@ pub mod authentication_context; pub mod authentication_errors; pub mod authentication_service; +pub mod ban_patrol; pub mod linux_authenticator; pub mod passwd_validator; diff --git a/src/authentication/authentication_context.rs b/src/authentication/authentication_context.rs index 6d475ab..c9e07b7 100644 --- a/src/authentication/authentication_context.rs +++ b/src/authentication/authentication_context.rs @@ -13,6 +13,7 @@ // limitations under the License. use super::authentication_errors::AuthenticationError; use super::authentication_errors::SchemedAuthError; +use super::ban_patrol::BanPatrol; use super::passwd_validator::PasswordValidator; use super::passwd_validator::UnixValidator; use base64::{engine::general_purpose, Engine as _}; @@ -23,18 +24,17 @@ use serde::Deserialize; use serde::Serialize; use std::collections::HashMap; use std::marker::PhantomData; -use tokio::sync::Mutex; use tokio::time::{Duration, Instant}; pub struct AuthenticationContext

where P: PasswordValidator + 'static, { - token_store: Mutex>, + token_store: HashMap, passwds: HashMap, password_validator: PhantomData

, expire_timeout: Duration, - authentication_attempts: usize, + ban_patrol: BanPatrol, } impl

AuthenticationContext

@@ -47,11 +47,11 @@ where authentication_attempts: usize, ) -> AuthenticationContext { AuthenticationContext:: { - token_store: Mutex::new(HashMap::new()), + token_store: HashMap::new(), passwds: HashMap::from_iter(password_entries), password_validator: PhantomData::, expire_timeout, - authentication_attempts, + ban_patrol: BanPatrol::new(authentication_attempts), } } @@ -59,48 +59,75 @@ where /// request. This imposes a small penalty on each request. Its deemed not /// significant enough to justify optimization given the expected volume /// of incoming authentication requests. - async fn new_and_remove_expired_tokens(&self, key: String) { - let mut store = self.token_store.lock().await; - - store.retain(|_, last_access| { + async fn new_and_remove_expired_tokens(&mut self, key: String) { + self.token_store.retain(|_, last_access| { let duration = Instant::now().saturating_duration_since(*last_access); duration <= self.expire_timeout }); - store.insert(key, Instant::now()); + self.token_store.insert(key, Instant::now()); } - async fn authorize_bearer(&self, token: &str) -> Result<(), AuthenticationError> { - let mut store = self.token_store.lock().await; - let Some(last_access) = store.get_mut(token) else { - return Err(AuthenticationError::NoMatch(token.to_string())); + async fn authorize_bearer( + &mut self, + peer: &str, + token: &str, + ) -> Result<(), AuthenticationError> { + self.ban_patrol.patrole_ban(peer.to_string())?; + + let Some(last_access) = self.token_store.get_mut(token) else { + return Err(self + .ban_patrol + .penalize(peer.to_string()) + .err() + .unwrap_or(AuthenticationError::NoMatch(token.to_string()))); }; let instant = *last_access; let duration = Instant::now().saturating_duration_since(instant); if duration < self.expire_timeout { *last_access = Instant::now(); + self.ban_patrol.clear_penalties(peer); return Ok(()); } - store.remove(token); + self.token_store.remove(token); Err(AuthenticationError::TokenExpired(instant)) } fn validate_credentials( - &self, + &mut self, + peer: &str, username: &str, password: &str, ) -> Result<(), AuthenticationError> { - let Some(pass) = self.passwds.get(username) else { - log::debug!("user {} not in database", username); - return Err(AuthenticationError::IncorrectCredentials); - }; - - P::validate(pass, password) + self.ban_patrol.patrole_ban(peer.to_string())?; + + match self + .passwds + .get(username) + .ok_or(AuthenticationError::IncorrectCredentials) + .and_then(|pass| P::validate(pass, password)) + { + Ok(_) => { + log::debug!("{username} validated successfully"); + self.ban_patrol.clear_penalties(peer); + Ok(()) + } + Err(AuthenticationError::IncorrectCredentials) => Err(self + .ban_patrol + .penalize(peer.to_string()) + .err() + .unwrap_or(AuthenticationError::IncorrectCredentials)), + Err(err) => Err(err), + } } - async fn authorize_basic(&self, credentials: &str) -> Result<(), AuthenticationError> { + async fn authorize_basic( + &mut self, + peer: &str, + credentials: &str, + ) -> Result<(), AuthenticationError> { let decoded = general_purpose::STANDARD.decode(credentials)?; let utf8 = std::str::from_utf8(&decoded)?; let Some((user, pass)) = utf8.split_once(':') else { @@ -109,20 +136,21 @@ where )); }; - self.validate_credentials(user, pass) + self.validate_credentials(peer, user, pass) } pub async fn authorize_request( - &self, + &mut self, + peer: &str, http_authorization_line: &str, ) -> Result<(), SchemedAuthError> { match http_authorization_line.split_once(' ') { Some(("Bearer", token)) => self - .authorize_bearer(token) + .authorize_bearer(peer, token) .await .map_err(AuthenticationError::into_bearer_error), Some(("Basic", credentials)) => self - .authorize_basic(credentials) + .authorize_basic(peer, credentials) .await .map_err(AuthenticationError::into_basic_error), Some((auth, _)) => { @@ -135,10 +163,14 @@ where } } - pub async fn authenticate_request(&self, body: &[u8]) -> Result { + pub async fn authenticate_request( + &mut self, + peer: &str, + body: &[u8], + ) -> Result { let credentials = serde_json::from_slice::(body)?; - self.validate_credentials(&credentials.username, &credentials.password)?; + self.validate_credentials(peer, &credentials.username, &credentials.password)?; let token: String = thread_rng() .sample_iter(&Alphanumeric) @@ -202,10 +234,11 @@ pub mod tests { user_data: impl IntoIterator, ) -> AuthenticationContext { AuthenticationContext { - token_store: Mutex::new(HashMap::from_iter(token_data)), + token_store: HashMap::from_iter(token_data), passwds: HashMap::from_iter(user_data), password_validator: PhantomData::, expire_timeout: Duration::from_secs(20), + ban_patrol: BanPatrol::new(10), } } @@ -213,14 +246,14 @@ pub mod tests { async fn test_token_failures() { let now = Instant::now(); let twenty_sec_ago = now.sub(Duration::from_secs(20)); - let context = build_test_context( + let mut context = build_test_context( [("123".to_string(), now), ("2".to_string(), twenty_sec_ago)], Vec::new(), ); assert_eq!( context - .authorize_request("Bearer 1234") + .authorize_request("peer", "Bearer 1234") .await .unwrap_err() .1, @@ -228,40 +261,51 @@ pub mod tests { ); assert_eq!( - context.authorize_request("Bearer 2").await.unwrap_err().1, + context + .authorize_request("peer", "Bearer 2") + .await + .unwrap_err() + .1, AuthenticationError::TokenExpired(twenty_sec_ago) ); // After expired error, the token gets removed. Subsequent calls for that token will // therefore return "NoMatch" assert_eq!( - context.authorize_request("Bearer 2").await.unwrap_err().1, + context + .authorize_request("peer", "Bearer 2") + .await + .unwrap_err() + .1, AuthenticationError::NoMatch("2".to_string()) ); } #[actix_web::test] async fn test_happy_flow() { - let context = build_test_context( + let mut context = build_test_context( [ ("123".to_string(), Instant::now()), ("2".to_string(), Instant::now().sub(Duration::from_secs(20))), ], Vec::new(), ); - assert_eq!(Ok(()), context.authorize_request("Bearer 123").await); + assert_eq!( + Ok(()), + context.authorize_request("peer1", "Bearer 123").await + ); } #[actix_web::test] async fn authentication_errors() { - let context = build_test_context( + let mut context = build_test_context( Vec::new(), [("test_user".to_string(), "password".to_string())], ); assert!(matches!( context - .authenticate_request(b"{not a valid json") + .authenticate_request("peer1", b"{not a valid json") .await .unwrap_err(), AuthenticationError::ParseError(_) @@ -274,7 +318,10 @@ pub mod tests { .unwrap(); assert_eq!( - context.authenticate_request(&json).await.unwrap_err(), + context + .authenticate_request("peer", &json) + .await + .unwrap_err(), AuthenticationError::IncorrectCredentials ); let json = serde_json::to_vec(&Login { @@ -284,14 +331,17 @@ pub mod tests { .unwrap(); assert_eq!( - context.authenticate_request(&json).await.unwrap_err(), + context + .authenticate_request("peer", &json) + .await + .unwrap_err(), AuthenticationError::IncorrectCredentials ); } #[actix_web::test] async fn pass_authentication() { - let context = build_test_context( + let mut context = build_test_context( Vec::new(), [("test_user".to_string(), "password".to_string())], ); @@ -302,7 +352,10 @@ pub mod tests { .unwrap(); assert_eq!( - context.authenticate_request(&json).await.unwrap_err(), + context + .authenticate_request("peer", &json) + .await + .unwrap_err(), AuthenticationError::IncorrectCredentials ); } diff --git a/src/authentication/authentication_service.rs b/src/authentication/authentication_service.rs index aee910f..aafb700 100644 --- a/src/authentication/authentication_service.rs +++ b/src/authentication/authentication_service.rs @@ -30,6 +30,7 @@ use std::{ rc::Rc, sync::Arc, }; +use tokio::sync::Mutex; const LOCALHOSTV4: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); const LOCALHOSTV6: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)); @@ -41,7 +42,7 @@ const LOCALHOSTV6: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)); #[derive(Clone)] pub struct AuthenticationService { service: Rc, - context: Arc>, + context: Arc>>, authentication_path: &'static str, realm: &'static str, } @@ -49,7 +50,7 @@ pub struct AuthenticationService { impl AuthenticationService { pub fn new( service: Rc, - context: Arc>, + context: Arc>>, authentication_path: &'static str, realm: &'static str, ) -> Self { @@ -96,41 +97,26 @@ where let realm = self.realm; Box::pin(async move { + let peer = request + .connection_info() + .peer_addr() + .unwrap_or_default() + .to_string(); + let mut context = context.lock().await; + + // handle authentication requests and return if request.request().uri().path() == auth_path { - log::debug!("authentication request"); - let mut buffer = Vec::new(); - while let Some(Ok(bytes)) = request.parts_mut().1.next().await { - buffer.extend_from_slice(&bytes); - } - - let response = match context.authenticate_request(&buffer).await { - Ok(session) => { - authenticated_response(request.request(), session.id.clone(), session) - } - Err(error) => forbidden_response(request.request(), error), - }; - - return response; + return authentication_request(&mut request, &peer, &mut context).await; } - log::debug!("authorize request"); - let parse_result = request - .headers() - .get(header::AUTHORIZATION) - .ok_or(AuthenticationError::Empty) - .and_then(|auth| { - auth.to_str() - .map_err(|e| AuthenticationError::HttpParseError(e.to_string())) - }); - - let auth = match parse_result { + let auth = match parse_authorization_header(&request) { Ok(p) => p, Err(e) => { return unauthorized_response(request.request(), e.into_basic_error(), realm) } }; - if let Err(e) = context.authorize_request(auth).await { + if let Err(e) = context.authorize_request(&peer, auth).await { unauthorized_response(request.request(), e, realm) } else { service @@ -142,6 +128,37 @@ where } } +async fn authentication_request( + request: &mut ServiceRequest, + peer: &str, + context: &mut AuthenticationContext, +) -> Result>, Error> { + log::debug!("authentication request"); + let mut buffer = Vec::new(); + while let Some(Ok(bytes)) = request.parts_mut().1.next().await { + buffer.extend_from_slice(&bytes); + } + + let response = match context.authenticate_request(peer, &buffer).await { + Ok(session) => authenticated_response(request.request(), session.id.clone(), session), + Err(error) => forbidden_response(request.request(), error), + }; + + response +} + +fn parse_authorization_header(request: &ServiceRequest) -> Result<&str, AuthenticationError> { + log::debug!("authorize request"); + request + .headers() + .get(header::AUTHORIZATION) + .ok_or(AuthenticationError::Empty) + .and_then(|auth| { + auth.to_str() + .map_err(|e| AuthenticationError::HttpParseError(e.to_string())) + }) +} + fn forbidden_response( request: &HttpRequest, response_text: E, diff --git a/src/authentication/ban_patrol.rs b/src/authentication/ban_patrol.rs new file mode 100644 index 0000000..6957301 --- /dev/null +++ b/src/authentication/ban_patrol.rs @@ -0,0 +1,156 @@ +// Copyright 2023 Turing Machines +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +use super::authentication_errors::AuthenticationError; +use std::{collections::HashMap, ops::Add, time::Duration}; +use tokio::time::Instant; + +const BAN_LEVELS: usize = 10; +const BAN_DURATION: Duration = Duration::from_secs(60); + +/// Book-keeps how many consecutive failed attempts a given peer has made to +/// authenticate itself. When a given threshold is exceeded, every consecutive +/// failed attempt exponentially increases the cool down period in which the +/// peer is blocked from authenticate itself up to a limit declared in +/// [`BAN_LEVELS`]. +pub struct BanPatrol { + penalties: HashMap, + max_authentication_attempts: usize, +} + +impl BanPatrol { + /// construct an instance. `max_authentication_attempts` is not allowed to + /// be 0. + pub fn new(max_authentication_attempts: usize) -> Self { + assert!( + max_authentication_attempts != 0, + "chosing a value of 0 for the maximum permitted \ + authentication attempts is not allowed." + ); + + Self { + penalties: HashMap::new(), + max_authentication_attempts, + } + } + + /// Verifies if a given peer is banned and therefore is denied to authenticate + pub fn patrole_ban(&mut self, peer: String) -> Result<(), AuthenticationError> { + let Some((attempts, start_time)) = self.penalties.get(&peer) else { + return Ok(()); + }; + + self.patrole_peer(attempts, start_time) + } + + pub fn clear_penalties(&mut self, peer: &str) { + self.penalties.remove(peer); + } + + pub fn penalize(&mut self, peer: String) -> Result<(), AuthenticationError> { + let (attempts, start_time) = self.penalties.entry(peer).or_insert((0, Instant::now())); + *attempts += 1; + + // TODO: cannot re-alias &mut _ to &_? + let attempts = *attempts; + let time = *start_time; + self.patrole_peer(&attempts, &time) + } + + fn patrole_peer( + &self, + attempts: &usize, + start_time: &Instant, + ) -> Result<(), AuthenticationError> { + if attempts >= &self.max_authentication_attempts { + let serverity = 1 << (attempts - self.max_authentication_attempts).min(BAN_LEVELS); + let deadline = start_time.add(BAN_DURATION * serverity); + + if deadline > Instant::now() { + return Err(AuthenticationError::ExceededAllowedAttempts(deadline)); + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_parole_consequtive() { + let mut patrol = BanPatrol::new(1); + assert!(patrol.patrole_ban("peer".to_string()).is_ok()); + assert!(patrol.patrole_ban("peer2".to_string()).is_ok()); + assert!(patrol.patrole_ban("peer".to_string()).is_ok()); + } + + #[test] + fn test_penalties() { + let mut patrol = BanPatrol::new(1); + assert!(patrol.patrole_ban("peer".to_string()).is_ok()); + assert!(patrol.patrole_ban("peer1".to_string()).is_ok()); + + let expected = Instant::now().add(BAN_DURATION); + let denied = patrol.penalize("peer1".to_string()).unwrap_err(); + assert!( + matches!(denied, AuthenticationError::ExceededAllowedAttempts(expiry) if expiry > expected) + ); + let denied = patrol.patrole_ban("peer1".to_string()).unwrap_err(); + assert!( + matches!(denied, AuthenticationError::ExceededAllowedAttempts(expiry) if expiry > expected) + ); + assert!(patrol.patrole_ban("peer".to_string()).is_ok()); + let denied = patrol.penalize("peer".to_string()).unwrap_err(); + assert!( + matches!(denied, AuthenticationError::ExceededAllowedAttempts(expiry) if expiry > expected) + ); + let denied = patrol.patrole_ban("peer".to_string()).unwrap_err(); + assert!( + matches!(denied, AuthenticationError::ExceededAllowedAttempts(expiry) if expiry > expected) + ); + + // clear peer1 + patrol.clear_penalties("peer1"); + assert!(patrol.patrole_ban("peer1".to_string()).is_ok()); + let denied = patrol.patrole_ban("peer".to_string()).unwrap_err(); + assert!( + matches!(denied, AuthenticationError::ExceededAllowedAttempts(expiry) if expiry > expected) + ); + } + + #[test] + fn test_penalty_expiry_cap() { + let mut patrol = BanPatrol::new(1); + let now = Instant::now(); + (0..BAN_LEVELS).for_each(|i|{ + let expected = now.add((1 < expected) + ); + }); + + let start_time = patrol.penalties.get("peer").unwrap().1; + assert!(patrol.penalize("peer".to_string()).is_err()); + assert!(patrol.penalize("peer".to_string()).is_err()); + let expected = start_time.add((1 << BAN_LEVELS) * BAN_DURATION); + let denied = patrol.patrole_ban("peer".to_string()).unwrap_err(); + assert!( + matches!(denied, AuthenticationError::ExceededAllowedAttempts(expiry) if expiry == expected) + ); + } +} diff --git a/src/authentication/linux_authenticator.rs b/src/authentication/linux_authenticator.rs index cc7ad48..35d1394 100644 --- a/src/authentication/linux_authenticator.rs +++ b/src/authentication/linux_authenticator.rs @@ -29,12 +29,13 @@ use std::{rc::Rc, sync::Arc}; use tokio::{ fs::OpenOptions, io::{AsyncBufReadExt, BufReader}, + sync::Mutex, }; type LinuxContext = AuthenticationContext; pub struct LinuxAuthenticator { - context: Arc, + context: Arc>, authentication_path: &'static str, realm: &'static str, } @@ -48,11 +49,11 @@ impl LinuxAuthenticator { ) -> io::Result { let password_entries = Self::parse_shadow_file().await?; Ok(Self { - context: Arc::new(LinuxContext::with_unix_validator( + context: Arc::new(Mutex::new(LinuxContext::with_unix_validator( password_entries, authentication_token_duration, authentication_attemps, - )), + ))), authentication_path, realm, }) diff --git a/src/firmware_update/rockusb_fwudate.rs b/src/firmware_update/rockusb_fwudate.rs index 3a4d455..f8c1dbe 100644 --- a/src/firmware_update/rockusb_fwudate.rs +++ b/src/firmware_update/rockusb_fwudate.rs @@ -89,10 +89,10 @@ fn parse_boot_header_entry( let boot_entry = parse_boot_entry(blob, &range); let name = String::from_utf16(boot_entry.name.as_slice()).unwrap_or_default(); log::debug!( - "Found boot entry [{:x}] {} {} KiB", + "Found boot entry [{:x}] {} {}", entry_type, name, - boot_entry.data_size / 1024, + humansize::format_size(boot_entry.data_size, humansize::DECIMAL) ); if boot_entry.size == 0 {