From cf37e8a75fb8e217fda2fc0bc126c76db1ce16df Mon Sep 17 00:00:00 2001 From: Flix Date: Sat, 24 Aug 2024 14:42:10 +0200 Subject: [PATCH] feat: Re-use the same correlation ID in paging --- crates/fhir-sdk/src/client/fhir/crud.rs | 47 +++++++++++++++++------ crates/fhir-sdk/src/client/fhir/paging.rs | 35 ++++++++++++----- crates/fhir-sdk/src/client/misc.rs | 24 +++++++++++- crates/fhir-sdk/src/client/mod.rs | 31 +++++++-------- crates/fhir-sdk/src/client/request.rs | 7 +--- 5 files changed, 100 insertions(+), 44 deletions(-) diff --git a/crates/fhir-sdk/src/client/fhir/crud.rs b/crates/fhir-sdk/src/client/fhir/crud.rs index 0ef06517..cefca2b9 100644 --- a/crates/fhir-sdk/src/client/fhir/crud.rs +++ b/crates/fhir-sdk/src/client/fhir/crud.rs @@ -15,6 +15,7 @@ use super::{ Client, Error, SearchParameters, }; use crate::{ + client::misc::{extract_header, make_uuid_header_value}, extensions::{AnyResource, GenericResource, ReferenceExt}, version::FhirVersion, }; @@ -42,8 +43,12 @@ where pub(crate) async fn read_generic( &self, url: Url, + correlation_id: Option, ) -> Result, Error> { - let request = self.0.client.get(url).header(header::ACCEPT, V::MIME_TYPE); + let mut request = self.0.client.get(url).header(header::ACCEPT, V::MIME_TYPE); + if let Some(correlation_id) = correlation_id { + request = request.header("X-Correlation-Id", correlation_id); + } let response = self.run_request(request).await?; if response.status().is_success() { @@ -62,7 +67,7 @@ where id: &str, ) -> Result, Error> { let url = self.url(&[R::TYPE_STR, id]); - self.read_generic(url).await + self.read_generic(url, None).await } /// Read a specific version of a specific FHIR resource. @@ -72,7 +77,7 @@ where version_id: &str, ) -> Result, Error> { let url = self.url(&[R::TYPE_STR, id, "_history", version_id]); - self.read_generic(url).await + self.read_generic(url, None).await } /// Read the resource that is targeted in the reference. @@ -93,7 +98,7 @@ where }; let resource: V::Resource = self - .read_generic(url.clone()) + .read_generic(url.clone(), None) .await? .ok_or_else(|| Error::ResourceNotFound(url.to_string()))?; if let Some(resource_type) = reference.r#type() { @@ -114,6 +119,8 @@ where R: AnyResource + TryFrom + 'static, for<'a> &'a R: TryFrom<&'a V::Resource>, { + let correlation_id = make_uuid_header_value(); + let url = { if let Some(id) = id { self.url(&[R::TYPE_STR, id, "_history"]) @@ -121,12 +128,17 @@ where self.url(&[R::TYPE_STR, "_history"]) } }; - let request = self.0.client.get(url).header(header::ACCEPT, V::MIME_TYPE); + let request = self + .0 + .client + .get(url) + .header(header::ACCEPT, V::MIME_TYPE) + .header("X-Correlation-Id", correlation_id.clone()); let response = self.run_request(request).await?; if response.status().is_success() { let bundle: V::Bundle = response.json().await?; - Ok(Page::new(self.clone(), bundle)) + Ok(Page::new(self.clone(), bundle, correlation_id)) } else { Err(Error::from_response::(response).await) } @@ -233,18 +245,21 @@ where ) -> Result, Error> { // TODO: Use POST for long queries? + let correlation_id = make_uuid_header_value(); + let url = self.url(&[]); let request = self .0 .client .get(url) .query(&queries.into_queries()) - .header(header::ACCEPT, V::MIME_TYPE); + .header(header::ACCEPT, V::MIME_TYPE) + .header("X-Correlation-Id", correlation_id.clone()); let response = self.run_request(request).await?; if response.status().is_success() { let bundle: V::Bundle = response.json().await?; - Ok(Page::new(self.clone(), bundle)) + Ok(Page::new(self.clone(), bundle, correlation_id)) } else { Err(Error::from_response::(response).await) } @@ -258,18 +273,21 @@ where { // TODO: Use POST for long queries? + let correlation_id = make_uuid_header_value(); + let url = self.url(&[R::TYPE_STR]); let request = self .0 .client .get(url) .query(&queries.into_queries()) - .header(header::ACCEPT, V::MIME_TYPE); + .header(header::ACCEPT, V::MIME_TYPE) + .header("X-Correlation-Id", correlation_id.clone()); let response = self.run_request(request).await?; if response.status().is_success() { let bundle: V::Bundle = response.json().await?; - Ok(Page::new(self.clone(), bundle)) + Ok(Page::new(self.clone(), bundle, correlation_id)) } else { Err(Error::from_response::(response).await) } @@ -295,10 +313,15 @@ where R: TryFrom + Send + Sync + 'static, for<'a> &'a R: TryFrom<&'a V::Resource>, { - let response = self.send_custom_request(make_request).await?; + let request = (make_request)(&self.0.client); + let (mut request, correlation_id) = extract_header(request, "X-Correlation-Id")?; + let correlation_id = correlation_id.unwrap_or_else(make_uuid_header_value); + request = request.header("X-Correlation-Id", correlation_id.clone()); + + let response = self.run_request(request).await?; if response.status().is_success() { let bundle: V::Bundle = response.json().await?; - Ok(Page::new(self.clone(), bundle)) + Ok(Page::new(self.clone(), bundle, correlation_id)) } else { Err(Error::from_response::(response).await) } diff --git a/crates/fhir-sdk/src/client/fhir/paging.rs b/crates/fhir-sdk/src/client/fhir/paging.rs index c6ae3eda..012ae499 100644 --- a/crates/fhir-sdk/src/client/fhir/paging.rs +++ b/crates/fhir-sdk/src/client/fhir/paging.rs @@ -3,7 +3,7 @@ use std::{any::type_name, fmt::Debug, marker::PhantomData}; use futures::{stream, Stream, StreamExt, TryStreamExt}; -use reqwest::{StatusCode, Url}; +use reqwest::{header::HeaderValue, StatusCode, Url}; use super::{Client, Error}; use crate::{ @@ -21,6 +21,8 @@ pub struct Page { client: Client, /// The inner Bundle result. bundle: V::Bundle, + /// The correlation ID to send when fetching all further pages. + correlation_id: HeaderValue, /// The resource type to return in matches. _resource_type: PhantomData, @@ -33,8 +35,12 @@ where for<'a> &'a R: TryFrom<&'a V::Resource>, { /// Create a new `Page` result from a `Bundle` and client. - pub(crate) const fn new(client: Client, bundle: V::Bundle) -> Self { - Self { client, bundle, _resource_type: PhantomData } + pub(crate) const fn new( + client: Client, + bundle: V::Bundle, + correlation_id: HeaderValue, + ) -> Self { + Self { client, bundle, correlation_id, _resource_type: PhantomData } } /// Get the next page URL, if there is one. @@ -51,13 +57,17 @@ where }; tracing::debug!("Fetching next page from URL: {next_page_url}"); - let next_bundle = match self.client.read_generic::(url).await { + let next_bundle = match self + .client + .read_generic::(url, Some(self.correlation_id.clone())) + .await + { Ok(Some(bundle)) => bundle, Ok(None) => return Some(Err(Error::ResourceNotFound(next_page_url.clone()))), Err(err) => return Some(Err(err)), }; - Some(Ok(Self::new(self.client.clone(), next_bundle))) + Some(Ok(Self::new(self.client.clone(), next_bundle, self.correlation_id.clone()))) } /// Get the `total` field, indicating the total number of results. @@ -104,8 +114,10 @@ where &mut self, ) -> impl Stream> + Send + 'static { let client = self.client.clone(); - stream::iter(self.take_entries().into_iter().flatten()) - .filter_map(move |entry| resolve_bundle_entry(entry, client.clone())) + let correlation_id = self.correlation_id.clone(); + stream::iter(self.take_entries().into_iter().flatten()).filter_map(move |entry| { + resolve_bundle_entry(entry, client.clone(), correlation_id.clone()) + }) } /// Get the matches of this page, where the `fullUrl` is automatically resolved whenever there @@ -116,13 +128,16 @@ where /// Consumes the entries, leaving the page empty. pub fn matches_owned(&mut self) -> impl Stream> + Send + 'static { let client = self.client.clone(); + let correlation_id = self.correlation_id.clone(); stream::iter( self.take_entries() .into_iter() .flatten() .filter(|entry| entry.search_mode().is_some_and(SearchEntryModeExt::is_match)), ) - .filter_map(move |entry| resolve_bundle_entry(entry, client.clone())) + .filter_map(move |entry| { + resolve_bundle_entry(entry, client.clone(), correlation_id.clone()) + }) .try_filter_map(|resource| std::future::ready(Ok(resource.try_into().ok()))) } @@ -163,6 +178,7 @@ where async fn resolve_bundle_entry( entry: BundleEntry, client: Client, + correlation_id: HeaderValue, ) -> Option> where (StatusCode, V::OperationOutcome): Into, @@ -184,7 +200,7 @@ where }; let result = client - .read_generic::(url) + .read_generic::(url, Some(correlation_id)) .await .and_then(|opt| opt.ok_or_else(|| Error::ResourceNotFound(full_url.clone()))); Some(result) @@ -195,6 +211,7 @@ impl Clone for Page { Self { client: self.client.clone(), bundle: self.bundle.clone(), + correlation_id: self.correlation_id.clone(), _resource_type: self._resource_type, } } diff --git a/crates/fhir-sdk/src/client/misc.rs b/crates/fhir-sdk/src/client/misc.rs index 732599e4..32357b0e 100644 --- a/crates/fhir-sdk/src/client/misc.rs +++ b/crates/fhir-sdk/src/client/misc.rs @@ -1,6 +1,8 @@ //! Miscellaneous helpers. -use reqwest::header::{self, HeaderMap}; +use ::reqwest::{header::AsHeaderName, RequestBuilder}; +use ::uuid::Uuid; +use reqwest::header::{self, HeaderMap, HeaderValue}; use super::Error; @@ -62,6 +64,26 @@ pub fn escape_search_value(value: &str) -> String { value.replace('\\', "\\\\").replace('|', "\\|").replace('$', "\\$").replace(',', "\\,") } +/// Make a [HeaderValue] containing a new UUID. +pub fn make_uuid_header_value() -> HeaderValue { + #[allow(clippy::expect_used)] // Will not fail. + HeaderValue::from_str(&Uuid::new_v4().to_string()).expect("UUIDs are valid header values") +} + +/// Get a cloned header value from a request builder without cloning the whole request. +pub fn extract_header( + request_builder: RequestBuilder, + header: K, +) -> Result<(RequestBuilder, Option), reqwest::Error> +where + K: AsHeaderName, +{ + let (client, request_result) = request_builder.build_split(); + let request = request_result?; + let value = request.headers().get(header).cloned(); + Ok((RequestBuilder::from_parts(client, request), value)) +} + #[cfg(test)] mod tests { #![allow(clippy::expect_used)] // Allowed for tests diff --git a/crates/fhir-sdk/src/client/mod.rs b/crates/fhir-sdk/src/client/mod.rs index 8ba0db0a..3626d6cb 100644 --- a/crates/fhir-sdk/src/client/mod.rs +++ b/crates/fhir-sdk/src/client/mod.rs @@ -12,15 +12,14 @@ mod search; use std::{marker::PhantomData, sync::Arc}; use ::std::any::type_name; -use ::uuid::Uuid; use misc::parse_major_fhir_version; use reqwest::{header, StatusCode, Url}; -use self::auth::AuthCallback; pub use self::{ aliases::*, auth::LoginManager, builder::ClientBuilder, error::Error, fhir::*, request::RequestSettings, search::SearchParameters, }; +use self::{auth::AuthCallback, misc::make_uuid_header_value}; use crate::version::{DefaultVersion, FhirR4B, FhirR5, FhirStu3, FhirVersion}; /// FHIR REST Client. @@ -141,7 +140,10 @@ impl Client { &self, mut request: reqwest::RequestBuilder, ) -> Result { - let info_request = request.try_clone().ok_or(Error::RequestNotClone)?.build()?; + let (client, info_request_result) = request.build_split(); + let info_request = info_request_result?; + let req_method = info_request.method().clone(); + let req_url = info_request.url().clone(); // Check the URL origin if configured to ensure equality. if self.0.error_on_origin_mismatch { @@ -153,24 +155,19 @@ impl Client { // Generate a new correlation ID for this request/transaction across login, if there was // none. - let x_correlation_id = if let Some(value) = info_request.headers().get("X-Correlation-Id") { - value.to_str().ok().map(ToOwned::to_owned) - } else { - let id_str = Uuid::new_v4().to_string(); - #[allow(clippy::expect_used)] // Will not fail. - let id_value = header::HeaderValue::from_str(&id_str).expect("UUIDs are valid header values"); - request = request.header("X-Correlation-Id", id_value); - Some(id_str) - }; + let correlation_id = info_request + .headers() + .get("X-Correlation-Id") + .cloned() + .unwrap_or_else(make_uuid_header_value); + let x_correlation_id = correlation_id.to_str().ok().map(ToOwned::to_owned); + request = reqwest::RequestBuilder::from_parts(client, info_request) + .header("X-Correlation-Id", correlation_id); tracing::Span::current().record("x_correlation_id", x_correlation_id); // Try running the request let mut request_settings = self.request_settings(); - tracing::info!( - "Sending {} request to {} (potentially with retries)", - info_request.method(), - info_request.url() - ); + tracing::info!("Sending {req_method} request to {req_url} (potentially with retries)"); let mut response = request_settings .make_request(request.try_clone().ok_or(Error::RequestNotClone)?) .await?; diff --git a/crates/fhir-sdk/src/client/request.rs b/crates/fhir-sdk/src/client/request.rs index e51d119c..f1017b3c 100644 --- a/crates/fhir-sdk/src/client/request.rs +++ b/crates/fhir-sdk/src/client/request.rs @@ -2,14 +2,13 @@ use std::time::Duration; -use ::uuid::Uuid; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use tokio_retry::{ strategy::{ExponentialBackoff, FixedInterval}, RetryIf, }; -use super::error::Error; +use super::{error::Error, misc::make_uuid_header_value}; /// Settings for the HTTP Requests. /// @@ -105,9 +104,7 @@ impl RequestSettings { *request.headers_mut() = headers; // Add `X-Request-Id` and `X-Correlation-Id` header if not already set. - #[allow(clippy::expect_used)] // Will not fail. - let id_value = HeaderValue::from_str(&Uuid::new_v4().to_string()) - .expect("UUIDs are valid header values"); + let id_value = make_uuid_header_value(); request.headers_mut().entry("X-Correlation-Id").or_insert_with(|| id_value.clone()); request.headers_mut().entry("X-Request-Id").or_insert(id_value);