Skip to content

Commit

Permalink
Merge branch 'main' into feat/rama-http-core
Browse files Browse the repository at this point in the history
  • Loading branch information
GlenDC committed Nov 7, 2024
2 parents f249d37 + 03cb4eb commit 058c40c
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 15 deletions.
38 changes: 38 additions & 0 deletions rama-core/src/combinators/either.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,44 @@ macro_rules! define_either {
}
}
}


impl<$($param),+, Output> $id<$($param),+>
where
$($param: std::future::Future<Output = Output>),+
{
/// Convert `Pin<&mut Either<A, B>>` to `Either<Pin<&mut A>, Pin<&mut B>>`,
/// pinned projections of the inner variants.
fn as_pin_mut(self: Pin<&mut Self>) -> $id<$(Pin<&mut $param>),+> {
// SAFETY: `get_unchecked_mut` is fine because we don't move anything.
// We can use `new_unchecked` because the `inner` parts are guaranteed
// to be pinned, as they come from `self` which is pinned, and we never
// offer an unpinned `&mut A` or `&mut B` through `Pin<&mut Self>`. We
// also don't have an implementation of `Drop`, nor manual `Unpin`.
unsafe {
match self.get_unchecked_mut() {
$(
Self::$param(inner) => $id::$param(Pin::new_unchecked(inner)),
)+
}
}
}
}

impl<$($param),+, Output> std::future::Future for $id<$($param),+>
where
$($param: std::future::Future<Output = Output> + Unpin),+
{
type Output = Output;

fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
match self.as_pin_mut() {
$(
$id::$param(fut) => fut.poll(cx),
)+
}
}
}
};
}

Expand Down
157 changes: 145 additions & 12 deletions rama-http/src/layer/auth/require_authorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,39 @@ use crate::{
};
use rama_core::Context;

use rama_net::user::UserId;

const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;

impl<S, ResBody> ValidateRequestHeader<S, Basic<ResBody>> {
impl<C> ValidateRequestHeaderLayer<AuthorizeContext<C>> {
/// Allow anonymous requests.
pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
self.validate.allow_anonymous = allow_anonymous;
self
}

/// Allow anonymous requests.
pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
self.validate.allow_anonymous = allow_anonymous;
self
}
}

impl<S, C> ValidateRequestHeader<S, AuthorizeContext<C>> {
/// Allow anonymous requests.
pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
self.validate.allow_anonymous = allow_anonymous;
self
}

/// Allow anonymous requests.
pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
self.validate.allow_anonymous = allow_anonymous;
self
}
}

