Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Read Streams to Crypto API #158

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions examples/crypto/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fs;

use tokio::fs::File;
use tokio::io::AsyncReadExt;
use tokio::time::sleep;

use dapr::client::ReaderStream;
Expand Down Expand Up @@ -28,7 +29,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.await
.unwrap();

let decrypted = client
let mut decrypted = client
.decrypt(
encrypted,
dapr::client::DecryptRequestOptions {
Expand All @@ -39,7 +40,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.await
.unwrap();

assert_eq!(String::from_utf8(decrypted).unwrap().as_str(), "Test");
let mut value = String::new();

decrypted.read_to_string(&mut value).await.unwrap();

assert_eq!(value.as_str(), "Test");

println!("Successfully Decrypted String");

Expand All @@ -60,7 +65,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.await
.unwrap();

let decrypted = client
let mut decrypted = client
.decrypt(
encrypted,
dapr::client::DecryptRequestOptions {
Expand All @@ -73,7 +78,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let image = fs::read("./image.png").unwrap();

assert_eq!(decrypted, image);
let mut buf = bytes::BytesMut::with_capacity(image.len());

decrypted.read_buf(&mut buf).await.unwrap();

assert_eq!(buf.to_vec(), image);

println!("Successfully Decrypted Image");

Expand Down
139 changes: 93 additions & 46 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::collections::HashMap;
use std::pin::Pin;
use std::task::{Context, Poll};

use async_trait::async_trait;
use futures::StreamExt;
use prost_types::Any;
use serde::{Deserialize, Serialize};
use tokio::io::AsyncRead;
use tokio::io::{AsyncRead, ReadBuf};
use tonic::codegen::tokio_stream;
use tonic::{transport::Channel as TonicChannel, Request};
use tonic::{Status, Streaming};
Expand Down Expand Up @@ -394,7 +396,7 @@ impl<T: DaprInterface> Client<T> {
&mut self,
payload: ReaderStream<R>,
request_options: EncryptRequestOptions,
) -> Result<Vec<StreamPayload>, Status>
) -> Result<ResponseStream<EncryptResponse>, Status>
where
R: AsyncRead + Send,
{
Expand Down Expand Up @@ -433,26 +435,27 @@ impl<T: DaprInterface> Client<T> {
/// * `options` - Decryption request options.
pub async fn decrypt(
&mut self,
encrypted: Vec<StreamPayload>,
mut encrypted_stream: ResponseStream<EncryptResponse>,
options: DecryptRequestOptions,
) -> Result<Vec<u8>, Status> {
let requested_items: Vec<DecryptRequest> = encrypted
.iter()
.enumerate()
.map(|(i, item)| {
if i == 0 {
DecryptRequest {
options: Some(options.clone()),
payload: Some(item.clone()),
}
} else {
DecryptRequest {
options: None,
payload: Some(item.clone()),
) -> Result<ResponseStream<DecryptResponse>, Status> {
let mut requested_items = vec![];
while let Some(resp_result) = encrypted_stream.stream.next().await {
if let Ok(resp) = resp_result {
if let Some(payload) = resp.payload {
if requested_items.len() == 0 {
requested_items.push(DecryptRequest {
options: Some(options.clone()),
payload: Some(payload),
})
} else {
requested_items.push(DecryptRequest {
options: None,
payload: Some(payload),
})
}
}
})
.collect();
}
}
self.0.decrypt(requested_items).await
}
}
Expand Down Expand Up @@ -497,10 +500,15 @@ pub trait DaprInterface: Sized {
request: UnsubscribeConfigurationRequest,
) -> Result<UnsubscribeConfigurationResponse, Error>;

async fn encrypt(&mut self, payload: Vec<EncryptRequest>)
-> Result<Vec<StreamPayload>, Status>;
async fn encrypt(
&mut self,
payload: Vec<EncryptRequest>,
) -> Result<ResponseStream<EncryptResponse>, Status>;

async fn decrypt(&mut self, payload: Vec<DecryptRequest>) -> Result<Vec<u8>, Status>;
async fn decrypt(
&mut self,
payload: Vec<DecryptRequest>,
) -> Result<ResponseStream<DecryptResponse>, Status>;
}

#[async_trait]
Expand Down Expand Up @@ -626,19 +634,10 @@ impl DaprInterface for dapr_v1::dapr_client::DaprClient<TonicChannel> {
async fn encrypt(
&mut self,
request: Vec<EncryptRequest>,
) -> Result<Vec<StreamPayload>, Status> {
) -> Result<ResponseStream<EncryptResponse>, Status> {
let request = Request::new(tokio_stream::iter(request));
let stream = self.encrypt_alpha1(request).await?;
let mut stream = stream.into_inner();
let mut return_data = vec![];
while let Some(resp) = stream.next().await {
if let Ok(resp) = resp {
if let Some(data) = resp.payload {
return_data.push(data)
}
}
}
Ok(return_data)
let stream = self.encrypt_alpha1(request).await?.into_inner();
Ok(ResponseStream { stream })
}

/// Decrypt binary data using Dapr. returns Vec<u8>.
Expand All @@ -647,19 +646,13 @@ impl DaprInterface for dapr_v1::dapr_client::DaprClient<TonicChannel> {
///
/// * `encrypted` - Encrypted data usually returned from encrypted, Vec<StreamPayload>
/// * `options` - Decryption request options.
async fn decrypt(&mut self, request: Vec<DecryptRequest>) -> Result<Vec<u8>, Status> {
async fn decrypt(
&mut self,
request: Vec<DecryptRequest>,
) -> Result<ResponseStream<DecryptResponse>, Status> {
let request = Request::new(tokio_stream::iter(request));
let stream = self.decrypt_alpha1(request).await?;
let mut stream = stream.into_inner();
let mut data = vec![];
while let Some(resp) = stream.next().await {
if let Ok(resp) = resp {
if let Some(mut payload) = resp.payload {
data.append(payload.data.as_mut())
}
}
}
Ok(data)
let stream = self.decrypt_alpha1(request).await?.into_inner();
Ok(ResponseStream { stream })
}
}

Expand Down Expand Up @@ -752,6 +745,10 @@ pub type EncryptRequestOptions = crate::dapr::dapr::proto::runtime::v1::EncryptR
/// Decryption request options
pub type DecryptRequestOptions = crate::dapr::dapr::proto::runtime::v1::DecryptRequestOptions;

pub type EncryptResponse = crate::dapr::dapr::proto::runtime::v1::EncryptResponse;

pub type DecryptResponse = crate::dapr::dapr::proto::runtime::v1::DecryptResponse;

type StreamPayload = crate::dapr::dapr::proto::common::v1::StreamPayload;
impl<K> From<(K, Vec<u8>)> for common_v1::StateItem
where
Expand All @@ -773,3 +770,53 @@ impl<T: AsyncRead> ReaderStream<T> {
ReaderStream(tokio_util::io::ReaderStream::new(data))
}
}

pub struct ResponseStream<T> {
stream: Streaming<T>,
}

impl AsyncRead for ResponseStream<EncryptResponse> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.stream.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(resp))) => {
if let Some(payload) = resp.payload {
buf.put_slice(&payload.data);
}
Poll::Ready(Ok(()))
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
format!("{:?}", e),
))),
Poll::Ready(None) => Poll::Ready(Ok(())),
Poll::Pending => Poll::Pending,
}
}
}

impl AsyncRead for ResponseStream<DecryptResponse> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.stream.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(resp))) => {
if let Some(payload) = resp.payload {
buf.put_slice(&payload.data);
}
Poll::Ready(Ok(()))
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Other,
format!("{:?}", e),
))),
Poll::Ready(None) => Poll::Ready(Ok(())),
Poll::Pending => Poll::Pending,
}
}
}
Loading