Skip to content

Commit

Permalink
feat: Re-use the same correlation ID in paging
Browse files Browse the repository at this point in the history
  • Loading branch information
Flix committed Aug 24, 2024
1 parent 59f7b57 commit cf37e8a
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 44 deletions.
47 changes: 35 additions & 12 deletions crates/fhir-sdk/src/client/fhir/crud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use super::{
Client, Error, SearchParameters,
};
use crate::{
client::misc::{extract_header, make_uuid_header_value},
extensions::{AnyResource, GenericResource, ReferenceExt},
version::FhirVersion,
};
Expand Down Expand Up @@ -42,8 +43,12 @@ where
pub(crate) async fn read_generic<R: DeserializeOwned>(
&self,
url: Url,
correlation_id: Option<HeaderValue>,
) -> Result<Option<R>, Error> {
let request = self.0.client.get(url).header(header::ACCEPT, V::MIME_TYPE);
let mut request = self.0.client.get(url).header(header::ACCEPT, V::MIME_TYPE);
if let Some(correlation_id) = correlation_id {
request = request.header("X-Correlation-Id", correlation_id);
}

let response = self.run_request(request).await?;
if response.status().is_success() {
Expand All @@ -62,7 +67,7 @@ where
id: &str,
) -> Result<Option<R>, Error> {
let url = self.url(&[R::TYPE_STR, id]);
self.read_generic(url).await
self.read_generic(url, None).await
}

/// Read a specific version of a specific FHIR resource.
Expand All @@ -72,7 +77,7 @@ where
version_id: &str,
) -> Result<Option<R>, Error> {
let url = self.url(&[R::TYPE_STR, id, "_history", version_id]);
self.read_generic(url).await
self.read_generic(url, None).await
}

/// Read the resource that is targeted in the reference.
Expand All @@ -93,7 +98,7 @@ where
};

