diff --git a/Cargo.lock b/Cargo.lock index 3713119..918d0d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -616,6 +616,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" +[[package]] +name = "fastrand" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" + [[package]] name = "fd-lock" version = "4.0.0" @@ -976,6 +982,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "tempfile", "thiserror", "tokio", "tokio-stream", @@ -1635,6 +1642,19 @@ version = "0.12.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14c39fd04924ca3a864207c66fc2cd7d22d7c016007f9ce846cbb9326331930a" +[[package]] +name = "tempfile" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ef1adac450ad7f4b3c28589471ade84f25f731a7a0fe30d71dfa9f60fd808e5" +dependencies = [ + "cfg-if", + "fastrand", + "redox_syscall", + "rustix", + "windows-sys", +] + [[package]] name = "thiserror" version = "1.0.50" diff --git a/src/kiwi/Cargo.toml b/src/kiwi/Cargo.toml index fd56c63..630b0c5 100644 --- a/src/kiwi/Cargo.toml +++ b/src/kiwi/Cargo.toml @@ -27,3 +27,6 @@ tracing = "0.1.40" tracing-subscriber = "0.3.18" wasmtime = { version = "14.0.4", features = ["component-model"] } wasmtime-wasi = "14.0.4" + +[dev-dependencies] +tempfile = "3" diff --git a/src/kiwi/src/protocol.rs b/src/kiwi/src/protocol.rs index 3a066dc..86ab887 100644 --- a/src/kiwi/src/protocol.rs +++ b/src/kiwi/src/protocol.rs @@ -3,7 +3,7 @@ use thiserror::Error; use crate::source::SourceId; -#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] #[serde(tag = "type")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] /// A request that is sent from a client to the server @@ -14,7 +14,7 @@ pub enum Command { Unsubscribe { source_id: SourceId }, } -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] #[serde(tag = "type")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum CommandResponse { @@ -28,7 +28,7 @@ pub enum CommandResponse { UnsubscribeError { source_id: SourceId, error: String }, } -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] #[serde(tag = "type")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] /// An info or error message that may be pushed to a client. A notice, in many @@ -37,7 +37,7 @@ pub enum Notice { Lag { source: SourceId, count: u64 }, } -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] #[serde(tag = "type", content = "data")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] /// An outbound message that is sent from the server to a client diff --git a/src/kiwi/src/ws.rs b/src/kiwi/src/ws.rs index a5404bf..5346f93 100644 --- a/src/kiwi/src/ws.rs +++ b/src/kiwi/src/ws.rs @@ -25,13 +25,13 @@ use crate::source::Source; use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; use tokio_tungstenite::tungstenite::protocol::{CloseFrame, Message as ProtocolMessage}; -#[derive(Debug, Clone, serde::Serialize)] +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] pub struct MessageData { - payload: Option, - topic: String, - timestamp: Option, - partition: i32, - offset: i64, + pub payload: Option, + pub topic: String, + pub timestamp: Option, + pub partition: i32, + pub offset: i64, } impl From for MessageData { diff --git a/src/kiwi/tests/kafka.rs b/src/kiwi/tests/kafka.rs new file mode 100644 index 0000000..4e1cb62 --- /dev/null +++ b/src/kiwi/tests/kafka.rs @@ -0,0 +1,130 @@ +use base64::Engine; +use futures::{SinkExt, StreamExt}; +use rdkafka::producer::{FutureProducer, FutureRecord}; +use std::{io::Write, process, time::Duration}; +use tokio_tungstenite::{connect_async, tungstenite}; + +use kiwi::{ + protocol::{Command, CommandResponse, Message}, + ws::MessageData, +}; + +use tempfile::NamedTempFile; + +// Helper function to start the kiwi process. +fn start_kiwi(config: &str) -> anyhow::Result { + // Expects `kiwi` to be in the PATH + let mut cmd = process::Command::new("kiwi"); + + let mut config_file = NamedTempFile::new()?; + config_file + .as_file_mut() + .write_all(config.as_bytes()) + .unwrap(); + + cmd.args(&[ + "--config", + config_file.path().to_str().expect("path should be valid"), + "--log-level", + "debug", + ]); + let child = cmd.spawn().unwrap(); + + Ok(child) +} + +#[ignore = "todo"] +#[tokio::test] +async fn test_kafka_source() -> anyhow::Result<()> { + let config = r#" + sources: + kafka: + group_prefix: '' + bootstrap_servers: + - 'localhost:9092' + topics: + - name: topic1 + + server: + address: '127.0.0.1:8000' + "#; + + let mut proc = start_kiwi(config)?; + + let (ws_stream, _) = connect_async("http://127.0.0.1:8000") + .await + .expect("Failed to connect"); + + let (mut write, mut read) = ws_stream.split(); + + let cmd = Command::Subscribe { + source_id: "topic1".into(), + }; + + write + .send(tungstenite::protocol::Message::Text( + serde_json::to_string(&cmd).unwrap(), + )) + .await?; + + let resp = read.next().await.expect("Expected response")?; + + let resp: Message<()> = serde_json::from_str(&resp.to_text().unwrap())?; + + match resp { + Message::CommandResponse(CommandResponse::SubscribeOk { source_id }) => { + assert_eq!(source_id, "topic1".to_string()); + } + _ => panic!("Expected subscribe ok"), + } + + let producer = tokio::spawn(async { + let producer: FutureProducer = rdkafka::config::ClientConfig::new() + .set("bootstrap.servers", "localhost:9092") + .set("message.timeout.ms", "5000") + .create() + .expect("Producer creation error"); + + for i in 0..1000 { + let payload = format!("Message {}", i); + let key = format!("Key {}", i); + let record = FutureRecord::to("topic1").payload(&payload).key(&key); + + producer + .send(record, Duration::from_secs(0)) + .await + .expect("Failed to enqueue"); + } + }); + + let reader = tokio::spawn(async move { + let mut count = 0; + while let Some(msg) = read.next().await { + let msg = msg.unwrap(); + let msg: Message = serde_json::from_str(&msg.to_text().unwrap()).unwrap(); + + match msg { + Message::Result(msg) => { + assert_eq!(msg.topic.as_ref(), "topic1".to_string()); + let msg = base64::engine::general_purpose::STANDARD + .decode(msg.payload.as_ref().unwrap()) + .unwrap(); + assert_eq!( + std::str::from_utf8(&msg).unwrap(), + format!("Message {}", count) + ); + count += 1; + } + _ => panic!("Expected message"), + } + } + + assert_eq!(count, 1000, "failed to receive all messages"); + }); + + let _ = futures::join!(producer, reader); + + let _ = proc.kill(); + + Ok(()) +}