Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ee): implement auth claims #932

Merged
merged 3 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion crates/juniper-axum/src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
async_trait,
body::Body,
extract::{FromRequest, FromRequestParts, Query},
http::{HeaderValue, Method, Request, StatusCode},
http::{request::Parts, HeaderValue, Method, Request, StatusCode},
response::{IntoResponse as _, Response},
Json, RequestExt as _,
};
Expand All @@ -16,6 +16,46 @@
};
use serde::Deserialize;

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct AuthBearer(pub Option<String>);

pub type Rejection = (StatusCode, &'static str);

#[async_trait]
impl<B> FromRequestParts<B> for AuthBearer
where
B: Send + Sync,
{
type Rejection = Rejection;

async fn from_request_parts(req: &mut Parts, _: &B) -> Result<Self, Self::Rejection> {
// Get authorization header
let authorization = req
.headers
.get("authorization")
.map(HeaderValue::to_str)
.transpose()
.map_err(|_| {
(
StatusCode::BAD_REQUEST,
"authorization contains invalid characters",
)
})?;

let Some(authorization) = authorization else {
return Ok(Self(None));
};

// Check that its a well-formed bearer and return
let split = authorization.split_once(' ');
match split {
// Found proper bearer
Some((name, contents)) if name == "Bearer" => Ok(Self(Some(contents.to_owned()))),

Check warning on line 53 in crates/juniper-axum/src/extract.rs

View workflow job for this annotation

GitHub Actions / autofix

redundant guard
_ => Ok(Self(None)),
}
}
}

#[derive(Debug, PartialEq)]
pub struct JuniperRequest<S = DefaultScalarValue>(pub GraphQLBatchRequest<S>)
where
Expand Down
16 changes: 12 additions & 4 deletions crates/juniper-axum/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
pub mod extract;
pub mod response;

use std::{future, sync::Arc};
use std::{future};

use axum::{
extract::{Extension, State},
response::{Html, IntoResponse},
};
use extract::AuthBearer;
use juniper_graphql_ws::Schema;

use self::{extract::JuniperRequest, response::JuniperResponse};

pub trait FromAuth<S> {
fn build(state: S, bearer: Option<String>) -> Self;
}

#[cfg_attr(text, axum::debug_handler)]
pub async fn graphql<S>(
State(state): State<Arc<S::Context>>,
pub async fn graphql<S, C>(
State(state): State<C>,
Extension(schema): Extension<S>,
AuthBearer(bearer): AuthBearer,
JuniperRequest(req): JuniperRequest<S::ScalarValue>,
) -> impl IntoResponse
where
S: Schema, // TODO: Refactor in the way we don't depend on `juniper_graphql_ws::Schema` here.
S::Context: FromAuth<C>,
{
JuniperResponse(req.execute(schema.root_node(), &state).await).into_response()
let ctx = S::Context::build(state, bearer);
JuniperResponse(req.execute(schema.root_node(), &ctx).await).into_response()
}

/// Creates a [`Handler`] that replies with an HTML page containing [GraphiQL].
Expand Down
10 changes: 1 addition & 9 deletions ee/tabby-webserver/graphql/schema.graphql
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
type RegisterResponse {
accessToken: String!
refreshToken: String!
errors: [AuthError!]!
}

type AuthError {
message: String!
code: String!
}

enum WorkerKind {
Expand All @@ -15,7 +9,7 @@ enum WorkerKind {
}

type Mutation {
resetRegistrationToken(token: String): String!
resetRegistrationToken: String!
register(email: String!, password1: String!, password2: String!): RegisterResponse!
tokenAuth(email: String!, password: String!): TokenAuthResponse!
verifyToken(token: String!): VerifyTokenResponse!
Expand All @@ -27,7 +21,6 @@ type UserInfo {
}

type VerifyTokenResponse {
errors: [AuthError!]!
claims: Claims!
}

Expand Down Expand Up @@ -56,7 +49,6 @@ type Worker {
type TokenAuthResponse {
accessToken: String!
refreshToken: String!
errors: [AuthError!]!
}

schema {
Expand Down
2 changes: 1 addition & 1 deletion ee/tabby-webserver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub async fn attach_webserver(
.layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer))
.route(
"/graphql",
routing::post(graphql::<Arc<Schema>>).with_state(ctx.clone()),
routing::post(graphql::<Arc<Schema>, Arc<ServerContext>>).with_state(ctx.clone()),
)
.route("/graphql", routing::get(playground("/graphql", None)))
.layer(Extension(schema))
Expand Down
53 changes: 33 additions & 20 deletions ee/tabby-webserver/src/schema.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
pub mod auth;

use std::sync::Arc;


use juniper::{
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, RootNode,
};
use juniper_axum::FromAuth;

use crate::{
api::Worker,
Expand All @@ -13,36 +17,45 @@ use crate::{
},
};

pub struct Context {
claims: Option<auth::Claims>,
server: Arc<ServerContext>,
}

impl FromAuth<Arc<ServerContext>> for Context {
fn build(server: Arc<ServerContext>, bearer: Option<String>) -> Self {
let claims = bearer.and_then(|token| validate_jwt(&token).ok());
Self { claims, server }
}
}

// To make our context usable by Juniper, we have to implement a marker trait.
impl juniper::Context for ServerContext {}
impl juniper::Context for Context {}

#[derive(Default)]
pub struct Query;

#[graphql_object(context = ServerContext)]
#[graphql_object(context = Context)]
impl Query {
async fn workers(ctx: &ServerContext) -> Vec<Worker> {
ctx.list_workers().await
async fn workers(ctx: &Context) -> Vec<Worker> {
ctx.server.list_workers().await
}

async fn registration_token(ctx: &ServerContext) -> FieldResult<String> {
let token = ctx.read_registration_token().await?;
async fn registration_token(ctx: &Context) -> FieldResult<String> {
let token = ctx.server.read_registration_token().await?;
Ok(token)
}
}

#[derive(Default)]
pub struct Mutation;

#[graphql_object(context = ServerContext)]
#[graphql_object(context = Context)]
impl Mutation {
async fn reset_registration_token(
ctx: &ServerContext,
token: Option<String>,
) -> FieldResult<String> {
if let Some(Ok(claims)) = token.map(|t| validate_jwt(&t)) {
async fn reset_registration_token(ctx: &Context) -> FieldResult<String> {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
let reg_token = ctx.reset_registration_token().await?;
let reg_token = ctx.server.reset_registration_token().await?;
return Ok(reg_token);
}
}
Expand All @@ -53,7 +66,7 @@ impl Mutation {
}

async fn register(
ctx: &ServerContext,
ctx: &Context,
email: String,
password1: String,
password2: String,
Expand All @@ -63,24 +76,24 @@ impl Mutation {
password1,
password2,
};
ctx.auth().register(input).await
ctx.server.auth().register(input).await
}

async fn token_auth(
ctx: &ServerContext,
ctx: &Context,
email: String,
password: String,
) -> FieldResult<TokenAuthResponse> {
let input = TokenAuthInput { email, password };
ctx.auth().token_auth(input).await
ctx.server.auth().token_auth(input).await
}

async fn verify_token(ctx: &ServerContext, token: String) -> FieldResult<VerifyTokenResponse> {
ctx.auth().verify_token(token).await
async fn verify_token(ctx: &Context, token: String) -> FieldResult<VerifyTokenResponse> {
ctx.server.auth().verify_token(token).await
}
}

pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<ServerContext>>;
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Context>>;

pub fn create_schema() -> Schema {
Schema::new(Query, Mutation, EmptySubscription::new())
Expand Down
4 changes: 2 additions & 2 deletions ee/tabby-webserver/src/schema/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl Claims {
}
}

pub fn user_info(self) -> UserInfo {
self.user
pub fn user_info(&self) -> &UserInfo {
&self.user
}
}
5 changes: 4 additions & 1 deletion ee/tabby-webserver/src/server/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@
})?;

// check if email exists
if let Some(_) = self.get_user_by_email(&input.email).await? {

Check warning on line 141 in ee/tabby-webserver/src/server/auth.rs

View workflow job for this annotation

GitHub Actions / autofix

redundant pattern matching, consider using `is_some()`
return Err("Email already exists".into());
}

Expand Down Expand Up @@ -268,6 +268,9 @@
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();
let claims = validate_jwt(&token).unwrap();
assert_eq!(claims.user_info(), UserInfo::new("test".to_string(), false));
assert_eq!(
claims.user_info(),
&UserInfo::new("test".to_string(), false)
);
}
}
Loading