let resource: V::Resource = self
.read_generic(url.clone())
.read_generic(url.clone(), None)
.await?
.ok_or_else(|| Error::ResourceNotFound(url.to_string()))?;
if let Some(resource_type) = reference.r#type() {
Expand All @@ -114,19 +119,26 @@ where
R: AnyResource<V> + TryFrom<V::Resource, Error = WrongResourceType> + 'static,
for<'a> &'a R: TryFrom<&'a V::Resource>,
{
let correlation_id = make_uuid_header_value();

let url = {
if let Some(id) = id {
self.url(&[R::TYPE_STR, id, "_history"])
} else {
self.url(&[R::TYPE_STR, "_history"])
}
};
let request = self.0.client.get(url).header(header::ACCEPT, V::MIME_TYPE);
let request = self
.0
.client
.get(url)
.header(header::ACCEPT, V::MIME_TYPE)
.header("X-Correlation-Id", correlation_id.clone());

let response = self.run_request(request).await?;
if response.status().is_success() {
let bundle: V::Bundle = response.json().await?;
Ok(Page::new(self.clone(), bundle))
Ok(Page::new(self.clone(), bundle, correlation_id))
} else {
Err(Error::from_response::<V>(response).await)
}
Expand Down Expand Up @@ -233,18 +245,21 @@ where
) -> Result<Page<V, V::Resource>, Error> {
// TODO: Use POST for long queries?

let correlation_id = make_uuid_header_value();

let url = self.url(&[]);
let request = self
.0
.client
.get(url)
.query(&queries.into_queries())
.header(header::ACCEPT, V::MIME_TYPE);
.header(header::ACCEPT, V::MIME_TYPE)
.header("X-Correlation-Id", correlation_id.clone());

let response = self.run_request(request).await?;
if response.status().is_success() {
let bundle: V::Bundle = response.json().await?;
Ok(Page::new(self.clone(), bundle))
Ok(Page::new(self.clone(), bundle, correlation_id))
} else {
Err(Error::from_response::<V>(response).await)
}
Expand All @@ -258,18 +273,21 @@ where
{
// TODO: Use POST for long queries?

let correlation_id = make_uuid_header_value();

let url = self.url(&[R::TYPE_STR]);
let request = self
.0
.client
.get(url)
.query(&queries.into_queries())
.header(header::ACCEPT, V::MIME_TYPE);
.header(header::ACCEPT, V::MIME_TYPE)
.header("X-Correlation-Id", correlation_id.clone());

let response = self.run_request(request).await?;
if response.status().is_success() {
let bundle: V::Bundle = response.json().await?;
Ok(Page::new(self.clone(), bundle))
Ok(Page::new(self.clone(), bundle, correlation_id))
} else {
Err(Error::from_response::<V>(response).await)
}
Expand All @@ -295,10 +313,15 @@ where
R: TryFrom<V::Resource> + Send + Sync + 'static,
for<'a> &'a R: TryFrom<&'a V::Resource>,
{
let response = self.send_custom_request(make_request).await?;
let request = (make_request)(&self.0.client);
let (mut request, correlation_id) = extract_header(request, "X-Correlation-Id")?;
let correlation_id = correlation_id.unwrap_or_else(make_uuid_header_value);
request = request.header("X-Correlation-Id", correlation_id.clone());

let response = self.run_request(request).await?;
if response.status().is_success() {
let bundle: V::Bundle = response.json().await?;
Ok(Page::new(self.clone(), bundle))
Ok(Page::new(self.clone(), bundle, correlation_id))
} else {
Err(Error::from_response::<V>(response).await)
}
Expand Down
35 changes: 26 additions & 9 deletions crates/fhir-sdk/src/client/fhir/paging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use std::{any::type_name, fmt::Debug, marker::PhantomData};

use futures::{stream, Stream, StreamExt, TryStreamExt};
use reqwest::{StatusCode, Url};
use reqwest::{header::HeaderValue, StatusCode, Url};

use super::{Client, Error};
use crate::{
Expand All @@ -21,6 +21,8 @@ pub struct Page<V: FhirVersion, R> {
client: Client<V>,
/// The inner Bundle result.
bundle: V::Bundle,
/// The correlation ID to send when fetching all further pages.
correlation_id: HeaderValue,

/// The resource type to return in matches.
_resource_type: PhantomData<R>,
Expand All @@ -33,8 +35,12 @@ where
for<'a> &'a R: TryFrom<&'a V::Resource>,
{
/// Create a new `Page` result from a `Bundle` and client.
pub(crate) const fn new(client: Client<V>, bundle: V::Bundle) -> Self {
Self { client, bundle, _resource_type: PhantomData }
pub(crate) const fn new(
client: Client<V>,
bundle: V::Bundle,
correlation_id: HeaderValue,
) -> Self {
Self { client, bundle, correlation_id, _resource_type: PhantomData }
}

/// Get the next page URL, if there is one.
Expand All @@ -51,13 +57,17 @@ where
};

tracing::debug!("Fetching next page from URL: {next_page_url}");
let next_bundle = match self.client.read_generic::<V::Bundle>(url).await {
let next_bundle = match self
.client
.read_generic::<V::Bundle>(url, Some(self.correlation_id.clone()))
.await
{
Ok(Some(bundle)) => bundle,
Ok(None) => return Some(Err(Error::ResourceNotFound(next_page_url.clone()))),
Err(err) => return Some(Err(err)),
};

Some(Ok(Self::new(self.client.clone(), next_bundle)))
Some(Ok(Self::new(self.client.clone(), next_bundle, self.correlation_id.clone())))
}

/// Get the `total` field, indicating the total number of results.
Expand Down Expand Up @@ -104,8 +114,10 @@ where
&mut self,
) -> impl Stream<Item = Result<V::Resource, Error>> + Send + 'static {
let client = self.client.clone();
stream::iter(self.take_entries().into_iter().flatten())
.filter_map(move |entry| resolve_bundle_entry(entry, client.clone()))
let correlation_id = self.correlation_id.clone();
stream::iter(self.take_entries().into_iter().flatten()).filter_map(move |entry| {
resolve_bundle_entry(entry, client.clone(), correlation_id.clone())
})
}

/// Get the matches of this page, where the `fullUrl` is automatically resolved whenever there
Expand All @@ -116,13 +128,16 @@ where
/// Consumes the entries, leaving the page empty.
pub fn matches_owned(&mut self) -> impl Stream<Item = Result<R, Error>> + Send + 'static {
let client = self.client.clone();
let correlation_id = self.correlation_id.clone();
stream::iter(
self.take_entries()
.into_iter()
.flatten()
.filter(|entry| entry.search_mode().is_some_and(SearchEntryModeExt::is_match)),
)
.filter_map(move |entry| resolve_bundle_entry(entry, client.clone()))
.filter_map(move |entry| {
resolve_bundle_entry(entry, client.clone(), correlation_id.clone())
})
.try_filter_map(|resource| std::future::ready(Ok(resource.try_into().ok())))
}

Expand Down Expand Up @@ -163,6 +178,7 @@ where
async fn resolve_bundle_entry<V: FhirVersion>(
entry: BundleEntry<V>,
client: Client<V>,
correlation_id: HeaderValue,
) -> Option<Result<V::Resource, Error>>
where
(StatusCode, V::OperationOutcome): Into<Error>,
Expand All @@ -184,7 +200,7 @@ where
};

let result = client
.read_generic::<V::Resource>(url)
.read_generic::<V::Resource>(url, Some(correlation_id))
.await
.and_then(|opt| opt.ok_or_else(|| Error::ResourceNotFound(full_url.clone())));
Some(result)
Expand All @@ -195,6 +211,7 @@ impl<V: FhirVersion, R> Clone for Page<V, R> {
Self {
client: self.client.clone(),
bundle: self.bundle.clone(),
correlation_id: self.correlation_id.clone(),
_resource_type: self._resource_type,
}
}
Expand Down
24 changes: 23 additions & 1 deletion crates/fhir-sdk/src/client/misc.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! Miscellaneous helpers.
use reqwest::header::{self, HeaderMap};
use ::reqwest::{header::AsHeaderName, RequestBuilder};
use ::uuid::Uuid;
use reqwest::header::{self, HeaderMap, HeaderValue};

use super::Error;

Expand Down Expand Up @@ -62,6 +64,26 @@ pub fn escape_search_value(value: &str) -> String {
value.replace('\\', "\\\\").replace('|', "\\|").replace('$', "\\$").replace(',', "\\,")
}

/// Make a [HeaderValue] containing a new UUID.
pub fn make_uuid_header_value() -> HeaderValue {
#[allow(clippy::expect_used)] // Will not fail.
HeaderValue::from_str(&Uuid::new_v4().to_string()).expect("UUIDs are valid header values")
}

/// Get a cloned header value from a request builder without cloning the whole request.
pub fn extract_header<K>(
request_builder: RequestBuilder,
header: K,
) -> Result<(RequestBuilder, Option<HeaderValue>), reqwest::Error>
where
K: AsHeaderName,
{
let (client, request_result) = request_builder.build_split();
let request = request_result?;
let value = request.headers().get(header).cloned();
Ok((RequestBuilder::from_parts(client, request), value))
}

#[cfg(test)]
mod tests {
#![allow(clippy::expect_used)] // Allowed for tests
Expand Down
31 changes: 14 additions & 17 deletions crates/fhir-sdk/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@ mod search;
use std::{marker::PhantomData, sync::Arc};

use ::std::any::type_name;
use ::uuid::Uuid;
use misc::parse_major_fhir_version;
use reqwest::{header, StatusCode, Url};

use self::auth::AuthCallback;
pub use self::{
aliases::*, auth::LoginManager, builder::ClientBuilder, error::Error, fhir::*,
request::RequestSettings, search::SearchParameters,
};
use self::{auth::AuthCallback, misc::make_uuid_header_value};
use crate::version::{DefaultVersion, FhirR4B, FhirR5, FhirStu3, FhirVersion};

/// FHIR REST Client.
Expand Down Expand Up @@ -141,7 +140,10 @@ impl<V: FhirVersion> Client<V> {
&self,
mut request: reqwest::RequestBuilder,
) -> Result<reqwest::Response, Error> {
let info_request = request.try_clone().ok_or(Error::RequestNotClone)?.build()?;
let (client, info_request_result) = request.build_split();
let info_request = info_request_result?;
let req_method = info_request.method().clone();
let req_url = info_request.url().clone();

// Check the URL origin if configured to ensure equality.
if self.0.error_on_origin_mismatch {
Expand All @@ -153,24 +155,19 @@ impl<V: FhirVersion> Client<V> {

// Generate a new correlation ID for this request/transaction across login, if there was
// none.
let x_correlation_id = if let Some(value) = info_request.headers().get("X-Correlation-Id") {
value.to_str().ok().map(ToOwned::to_owned)
} else {
let id_str = Uuid::new_v4().to_string();
#[allow(clippy::expect_used)] // Will not fail.
let id_value = header::HeaderValue::from_str(&id_str).expect("UUIDs are valid header values");
request = request.header("X-Correlation-Id", id_value);
Some(id_str)
};
let correlation_id = info_request
.headers()
.get("X-Correlation-Id")
.cloned()
.unwrap_or_else(make_uuid_header_value);
let x_correlation_id = correlation_id.to_str().ok().map(ToOwned::to_owned);
request = reqwest::RequestBuilder::from_parts(client, info_request)
.header("X-Correlation-Id", correlation_id);
tracing::Span::current().record("x_correlation_id", x_correlation_id);

// Try running the request
let mut request_settings = self.request_settings();
tracing::info!(
"Sending {} request to {} (potentially with retries)",
info_request.method(),
info_request.url()
);
tracing::info!("Sending {req_method} request to {req_url} (potentially with retries)");
let mut response = request_settings
.make_request(request.try_clone().ok_or(Error::RequestNotClone)?)
.await?;
Expand Down
7 changes: 2 additions & 5 deletions crates/fhir-sdk/src/client/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
use std::time::Duration;

use ::uuid::Uuid;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use tokio_retry::{
strategy::{ExponentialBackoff, FixedInterval},
RetryIf,
};

use super::error::Error;
use super::{error::Error, misc::make_uuid_header_value};

/// Settings for the HTTP Requests.
///
Expand Down Expand Up @@ -105,9 +104,7 @@ impl RequestSettings {
*request.headers_mut() = headers;

// Add `X-Request-Id` and `X-Correlation-Id` header if not already set.
#[allow(clippy::expect_used)] // Will not fail.
let id_value = HeaderValue::from_str(&Uuid::new_v4().to_string())
.expect("UUIDs are valid header values");
let id_value = make_uuid_header_value();
request.headers_mut().entry("X-Correlation-Id").or_insert_with(|| id_value.clone());
request.headers_mut().entry("X-Request-Id").or_insert(id_value);

Expand Down

0 comments on commit cf37e8a

Please sign in to comment.