From d940c3fb76df1b89409720c22dc9f0df5402f172 Mon Sep 17 00:00:00 2001 From: Coenen Benjamin Date: Tue, 30 Jul 2024 10:14:58 +0200 Subject: [PATCH] Create the invalidation endpoint for entity caching (#5614) Signed-off-by: Benjamin Coenen <5719034+bnjjj@users.noreply.github.com> Co-authored-by: Geoffroy Couprie --- apollo-router/src/configuration/schema.rs | 8 +- ...nfiguration__tests__schema_generation.snap | 47 +- ...nfiguration@entity_cache_preview.yaml.snap | 3 + .../testdata/metrics/entities.router.yaml | 3 + .../migrations/entity_cache_preview.yaml | 3 + apollo-router/src/notification.rs | 4 +- apollo-router/src/plugins/cache/entity.rs | 102 +++- .../src/plugins/cache/invalidation.rs | 176 ++++-- .../plugins/cache/invalidation_endpoint.rs | 569 ++++++++++++++++++ apollo-router/src/plugins/cache/mod.rs | 1 + apollo-router/src/plugins/cache/tests.rs | 6 +- apollo-router/src/plugins/subscription.rs | 2 +- apollo-router/src/router_factory.rs | 1 - .../uplink/testdata/restricted.router.yaml | 3 + apollo-router/tests/integration/redis.rs | 12 + .../configuration.yaml | 7 + .../invalidation-subgraph-type/skipped.json | 38 +- .../invalidation-subgraph/configuration.yaml | 3 + apollo-router/tests/samples_tests.rs | 52 +- 19 files changed, 944 insertions(+), 96 deletions(-) create mode 100644 apollo-router/src/plugins/cache/invalidation_endpoint.rs diff --git a/apollo-router/src/configuration/schema.rs b/apollo-router/src/configuration/schema.rs index a78015ab63..4d05b786ef 100644 --- a/apollo-router/src/configuration/schema.rs +++ b/apollo-router/src/configuration/schema.rs @@ -161,8 +161,12 @@ pub(crate) fn validate_yaml_configuration( let offset = start_marker .line() .saturating_sub(NUMBER_OF_PREVIOUS_LINES_TO_DISPLAY); - - let lines = yaml_split_by_lines[offset..end_marker.line()] + let end = if end_marker.line() > yaml_split_by_lines.len() { + yaml_split_by_lines.len() + } else { + end_marker.line() + }; + let lines = yaml_split_by_lines[offset..end] .iter() .map(|line| format!(" {line}")) .join("\n"); diff --git a/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap b/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap index 1f66e536ba..a702a933cd 100644 --- a/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap +++ b/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap @@ -1606,6 +1606,11 @@ expression: "&schema" "description": "Enable or disable the entity caching feature", "type": "boolean" }, + "invalidation": { + "$ref": "#/definitions/InvalidationEndpointConfig", + "description": "#/definitions/InvalidationEndpointConfig", + "nullable": true + }, "metrics": { "$ref": "#/definitions/Metrics", "description": "#/definitions/Metrics" @@ -3518,6 +3523,24 @@ expression: "&schema" }, "type": "object" }, + "InvalidationEndpointConfig": { + "additionalProperties": false, + "properties": { + "listen": { + "$ref": "#/definitions/ListenAddr", + "description": "#/definitions/ListenAddr" + }, + "path": { + "description": "Specify on which path you want to listen for invalidation endpoint.", + "type": "string" + } + }, + "required": [ + "listen", + "path" + ], + "type": "object" + }, "JWTConf": { "additionalProperties": false, "properties": { @@ -5571,11 +5594,17 @@ expression: "&schema" "description": "Per subgraph configuration for entity caching", "properties": { "enabled": { + "default": true, "description": "activates caching for this subgraph, overrides the global configuration", - "nullable": true, "type": "boolean" }, + "invalidation": { + "$ref": "#/definitions/SubgraphInvalidationConfig", + "description": "#/definitions/SubgraphInvalidationConfig", + "nullable": true + }, "private_id": { + "default": null, "description": "Context key used to separate cache sections per user", "nullable": true, "type": "string" @@ -5779,6 +5808,22 @@ expression: "&schema" }, "type": "object" }, + "SubgraphInvalidationConfig": { + "additionalProperties": false, + "properties": { + "enabled": { + "default": false, + "description": "Enable the invalidation", + "type": "boolean" + }, + "shared_key": { + "default": "", + "description": "Shared key needed to request the invalidation endpoint", + "type": "string" + } + }, + "type": "object" + }, "SubgraphPassthroughMode": { "additionalProperties": false, "properties": { diff --git a/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__upgrade_old_configuration@entity_cache_preview.yaml.snap b/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__upgrade_old_configuration@entity_cache_preview.yaml.snap index 08bc3e55b9..5544788d20 100644 --- a/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__upgrade_old_configuration@entity_cache_preview.yaml.snap +++ b/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__upgrade_old_configuration@entity_cache_preview.yaml.snap @@ -10,6 +10,9 @@ preview_entity_cache: timeout: 5ms ttl: 60s enabled: true + invalidation: + listen: "127.0.0.1:4000" + path: /invalidation subgraph: subgraphs: accounts: diff --git a/apollo-router/src/configuration/testdata/metrics/entities.router.yaml b/apollo-router/src/configuration/testdata/metrics/entities.router.yaml index 8c810effa7..0c886c2d64 100644 --- a/apollo-router/src/configuration/testdata/metrics/entities.router.yaml +++ b/apollo-router/src/configuration/testdata/metrics/entities.router.yaml @@ -4,6 +4,9 @@ preview_entity_cache: urls: [ "redis://localhost:6379" ] timeout: 5ms ttl: 60s + invalidation: + listen: 127.0.0.1:4000 + path: /invalidation subgraph: all: enabled: true diff --git a/apollo-router/src/configuration/testdata/migrations/entity_cache_preview.yaml b/apollo-router/src/configuration/testdata/migrations/entity_cache_preview.yaml index 2539a571ce..c210551098 100644 --- a/apollo-router/src/configuration/testdata/migrations/entity_cache_preview.yaml +++ b/apollo-router/src/configuration/testdata/migrations/entity_cache_preview.yaml @@ -4,6 +4,9 @@ preview_entity_cache: timeout: 5ms ttl: 60s enabled: true + invalidation: + listen: 127.0.0.1:4000 + path: /invalidation subgraphs: accounts: enabled: false diff --git a/apollo-router/src/notification.rs b/apollo-router/src/notification.rs index 77aff5db43..7cfba87e7a 100644 --- a/apollo-router/src/notification.rs +++ b/apollo-router/src/notification.rs @@ -807,6 +807,7 @@ where } #[allow(clippy::collapsible_if)] if topic_to_delete { + tracing::trace!("deleting subscription from unsubscribe"); if self.subscriptions.remove(&topic).is_some() { i64_up_down_counter!( "apollo_router_opened_subscriptions", @@ -880,6 +881,7 @@ where // Send error message to all killed connections for (_subscriber_id, subscription) in closed_subs { + tracing::trace!("deleting subscription from kill_dead_topics"); i64_up_down_counter!( "apollo_router_opened_subscriptions", "Number of opened subscriptions", @@ -907,7 +909,7 @@ where } fn force_delete(&mut self, topic: K) { - tracing::trace!("deleting subscription"); + tracing::trace!("deleting subscription from force_delete"); let sub = self.subscriptions.remove(&topic); if let Some(sub) = sub { i64_up_down_counter!( diff --git a/apollo-router/src/plugins/cache/entity.rs b/apollo-router/src/plugins/cache/entity.rs index 2375d4fde4..7332ca0d10 100644 --- a/apollo-router/src/plugins/cache/entity.rs +++ b/apollo-router/src/plugins/cache/entity.rs @@ -7,6 +7,7 @@ use std::time::Duration; use http::header; use http::header::CACHE_CONTROL; +use multimap::MultiMap; use schemars::JsonSchema; use serde::Deserialize; use serde::Serialize; @@ -26,6 +27,9 @@ use tracing::Level; use super::cache_control::CacheControl; use super::invalidation::Invalidation; use super::invalidation::InvalidationOrigin; +use super::invalidation_endpoint::InvalidationEndpointConfig; +use super::invalidation_endpoint::InvalidationService; +use super::invalidation_endpoint::SubgraphInvalidationConfig; use super::metrics::CacheMetricContextKey; use super::metrics::CacheMetricsService; use crate::batching::BatchQuery; @@ -49,6 +53,8 @@ use crate::services::subgraph; use crate::services::supergraph; use crate::spec::TYPENAME; use crate::Context; +use crate::Endpoint; +use crate::ListenAddr; /// Change this key if you introduce a breaking change in entity caching algorithm to make sure it won't take the previous entries pub(crate) const ENTITY_CACHE_VERSION: &str = "1.0"; @@ -61,6 +67,7 @@ register_plugin!("apollo", "preview_entity_cache", EntityCache); #[derive(Clone)] pub(crate) struct EntityCache { storage: Option, + endpoint_config: Option>, subgraphs: Arc>, entity_type: Option, enabled: bool, @@ -78,25 +85,43 @@ pub(crate) struct Config { #[serde(default)] enabled: bool, + /// Configure invalidation per subgraph subgraph: SubgraphConfiguration, + /// Global invalidation configuration + invalidation: Option, + /// Entity caching evaluation metrics #[serde(default)] metrics: Metrics, } /// Per subgraph configuration for entity caching -#[derive(Clone, Debug, Default, JsonSchema, Deserialize, Serialize)] -#[serde(rename_all = "snake_case", deny_unknown_fields)] +#[derive(Clone, Debug, JsonSchema, Deserialize, Serialize)] +#[serde(rename_all = "snake_case", deny_unknown_fields, default)] pub(crate) struct Subgraph { /// expiration for all keys for this subgraph, unless overriden by the `Cache-Control` header in subgraph responses pub(crate) ttl: Option, /// activates caching for this subgraph, overrides the global configuration - pub(crate) enabled: Option, + pub(crate) enabled: bool, /// Context key used to separate cache sections per user pub(crate) private_id: Option, + + /// Invalidation configuration + pub(crate) invalidation: Option, +} + +impl Default for Subgraph { + fn default() -> Self { + Self { + enabled: true, + ttl: Default::default(), + private_id: Default::default(), + invalidation: Default::default(), + } + } } /// Per subgraph configuration for entity caching @@ -179,12 +204,29 @@ impl Plugin for EntityCache { .into()); } + if init + .config + .subgraph + .all + .invalidation + .as_ref() + .map(|i| i.shared_key.is_empty()) + .unwrap_or_default() + { + return Err( + "you must set a default shared_key invalidation for all subgraphs" + .to_string() + .into(), + ); + } + let invalidation = Invalidation::new(storage.clone()).await?; Ok(Self { storage, entity_type, enabled: init.config.enabled, + endpoint_config: init.config.invalidation.clone().map(Arc::new), subgraphs: Arc::new(init.config.subgraph), metrics: init.config.metrics, private_queries: Arc::new(RwLock::new(HashSet::new())), @@ -240,13 +282,8 @@ impl Plugin for EntityCache { .clone() .map(|t| t.0) .or_else(|| storage.ttl()); - let subgraph_enabled = self.enabled - && self - .subgraphs - .get(name) - .enabled - // if the top level `enabled` is true but there is no other configuration, caching is enabled for this plugin - .unwrap_or(true); + let subgraph_enabled = + self.enabled && (self.subgraphs.all.enabled || self.subgraphs.get(name).enabled); let private_id = self.subgraphs.get(name).private_id.clone(); let name = name.to_string(); @@ -300,6 +337,40 @@ impl Plugin for EntityCache { .boxed() } } + + fn web_endpoints(&self) -> MultiMap { + let mut map = MultiMap::new(); + if self.enabled + && self + .subgraphs + .all + .invalidation + .as_ref() + .map(|i| i.enabled) + .unwrap_or_default() + { + match &self.endpoint_config { + Some(endpoint_config) => { + let endpoint = Endpoint::from_router_service( + endpoint_config.path.clone(), + InvalidationService::new(self.subgraphs.clone(), self.invalidation.clone()) + .boxed(), + ); + tracing::info!( + "Entity caching invalidation endpoint listening on: {}{}", + endpoint_config.listen, + endpoint_config.path + ); + map.insert(endpoint_config.listen.clone(), endpoint); + } + None => { + tracing::warn!("Cannot start entity caching invalidation endpoint because the listen address and endpoint is not configured"); + } + } + } + + map + } } impl EntityCache { @@ -311,6 +382,10 @@ impl EntityCache { where Self: Sized, { + use std::net::IpAddr; + use std::net::Ipv4Addr; + use std::net::SocketAddr; + let invalidation = Invalidation::new(Some(storage.clone())).await?; Ok(Self { storage: Some(storage), @@ -322,6 +397,13 @@ impl EntityCache { }), metrics: Metrics::default(), private_queries: Default::default(), + endpoint_config: Some(Arc::new(InvalidationEndpointConfig { + path: String::from("/invalidation"), + listen: ListenAddr::SocketAddr(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + 4000, + )), + })), invalidation, }) } diff --git a/apollo-router/src/plugins/cache/invalidation.rs b/apollo-router/src/plugins/cache/invalidation.rs index 96c863e437..4e8e5d5204 100644 --- a/apollo-router/src/plugins/cache/invalidation.rs +++ b/apollo-router/src/plugins/cache/invalidation.rs @@ -1,11 +1,15 @@ use std::time::Instant; +use fred::error::RedisError; use fred::types::Scanner; use futures::SinkExt; use futures::StreamExt; +use itertools::Itertools; use serde::Deserialize; use serde::Serialize; use serde_json_bytes::Value; +use thiserror::Error; +use tokio::sync::broadcast; use tower::BoxError; use tracing::Instrument; @@ -19,15 +23,48 @@ use crate::Notify; #[derive(Clone)] pub(crate) struct Invalidation { - enabled: bool, - handle: Handle)>, + pub(super) enabled: bool, + #[allow(clippy::type_complexity)] + pub(super) handle: Handle< + InvalidationTopic, + ( + Vec, + InvalidationOrigin, + broadcast::Sender>, + ), + >, } +#[derive(Error, Debug, Clone)] +pub(crate) enum InvalidationError { + #[error("redis error")] + RedisError(#[from] RedisError), + #[error("several errors")] + Errors(#[from] InvalidationErrors), + #[cfg(test)] + #[error("custom error: {0}")] + Custom(String), +} + +#[derive(Debug, Clone)] +pub(crate) struct InvalidationErrors(Vec); + +impl std::fmt::Display for InvalidationErrors { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "invalidation errors: [{}]", + self.0.iter().map(|e| e.to_string()).join("; ") + ) + } +} + +impl std::error::Error for InvalidationErrors {} + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] pub(crate) struct InvalidationTopic; -#[derive(Clone, Debug)] -#[allow(dead_code)] +#[derive(Clone, Debug, PartialEq)] pub(crate) enum InvalidationOrigin { Endpoint, Extensions, @@ -38,10 +75,12 @@ impl Invalidation { let mut notify = Notify::new(None, None, None); let (handle, _b) = notify.create_or_subscribe(InvalidationTopic, false).await?; let enabled = storage.is_some(); - if let Some(storage) = storage { + if let Some(storage) = storage.clone() { let h = handle.clone(); - tokio::task::spawn(async move { start(storage, h.into_stream()).await }); + tokio::task::spawn(async move { + start(storage, h.into_stream()).await; + }); } Ok(Self { enabled, handle }) } @@ -50,21 +89,46 @@ impl Invalidation { &mut self, origin: InvalidationOrigin, requests: Vec, - ) -> Result<(), BoxError> { + ) -> Result { if self.enabled { let mut sink = self.handle.clone().into_sink(); - sink.send((origin, requests)).await.map_err(|e| e.message)?; - } + let (response_tx, mut response_rx) = broadcast::channel(2); + sink.send((requests, origin, response_tx.clone())) + .await + .map_err(|e| format!("cannot send invalidation request: {}", e.message))?; + + let result = response_rx + .recv() + .await + .map_err(|err| { + format!( + "cannot receive response for invalidation request: {:?}", + err + ) + })? + .map_err(|err| format!("received an invalidation error: {:?}", err))?; - Ok(()) + Ok(result) + } else { + Ok(0) + } } } +// TODO refactor +#[allow(clippy::type_complexity)] async fn start( storage: RedisCacheStorage, - mut handle: HandleStream)>, + mut handle: HandleStream< + InvalidationTopic, + ( + Vec, + InvalidationOrigin, + broadcast::Sender>, + ), + >, ) { - while let Some((origin, requests)) = handle.next().await { + while let Some((requests, origin, response_tx)) = handle.next().await { let origin = match origin { InvalidationOrigin::Endpoint => "endpoint", InvalidationOrigin::Extensions => "extensions", @@ -75,30 +139,16 @@ async fn start( 1u64, "origin" = origin ); - handle_request_batch(&storage, origin, requests) - .instrument(tracing::info_span!( - "cache.invalidation.batch", - "origin" = origin - )) - .await - } -} - -async fn handle_request_batch( - storage: &RedisCacheStorage, - origin: &'static str, - requests: Vec, -) { - for request in requests { - let start = Instant::now(); - handle_request(storage, origin, &request) - .instrument(tracing::info_span!("cache.invalidation.request")) - .await; - f64_histogram!( - "apollo.router.cache.invalidation.duration", - "Duration of the invalidation event execution.", - start.elapsed().as_secs_f64() - ); + if let Err(err) = response_tx.send( + handle_request_batch(&storage, origin, requests) + .instrument(tracing::info_span!( + "cache.invalidation.batch", + "origin" = origin + )) + .await, + ) { + ::tracing::error!("cannot send answer to invalidation request in the channel: {err}"); + } } } @@ -106,9 +156,9 @@ async fn handle_request( storage: &RedisCacheStorage, origin: &'static str, request: &InvalidationRequest, -) { +) -> Result { let key_prefix = request.key_prefix(); - let subgraph = request.subgraph(); + let subgraph = request.subgraph_name(); tracing::debug!( "got invalidation request: {request:?}, will scan for: {}", key_prefix @@ -117,6 +167,7 @@ async fn handle_request( // FIXME: configurable batch size let mut stream = storage.scan(key_prefix.clone(), Some(10)); let mut count = 0u64; + let mut error = None; while let Some(res) = stream.next().await { match res { @@ -126,6 +177,7 @@ async fn handle_request( error = %e, message = "error scanning for key", ); + error = Some(e); break; } Ok(scan_res) => { @@ -158,9 +210,46 @@ async fn handle_request( "Number of invalidated keys.", count ); + + match error { + Some(err) => Err(err.into()), + None => Ok(count), + } +} + +async fn handle_request_batch( + storage: &RedisCacheStorage, + origin: &'static str, + requests: Vec, +) -> Result { + let mut count = 0; + let mut errors = Vec::new(); + for request in requests { + let start = Instant::now(); + match handle_request(storage, origin, &request) + .instrument(tracing::info_span!("cache.invalidation.request")) + .await + { + Ok(c) => count += c, + Err(err) => { + errors.push(err); + } + } + f64_histogram!( + "apollo.router.cache.invalidation.duration", + "Duration of the invalidation event execution.", + start.elapsed().as_secs_f64() + ); + } + + if !errors.is_empty() { + Err(InvalidationErrors(errors).into()) + } else { + Ok(count) + } } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] #[serde(tag = "kind", rename_all = "lowercase")] pub(crate) enum InvalidationRequest { Subgraph { @@ -197,12 +286,11 @@ impl InvalidationRequest { } } - fn subgraph(&self) -> String { + pub(super) fn subgraph_name(&self) -> &String { match self { - InvalidationRequest::Subgraph { subgraph } => subgraph.clone(), - _ => { - todo!() - } + InvalidationRequest::Subgraph { subgraph } + | InvalidationRequest::Type { subgraph, .. } + | InvalidationRequest::Entity { subgraph, .. } => subgraph, } } } diff --git a/apollo-router/src/plugins/cache/invalidation_endpoint.rs b/apollo-router/src/plugins/cache/invalidation_endpoint.rs new file mode 100644 index 0000000000..424751c830 --- /dev/null +++ b/apollo-router/src/plugins/cache/invalidation_endpoint.rs @@ -0,0 +1,569 @@ +use std::sync::Arc; +use std::task::Poll; + +use bytes::Buf; +use futures::future::BoxFuture; +use http::header::AUTHORIZATION; +use http::Method; +use http::StatusCode; +use schemars::JsonSchema; +use serde::Deserialize; +use serde::Serialize; +use serde_json_bytes::json; +use tower::BoxError; +use tower::Service; +use tracing_futures::Instrument; + +use super::entity::Subgraph; +use super::invalidation::Invalidation; +use super::invalidation::InvalidationOrigin; +use crate::configuration::subgraph::SubgraphConfiguration; +use crate::plugins::cache::invalidation::InvalidationRequest; +use crate::services::router; +use crate::services::router::body::RouterBody; +use crate::ListenAddr; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, Default)] +#[serde(rename_all = "snake_case", deny_unknown_fields, default)] +pub(crate) struct SubgraphInvalidationConfig { + /// Enable the invalidation + pub(crate) enabled: bool, + /// Shared key needed to request the invalidation endpoint + pub(crate) shared_key: String, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case", deny_unknown_fields)] +pub(crate) struct InvalidationEndpointConfig { + /// Specify on which path you want to listen for invalidation endpoint. + pub(crate) path: String, + /// Listen address on which the invalidation endpoint must listen. + pub(crate) listen: ListenAddr, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub(crate) enum InvalidationType { + EntityType, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub(crate) struct InvalidationKey { + pub(crate) id: String, + pub(crate) field: String, +} + +#[derive(Clone)] +pub(crate) struct InvalidationService { + config: Arc>, + invalidation: Invalidation, +} + +impl InvalidationService { + pub(crate) fn new( + config: Arc>, + invalidation: Invalidation, + ) -> Self { + Self { + config, + invalidation, + } + } +} + +impl Service for InvalidationService { + type Response = router::Response; + type Error = BoxError; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll> { + Ok(()).into() + } + + fn call(&mut self, req: router::Request) -> Self::Future { + let mut invalidation = self.invalidation.clone(); + let config = self.config.clone(); + Box::pin( + async move { + let (parts, body) = req.router_request.into_parts(); + if !parts.headers.contains_key(AUTHORIZATION) { + return Ok(router::Response { + response: http::Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body("Missing authorization header".into()) + .map_err(BoxError::from)?, + context: req.context, + }); + } + match parts.method { + Method::POST => { + let body = Into::::into(body) + .to_bytes() + .await + .map_err(|e| format!("failed to get the request body: {e}")) + .and_then(|bytes| { + serde_json::from_reader::<_, Vec>( + bytes.reader(), + ) + .map_err(|err| { + format!( + "failed to deserialize the request body into JSON: {err}" + ) + }) + }); + let shared_key = parts + .headers + .get(AUTHORIZATION) + .ok_or("cannot find authorization header")? + .to_str()?; + match body { + Ok(body) => { + let valid_shared_key = + body.iter().map(|b| b.subgraph_name()).any(|subgraph_name| { + valid_shared_key(&config, shared_key, subgraph_name) + }); + if !valid_shared_key { + return Ok(router::Response { + response: http::Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body("Invalid authorization header".into()) + .map_err(BoxError::from)?, + context: req.context, + }); + } + match invalidation + .invalidate(InvalidationOrigin::Endpoint, body) + .await + { + Ok(count) => Ok(router::Response { + response: http::Response::builder() + .status(StatusCode::ACCEPTED) + .body( + serde_json::to_string(&json!({ + "count": count + }))? + .into(), + ) + .map_err(BoxError::from)?, + context: req.context, + }), + Err(err) => Ok(router::Response { + response: http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(err.to_string().into()) + .map_err(BoxError::from)?, + context: req.context, + }), + } + } + Err(err) => Ok(router::Response { + response: http::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(err.into()) + .map_err(BoxError::from)?, + context: req.context, + }), + } + } + _ => Ok(router::Response { + response: http::Response::builder() + .status(StatusCode::METHOD_NOT_ALLOWED) + .body("".into()) + .map_err(BoxError::from)?, + context: req.context, + }), + } + } + .instrument(tracing::info_span!("invalidation_endpoint")), + ) + } +} + +fn valid_shared_key( + config: &SubgraphConfiguration, + shared_key: &str, + subgraph_name: &str, +) -> bool { + config + .all + .invalidation + .as_ref() + .map(|i| i.shared_key == shared_key) + .unwrap_or_default() + || config + .subgraphs + .get(subgraph_name) + .and_then(|s| s.invalidation.as_ref()) + .map(|i| i.shared_key == shared_key) + .unwrap_or_default() +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use tokio::sync::broadcast::Sender; + use tokio_stream::StreamExt; + use tower::ServiceExt; + + use super::*; + use crate::plugins::cache::invalidation::InvalidationError; + use crate::plugins::cache::invalidation::InvalidationTopic; + use crate::Notify; + + #[tokio::test] + async fn test_invalidation_service_bad_shared_key() { + #[allow(clippy::type_complexity)] + let mut notify: Notify< + InvalidationTopic, + ( + Vec, + InvalidationOrigin, + Sender>, + ), + > = Notify::new(None, None, None); + let (handle, _b) = notify + .create_or_subscribe(InvalidationTopic, false) + .await + .unwrap(); + let invalidation = Invalidation { + enabled: true, + handle, + }; + let config = Arc::new(SubgraphConfiguration { + all: Subgraph { + ttl: None, + enabled: true, + private_id: None, + invalidation: Some(SubgraphInvalidationConfig { + enabled: true, + shared_key: String::from("test"), + }), + }, + subgraphs: HashMap::new(), + }); + let service = InvalidationService::new(config, invalidation); + let req = router::Request::fake_builder() + .method(http::Method::POST) + .header(AUTHORIZATION, "testttt") + .body( + serde_json::to_vec(&[ + InvalidationRequest::Subgraph { + subgraph: String::from("test"), + }, + InvalidationRequest::Type { + subgraph: String::from("test"), + r#type: String::from("Test"), + }, + ]) + .unwrap(), + ) + .build() + .unwrap(); + let res = service.oneshot(req).await.unwrap(); + assert_eq!(res.response.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn test_invalidation_service_good_sub_shared_key() { + #[allow(clippy::type_complexity)] + let mut notify: Notify< + InvalidationTopic, + ( + Vec, + InvalidationOrigin, + Sender>, + ), + > = Notify::new(None, None, None); + let (handle, _b) = notify + .create_or_subscribe(InvalidationTopic, false) + .await + .unwrap(); + let h = handle.clone(); + + tokio::task::spawn(async move { + let mut handle = h.into_stream(); + let mut called = false; + while let Some((requests, origin, response_tx)) = handle.next().await { + called = true; + if requests + != [ + InvalidationRequest::Subgraph { + subgraph: String::from("test"), + }, + InvalidationRequest::Type { + subgraph: String::from("test"), + r#type: String::from("Test"), + }, + ] + { + response_tx + .send(Err(InvalidationError::Custom(format!( + "it's not the right invalidation requests : {requests:?}" + )))) + .unwrap(); + return; + } + if origin != InvalidationOrigin::Endpoint { + response_tx + .send(Err(InvalidationError::Custom(format!( + "it's not the right invalidation origin : {origin:?}" + )))) + .unwrap(); + return; + } + response_tx.send(Ok(0)).unwrap(); + } + assert!(called); + }); + + let invalidation = Invalidation { + enabled: true, + handle: handle.clone(), + }; + let config = Arc::new(SubgraphConfiguration { + all: Subgraph { + ttl: None, + enabled: true, + private_id: None, + invalidation: Some(SubgraphInvalidationConfig { + enabled: true, + shared_key: String::from("test"), + }), + }, + subgraphs: [( + String::from("test"), + Subgraph { + ttl: None, + enabled: true, + private_id: None, + invalidation: Some(SubgraphInvalidationConfig { + enabled: true, + shared_key: String::from("test_test"), + }), + }, + )] + .into_iter() + .collect(), + }); + let service = InvalidationService::new(config, invalidation); + let req = router::Request::fake_builder() + .method(http::Method::POST) + .header(AUTHORIZATION, "test_test") + .body( + serde_json::to_vec(&[ + InvalidationRequest::Subgraph { + subgraph: String::from("test"), + }, + InvalidationRequest::Type { + subgraph: String::from("test"), + r#type: String::from("Test"), + }, + ]) + .unwrap(), + ) + .build() + .unwrap(); + let res = service.oneshot(req).await.unwrap(); + assert_eq!(res.response.status(), StatusCode::ACCEPTED); + let h = handle.clone(); + + tokio::task::spawn(async move { + let mut handle = h.into_stream(); + let mut called = false; + while let Some((requests, origin, response_tx)) = handle.next().await { + called = true; + if requests + != [ + InvalidationRequest::Subgraph { + subgraph: String::from("test"), + }, + InvalidationRequest::Type { + subgraph: String::from("test"), + r#type: String::from("Test"), + }, + ] + { + response_tx + .send(Err(InvalidationError::Custom(format!( + "it's not the right invalidation requests : {requests:?}" + )))) + .unwrap(); + return; + } + if origin != InvalidationOrigin::Endpoint { + response_tx + .send(Err(InvalidationError::Custom(format!( + "it's not the right invalidation origin : {origin:?}" + )))) + .unwrap(); + return; + } + response_tx.send(Ok(0)).unwrap(); + } + assert!(called); + }); + } + + #[tokio::test] + async fn test_invalidation_service_bad_shared_key_subgraph() { + #[allow(clippy::type_complexity)] + let mut notify: Notify< + InvalidationTopic, + ( + Vec, + InvalidationOrigin, + Sender>, + ), + > = Notify::new(None, None, None); + let (handle, _b) = notify + .create_or_subscribe(InvalidationTopic, false) + .await + .unwrap(); + let invalidation = Invalidation { + enabled: true, + handle, + }; + let config = Arc::new(SubgraphConfiguration { + all: Subgraph { + ttl: None, + enabled: true, + private_id: None, + invalidation: Some(SubgraphInvalidationConfig { + enabled: true, + shared_key: String::from("test"), + }), + }, + subgraphs: [( + String::from("test"), + Subgraph { + ttl: None, + enabled: true, + private_id: None, + invalidation: Some(SubgraphInvalidationConfig { + enabled: true, + shared_key: String::from("test_test"), + }), + }, + )] + .into_iter() + .collect(), + }); + // Trying to invalidation with shared_key on subgraph test for a subgraph foo + let service = InvalidationService::new(config, invalidation); + let req = router::Request::fake_builder() + .method(http::Method::POST) + .header(AUTHORIZATION, "test_test") + .body( + serde_json::to_vec(&[InvalidationRequest::Subgraph { + subgraph: String::from("foo"), + }]) + .unwrap(), + ) + .build() + .unwrap(); + let res = service.oneshot(req).await.unwrap(); + assert_eq!(res.response.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn test_invalidation_service() { + #[allow(clippy::type_complexity)] + let mut notify: Notify< + InvalidationTopic, + ( + Vec, + InvalidationOrigin, + Sender>, + ), + > = Notify::new(None, None, None); + let (handle, _b) = notify + .create_or_subscribe(InvalidationTopic, false) + .await + .unwrap(); + let h = handle.clone(); + + tokio::task::spawn(async move { + let mut handle = h.into_stream(); + let mut called = false; + while let Some((requests, origin, response_tx)) = handle.next().await { + called = true; + if requests + != [ + InvalidationRequest::Subgraph { + subgraph: String::from("test"), + }, + InvalidationRequest::Type { + subgraph: String::from("test"), + r#type: String::from("Test"), + }, + ] + { + response_tx + .send(Err(InvalidationError::Custom(format!( + "it's not the right invalidation requests : {requests:?}" + )))) + .unwrap(); + return; + } + if origin != InvalidationOrigin::Endpoint { + response_tx + .send(Err(InvalidationError::Custom(format!( + "it's not the right invalidation origin : {origin:?}" + )))) + .unwrap(); + return; + } + response_tx.send(Ok(2)).unwrap(); + } + assert!(called); + }); + + let invalidation = Invalidation { + enabled: true, + handle, + }; + let config = Arc::new(SubgraphConfiguration { + all: Subgraph { + ttl: None, + enabled: true, + private_id: None, + invalidation: Some(SubgraphInvalidationConfig { + enabled: true, + shared_key: String::from("test"), + }), + }, + subgraphs: HashMap::new(), + }); + let service = InvalidationService::new(config, invalidation); + let req = router::Request::fake_builder() + .method(http::Method::POST) + .header(AUTHORIZATION, "test") + .body( + serde_json::to_vec(&[ + InvalidationRequest::Subgraph { + subgraph: String::from("test"), + }, + InvalidationRequest::Type { + subgraph: String::from("test"), + r#type: String::from("Test"), + }, + ]) + .unwrap(), + ) + .build() + .unwrap(); + let res = service.oneshot(req).await.unwrap(); + assert_eq!(res.response.status(), StatusCode::ACCEPTED); + assert_eq!( + serde_json::from_slice::( + &hyper::body::to_bytes(res.response.into_body()) + .await + .unwrap() + ) + .unwrap(), + serde_json::json!({"count": 2}) + ); + } +} diff --git a/apollo-router/src/plugins/cache/mod.rs b/apollo-router/src/plugins/cache/mod.rs index dded2f9586..c45265a3d3 100644 --- a/apollo-router/src/plugins/cache/mod.rs +++ b/apollo-router/src/plugins/cache/mod.rs @@ -1,6 +1,7 @@ pub(crate) mod cache_control; pub(crate) mod entity; pub(crate) mod invalidation; +pub(crate) mod invalidation_endpoint; pub(crate) mod metrics; #[cfg(test)] pub(crate) mod tests; diff --git a/apollo-router/src/plugins/cache/tests.rs b/apollo-router/src/plugins/cache/tests.rs index 3d0bb21169..8af136c0c9 100644 --- a/apollo-router/src/plugins/cache/tests.rs +++ b/apollo-router/src/plugins/cache/tests.rs @@ -399,16 +399,18 @@ async fn private() { "user".to_string(), Subgraph { private_id: Some("sub".to_string()), - enabled: Some(true), + enabled: true, ttl: None, + ..Default::default() }, ), ( "orga".to_string(), Subgraph { private_id: Some("sub".to_string()), - enabled: Some(true), + enabled: true, ttl: None, + ..Default::default() }, ), ] diff --git a/apollo-router/src/plugins/subscription.rs b/apollo-router/src/plugins/subscription.rs index 4ca4d56201..50d5e78ead 100644 --- a/apollo-router/src/plugins/subscription.rs +++ b/apollo-router/src/plugins/subscription.rs @@ -229,7 +229,7 @@ fn default_path() -> String { String::from("/callback") } -fn default_listen_addr() -> ListenAddr { +pub(crate) fn default_listen_addr() -> ListenAddr { ListenAddr::SocketAddr("127.0.0.1:4000".parse().expect("valid ListenAddr")) } diff --git a/apollo-router/src/router_factory.rs b/apollo-router/src/router_factory.rs index ca1dd4cb7b..e2d20593a9 100644 --- a/apollo-router/src/router_factory.rs +++ b/apollo-router/src/router_factory.rs @@ -602,7 +602,6 @@ pub(crate) async fn create_plugins( ($name: literal, $opt_plugin_config: expr) => {{ let name = concat!("apollo.", $name); let span = tracing::info_span!(concat!("plugin: ", "apollo.", $name)); - async { let factory = apollo_plugin_factories .remove(name) diff --git a/apollo-router/src/uplink/testdata/restricted.router.yaml b/apollo-router/src/uplink/testdata/restricted.router.yaml index 14aa7bd994..278e6134c8 100644 --- a/apollo-router/src/uplink/testdata/restricted.router.yaml +++ b/apollo-router/src/uplink/testdata/restricted.router.yaml @@ -54,6 +54,9 @@ plugins: preview_entity_cache: enabled: true + invalidation: + listen: 127.0.0.1:4000 + path: /invalidation redis: urls: - https://example.com diff --git a/apollo-router/tests/integration/redis.rs b/apollo-router/tests/integration/redis.rs index b7bfe2ea52..e0cfc0037e 100644 --- a/apollo-router/tests/integration/redis.rs +++ b/apollo-router/tests/integration/redis.rs @@ -364,6 +364,10 @@ async fn entity_cache() -> Result<(), BoxError> { "urls": ["redis://127.0.0.1:6379"], "ttl": "2s" }, + "invalidation": { + "listen": "127.0.0.1:4000", + "path": "/invalidation" + }, "subgraph": { "all": { "enabled": false @@ -474,6 +478,10 @@ async fn entity_cache() -> Result<(), BoxError> { "urls": ["redis://127.0.0.1:6379"], "ttl": "2s" }, + "invalidation": { + "listen": "127.0.0.1:4000", + "path": "/invalidation" + }, "subgraph": { "all": { "enabled": false, @@ -677,6 +685,10 @@ async fn entity_cache_authorization() -> Result<(), BoxError> { "urls": ["redis://127.0.0.1:6379"], "ttl": "2s" }, + "invalidation": { + "listen": "127.0.0.1:4000", + "path": "/invalidation" + }, "subgraph": { "all": { "enabled": false, diff --git a/apollo-router/tests/samples/enterprise/entity-cache/invalidation-subgraph-type/configuration.yaml b/apollo-router/tests/samples/enterprise/entity-cache/invalidation-subgraph-type/configuration.yaml index b297fee443..55728b841b 100644 --- a/apollo-router/tests/samples/enterprise/entity-cache/invalidation-subgraph-type/configuration.yaml +++ b/apollo-router/tests/samples/enterprise/entity-cache/invalidation-subgraph-type/configuration.yaml @@ -8,9 +8,16 @@ preview_entity_cache: redis: urls: ["redis://localhost:6379",] + invalidation: + # FIXME: right now we cannot configure it to use the same port used for the GraphQL endpoint if it is chosen at random + listen: 127.0.0.1:12345 + path: /invalidation-sample-subgraph-type subgraph: all: enabled: true + invalidation: + enabled: true + shared_key: "1234" subgraphs: reviews: ttl: 120s diff --git a/apollo-router/tests/samples/enterprise/entity-cache/invalidation-subgraph-type/skipped.json b/apollo-router/tests/samples/enterprise/entity-cache/invalidation-subgraph-type/skipped.json index f6996f21b8..89e90f1be9 100644 --- a/apollo-router/tests/samples/enterprise/entity-cache/invalidation-subgraph-type/skipped.json +++ b/apollo-router/tests/samples/enterprise/entity-cache/invalidation-subgraph-type/skipped.json @@ -43,28 +43,7 @@ "type": "ReloadSubgraphs", "subgraphs": { "accounts": { - "requests": [ - { - "request": { - "body": {"query":"mutation{updateMyAccount{name}}"} - }, - "response": { - "headers": { - "Content-Type": "application/json" - }, - "body": { - "data": { "updateMyAccount": { "name": "invalidation-subgraph-type2" } }, - "extensions": { - "invalidation": [{ - "kind": "type", - "subgraph": "accounts", - "type": "Query" - }] - } - } - } - } - ] + "requests": [] } } }, @@ -83,15 +62,14 @@ } }, { - "type": "Request", + "type": "EndpointRequest", + "url": "http://127.0.0.1:12345/invalidation-sample-subgraph-type", "request": { - "query": "mutation { updateMyAccount { name } }" - }, - "expected_response": { - "data":{ - "updateMyAccount":{ - "name":"invalidation-subgraph-type2" - } + "method": "POST", + "body": { + "kind": "type", + "subgraph": "accounts", + "type": "Query" } } }, diff --git a/apollo-router/tests/samples/enterprise/entity-cache/invalidation-subgraph/configuration.yaml b/apollo-router/tests/samples/enterprise/entity-cache/invalidation-subgraph/configuration.yaml index b297fee443..a54c33f25d 100644 --- a/apollo-router/tests/samples/enterprise/entity-cache/invalidation-subgraph/configuration.yaml +++ b/apollo-router/tests/samples/enterprise/entity-cache/invalidation-subgraph/configuration.yaml @@ -5,6 +5,9 @@ include_subgraph_errors: preview_entity_cache: enabled: true + invalidation: + listen: 127.0.0.1:4000 + path: /invalidation redis: urls: ["redis://localhost:6379",] diff --git a/apollo-router/tests/samples_tests.rs b/apollo-router/tests/samples_tests.rs index 4507089a66..b6f2f902f1 100644 --- a/apollo-router/tests/samples_tests.rs +++ b/apollo-router/tests/samples_tests.rs @@ -176,6 +176,9 @@ impl TestExecution { ) .await } + Action::EndpointRequest { url, request } => { + self.endpoint_request(url, request.clone(), out).await + } Action::Stop => self.stop(out).await, } } @@ -479,6 +482,43 @@ impl TestExecution { Ok(()) } + + async fn endpoint_request( + &mut self, + url: &url::Url, + request: HttpRequest, + out: &mut String, + ) -> Result<(), Failed> { + let client = reqwest::Client::new(); + + let mut builder = client.request( + request + .method + .as_deref() + .unwrap_or("POST") + .try_into() + .unwrap(), + url.clone(), + ); + for (name, value) in request.headers { + builder = builder.header(name, value); + } + + let request = builder.json(&request.body).build().unwrap(); + let response = client.execute(request).await.map_err(|e| { + writeln!( + out, + "could not send request to Router endpoint at {url}: {e}" + ) + .unwrap(); + let f: Failed = out.clone().into(); + f + })?; + + writeln!(out, "Endpoint returned: {response:?}").unwrap(); + + Ok(()) + } } fn open_file(path: &Path, out: &mut String) -> Result { @@ -537,6 +577,10 @@ enum Action { query_path: Option, expected_response: Value, }, + EndpointRequest { + url: url::Url, + request: HttpRequest, + }, Stop, } @@ -547,12 +591,12 @@ struct Subgraph { #[derive(Clone, Debug, Deserialize)] struct SubgraphRequestMock { - request: SubgraphRequest, - response: SubgraphResponse, + request: HttpRequest, + response: HttpResponse, } #[derive(Clone, Debug, Deserialize)] -struct SubgraphRequest { +struct HttpRequest { method: Option, path: Option, #[serde(default)] @@ -561,7 +605,7 @@ struct SubgraphRequest { } #[derive(Clone, Debug, Deserialize)] -struct SubgraphResponse { +struct HttpResponse { status: Option, #[serde(default)] headers: HashMap,