Skip to content

Commit

Permalink
feat(ee): implement auth claims (#932)
Browse files Browse the repository at this point in the history
* feat(ee): implement auth claims

* fix test

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
wsxiaoys and autofix-ci[bot] authored Dec 1, 2023
1 parent 8d3be2e commit 1a9cbdc
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 38 deletions.
42 changes: 41 additions & 1 deletion crates/juniper-axum/src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use axum::{
async_trait,
body::Body,
extract::{FromRequest, FromRequestParts, Query},
http::{HeaderValue, Method, Request, StatusCode},
http::{request::Parts, HeaderValue, Method, Request, StatusCode},
response::{IntoResponse as _, Response},
Json, RequestExt as _,
};
Expand All @@ -16,6 +16,46 @@ use juniper::{
};
use serde::Deserialize;

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct AuthBearer(pub Option<String>);

pub type Rejection = (StatusCode, &'static str);

#[async_trait]
impl<B> FromRequestParts<B> for AuthBearer
where
B: Send + Sync,
{
type Rejection = Rejection;

async fn from_request_parts(req: &mut Parts, _: &B) -> Result<Self, Self::Rejection> {
// Get authorization header
let authorization = req
.headers
.get("authorization")
.map(HeaderValue::to_str)
.transpose()
.map_err(|_| {
(
StatusCode::BAD_REQUEST,
"authorization contains invalid characters",
)
})?;

let Some(authorization) = authorization else {
return Ok(Self(None));
};

// Check that its a well-formed bearer and return
let split = authorization.split_once(' ');
match split {
// Found proper bearer
Some((name, contents)) if name == "Bearer" => Ok(Self(Some(contents.to_owned()))),
_ => Ok(Self(None)),
}
}
}

#[derive(Debug, PartialEq)]
pub struct JuniperRequest<S = DefaultScalarValue>(pub GraphQLBatchRequest<S>)
where
Expand Down
16 changes: 12 additions & 4 deletions crates/juniper-axum/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
pub mod extract;
pub mod response;

use std::{future, sync::Arc};
use std::{future};

use axum::{
extract::{Extension, State},
response::{Html, IntoResponse},
};
use extract::AuthBearer;
use juniper_graphql_ws::Schema;

use self::{extract::JuniperRequest, response::JuniperResponse};

pub trait FromAuth<S> {
fn build(state: S, bearer: Option<String>) -> Self;
}

#[cfg_attr(text, axum::debug_handler)]
pub async fn graphql<S>(
State(state): State<Arc<S::Context>>,
pub async fn graphql<S, C>(
State(state): State<C>,
Extension(schema): Extension<S>,
AuthBearer(bearer): AuthBearer,
JuniperRequest(req): JuniperRequest<S::ScalarValue>,
) -> impl IntoResponse
where
S: Schema, // TODO: Refactor in the way we don't depend on `juniper_graphql_ws::Schema` here.
S::Context: FromAuth<C>,
{
JuniperResponse(req.execute(schema.root_node(), &state).await).into_response()
let ctx = S::Context::build(state, bearer);
JuniperResponse(req.execute(schema.root_node(), &ctx).await).into_response()
}

/// Creates a [`Handler`] that replies with an HTML page containing [GraphiQL].
Expand Down
10 changes: 1 addition & 9 deletions ee/tabby-webserver/graphql/schema.graphql
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
type RegisterResponse {
accessToken: String!
refreshToken: String!
errors: [AuthError!]!
}

type AuthError {
message: String!
code: String!
}

enum WorkerKind {
Expand All @@ -15,7 +9,7 @@ enum WorkerKind {
}

type Mutation {
resetRegistrationToken(token: String): String!
resetRegistrationToken: String!
register(email: String!, password1: String!, password2: String!): RegisterResponse!
tokenAuth(email: String!, password: String!): TokenAuthResponse!
verifyToken(token: String!): VerifyTokenResponse!
Expand All @@ -27,7 +21,6 @@ type UserInfo {
}

type VerifyTokenResponse {
errors: [AuthError!]!
claims: Claims!
}

Expand Down Expand Up @@ -56,7 +49,6 @@ type Worker {
type TokenAuthResponse {
accessToken: String!
refreshToken: String!
errors: [AuthError!]!
}

schema {
Expand Down
2 changes: 1 addition & 1 deletion ee/tabby-webserver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub async fn attach_webserver(
.layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer))
.route(
"/graphql",
routing::post(graphql::<Arc<Schema>>).with_state(ctx.clone()),
routing::post(graphql::<Arc<Schema>, Arc<ServerContext>>).with_state(ctx.clone()),
)
.route("/graphql", routing::get(playground("/graphql", None)))
.layer(Extension(schema))
Expand Down
53 changes: 33 additions & 20 deletions ee/tabby-webserver/src/schema.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
pub mod auth;

use std::sync::Arc;


use juniper::{
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, RootNode,
};
use juniper_axum::FromAuth;

use crate::{
api::Worker,
Expand All @@ -13,36 +17,45 @@ use crate::{
},
};

pub struct Context {
claims: Option<auth::Claims>,
server: Arc<ServerContext>,
}

impl FromAuth<Arc<ServerContext>> for Context {
fn build(server: Arc<ServerContext>, bearer: Option<String>) -> Self {
let claims = bearer.and_then(|token| validate_jwt(&token).ok());
Self { claims, server }
}
}

// To make our context usable by Juniper, we have to implement a marker trait.
impl juniper::Context for ServerContext {}
impl juniper::Context for Context {}

#[derive(Default)]
pub struct Query;

#[graphql_object(context = ServerContext)]
#[graphql_object(context = Context)]
impl Query {
async fn workers(ctx: &ServerContext) -> Vec<Worker> {
ctx.list_workers().await
async fn workers(ctx: &Context) -> Vec<Worker> {
ctx.server.list_workers().await
}

async fn registration_token(ctx: &ServerContext) -> FieldResult<String> {
let token = ctx.read_registration_token().await?;
async fn registration_token(ctx: &Context) -> FieldResult<String> {
let token = ctx.server.read_registration_token().await?;
Ok(token)
}
}

#[derive(Default)]
pub struct Mutation;

#[graphql_object(context = ServerContext)]
#[graphql_object(context = Context)]
impl Mutation {
async fn reset_registration_token(
ctx: &ServerContext,
token: Option<String>,
) -> FieldResult<String> {
if let Some(Ok(claims)) = token.map(|t| validate_jwt(&t)) {
async fn reset_registration_token(ctx: &Context) -> FieldResult<String> {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
let reg_token = ctx.reset_registration_token().await?;
let reg_token = ctx.server.reset_registration_token().await?;
return Ok(reg_token);
}
}
Expand All @@ -53,7 +66,7 @@ impl Mutation {
}

async fn register(
ctx: &ServerContext,
ctx: &Context,
email: String,
password1: String,
password2: String,
Expand All @@ -63,24 +76,24 @@ impl Mutation {
password1,
password2,
};
ctx.auth().register(input).await
ctx.server.auth().register(input).await
}

async fn token_auth(
ctx: &ServerContext,
ctx: &Context,
email: String,
password: String,
) -> FieldResult<TokenAuthResponse> {
let input = TokenAuthInput { email, password };
ctx.auth().token_auth(input).await
ctx.server.auth().token_auth(input).await
}

async fn verify_token(ctx: &ServerContext, token: String) -> FieldResult<VerifyTokenResponse> {
ctx.auth().verify_token(token).await
async fn verify_token(ctx: &Context, token: String) -> FieldResult<VerifyTokenResponse> {
ctx.server.auth().verify_token(token).await
}
}

pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<ServerContext>>;
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Context>>;

pub fn create_schema() -> Schema {
Schema::new(Query, Mutation, EmptySubscription::new())
Expand Down
4 changes: 2 additions & 2 deletions ee/tabby-webserver/src/schema/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl Claims {
}
}

pub fn user_info(self) -> UserInfo {
self.user
pub fn user_info(&self) -> &UserInfo {
&self.user
}
}
5 changes: 4 additions & 1 deletion ee/tabby-webserver/src/server/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ mod tests {
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();
let claims = validate_jwt(&token).unwrap();
assert_eq!(claims.user_info(), UserInfo::new("test".to_string(), false));
assert_eq!(
claims.user_info(),
&UserInfo::new("test".to_string(), false)
);
}
}

0 comments on commit 1a9cbdc

Please sign in to comment.