Skip to content

Commit

Permalink
feat: ✨ split generic auth action
Browse files Browse the repository at this point in the history
  • Loading branch information
holmofy committed Oct 7, 2024
1 parent 4e9328d commit 03626b9
Show file tree
Hide file tree
Showing 11 changed files with 247 additions and 205 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ description = "just for oauth2 login"
edition = "2021"
license = "MIT"
name = "just-auth"
version = "0.1.2"
version = "0.1.3"

[dependencies]
async-trait = "0.1"
Expand Down
58 changes: 32 additions & 26 deletions src/baidu.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! https://openauth.baidu.com/doc/doc.html
use crate::error::Result;
use crate::{auth_server_builder, AuthAction, AuthConfig, AuthUrlProvider, AuthUser};
use crate::{
auth_server_builder, AuthAction, AuthConfig, AuthUrlProvider, AuthUser, GenericAuthAction,
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
Expand Down Expand Up @@ -46,6 +48,33 @@ impl AuthAction for AuthorizationServer {
type AuthToken = TokenResponse;
type AuthUser = UserInfoResponse;

async fn get_access_token(&self, callback: Self::AuthCallback) -> Result<Self::AuthToken> {
let AuthConfig {
client_id,
client_secret,
redirect_uri,
..
} = &self.config;
let access_token_url = Self::access_token_url(GetTokenRequest {
client_id: client_id.to_string(),
client_secret: client_secret.clone().expect("client_secret is empty"),
code: callback.code,
redirect_uri: redirect_uri.to_string(),
})?;
Ok(reqwest::get(access_token_url).await?.json().await?)
}

async fn get_user_info(&self, token: Self::AuthToken) -> Result<Self::AuthUser> {
let user_info_url = Self::user_info_url(GetUserInfoRequest {
access_token: token.access_token,
get_unionid: Some(1),
})?;
Ok(reqwest::get(user_info_url).await?.json().await?)
}
}

#[async_trait]
impl GenericAuthAction for AuthorizationServer {
async fn authorize<S: Into<String> + Send>(&self, state: S) -> Result<String> {
let AuthConfig {
client_id,
Expand All @@ -62,7 +91,8 @@ impl AuthAction for AuthorizationServer {
})
}

async fn login(&self, callback: Self::AuthCallback) -> Result<AuthUser> {
async fn login<S: Into<String> + Send>(&self, callback: S) -> Result<AuthUser> {
let callback: AuthCallback = serde_urlencoded::from_str(&callback.into())?;
let token = self.get_access_token(callback).await?;
let user = self.get_user_info(token.clone()).await?;
Ok(AuthUser {
Expand All @@ -74,30 +104,6 @@ impl AuthAction for AuthorizationServer {
extra: user.extra,
})
}

async fn get_access_token(&self, callback: Self::AuthCallback) -> Result<Self::AuthToken> {
let AuthConfig {
client_id,
client_secret,
redirect_uri,
..
} = &self.config;
let access_token_url = Self::access_token_url(GetTokenRequest {
client_id: client_id.to_string(),
client_secret: client_secret.clone().expect("client_secret is empty"),
code: callback.code,
redirect_uri: redirect_uri.to_string(),
})?;
Ok(reqwest::get(access_token_url).await?.json().await?)
}

async fn get_user_info(&self, token: Self::AuthToken) -> Result<Self::AuthUser> {
let user_info_url = Self::user_info_url(GetUserInfoRequest {
access_token: token.access_token,
get_unionid: Some(1),
})?;
Ok(reqwest::get(user_info_url).await?.json().await?)
}
}

#[serde_as]
Expand Down
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ pub enum AuthError {
#[error(transparent)]
UrlEncodedSerializeErr(#[from] serde_urlencoded::ser::Error),

#[error(transparent)]
UrlEncodedDeserializeErr(#[from] serde_urlencoded::de::Error),

#[error(transparent)]
JsonParseErr(#[from] serde_json::Error),

Expand Down
53 changes: 29 additions & 24 deletions src/facebook.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! https://developers.facebook.com/docs/facebook-login/guides/advanced/manual-flow
use crate::{
auth_server_builder, error::Result, AuthAction, AuthConfig, AuthUrlProvider, AuthUser,
GenericAuthAction,
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -48,6 +49,32 @@ impl AuthAction for AuthorizationServer {
type AuthToken = TokenResponse;
type AuthUser = UserInfoResponse;

async fn get_access_token(&self, callback: Self::AuthCallback) -> Result<Self::AuthToken> {
let AuthConfig {
client_id,
client_secret,
redirect_uri,
..
} = &self.config;
let access_token_url = Self::access_token_url(GetTokenRequest {
client_id: client_id.to_string(),
client_secret: client_secret.clone().expect("client_secret is empty"),
code: callback.code,
redirect_uri: redirect_uri.to_string(),
})?;
Ok(reqwest::get(access_token_url).await?.json().await?)
}

async fn get_user_info(&self, token: Self::AuthToken) -> Result<Self::AuthUser> {
let user_info_url = Self::user_info_url(GetUserInfoRequest {
access_token: token.access_token,
})?;
Ok(reqwest::get(user_info_url).await?.json().await?)
}
}

#[async_trait]
impl GenericAuthAction for AuthorizationServer {
async fn authorize<S: Into<String> + Send>(&self, state: S) -> Result<String> {
let AuthConfig {
client_id,
Expand All @@ -64,7 +91,8 @@ impl AuthAction for AuthorizationServer {
})
}

async fn login(&self, callback: Self::AuthCallback) -> Result<AuthUser> {
async fn login<S: Into<String> + Send>(&self, callback: S) -> Result<AuthUser> {
let callback: AuthCallback = serde_urlencoded::from_str(&callback.into())?;
let token = self.get_access_token(callback).await?;
let user = self.get_user_info(token.clone()).await?;
Ok(AuthUser {
Expand All @@ -76,29 +104,6 @@ impl AuthAction for AuthorizationServer {
extra: user.extra,
})
}

async fn get_access_token(&self, callback: Self::AuthCallback) -> Result<Self::AuthToken> {
let AuthConfig {
client_id,
client_secret,
redirect_uri,
..
} = &self.config;
let access_token_url = Self::access_token_url(GetTokenRequest {
client_id: client_id.to_string(),
client_secret: client_secret.clone().expect("client_secret is empty"),
code: callback.code,
redirect_uri: redirect_uri.to_string(),
})?;
Ok(reqwest::get(access_token_url).await?.json().await?)
}

async fn get_user_info(&self, token: Self::AuthToken) -> Result<Self::AuthUser> {
let user_info_url = Self::user_info_url(GetUserInfoRequest {
access_token: token.access_token,
})?;
Ok(reqwest::get(user_info_url).await?.json().await?)
}
}

#[serde_as]
Expand Down
72 changes: 39 additions & 33 deletions src/github.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! https://docs.github.com/zh/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps
use crate::error::Result;
use crate::{auth_server_builder, AuthAction, AuthConfig, AuthUrlProvider, AuthUser};
use crate::{
auth_server_builder, AuthAction, AuthConfig, AuthUrlProvider, AuthUser, GenericAuthAction,
};
use async_trait::async_trait;
use reqwest::header::ACCEPT;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -42,38 +44,6 @@ impl AuthAction for AuthorizationServer {
type AuthToken = TokenResponse;
type AuthUser = UserInfoResponse;

async fn authorize<S: Into<String> + Send>(&self, state: S) -> Result<String> {
let AuthConfig {
client_id,
redirect_uri,
scope,
..
} = &self.config;
Self::authorize_url(AuthRequest {
client_id: client_id.to_string(),
redirect_uri: redirect_uri.to_string(),
state: state.into(),
scope: scope
.clone()
.or_else(|| Some(vec!["read:user".into(), "user:email".into()]))
.expect("scope is empty"),
..Default::default()
})
}

async fn login(&self, callback: Self::AuthCallback) -> Result<AuthUser> {
let token = self.get_access_token(callback).await?;
let user = self.get_user_info(token.clone()).await?;
Ok(AuthUser {
user_id: user.id.to_string(),
name: user.name,
access_token: token.access_token,
refresh_token: token.token_type,
expires_in: i64::MAX,
extra: user.extra,
})
}

async fn get_access_token(&self, callback: Self::AuthCallback) -> Result<Self::AuthToken> {
let AuthConfig {
client_id,
Expand Down Expand Up @@ -108,6 +78,42 @@ impl AuthAction for AuthorizationServer {
}
}

#[async_trait]
impl GenericAuthAction for AuthorizationServer {
async fn authorize<S: Into<String> + Send>(&self, state: S) -> Result<String> {
let AuthConfig {
client_id,
redirect_uri,
scope,
..
} = &self.config;
Self::authorize_url(AuthRequest {
client_id: client_id.to_string(),
redirect_uri: redirect_uri.to_string(),
state: state.into(),
scope: scope
.clone()
.or_else(|| Some(vec!["read:user".into(), "user:email".into()]))
.expect("scope is empty"),
..Default::default()
})
}

async fn login<S: Into<String> + Send>(&self, callback: S) -> Result<AuthUser> {
let callback: AuthCallback = serde_urlencoded::from_str(&callback.into())?;
let token = self.get_access_token(callback).await?;
let user = self.get_user_info(token.clone()).await?;
Ok(AuthUser {
user_id: user.id.to_string(),
name: user.name,
access_token: token.access_token,
refresh_token: token.token_type,
expires_in: i64::MAX,
extra: user.extra,
})
}
}

#[serde_as]
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct AuthRequest {
Expand Down
11 changes: 7 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,18 @@ pub trait AuthAction {
type AuthToken: Send;
type AuthUser;

async fn authorize<S: Into<String> + Send>(&self, state: S) -> Result<String>;

async fn login(&self, callback: Self::AuthCallback) -> Result<AuthUser>;

async fn get_access_token(&self, callback: Self::AuthCallback) -> Result<Self::AuthToken>;

async fn get_user_info(&self, token: Self::AuthToken) -> Result<Self::AuthUser>;
}

#[async_trait]
pub trait GenericAuthAction {
async fn authorize<S: Into<String> + Send>(&self, state: S) -> Result<String>;

async fn login<S: Into<String> + Send>(&self, callback_raw_query: S) -> Result<AuthUser>;
}

pub struct AuthUser {
pub user_id: String,
pub name: String,
Expand Down
Loading

0 comments on commit 03626b9

Please sign in to comment.