diff --git a/lib/core/src/sync/client.rs b/lib/core/src/sync/client.rs index 3ab203cf5..698b0908b 100644 --- a/lib/core/src/sync/client.rs +++ b/lib/core/src/sync/client.rs @@ -1,8 +1,16 @@ -use anyhow::{anyhow, Result}; +use std::time::Duration; + +use anyhow::{anyhow, Error, Result}; use async_trait::async_trait; use log::debug; use tokio::sync::Mutex; +use tonic::{ + metadata::{errors::InvalidMetadataValue, Ascii, MetadataValue}, + service::{interceptor::InterceptedService, Interceptor}, + transport::{Channel, ClientTlsConfig, Endpoint}, + Request, Status, +}; use super::model::sync::{ syncer_client::SyncerClient as ProtoSyncerClient, ListChangesReply, ListChangesRequest, @@ -18,58 +26,97 @@ pub(crate) trait SyncerClient: Send + Sync { } pub(crate) struct BreezSyncerClient { - inner: Mutex>>, + grpc_channel: Mutex>, api_key: Option, } impl BreezSyncerClient { pub(crate) fn new(api_key: Option) -> Self { Self { - inner: Default::default(), + grpc_channel: Mutex::new(None), api_key, } } + + fn create_endpoint(server_url: &str) -> Result { + Ok(Endpoint::from_shared(server_url.to_string())? + .http2_keep_alive_interval(Duration::new(5, 0)) + .tcp_keepalive(Some(Duration::from_secs(5))) + .keep_alive_timeout(Duration::from_secs(5)) + .keep_alive_while_idle(true) + .tls_config(ClientTlsConfig::new().with_enabled_roots())?) + } + + fn api_key_metadata(&self) -> Result>, Error> { + match &self.api_key { + Some(key) => Ok(Some(format!("Bearer {key}").parse().map_err( + |e: InvalidMetadataValue| { + anyhow!(format!( + "(Breez: {:?}) Failed parse API key: {e}", + self.api_key + )) + }, + )?)), + _ => Ok(None), + } + } } impl BreezSyncerClient { - fn set_api_key(&self, req: T) -> Result> { - let mut req = tonic::Request::new(req); - if let Some(api_key) = &self.api_key { - let metadata = req.metadata_mut(); - metadata.insert("authorization", format!("Bearer {}", api_key).parse()?); - } - Ok(req) + async fn get_client( + &self, + ) -> Result>, Error> { + let Some(channel) = self.grpc_channel.lock().await.clone() else { + return Err(anyhow!("Cannot get sync client: not connected")); + }; + let api_key_metadata = self.api_key_metadata()?; + Ok(ProtoSyncerClient::with_interceptor( + channel, + ApiKeyInterceptor { api_key_metadata }, + )) } } #[async_trait] impl SyncerClient for BreezSyncerClient { async fn connect(&self, connect_url: String) -> Result<()> { - let mut client = self.inner.lock().await; - *client = Some(ProtoSyncerClient::connect(connect_url.clone()).await?); + let mut grpc_channel = self.grpc_channel.lock().await; + *grpc_channel = Some(Self::create_endpoint(&connect_url)?.connect_lazy()); debug!("Successfully connected to {connect_url}"); Ok(()) } async fn push(&self, req: SetRecordRequest) -> Result { - let Some(mut client) = self.inner.lock().await.clone() else { - return Err(anyhow!("Cannot run `set_record`: client not connected")); - }; - let req = self.set_api_key(req)?; - Ok(client.set_record(req).await?.into_inner()) + Ok(self.get_client().await?.set_record(req).await?.into_inner()) } async fn pull(&self, req: ListChangesRequest) -> Result { - let Some(mut client) = self.inner.lock().await.clone() else { - return Err(anyhow!("Cannot run `list_changes`: client not connected")); - }; - let req = self.set_api_key(req)?; - Ok(client.list_changes(req).await?.into_inner()) + Ok(self + .get_client() + .await? + .list_changes(req) + .await? + .into_inner()) } async fn disconnect(&self) -> Result<()> { - let mut client = self.inner.lock().await; - *client = None; + let mut channel = self.grpc_channel.lock().await; + *channel = None; Ok(()) } } + +#[derive(Clone)] +pub struct ApiKeyInterceptor { + api_key_metadata: Option>, +} + +impl Interceptor for ApiKeyInterceptor { + fn call(&mut self, mut req: Request<()>) -> Result, Status> { + if self.api_key_metadata.clone().is_some() { + req.metadata_mut() + .insert("authorization", self.api_key_metadata.clone().unwrap()); + } + Ok(req) + } +}