Skip to content

Commit

Permalink
Add configurable cors origin header to attestation endpoint (#263)
Browse files Browse the repository at this point in the history
* Add attestation cors to feature context

* Add cors headers to attestation response

* Pass feature context

* Clone feature context before passing into async block'

* Fix test

* Fix typo
  • Loading branch information
donaltuohy authored Jan 9, 2025
1 parent 660ecb0 commit a71df7c
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 4 deletions.
14 changes: 13 additions & 1 deletion data-plane/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ impl From<ProvisionerContext> for EnclaveContext {
}
}

#[derive(Clone, Deserialize, Debug)]
pub struct AttestationCors {
pub origin: String,
}

#[derive(Clone, Deserialize, Debug)]
pub struct FeatureContext {
pub api_key_auth: bool,
Expand All @@ -157,6 +162,7 @@ pub struct FeatureContext {
pub trx_logging_enabled: bool,
pub forward_proxy_protocol: bool,
pub trusted_headers: Vec<String>,
pub attestation_cors: Option<AttestationCors>,
#[cfg(feature = "network_egress")]
pub egress: EgressConfig,
}
Expand Down Expand Up @@ -196,14 +202,19 @@ mod test {
#[cfg(not(feature = "network_egress"))]
#[test]
fn test_config_deserialization_without_proxy_protocol() {
let raw_feature_context = r#"{ "api_key_auth": true, "trx_logging_enabled": false, "forward_proxy_protocol": false, "trusted_headers": [] }"#;
let raw_feature_context = r#"{ "api_key_auth": true, "attestation_cors": { "origin": "test.com" }, "trx_logging_enabled": false, "forward_proxy_protocol": false, "trusted_headers": [] }"#;
let parsed = serde_json::from_str(raw_feature_context);
assert!(parsed.is_ok());
let feature_context: FeatureContext = parsed.unwrap();
assert_eq!(feature_context.api_key_auth, true);
assert_eq!(feature_context.trx_logging_enabled, false);
assert_eq!(feature_context.forward_proxy_protocol, false);
assert!(feature_context.healthcheck.is_none());
assert!(feature_context.attestation_cors.is_some());
assert_eq!(
feature_context.attestation_cors.unwrap().origin,
"test.com".to_string()
);
}

#[cfg(not(feature = "network_egress"))]
Expand All @@ -217,6 +228,7 @@ mod test {
assert_eq!(feature_context.trx_logging_enabled, false);
assert_eq!(feature_context.forward_proxy_protocol, false);
assert_eq!(feature_context.healthcheck, Some("/health".into()));
assert!(feature_context.attestation_cors.is_none());
}

#[cfg(feature = "network_egress")]
Expand Down
29 changes: 27 additions & 2 deletions data-plane/src/server/layers/attest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,34 @@ use hyper::Body;
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tower::{Layer, Service};

use crate::cache::ATTESTATION_DOC;
use crate::crypto::attest;
use crate::server::http::build_internal_error_response;
use crate::server::tls::TRUSTED_PUB_CERT;
use crate::FeatureContext;

#[derive(Clone)]
pub struct AttestLayer;
pub struct AttestLayer {
feature_context: Arc<FeatureContext>,
}

impl AttestLayer {
pub fn new(feature_context: Arc<FeatureContext>) -> Self {
Self { feature_context }
}
}

impl<S> Layer<S> for AttestLayer {
type Service = AttestService<S>;

fn layer(&self, inner: S) -> Self::Service {
AttestService { inner }
AttestService {
feature_context: self.feature_context.clone(),
inner,
}
}
}

Expand All @@ -29,6 +42,7 @@ struct AttestationResponse {

#[derive(Clone)]
pub struct AttestService<S> {
feature_context: Arc<FeatureContext>,
inner: S,
}

Expand Down Expand Up @@ -59,6 +73,8 @@ where
return Box::pin(inner.call(req));
}

let feature_context = self.feature_context.clone();

Box::pin(async move {
let attestation_doc_key: String = "attestation_doc".to_string();
let mut cache = ATTESTATION_DOC.lock().await;
Expand Down Expand Up @@ -86,10 +102,19 @@ where
};

let response_payload = serde_json::to_string(&response).expect("Infallible");
let cors_origin = feature_context
.attestation_cors
.as_ref()
.map_or("*", |cors| cors.origin.as_str());

let attestation_response = Response::builder()
.status(200)
.header(hyper::http::header::CONTENT_TYPE, "application/json")
.header(hyper::http::header::CONTENT_LENGTH, response_payload.len())
.header(
hyper::http::header::ACCESS_CONTROL_ALLOW_ORIGIN,
cors_origin,
)
.body(Body::from(response_payload))
.unwrap_or_else(|e| build_internal_error_response(Some(e.to_string())));

Expand Down
2 changes: 1 addition & 1 deletion data-plane/src/server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ where

// Only apply attestation layer in enclave mode
#[cfg(feature = "enclave")]
let service_builder = service_builder.layer(AttestLayer);
let service_builder = service_builder.layer(AttestLayer::new(feature_context.clone()));

// layers are invoked in the order that they're registered to the service
let service = service_builder
Expand Down

0 comments on commit a71df7c

Please sign in to comment.