diff --git a/src/authentication.rs b/src/authentication.rs
index 5da227b..2f28555 100644
--- a/src/authentication.rs
+++ b/src/authentication.rs
@@ -14,5 +14,6 @@
pub mod authentication_context;
pub mod authentication_errors;
pub mod authentication_service;
+pub mod ban_patrol;
pub mod linux_authenticator;
pub mod passwd_validator;
diff --git a/src/authentication/authentication_context.rs b/src/authentication/authentication_context.rs
index 6d475ab..c9e07b7 100644
--- a/src/authentication/authentication_context.rs
+++ b/src/authentication/authentication_context.rs
@@ -13,6 +13,7 @@
// limitations under the License.
use super::authentication_errors::AuthenticationError;
use super::authentication_errors::SchemedAuthError;
+use super::ban_patrol::BanPatrol;
use super::passwd_validator::PasswordValidator;
use super::passwd_validator::UnixValidator;
use base64::{engine::general_purpose, Engine as _};
@@ -23,18 +24,17 @@ use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::marker::PhantomData;
-use tokio::sync::Mutex;
use tokio::time::{Duration, Instant};
pub struct AuthenticationContext
where
P: PasswordValidator + 'static,
{
- token_store: Mutex>,
+ token_store: HashMap,
passwds: HashMap,
password_validator: PhantomData,
expire_timeout: Duration,
- authentication_attempts: usize,
+ ban_patrol: BanPatrol,
}
impl
AuthenticationContext
@@ -47,11 +47,11 @@ where
authentication_attempts: usize,
) -> AuthenticationContext {
AuthenticationContext:: {
- token_store: Mutex::new(HashMap::new()),
+ token_store: HashMap::new(),
passwds: HashMap::from_iter(password_entries),
password_validator: PhantomData::,
expire_timeout,
- authentication_attempts,
+ ban_patrol: BanPatrol::new(authentication_attempts),
}
}
@@ -59,48 +59,75 @@ where
/// request. This imposes a small penalty on each request. Its deemed not
/// significant enough to justify optimization given the expected volume
/// of incoming authentication requests.
- async fn new_and_remove_expired_tokens(&self, key: String) {
- let mut store = self.token_store.lock().await;
-
- store.retain(|_, last_access| {
+ async fn new_and_remove_expired_tokens(&mut self, key: String) {
+ self.token_store.retain(|_, last_access| {
let duration = Instant::now().saturating_duration_since(*last_access);
duration <= self.expire_timeout
});
- store.insert(key, Instant::now());
+ self.token_store.insert(key, Instant::now());
}
- async fn authorize_bearer(&self, token: &str) -> Result<(), AuthenticationError> {
- let mut store = self.token_store.lock().await;
- let Some(last_access) = store.get_mut(token) else {
- return Err(AuthenticationError::NoMatch(token.to_string()));
+ async fn authorize_bearer(
+ &mut self,
+ peer: &str,
+ token: &str,
+ ) -> Result<(), AuthenticationError> {
+ self.ban_patrol.patrole_ban(peer.to_string())?;
+
+ let Some(last_access) = self.token_store.get_mut(token) else {
+ return Err(self
+ .ban_patrol
+ .penalize(peer.to_string())
+ .err()
+ .unwrap_or(AuthenticationError::NoMatch(token.to_string())));
};
let instant = *last_access;
let duration = Instant::now().saturating_duration_since(instant);
if duration < self.expire_timeout {
*last_access = Instant::now();
+ self.ban_patrol.clear_penalties(peer);
return Ok(());
}
- store.remove(token);
+ self.token_store.remove(token);
Err(AuthenticationError::TokenExpired(instant))
}
fn validate_credentials(
- &self,
+ &mut self,
+ peer: &str,
username: &str,
password: &str,
) -> Result<(), AuthenticationError> {
- let Some(pass) = self.passwds.get(username) else {
- log::debug!("user {} not in database", username);
- return Err(AuthenticationError::IncorrectCredentials);
- };
-
- P::validate(pass, password)
+ self.ban_patrol.patrole_ban(peer.to_string())?;
+
+ match self
+ .passwds
+ .get(username)
+ .ok_or(AuthenticationError::IncorrectCredentials)
+ .and_then(|pass| P::validate(pass, password))
+ {
+ Ok(_) => {
+ log::debug!("{username} validated successfully");
+ self.ban_patrol.clear_penalties(peer);
+ Ok(())
+ }
+ Err(AuthenticationError::IncorrectCredentials) => Err(self
+ .ban_patrol
+ .penalize(peer.to_string())
+ .err()
+ .unwrap_or(AuthenticationError::IncorrectCredentials)),
+ Err(err) => Err(err),
+ }
}
- async fn authorize_basic(&self, credentials: &str) -> Result<(), AuthenticationError> {
+ async fn authorize_basic(
+ &mut self,
+ peer: &str,
+ credentials: &str,
+ ) -> Result<(), AuthenticationError> {
let decoded = general_purpose::STANDARD.decode(credentials)?;
let utf8 = std::str::from_utf8(&decoded)?;
let Some((user, pass)) = utf8.split_once(':') else {
@@ -109,20 +136,21 @@ where
));
};
- self.validate_credentials(user, pass)
+ self.validate_credentials(peer, user, pass)
}
pub async fn authorize_request(
- &self,
+ &mut self,
+ peer: &str,
http_authorization_line: &str,
) -> Result<(), SchemedAuthError> {
match http_authorization_line.split_once(' ') {
Some(("Bearer", token)) => self
- .authorize_bearer(token)
+ .authorize_bearer(peer, token)
.await
.map_err(AuthenticationError::into_bearer_error),
Some(("Basic", credentials)) => self
- .authorize_basic(credentials)
+ .authorize_basic(peer, credentials)
.await
.map_err(AuthenticationError::into_basic_error),
Some((auth, _)) => {
@@ -135,10 +163,14 @@ where
}
}
- pub async fn authenticate_request(&self, body: &[u8]) -> Result {
+ pub async fn authenticate_request(
+ &mut self,
+ peer: &str,
+ body: &[u8],
+ ) -> Result {
let credentials = serde_json::from_slice::(body)?;
- self.validate_credentials(&credentials.username, &credentials.password)?;
+ self.validate_credentials(peer, &credentials.username, &credentials.password)?;
let token: String = thread_rng()
.sample_iter(&Alphanumeric)
@@ -202,10 +234,11 @@ pub mod tests {
user_data: impl IntoIterator- ,
) -> AuthenticationContext {
AuthenticationContext {
- token_store: Mutex::new(HashMap::from_iter(token_data)),
+ token_store: HashMap::from_iter(token_data),
passwds: HashMap::from_iter(user_data),
password_validator: PhantomData::,
expire_timeout: Duration::from_secs(20),
+ ban_patrol: BanPatrol::new(10),
}
}
@@ -213,14 +246,14 @@ pub mod tests {
async fn test_token_failures() {
let now = Instant::now();
let twenty_sec_ago = now.sub(Duration::from_secs(20));
- let context = build_test_context(
+ let mut context = build_test_context(
[("123".to_string(), now), ("2".to_string(), twenty_sec_ago)],
Vec::new(),
);
assert_eq!(
context
- .authorize_request("Bearer 1234")
+ .authorize_request("peer", "Bearer 1234")
.await
.unwrap_err()
.1,
@@ -228,40 +261,51 @@ pub mod tests {
);
assert_eq!(
- context.authorize_request("Bearer 2").await.unwrap_err().1,
+ context
+ .authorize_request("peer", "Bearer 2")
+ .await
+ .unwrap_err()
+ .1,
AuthenticationError::TokenExpired(twenty_sec_ago)
);
// After expired error, the token gets removed. Subsequent calls for that token will
// therefore return "NoMatch"
assert_eq!(
- context.authorize_request("Bearer 2").await.unwrap_err().1,
+ context
+ .authorize_request("peer", "Bearer 2")
+ .await
+ .unwrap_err()
+ .1,
AuthenticationError::NoMatch("2".to_string())
);
}
#[actix_web::test]
async fn test_happy_flow() {
- let context = build_test_context(
+ let mut context = build_test_context(
[
("123".to_string(), Instant::now()),
("2".to_string(), Instant::now().sub(Duration::from_secs(20))),
],
Vec::new(),
);
- assert_eq!(Ok(()), context.authorize_request("Bearer 123").await);
+ assert_eq!(
+ Ok(()),
+ context.authorize_request("peer1", "Bearer 123").await
+ );
}
#[actix_web::test]
async fn authentication_errors() {
- let context = build_test_context(
+ let mut context = build_test_context(
Vec::new(),
[("test_user".to_string(), "password".to_string())],
);
assert!(matches!(
context
- .authenticate_request(b"{not a valid json")
+ .authenticate_request("peer1", b"{not a valid json")
.await
.unwrap_err(),
AuthenticationError::ParseError(_)
@@ -274,7 +318,10 @@ pub mod tests {
.unwrap();
assert_eq!(
- context.authenticate_request(&json).await.unwrap_err(),
+ context
+ .authenticate_request("peer", &json)
+ .await
+ .unwrap_err(),
AuthenticationError::IncorrectCredentials
);
let json = serde_json::to_vec(&Login {
@@ -284,14 +331,17 @@ pub mod tests {
.unwrap();
assert_eq!(
- context.authenticate_request(&json).await.unwrap_err(),
+ context
+ .authenticate_request("peer", &json)
+ .await
+ .unwrap_err(),
AuthenticationError::IncorrectCredentials
);
}
#[actix_web::test]
async fn pass_authentication() {
- let context = build_test_context(
+ let mut context = build_test_context(
Vec::new(),
[("test_user".to_string(), "password".to_string())],
);
@@ -302,7 +352,10 @@ pub mod tests {
.unwrap();
assert_eq!(
- context.authenticate_request(&json).await.unwrap_err(),
+ context
+ .authenticate_request("peer", &json)
+ .await
+ .unwrap_err(),
AuthenticationError::IncorrectCredentials
);
}
diff --git a/src/authentication/authentication_service.rs b/src/authentication/authentication_service.rs
index aee910f..aafb700 100644
--- a/src/authentication/authentication_service.rs
+++ b/src/authentication/authentication_service.rs
@@ -30,6 +30,7 @@ use std::{
rc::Rc,
sync::Arc,
};
+use tokio::sync::Mutex;
const LOCALHOSTV4: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
const LOCALHOSTV6: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));
@@ -41,7 +42,7 @@ const LOCALHOSTV6: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));
#[derive(Clone)]
pub struct AuthenticationService
{
service: Rc,
- context: Arc>,
+ context: Arc>>,
authentication_path: &'static str,
realm: &'static str,
}
@@ -49,7 +50,7 @@ pub struct AuthenticationService {
impl AuthenticationService {
pub fn new(
service: Rc,
- context: Arc>,
+ context: Arc>>,
authentication_path: &'static str,
realm: &'static str,
) -> Self {
@@ -96,41 +97,26 @@ where
let realm = self.realm;
Box::pin(async move {
+ let peer = request
+ .connection_info()
+ .peer_addr()
+ .unwrap_or_default()
+ .to_string();
+ let mut context = context.lock().await;
+
+ // handle authentication requests and return
if request.request().uri().path() == auth_path {
- log::debug!("authentication request");
- let mut buffer = Vec::new();
- while let Some(Ok(bytes)) = request.parts_mut().1.next().await {
- buffer.extend_from_slice(&bytes);
- }
-
- let response = match context.authenticate_request(&buffer).await {
- Ok(session) => {
- authenticated_response(request.request(), session.id.clone(), session)
- }
- Err(error) => forbidden_response(request.request(), error),
- };
-
- return response;
+ return authentication_request(&mut request, &peer, &mut context).await;
}
- log::debug!("authorize request");
- let parse_result = request
- .headers()
- .get(header::AUTHORIZATION)
- .ok_or(AuthenticationError::Empty)
- .and_then(|auth| {
- auth.to_str()
- .map_err(|e| AuthenticationError::HttpParseError(e.to_string()))
- });
-
- let auth = match parse_result {
+ let auth = match parse_authorization_header(&request) {
Ok(p) => p,
Err(e) => {
return unauthorized_response(request.request(), e.into_basic_error(), realm)
}
};
- if let Err(e) = context.authorize_request(auth).await {
+ if let Err(e) = context.authorize_request(&peer, auth).await {
unauthorized_response(request.request(), e, realm)
} else {
service
@@ -142,6 +128,37 @@ where
}
}
+async fn authentication_request(
+ request: &mut ServiceRequest,
+ peer: &str,
+ context: &mut AuthenticationContext,
+) -> Result>, Error> {
+ log::debug!("authentication request");
+ let mut buffer = Vec::new();
+ while let Some(Ok(bytes)) = request.parts_mut().1.next().await {
+ buffer.extend_from_slice(&bytes);
+ }
+
+ let response = match context.authenticate_request(peer, &buffer).await {
+ Ok(session) => authenticated_response(request.request(), session.id.clone(), session),
+ Err(error) => forbidden_response(request.request(), error),
+ };
+
+ response
+}
+
+fn parse_authorization_header(request: &ServiceRequest) -> Result<&str, AuthenticationError> {
+ log::debug!("authorize request");
+ request
+ .headers()
+ .get(header::AUTHORIZATION)
+ .ok_or(AuthenticationError::Empty)
+ .and_then(|auth| {
+ auth.to_str()
+ .map_err(|e| AuthenticationError::HttpParseError(e.to_string()))
+ })
+}
+
fn forbidden_response(
request: &HttpRequest,
response_text: E,
diff --git a/src/authentication/ban_patrol.rs b/src/authentication/ban_patrol.rs
new file mode 100644
index 0000000..6957301
--- /dev/null
+++ b/src/authentication/ban_patrol.rs
@@ -0,0 +1,156 @@
+// Copyright 2023 Turing Machines
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+use super::authentication_errors::AuthenticationError;
+use std::{collections::HashMap, ops::Add, time::Duration};
+use tokio::time::Instant;
+
+const BAN_LEVELS: usize = 10;
+const BAN_DURATION: Duration = Duration::from_secs(60);
+
+/// Book-keeps how many consecutive failed attempts a given peer has made to
+/// authenticate itself. When a given threshold is exceeded, every consecutive
+/// failed attempt exponentially increases the cool down period in which the
+/// peer is blocked from authenticate itself up to a limit declared in
+/// [`BAN_LEVELS`].
+pub struct BanPatrol {
+ penalties: HashMap,
+ max_authentication_attempts: usize,
+}
+
+impl BanPatrol {
+ /// construct an instance. `max_authentication_attempts` is not allowed to
+ /// be 0.
+ pub fn new(max_authentication_attempts: usize) -> Self {
+ assert!(
+ max_authentication_attempts != 0,
+ "chosing a value of 0 for the maximum permitted \
+ authentication attempts is not allowed."
+ );
+
+ Self {
+ penalties: HashMap::new(),
+ max_authentication_attempts,
+ }
+ }
+
+ /// Verifies if a given peer is banned and therefore is denied to authenticate
+ pub fn patrole_ban(&mut self, peer: String) -> Result<(), AuthenticationError> {
+ let Some((attempts, start_time)) = self.penalties.get(&peer) else {
+ return Ok(());
+ };
+
+ self.patrole_peer(attempts, start_time)
+ }
+
+ pub fn clear_penalties(&mut self, peer: &str) {
+ self.penalties.remove(peer);
+ }
+
+ pub fn penalize(&mut self, peer: String) -> Result<(), AuthenticationError> {
+ let (attempts, start_time) = self.penalties.entry(peer).or_insert((0, Instant::now()));
+ *attempts += 1;
+
+ // TODO: cannot re-alias &mut _ to &_?
+ let attempts = *attempts;
+ let time = *start_time;
+ self.patrole_peer(&attempts, &time)
+ }
+
+ fn patrole_peer(
+ &self,
+ attempts: &usize,
+ start_time: &Instant,
+ ) -> Result<(), AuthenticationError> {
+ if attempts >= &self.max_authentication_attempts {
+ let serverity = 1 << (attempts - self.max_authentication_attempts).min(BAN_LEVELS);
+ let deadline = start_time.add(BAN_DURATION * serverity);
+
+ if deadline > Instant::now() {
+ return Err(AuthenticationError::ExceededAllowedAttempts(deadline));
+ }
+ }
+
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_parole_consequtive() {
+ let mut patrol = BanPatrol::new(1);
+ assert!(patrol.patrole_ban("peer".to_string()).is_ok());
+ assert!(patrol.patrole_ban("peer2".to_string()).is_ok());
+ assert!(patrol.patrole_ban("peer".to_string()).is_ok());
+ }
+
+ #[test]
+ fn test_penalties() {
+ let mut patrol = BanPatrol::new(1);
+ assert!(patrol.patrole_ban("peer".to_string()).is_ok());
+ assert!(patrol.patrole_ban("peer1".to_string()).is_ok());
+
+ let expected = Instant::now().add(BAN_DURATION);
+ let denied = patrol.penalize("peer1".to_string()).unwrap_err();
+ assert!(
+ matches!(denied, AuthenticationError::ExceededAllowedAttempts(expiry) if expiry > expected)
+ );
+ let denied = patrol.patrole_ban("peer1".to_string()).unwrap_err();
+ assert!(
+ matches!(denied, AuthenticationError::ExceededAllowedAttempts(expiry) if expiry > expected)
+ );
+ assert!(patrol.patrole_ban("peer".to_string()).is_ok());
+ let denied = patrol.penalize("peer".to_string()).unwrap_err();
+ assert!(
+ matches!(denied, AuthenticationError::ExceededAllowedAttempts(expiry) if expiry > expected)
+ );
+ let denied = patrol.patrole_ban("peer".to_string()).unwrap_err();
+ assert!(
+ matches!(denied, AuthenticationError::ExceededAllowedAttempts(expiry) if expiry > expected)
+ );
+
+ // clear peer1
+ patrol.clear_penalties("peer1");
+ assert!(patrol.patrole_ban("peer1".to_string()).is_ok());
+ let denied = patrol.patrole_ban("peer".to_string()).unwrap_err();
+ assert!(
+ matches!(denied, AuthenticationError::ExceededAllowedAttempts(expiry) if expiry > expected)
+ );
+ }
+
+ #[test]
+ fn test_penalty_expiry_cap() {
+ let mut patrol = BanPatrol::new(1);
+ let now = Instant::now();
+ (0..BAN_LEVELS).for_each(|i|{
+ let expected = now.add((1 < expected)
+ );
+ });
+
+ let start_time = patrol.penalties.get("peer").unwrap().1;
+ assert!(patrol.penalize("peer".to_string()).is_err());
+ assert!(patrol.penalize("peer".to_string()).is_err());
+ let expected = start_time.add((1 << BAN_LEVELS) * BAN_DURATION);
+ let denied = patrol.patrole_ban("peer".to_string()).unwrap_err();
+ assert!(
+ matches!(denied, AuthenticationError::ExceededAllowedAttempts(expiry) if expiry == expected)
+ );
+ }
+}
diff --git a/src/authentication/linux_authenticator.rs b/src/authentication/linux_authenticator.rs
index cc7ad48..35d1394 100644
--- a/src/authentication/linux_authenticator.rs
+++ b/src/authentication/linux_authenticator.rs
@@ -29,12 +29,13 @@ use std::{rc::Rc, sync::Arc};
use tokio::{
fs::OpenOptions,
io::{AsyncBufReadExt, BufReader},
+ sync::Mutex,
};
type LinuxContext = AuthenticationContext;
pub struct LinuxAuthenticator {
- context: Arc,
+ context: Arc>,
authentication_path: &'static str,
realm: &'static str,
}
@@ -48,11 +49,11 @@ impl LinuxAuthenticator {
) -> io::Result {
let password_entries = Self::parse_shadow_file().await?;
Ok(Self {
- context: Arc::new(LinuxContext::with_unix_validator(
+ context: Arc::new(Mutex::new(LinuxContext::with_unix_validator(
password_entries,
authentication_token_duration,
authentication_attemps,
- )),
+ ))),
authentication_path,
realm,
})
diff --git a/src/firmware_update/rockusb_fwudate.rs b/src/firmware_update/rockusb_fwudate.rs
index 3a4d455..f8c1dbe 100644
--- a/src/firmware_update/rockusb_fwudate.rs
+++ b/src/firmware_update/rockusb_fwudate.rs
@@ -89,10 +89,10 @@ fn parse_boot_header_entry(
let boot_entry = parse_boot_entry(blob, &range);
let name = String::from_utf16(boot_entry.name.as_slice()).unwrap_or_default();
log::debug!(
- "Found boot entry [{:x}] {} {} KiB",
+ "Found boot entry [{:x}] {} {}",
entry_type,
name,
- boot_entry.data_size / 1024,
+ humansize::format_size(boot_entry.data_size, humansize::DECIMAL)
);
if boot_entry.size == 0 {