Skip to content

Commit

Permalink
continue update oauth
Browse files Browse the repository at this point in the history
  • Loading branch information
xou816 committed Nov 26, 2024
1 parent 2447341 commit 08385d6
Show file tree
Hide file tree
Showing 10 changed files with 1,182 additions and 829 deletions.
1,760 changes: 1,072 additions & 688 deletions Cargo.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions src/api/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use serde_json::from_str;
use std::convert::Into;
use std::marker::PhantomData;
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use thiserror::Error;

use crate::player::TokenStore;
Expand Down Expand Up @@ -66,7 +66,7 @@ where
}

fn authenticated(mut self) -> Result<Self, SpotifyApiError> {
let token = self.client.token_store.get_cached();
let token = self.client.token_store.get_cached_blocking();
let token = token.as_ref().ok_or(SpotifyApiError::NoToken)?;
self.request = self
.request
Expand Down Expand Up @@ -200,7 +200,7 @@ impl SpotifyClient {
}

pub(crate) fn has_token(&self) -> bool {
self.token_store.get_cached().is_some()
self.token_store.get_cached_blocking().is_some()
}

fn parse_cache_control(cache_control: &str) -> Option<u64> {
Expand Down
1 change: 0 additions & 1 deletion src/app/components/login/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use gtk::CompositeTemplate;
use std::rc::Rc;

use crate::app::components::EventListener;
use crate::app::credentials::{self, Credentials};
use crate::app::state::LoginEvent;
use crate::app::AppEvent;

Expand Down
8 changes: 2 additions & 6 deletions src/app/components/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::collections::HashSet;
use std::future::Future;

use crate::api::SpotifyApiError;
use crate::app::{state::LoginAction, ActionDispatcher, AppAction, AppEvent};
use crate::app::{ActionDispatcher, AppAction, AppEvent};

mod navigation;
pub use navigation::*;
Expand Down Expand Up @@ -121,11 +121,7 @@ impl dyn ActionDispatcher {
match result {
Ok(actions) => actions,
Err(SpotifyApiError::NoToken) => vec![],
Err(SpotifyApiError::InvalidToken) => {
let mut retried = call().await.unwrap_or_else(|_| Vec::new());
retried.insert(0, LoginAction::RefreshToken.into());
retried
}
Err(SpotifyApiError::InvalidToken) => call().await.unwrap_or_else(|_| Vec::new()),
Err(err) => {
error!("Spotify API error: {}", err);
vec![AppAction::ShowNotification(gettext(
Expand Down
1 change: 0 additions & 1 deletion src/app/state/login_state.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use gettextrs::*;
use std::borrow::Cow;

use crate::app::credentials::Credentials;
use crate::app::models::PlaylistSummary;
use crate::app::state::{AppAction, AppEvent, UpdatableState};

Expand Down
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ extern crate glib;
extern crate lazy_static;
#[macro_use]
extern crate log;
extern crate gettextrs;

use app::state::ScreenName;
use futures::channel::mpsc::UnboundedSender;
Expand Down
2 changes: 0 additions & 2 deletions src/player/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
use librespot::core::spotify_id::SpotifyId;
use log::Log;
use std::cell::RefCell;
use std::rc::Rc;
use std::sync::Arc;
use tokio::task;

use crate::app::credentials::Credentials;
use crate::app::state::{LoginAction, PlaybackAction};
use crate::app::AppAction;
#[allow(clippy::module_inception)]
Expand Down
195 changes: 82 additions & 113 deletions src/player/oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,21 @@
//! is appropriate for headless systems.
use crate::app::credentials::Credentials;
use librespot::protocol::credentials;

use log::{error, info, trace};
use oauth2::reqwest::async_http_client;
use oauth2::{
basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge,
RedirectUrl, Scope, TokenResponse, TokenUrl,
};
use oauth2::{RefreshToken, RequestTokenError};
use std::alloc::System;
use std::collections::HashMap;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use std::{
io::{BufRead, BufReader, Write},
net::{SocketAddr, TcpListener},
};
use thiserror::Error;
use tokio::task::JoinHandle;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::time;
use url::Url;

Expand Down Expand Up @@ -87,24 +84,26 @@ impl SpotOauthClient {
.into_iter()
.map(|s| Scope::new(s.into()))
.collect();

let (auth_url, csrf_token) = self
.client
.authorize_url(CsrfToken::new_random)
.add_scopes(request_scopes)
.set_pkce_challenge(pkce_challenge)
.url();

println!("Browse to: {}", auth_url);
if let Err(err) = open::that(auth_url.to_string()) {
error!("An error occurred when opening '{auth_url}': {err}")
}

let addr = get_socket_address(REDIRECT_URI).expect("Invalid redirect uri");
let code = get_authcode_listener(addr, csrf_token)?;
let res = wait_for_authcode().await?;
if *csrf_token.secret() != *res.csrf_token.secret() {
return Err(OAuthError::InvalidState);
}

let token = self
.client
.exchange_code(code)
.exchange_code(res.code)
.set_pkce_verifier(pkce_verifier)
.request_async(async_http_client)
.await
Expand Down Expand Up @@ -138,16 +137,12 @@ impl SpotOauthClient {
),
};

self.token_store.set_async(token.clone()).await;
self.token_store.set(token.clone()).await;
Ok(token)
}

pub async fn get_refreshed_token(&self) -> Result<Credentials, OAuthError> {
let token = self
.token_store
.get_async()
.await
.ok_or(OAuthError::LoggedOut)?;
pub async fn get_valid_token(&self) -> Result<Credentials, OAuthError> {
let token = self.token_store.get().await.ok_or(OAuthError::LoggedOut)?;
if token.token_expired() {
self.refresh_token(token).await
} else {
Expand All @@ -170,7 +165,7 @@ impl SpotOauthClient {
}
})
else {
self.token_store.async_clear().await;
self.token_store.clear().await;
return Err(OAuthError::NoRefreshToken);
};

Expand All @@ -191,50 +186,45 @@ impl SpotOauthClient {
),
};

self.token_store.set_async(new_token.clone()).await;
self.token_store.set(new_token.clone()).await;
Ok(new_token)
}

pub async fn continuously_refresh(&self) {
let mut token = self.token_store.get_async().await;
pub async fn refresh_token_at_expiry(&self) -> Result<Credentials, OAuthError> {
let Some(old_token) = self.token_store.get_cached().await.take() else {
return Err(OAuthError::NoRefreshToken);
};

loop {
let Some(old_token) = token.take() else {
break;
};
let duration = old_token
.token_expiry_time
.and_then(|d| d.duration_since(SystemTime::now()).ok())
.unwrap_or(Duration::from_secs(120));

let duration = old_token
.token_expiry_time
.and_then(|d| d.duration_since(SystemTime::now()).ok())
.unwrap_or(Duration::from_secs(120));
time::sleep(duration.saturating_sub(Duration::from_secs(10))).await;
info!(
"Refreshing token in approx {}min",
duration.as_secs().div_euclid(60)
);
time::sleep(duration.saturating_sub(Duration::from_secs(10))).await;

info!("Refreshing token...");
token = self.refresh_token(old_token).await.ok();
}
info!("Refreshing token...");
self.refresh_token(old_token).await
}
}

#[derive(Debug, Error)]
pub enum OAuthError {
#[error("Unable to parse redirect URI {uri} ({e})")]
AuthCodeBadUri { uri: String, e: url::ParseError },

#[error("Auth code param not found in URI {uri}")]
AuthCodeNotFound { uri: String },
#[error("Auth code param not found in URI")]
AuthCodeNotFound,

#[error("CSRF token param not found in URI {uri}")]
CsrfTokenNotFound { uri: String },
#[error("CSRF token param not found in URI")]
CsrfTokenNotFound,

#[error("Failed to bind server to {addr} ({e})")]
AuthCodeListenerBind { addr: SocketAddr, e: io::Error },

#[error("Listener terminated without accepting a connection")]
AuthCodeListenerTerminated,

#[error("Failed to read redirect URI from HTTP request")]
AuthCodeListenerRead,

#[error("Failed to parse redirect URI from HTTP request")]
AuthCodeListenerParse,

Expand All @@ -249,82 +239,28 @@ pub enum OAuthError {

#[error("No saved token")]
LoggedOut,
}

/// Return state query-string parameter from the redirect URI (CSRF token).
fn get_state(redirect_url: &str) -> Result<String, OAuthError> {
let url = Url::parse(redirect_url).map_err(|e| OAuthError::AuthCodeBadUri {
uri: redirect_url.to_string(),
e,
})?;
let code = url
.query_pairs()
.find(|(key, _)| key == "state")
.map(|(_, state)| state.into_owned())
.ok_or(OAuthError::CsrfTokenNotFound {
uri: redirect_url.to_string(),
})?;

Ok(code)
#[error("Mismatched state during auth code exchange")]
InvalidState,
}

/// Return code query-string parameter from the redirect URI.
fn get_code(redirect_url: &str) -> Result<AuthorizationCode, OAuthError> {
let url = Url::parse(redirect_url).map_err(|e| OAuthError::AuthCodeBadUri {
uri: redirect_url.to_string(),
e,
})?;
let code = url
.query_pairs()
.find(|(key, _)| key == "code")
.map(|(_, code)| AuthorizationCode::new(code.into_owned()))
.ok_or(OAuthError::AuthCodeNotFound {
uri: redirect_url.to_string(),
})?;

Ok(code)
struct OAuthResult {
csrf_token: CsrfToken,
code: AuthorizationCode,
}

/// Spawn HTTP server at provided socket address to accept OAuth callback and return auth code.
fn get_authcode_listener(
socket_address: SocketAddr,
csrf_token: CsrfToken,
) -> Result<AuthorizationCode, OAuthError> {
let listener =
TcpListener::bind(socket_address).map_err(|e| OAuthError::AuthCodeListenerBind {
addr: socket_address,
e,
})?;
info!("OAuth server listening on {:?}", socket_address);

// The server will terminate itself after collecting the first code.
let mut stream = listener
.incoming()
.flatten()
.next()
.ok_or(OAuthError::AuthCodeListenerTerminated)?;
let mut reader = BufReader::new(&stream);
let mut request_line = String::new();
reader
.read_line(&mut request_line)
.map_err(|_| OAuthError::AuthCodeListenerRead)?;
async fn wait_for_authcode() -> Result<OAuthResult, OAuthError> {
let addr = get_socket_address(REDIRECT_URI).expect("Invalid redirect uri");

let redirect_url = request_line
.split_whitespace()
.nth(1)
.ok_or(OAuthError::AuthCodeListenerParse)?;
let listener = tokio::net::TcpListener::bind(addr)
.await
.map_err(|e| OAuthError::AuthCodeListenerBind { addr, e })?;

let token = get_state(&("http://localhost".to_string() + redirect_url));
if token.is_err() {
return Err(token.err().unwrap());
}
let token = token.ok().unwrap();
if !token.eq(csrf_token.secret()) {
return Err(OAuthError::CsrfTokenNotFound {
uri: redirect_url.to_string(),
});
}
let code = get_code(&("http://localhost".to_string() + redirect_url));
let (mut stream, _) = listener
.accept()
.await
.map_err(|_| OAuthError::AuthCodeListenerTerminated)?;

let message = include_str!("./login.html");
let response = format!(
Expand All @@ -334,9 +270,42 @@ fn get_authcode_listener(
);
stream
.write_all(response.as_bytes())
.await
.map_err(|_| OAuthError::AuthCodeListenerWrite)?;

code
let mut request_line = String::new();
let mut reader = BufReader::new(stream);
reader
.read_line(&mut request_line)
.await
.map_err(|_| OAuthError::AuthCodeListenerParse)?;

parse_query(&request_line)
}

fn parse_query(request_line: &str) -> Result<OAuthResult, OAuthError> {
let query = request_line
.split_whitespace()
.nth(1)
.ok_or(OAuthError::AuthCodeListenerParse)?
.split("?")
.nth(1)
.ok_or(OAuthError::AuthCodeListenerParse)?;

let mut query_params: HashMap<String, String> = url::form_urlencoded::parse(query.as_bytes())
.into_owned()
.collect();

let csrf_token = query_params
.remove("state")
.map(CsrfToken::new)
.ok_or(OAuthError::CsrfTokenNotFound)?;
let code = query_params
.remove("code")
.map(AuthorizationCode::new)
.ok_or(OAuthError::AuthCodeNotFound)?;

Ok(OAuthResult { csrf_token, code })
}

// If the specified `redirect_uri` is HTTP, loopback, and contains a port,
Expand Down
Loading

0 comments on commit 08385d6

Please sign in to comment.