diff --git a/rama-http/src/layer/auth/require_authorization.rs b/rama-http/src/layer/auth/require_authorization.rs index 77379b55..e82fd7c8 100644 --- a/rama-http/src/layer/auth/require_authorization.rs +++ b/rama-http/src/layer/auth/require_authorization.rs @@ -64,9 +64,39 @@ use crate::{ }; use rama_core::Context; +use rama_net::user::UserId; + const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD; -impl ValidateRequestHeader> { +impl ValidateRequestHeaderLayer> { + /// Allow anonymous requests. + pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self { + self.validate.allow_anonymous = allow_anonymous; + self + } + + /// Allow anonymous requests. + pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self { + self.validate.allow_anonymous = allow_anonymous; + self + } +} + +impl ValidateRequestHeader> { + /// Allow anonymous requests. + pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self { + self.validate.allow_anonymous = allow_anonymous; + self + } + + /// Allow anonymous requests. + pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self { + self.validate.allow_anonymous = allow_anonymous; + self + } +} + +impl ValidateRequestHeader>> { /// Authorize requests using a username and password pair. /// /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is @@ -78,11 +108,11 @@ impl ValidateRequestHeader> { where ResBody: Default, { - Self::custom(inner, Basic::new(username, value)) + Self::custom(inner, AuthorizeContext::new(Basic::new(username, value))) } } -impl ValidateRequestHeaderLayer> { +impl ValidateRequestHeaderLayer>> { /// Authorize requests using a username and password pair. /// /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is @@ -94,11 +124,11 @@ impl ValidateRequestHeaderLayer> { where ResBody: Default, { - Self::custom(Basic::new(username, password)) + Self::custom(AuthorizeContext::new(Basic::new(username, password))) } } -impl ValidateRequestHeader> { +impl ValidateRequestHeader>> { /// Authorize requests using a "bearer token". Commonly used for OAuth 2. /// /// The `Authorization` header is required to be `Bearer {token}`. @@ -110,11 +140,11 @@ impl ValidateRequestHeader> { where ResBody: Default, { - Self::custom(inner, Bearer::new(token)) + Self::custom(inner, AuthorizeContext::new(Bearer::new(token))) } } -impl ValidateRequestHeaderLayer> { +impl ValidateRequestHeaderLayer>> { /// Authorize requests using a "bearer token". Commonly used for OAuth 2. /// /// The `Authorization` header is required to be `Bearer {token}`. @@ -126,7 +156,7 @@ impl ValidateRequestHeaderLayer> { where ResBody: Default, { - Self::custom(Bearer::new(token)) + Self::custom(AuthorizeContext::new(Bearer::new(token))) } } @@ -169,7 +199,7 @@ impl fmt::Debug for Bearer { } } -impl ValidateRequest for Bearer +impl ValidateRequest for AuthorizeContext> where ResBody: Default + Send + 'static, B: Send + 'static, @@ -183,7 +213,12 @@ where request: Request, ) -> Result<(Context, Request), Response> { match request.headers().get(header::AUTHORIZATION) { - Some(actual) if actual == self.header_value => Ok((ctx, request)), + Some(actual) if actual == self.credential.header_value => Ok((ctx, request)), + None if self.allow_anonymous => { + let mut ctx = ctx; + ctx.insert(UserId::Anonymous); + Ok((ctx, request)) + } _ => { let mut res = Response::new(ResBody::default()); *res.status_mut() = StatusCode::UNAUTHORIZED; @@ -232,7 +267,7 @@ impl fmt::Debug for Basic { } } -impl ValidateRequest for Basic +impl ValidateRequest for AuthorizeContext> where ResBody: Default + Send + 'static, B: Send + 'static, @@ -246,7 +281,12 @@ where request: Request, ) -> Result<(Context, Request), Response> { match request.headers().get(header::AUTHORIZATION) { - Some(actual) if actual == self.header_value => Ok((ctx, request)), + Some(actual) if actual == self.credential.header_value => Ok((ctx, request)), + None if self.allow_anonymous => { + let mut ctx = ctx; + ctx.insert(UserId::Anonymous); + Ok((ctx, request)) + } _ => { let mut res = Response::new(ResBody::default()); *res.status_mut() = StatusCode::UNAUTHORIZED; @@ -258,6 +298,38 @@ where } } +pub struct AuthorizeContext { + credential: C, + allow_anonymous: bool, +} + +impl AuthorizeContext { + pub(crate) fn new(credential: C) -> Self { + Self { + credential, + allow_anonymous: false, + } + } +} + +impl Clone for AuthorizeContext { + fn clone(&self) -> Self { + Self { + credential: self.credential.clone(), + allow_anonymous: self.allow_anonymous, + } + } +} + +impl fmt::Debug for AuthorizeContext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AuthorizeContext") + .field("credential", &self.credential) + .field("allow_anonymous", &self.allow_anonymous) + .finish() + } +} + #[cfg(test)] mod tests { #[allow(unused_imports)] @@ -399,4 +471,65 @@ mod tests { async fn echo(req: Request) -> Result, BoxError> { Ok(Response::new(req.into_body())) } + + #[tokio::test] + async fn basic_allows_anonymous_if_header_is_missing() { + let service = ValidateRequestHeaderLayer::basic("foo", "bar") + .with_allow_anonymous(true) + .layer(service_fn(echo)); + + let request = Request::get("/").body(Body::empty()).unwrap(); + + let res = service.serve(Context::default(), request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn basic_fails_if_allow_anonymous_and_credentials_are_invalid() { + let service = ValidateRequestHeaderLayer::basic("foo", "bar") + .with_allow_anonymous(true) + .layer(service_fn(echo)); + + let request = Request::get("/") + .header( + header::AUTHORIZATION, + format!("Basic {}", BASE64.encode("wrong:credentials")), + ) + .body(Body::empty()) + .unwrap(); + + let res = service.serve(Context::default(), request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn bearer_allows_anonymous_if_header_is_missing() { + let service = ValidateRequestHeaderLayer::bearer("foobar") + .with_allow_anonymous(true) + .layer(service_fn(echo)); + + let request = Request::get("/").body(Body::empty()).unwrap(); + + let res = service.serve(Context::default(), request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn bearer_fails_if_allow_anonymous_and_credentials_are_invalid() { + let service = ValidateRequestHeaderLayer::bearer("foobar") + .with_allow_anonymous(true) + .layer(service_fn(echo)); + + let request = Request::get("/") + .header(header::AUTHORIZATION, "Bearer wrong") + .body(Body::empty()) + .unwrap(); + + let res = service.serve(Context::default(), request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + } } diff --git a/rama-http/src/layer/proxy_auth.rs b/rama-http/src/layer/proxy_auth.rs index c2d8869d..bb45ab87 100644 --- a/rama-http/src/layer/proxy_auth.rs +++ b/rama-http/src/layer/proxy_auth.rs @@ -6,7 +6,7 @@ use crate::header::PROXY_AUTHENTICATE; use crate::headers::{authorization::Credentials, HeaderMapExt, ProxyAuthorization}; use crate::{Request, Response, StatusCode}; use rama_core::{Context, Layer, Service}; -use rama_net::user::auth::Authority; +use rama_net::user::{auth::Authority, UserId}; use rama_utils::macros::define_inner_service_accessors; use std::fmt; use std::marker::PhantomData; @@ -16,6 +16,7 @@ use std::marker::PhantomData; /// See the [module docs](super) for an example. pub struct ProxyAuthLayer { proxy_auth: A, + allow_anonymous: bool, _phantom: PhantomData ()>, } @@ -35,6 +36,7 @@ impl Clone for ProxyAuthLayer { fn clone(&self) -> Self { Self { proxy_auth: self.proxy_auth.clone(), + allow_anonymous: self.allow_anonymous, _phantom: PhantomData, } } @@ -45,9 +47,22 @@ impl ProxyAuthLayer { pub const fn new(proxy_auth: A) -> Self { ProxyAuthLayer { proxy_auth, + allow_anonymous: false, _phantom: PhantomData, } } + + /// Allow anonymous requests. + pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self { + self.allow_anonymous = allow_anonymous; + self + } + + /// Allow anonymous requests. + pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self { + self.allow_anonymous = allow_anonymous; + self + } } impl ProxyAuthLayer { @@ -63,6 +78,7 @@ impl ProxyAuthLayer { pub fn with_labels(self) -> ProxyAuthLayer { ProxyAuthLayer { proxy_auth: self.proxy_auth, + allow_anonymous: self.allow_anonymous, _phantom: PhantomData, } } @@ -83,10 +99,13 @@ where /// Middleware that validates if a request has the appropriate Proxy Authorisation. /// /// If the request is not authorized a `407 Proxy Authentication Required` response will be sent. +/// If `allow_anonymous` is set to `true` then requests without a Proxy Authorization header will be +/// allowed and the user will be authoized as [`UserId::Anonymous`]. /// /// See the [module docs](self) for an example. pub struct ProxyAuthService { proxy_auth: A, + allow_anonymous: bool, inner: S, _phantom: PhantomData ()>, } @@ -96,11 +115,24 @@ impl ProxyAuthService { pub const fn new(proxy_auth: A, inner: S) -> Self { Self { proxy_auth, + allow_anonymous: false, inner, _phantom: PhantomData, } } + /// Allow anonymous requests. + pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self { + self.allow_anonymous = allow_anonymous; + self + } + + /// Allow anonymous requests. + pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self { + self.allow_anonymous = allow_anonymous; + self + } + define_inner_service_accessors!(); } @@ -108,6 +140,7 @@ impl fmt::Debug for ProxyAuthService) -> fmt::Result { f.debug_struct("ProxyAuthService") .field("proxy_auth", &self.proxy_auth) + .field("allow_anonymous", &self.allow_anonymous) .field("inner", &self.inner) .field( "_phantom", @@ -121,6 +154,7 @@ impl Clone for ProxyAuthService { fn clone(&self) -> Self { ProxyAuthService { proxy_auth: self.proxy_auth.clone(), + allow_anonymous: self.allow_anonymous, inner: self.inner.clone(), _phantom: PhantomData, } @@ -162,6 +196,9 @@ where .body(Default::default()) .unwrap()) } + } else if self.allow_anonymous { + ctx.insert(UserId::Anonymous); + self.inner.serve(ctx, req).await } else { Ok(Response::builder() .status(StatusCode::PROXY_AUTHENTICATION_REQUIRED) diff --git a/rama-http/src/layer/validate_request/validate_request_header.rs b/rama-http/src/layer/validate_request/validate_request_header.rs index 081f9dd3..2128f26c 100644 --- a/rama-http/src/layer/validate_request/validate_request_header.rs +++ b/rama-http/src/layer/validate_request/validate_request_header.rs @@ -8,7 +8,7 @@ use std::fmt; /// /// See the [module docs](crate::layer::validate_request) for an example. pub struct ValidateRequestHeaderLayer { - validate: T, + pub(crate) validate: T, } impl fmt::Debug for ValidateRequestHeaderLayer { @@ -90,7 +90,7 @@ where /// See the [module docs](crate::layer::validate_request) for an example. pub struct ValidateRequestHeader { inner: S, - validate: T, + pub(crate) validate: T, } impl fmt::Debug for ValidateRequestHeader { diff --git a/rama-net/src/user/id.rs b/rama-net/src/user/id.rs index 7061361d..b0240069 100644 --- a/rama-net/src/user/id.rs +++ b/rama-net/src/user/id.rs @@ -11,6 +11,10 @@ pub enum UserId { /// /// E.g. the token of a Bearer Auth user. Token(Vec), + /// User remains anonymous. + /// + /// E.g. the user is not authenticated via any credentials. + Anonymous, } impl PartialEq for UserId { @@ -21,6 +25,7 @@ impl PartialEq for UserId { let other = other.as_bytes(); token == other } + UserId::Anonymous => false, } } } @@ -39,6 +44,7 @@ impl PartialEq<[u8]> for UserId { username_bytes == other } UserId::Token(token) => token == other, + UserId::Anonymous => false, } } } @@ -57,6 +63,7 @@ impl PartialEq for UserId { let other = other.as_bytes(); token == other } + UserId::Anonymous => false, } } } @@ -75,6 +82,7 @@ impl PartialEq> for UserId { username_bytes == other } UserId::Token(token) => token == other, + UserId::Anonymous => false, } } }