Skip to content

Commit

Permalink
Replace rouille with warp
Browse files Browse the repository at this point in the history
Co-authored-by: Fedor Sakharov <[email protected]>
  • Loading branch information
Saruniks and montekki committed Aug 8, 2024
1 parent 0e66562 commit 621bec5
Show file tree
Hide file tree
Showing 17 changed files with 1,978 additions and 1,272 deletions.
1,075 changes: 582 additions & 493 deletions Cargo.lock

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ regex = "1.10.3"
reqsign = { version = "0.16.0", optional = true }
reqwest = { version = "0.12", features = [
"json",
"blocking",
"native-tls",
"stream",
"rustls-tls",
"rustls-tls-native-roots",
Expand Down Expand Up @@ -107,6 +107,7 @@ zstd = "0.13"

# dist-server only
memmap2 = "0.9.4"
native-tls = "0.2.8"
nix = { version = "0.28.0", optional = true, features = [
"mount",
"user",
Expand All @@ -115,11 +116,10 @@ nix = { version = "0.28.0", optional = true, features = [
"process",
] }
object = "0.32"
rouille = { version = "3.6", optional = true, default-features = false, features = [
"ssl",
] }
syslog = { version = "6", optional = true }
thiserror = { version = "1.0.30", optional = true }
version-compare = { version = "0.1.1", optional = true }
warp = { version = "0.3.2", optional = true, features = ["tls"] }

[dev-dependencies]
assert_cmd = "2.0.13"
Expand Down Expand Up @@ -190,15 +190,17 @@ dist-client = [
]
# Enables the sccache-dist binary
dist-server = [
"reqwest/blocking",
"jwt",
"flate2",
"libmount",
"nix",
"openssl",
"reqwest",
"rouille",
"syslog",
"version-compare",
"warp",
"thiserror",
]
# Enables dist tests with external requirements
dist-tests = ["dist-client", "dist-server"]
Expand Down
2 changes: 1 addition & 1 deletion src/bin/sccache-dist/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ impl OverlayBuilder {
for (tc, _) in entries {
warn!("Removing old un-compressed toolchain: {:?}", tc);
assert!(toolchain_dir_map.remove(tc).is_some());
fs::remove_dir_all(&self.dir.join("toolchains").join(&tc.archive_id))
fs::remove_dir_all(self.dir.join("toolchains").join(&tc.archive_id))
.context("Failed to remove old toolchain directory")?;
}
}
Expand Down
49 changes: 30 additions & 19 deletions src/bin/sccache-dist/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
extern crate log;

use anyhow::{bail, Context, Error, Result};
use async_trait::async_trait;
use base64::Engine;
use cmdline::{AuthSubcommand, Command};
use rand::{rngs::OsRng, RngCore};
use sccache::config::{
scheduler as scheduler_config, server as server_config, INSECURE_DIST_CLIENT_TOKEN,
Expand All @@ -22,17 +24,16 @@ use std::env;
use std::io;
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Mutex, MutexGuard};
use std::sync::{Arc, Mutex, MutexGuard};
use std::time::{Duration, Instant};
use tokio::runtime::Runtime;

#[cfg_attr(target_os = "freebsd", path = "build_freebsd.rs")]
mod build;

mod cmdline;
mod token_check;

use cmdline::{AuthSubcommand, Command};

pub const INSECURE_DIST_SERVER_TOKEN: &str = "dangerously_insecure_server";

// Only supported on x86_64 Linux machines and on FreeBSD
Expand Down Expand Up @@ -184,10 +185,10 @@ fn run(command: Command) -> Result<i32> {
scheduler_config::ServerAuth::Insecure => {
warn!("Scheduler starting with DANGEROUSLY_INSECURE server authentication");
let token = INSECURE_DIST_SERVER_TOKEN;
Box::new(move |server_token| check_server_token(server_token, token))
Arc::new(move |server_token| check_server_token(server_token, token))
}
scheduler_config::ServerAuth::Token { token } => {
Box::new(move |server_token| check_server_token(server_token, &token))
Arc::new(move |server_token| check_server_token(server_token, &token))
}
scheduler_config::ServerAuth::JwtHS256 { secret_key } => {
let secret_key = BASE64_URL_SAFE_ENGINE
Expand All @@ -203,7 +204,7 @@ fn run(command: Command) -> Result<i32> {
validation.validate_nbf = false;
validation
};
Box::new(move |server_token| {
Arc::new(move |server_token| {
check_jwt_server_token(server_token, &secret_key, &validation)
})
}
Expand All @@ -217,7 +218,10 @@ fn run(command: Command) -> Result<i32> {
check_client_auth,
check_server_auth,
);
http_scheduler.start()?;

// Create runtime after daemonize because Tokio doesn't work well with daemonize
let runtime = Runtime::new().context("Failed to create Tokio runtime")?;
runtime.block_on(async { http_scheduler.start().await })?;
unreachable!();
}

Expand Down Expand Up @@ -294,7 +298,8 @@ fn run(command: Command) -> Result<i32> {
server,
)
.context("Failed to create sccache HTTP server instance")?;
http_server.start()?;
let runtime = Runtime::new().context("Failed to create Tokio runtime")?;
runtime.block_on(async { http_server.start().await })?;
unreachable!();
}
}
Expand Down Expand Up @@ -399,8 +404,9 @@ impl Default for Scheduler {
}
}

#[async_trait]
impl SchedulerIncoming for Scheduler {
fn handle_alloc_job(
async fn handle_alloc_job(
&self,
requester: &dyn SchedulerOutgoing,
tc: Toolchain,
Expand Down Expand Up @@ -499,6 +505,7 @@ impl SchedulerIncoming for Scheduler {
need_toolchain,
} = requester
.do_assign_job(server_id, job_id, tc, auth.clone())
.await
.with_context(|| {
// LOCKS
let mut servers = self.servers.lock().unwrap();
Expand Down Expand Up @@ -717,7 +724,7 @@ impl SchedulerIncoming for Scheduler {
pub struct Server {
builder: Box<dyn BuilderIncoming>,
cache: Mutex<TcCache>,
job_toolchains: Mutex<HashMap<JobId, Toolchain>>,
job_toolchains: tokio::sync::Mutex<HashMap<JobId, Toolchain>>,
}

impl Server {
Expand All @@ -731,18 +738,19 @@ impl Server {
Ok(Server {
builder,
cache: Mutex::new(cache),
job_toolchains: Mutex::new(HashMap::new()),
job_toolchains: tokio::sync::Mutex::new(HashMap::new()),
})
}
}

#[async_trait]
impl ServerIncoming for Server {
fn handle_assign_job(&self, job_id: JobId, tc: Toolchain) -> Result<AssignJobResult> {
async fn handle_assign_job(&self, job_id: JobId, tc: Toolchain) -> Result<AssignJobResult> {
let need_toolchain = !self.cache.lock().unwrap().contains_toolchain(&tc);
assert!(self
.job_toolchains
.lock()
.unwrap()
.await
.insert(job_id, tc)
.is_none());
let state = if need_toolchain {
Expand All @@ -756,18 +764,19 @@ impl ServerIncoming for Server {
need_toolchain,
})
}
fn handle_submit_toolchain(
async fn handle_submit_toolchain(
&self,
requester: &dyn ServerOutgoing,
job_id: JobId,
tc_rdr: ToolchainReader,
tc_rdr: ToolchainReader<'_>,
) -> Result<SubmitToolchainResult> {
requester
.do_update_job_state(job_id, JobState::Ready)
.await
.context("Updating job state failed")?;
// TODO: need to lock the toolchain until the container has started
// TODO: can start prepping container
let tc = match self.job_toolchains.lock().unwrap().get(&job_id).cloned() {
let tc = match self.job_toolchains.lock().await.get(&job_id).cloned() {
Some(tc) => tc,
None => return Ok(SubmitToolchainResult::JobNotFound),
};
Expand All @@ -783,18 +792,19 @@ impl ServerIncoming for Server {
.map(|_| SubmitToolchainResult::Success)
.unwrap_or(SubmitToolchainResult::CannotCache))
}
fn handle_run_job(
async fn handle_run_job(
&self,
requester: &dyn ServerOutgoing,
job_id: JobId,
command: CompileCommand,
outputs: Vec<String>,
inputs_rdr: InputsReader,
inputs_rdr: InputsReader<'_>,
) -> Result<RunJobResult> {
requester
.do_update_job_state(job_id, JobState::Started)
.await
.context("Updating job state failed")?;
let tc = self.job_toolchains.lock().unwrap().remove(&job_id);
let tc = self.job_toolchains.lock().await.remove(&job_id);
let res = match tc {
None => Ok(RunJobResult::JobNotFound),
Some(tc) => {
Expand All @@ -812,6 +822,7 @@ impl ServerIncoming for Server {
};
requester
.do_update_job_state(job_id, JobState::Complete)
.await
.context("Updating job state failed")?;
res
}
Expand Down
41 changes: 25 additions & 16 deletions src/bin/sccache-dist/token_check.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use base64::Engine;
use sccache::dist::http::{ClientAuthCheck, ClientVisibleMsg};
use sccache::util::{new_reqwest_blocking_client, BASE64_URL_SAFE_ENGINE};
use sccache::util::new_reqwest_client;
use sccache::util::BASE64_URL_SAFE_ENGINE;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::result::Result as StdResult;
Expand Down Expand Up @@ -54,8 +56,9 @@ pub struct EqCheck {
s: String,
}

#[async_trait]
impl ClientAuthCheck for EqCheck {
fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
async fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
if self.s == token {
Ok(())
} else {
Expand All @@ -80,14 +83,15 @@ const MOZ_USERINFO_ENDPOINT: &str = "https://auth.mozilla.auth0.com/userinfo";
/// Mozilla-specific check by forwarding the token onto the auth0 userinfo endpoint
pub struct MozillaCheck {
// token, token_expiry
auth_cache: Mutex<HashMap<String, Instant>>,
client: reqwest::blocking::Client,
auth_cache: tokio::sync::Mutex<HashMap<String, Instant>>,
client: reqwest::Client,
required_groups: Vec<String>,
}

#[async_trait]
impl ClientAuthCheck for MozillaCheck {
fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
self.check_mozilla(token).map_err(|e| {
async fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
self.check_mozilla(token).await.map_err(|e| {
warn!("Mozilla token validation failed: {}", e);
ClientVisibleMsg::from_nonsensitive(
"Failed to validate Mozilla OAuth token, run sccache --dist-auth".to_owned(),
Expand All @@ -99,13 +103,13 @@ impl ClientAuthCheck for MozillaCheck {
impl MozillaCheck {
pub fn new(required_groups: Vec<String>) -> Self {
Self {
auth_cache: Mutex::new(HashMap::new()),
client: new_reqwest_blocking_client(),
auth_cache: tokio::sync::Mutex::new(HashMap::new()),
client: new_reqwest_client(),
required_groups,
}
}

fn check_mozilla(&self, token: &str) -> Result<()> {
async fn check_mozilla(&self, token: &str) -> Result<()> {
// azp == client_id
// {
// "iss": "https://auth.mozilla.auth0.com/",
Expand Down Expand Up @@ -139,7 +143,7 @@ impl MozillaCheck {
}

// If the token is cached and not expired, return it
let mut auth_cache = self.auth_cache.lock().unwrap();
let mut auth_cache = self.auth_cache.lock().await;
if let Some(cached_at) = auth_cache.get(token) {
if cached_at.elapsed() < MOZ_SESSION_TIMEOUT {
return Ok(());
Expand All @@ -158,10 +162,12 @@ impl MozillaCheck {
.get(url.clone())
.bearer_auth(token)
.send()
.await
.context("Failed to make request to mozilla userinfo")?;
let status = res.status();
let res_text = res
.text()
.await
.context("Failed to interpret response from mozilla userinfo as string")?;
if !status.is_success() {
bail!("JWT forwarded to {} returned {}: {}", url, status, res_text)
Expand Down Expand Up @@ -245,14 +251,15 @@ fn test_auth_verify_check_mozilla_profile() {
// Don't check a token is valid (it may not even be a JWT) just forward it to
// an API and check for success
pub struct ProxyTokenCheck {
client: reqwest::blocking::Client,
client: reqwest::Client,
maybe_auth_cache: Option<Mutex<(HashMap<String, Instant>, Duration)>>,
url: String,
}

#[async_trait]
impl ClientAuthCheck for ProxyTokenCheck {
fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
match self.check_token_with_forwarding(token) {
async fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
match self.check_token_with_forwarding(token).await {
Ok(()) => Ok(()),
Err(e) => {
warn!("Proxying token validation failed: {}", e);
Expand All @@ -269,13 +276,13 @@ impl ProxyTokenCheck {
let maybe_auth_cache: Option<Mutex<(HashMap<String, Instant>, Duration)>> =
cache_secs.map(|secs| Mutex::new((HashMap::new(), Duration::from_secs(secs))));
Self {
client: new_reqwest_blocking_client(),
client: new_reqwest_client(),
maybe_auth_cache,
url,
}
}

fn check_token_with_forwarding(&self, token: &str) -> Result<()> {
async fn check_token_with_forwarding(&self, token: &str) -> Result<()> {
trace!("Validating token by forwarding to {}", self.url);
// If the token is cached and not cache has not expired, return it
if let Some(ref auth_cache) = self.maybe_auth_cache {
Expand All @@ -294,6 +301,7 @@ impl ProxyTokenCheck {
.get(&self.url)
.bearer_auth(token)
.send()
.await
.context("Failed to make request to proxying url")?;
if !res.status().is_success() {
bail!("Token forwarded to {} returned {}", self.url, res.status());
Expand All @@ -315,8 +323,9 @@ pub struct ValidJWTCheck {
kid_to_pkcs1: HashMap<String, Vec<u8>>,
}

#[async_trait]
impl ClientAuthCheck for ValidJWTCheck {
fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
async fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
match self.check_jwt_validity(token) {
Ok(()) => Ok(()),
Err(e) => {
Expand Down
1 change: 0 additions & 1 deletion src/compiler/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1460,7 +1460,6 @@ mod test {
use std::io::{Cursor, Write};
use std::sync::Arc;
use std::time::Duration;
use std::u64;
use test_case::test_case;
use tokio::runtime::Runtime;

Expand Down
1 change: 0 additions & 1 deletion src/compiler/rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ use fs_err as fs;
use log::Level::Trace;
use once_cell::sync::Lazy;
#[cfg(feature = "dist-client")]
#[cfg(feature = "dist-client")]
use std::borrow::Borrow;
use std::borrow::Cow;
#[cfg(feature = "dist-client")]
Expand Down
Loading

0 comments on commit 621bec5

Please sign in to comment.