impl<S, ResBody> ValidateRequestHeader<S, AuthorizeContext<Basic<ResBody>>> {
/// Authorize requests using a username and password pair.
///
/// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is
Expand All @@ -78,11 +108,11 @@ impl<S, ResBody> ValidateRequestHeader<S, Basic<ResBody>> {
where
ResBody: Default,
{
Self::custom(inner, Basic::new(username, value))
Self::custom(inner, AuthorizeContext::new(Basic::new(username, value)))
}
}

impl<ResBody> ValidateRequestHeaderLayer<Basic<ResBody>> {
impl<ResBody> ValidateRequestHeaderLayer<AuthorizeContext<Basic<ResBody>>> {
/// Authorize requests using a username and password pair.
///
/// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is
Expand All @@ -94,11 +124,11 @@ impl<ResBody> ValidateRequestHeaderLayer<Basic<ResBody>> {
where
ResBody: Default,
{
Self::custom(Basic::new(username, password))
Self::custom(AuthorizeContext::new(Basic::new(username, password)))
}
}

impl<S, ResBody> ValidateRequestHeader<S, Bearer<ResBody>> {
impl<S, ResBody> ValidateRequestHeader<S, AuthorizeContext<Bearer<ResBody>>> {
/// Authorize requests using a "bearer token". Commonly used for OAuth 2.
///
/// The `Authorization` header is required to be `Bearer {token}`.
Expand All @@ -110,11 +140,11 @@ impl<S, ResBody> ValidateRequestHeader<S, Bearer<ResBody>> {
where
ResBody: Default,
{
Self::custom(inner, Bearer::new(token))
Self::custom(inner, AuthorizeContext::new(Bearer::new(token)))
}
}

impl<ResBody> ValidateRequestHeaderLayer<Bearer<ResBody>> {
impl<ResBody> ValidateRequestHeaderLayer<AuthorizeContext<Bearer<ResBody>>> {
/// Authorize requests using a "bearer token". Commonly used for OAuth 2.
///
/// The `Authorization` header is required to be `Bearer {token}`.
Expand All @@ -126,7 +156,7 @@ impl<ResBody> ValidateRequestHeaderLayer<Bearer<ResBody>> {
where
ResBody: Default,
{
Self::custom(Bearer::new(token))
Self::custom(AuthorizeContext::new(Bearer::new(token)))
}
}

Expand Down Expand Up @@ -169,7 +199,7 @@ impl<ResBody> fmt::Debug for Bearer<ResBody> {
}
}

impl<S, B, ResBody> ValidateRequest<S, B> for Bearer<ResBody>
impl<S, B, ResBody> ValidateRequest<S, B> for AuthorizeContext<Bearer<ResBody>>
where
ResBody: Default + Send + 'static,
B: Send + 'static,
Expand All @@ -183,7 +213,12 @@ where
request: Request<B>,
) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
match request.headers().get(header::AUTHORIZATION) {
Some(actual) if actual == self.header_value => Ok((ctx, request)),
Some(actual) if actual == self.credential.header_value => Ok((ctx, request)),
None if self.allow_anonymous => {
let mut ctx = ctx;
ctx.insert(UserId::Anonymous);
Ok((ctx, request))
}
_ => {
let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::UNAUTHORIZED;
Expand Down Expand Up @@ -232,7 +267,7 @@ impl<ResBody> fmt::Debug for Basic<ResBody> {
}
}

impl<S, B, ResBody> ValidateRequest<S, B> for Basic<ResBody>
impl<S, B, ResBody> ValidateRequest<S, B> for AuthorizeContext<Basic<ResBody>>
where
ResBody: Default + Send + 'static,
B: Send + 'static,
Expand All @@ -246,7 +281,12 @@ where
request: Request<B>,
) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
match request.headers().get(header::AUTHORIZATION) {
Some(actual) if actual == self.header_value => Ok((ctx, request)),
Some(actual) if actual == self.credential.header_value => Ok((ctx, request)),
None if self.allow_anonymous => {
let mut ctx = ctx;
ctx.insert(UserId::Anonymous);
Ok((ctx, request))
}
_ => {
let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::UNAUTHORIZED;
Expand All @@ -258,6 +298,38 @@ where
}
}

pub struct AuthorizeContext<C> {
credential: C,
allow_anonymous: bool,
}

impl<C> AuthorizeContext<C> {
pub(crate) fn new(credential: C) -> Self {
Self {
credential,
allow_anonymous: false,
}
}
}

impl<C: Clone> Clone for AuthorizeContext<C> {
fn clone(&self) -> Self {
Self {
credential: self.credential.clone(),
allow_anonymous: self.allow_anonymous,
}
}
}

impl<C: fmt::Debug> fmt::Debug for AuthorizeContext<C> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AuthorizeContext")
.field("credential", &self.credential)
.field("allow_anonymous", &self.allow_anonymous)
.finish()
}
}

#[cfg(test)]
mod tests {
#[allow(unused_imports)]
Expand Down Expand Up @@ -399,4 +471,65 @@ mod tests {
async fn echo<Body>(req: Request<Body>) -> Result<Response<Body>, BoxError> {
Ok(Response::new(req.into_body()))
}

#[tokio::test]
async fn basic_allows_anonymous_if_header_is_missing() {
let service = ValidateRequestHeaderLayer::basic("foo", "bar")
.with_allow_anonymous(true)
.layer(service_fn(echo));

let request = Request::get("/").body(Body::empty()).unwrap();

let res = service.serve(Context::default(), request).await.unwrap();

assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn basic_fails_if_allow_anonymous_and_credentials_are_invalid() {
let service = ValidateRequestHeaderLayer::basic("foo", "bar")
.with_allow_anonymous(true)
.layer(service_fn(echo));

let request = Request::get("/")
.header(
header::AUTHORIZATION,
format!("Basic {}", BASE64.encode("wrong:credentials")),
)
.body(Body::empty())
.unwrap();

let res = service.serve(Context::default(), request).await.unwrap();

assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn bearer_allows_anonymous_if_header_is_missing() {
let service = ValidateRequestHeaderLayer::bearer("foobar")
.with_allow_anonymous(true)
.layer(service_fn(echo));

let request = Request::get("/").body(Body::empty()).unwrap();

let res = service.serve(Context::default(), request).await.unwrap();

assert_eq!(res.status(), StatusCode::OK);
}

#[tokio::test]
async fn bearer_fails_if_allow_anonymous_and_credentials_are_invalid() {
let service = ValidateRequestHeaderLayer::bearer("foobar")
.with_allow_anonymous(true)
.layer(service_fn(echo));

let request = Request::get("/")
.header(header::AUTHORIZATION, "Bearer wrong")
.body(Body::empty())
.unwrap();

let res = service.serve(Context::default(), request).await.unwrap();

assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
}
Loading

0 comments on commit 058c40c

Please sign in to comment.