Skip to content

Commit

Permalink
Add unsubscribe handle
Browse files Browse the repository at this point in the history
With that client can unsubscribe from updates of a value.
Main change in API concerns the fact that now we store subscriptions
not based on SocketAddr, but on client uuid, that a) abstracts us from
TCP as a transport layer and b) gives easier flexibility to trace clients

Also this commit introduces sending Hello message to a client when it
subscribes to a value, it removes any need for sleeps in any tests making
async programming correct way

Signed-off-by: Pavel Abramov <[email protected]>
  • Loading branch information
uncleDecart committed Oct 2, 2024
1 parent c72b5ad commit 3d2678b
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 58 deletions.
1 change: 1 addition & 0 deletions benches/nkv_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ fn bench_server(c: &mut Criterion) {
let (del_resp_tx, mut del_resp_rx) = mpsc::channel(1);
let _ = del_tx.send(BaseMsg {
key: "key1".to_string(),
uuid: "0".to_string(),
resp_tx: del_resp_tx,
});
let result = del_resp_rx.recv().await.unwrap();
Expand Down
10 changes: 10 additions & 0 deletions src/client/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ async fn main() {
println!("SUBSCRIBE requires a key");
}
}
"UNSUBSCRIBE" => {
if let Some(key) = parts.get(1) {
let start = Instant::now();
let resp = client.unsubscribe(key.to_string()).await.unwrap();
let elapsed = start.elapsed();
println!("Request took: {:.2?}\n{}", elapsed, resp);
} else {
println!("SUBSCRIBE requires a key");
}
}
"QUIT" => {
break;
}
Expand Down
16 changes: 15 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ impl std::error::Error for NkvClientError {}

pub struct NkvClient {
addr: String,
uuid: String,
subscriptions: HashMap<String, bool>,
}

Expand All @@ -48,6 +49,7 @@ impl NkvClient {
pub fn new(addr: &str) -> Self {
Self {
addr: addr.to_string(),
uuid: Self::uuid(),
subscriptions: HashMap::new(),
}
}
Expand All @@ -59,6 +61,7 @@ impl NkvClient {
pub async fn get(&mut self, key: String) -> tokio::io::Result<ServerResponse> {
let req = ServerRequest::Get(BaseMessage {
id: Self::uuid(),
client_uuid: self.uuid.clone(),
key,
});
self.send_request(&req).await
Expand All @@ -68,6 +71,7 @@ impl NkvClient {
let req = ServerRequest::Put(PutMessage {
base: BaseMessage {
id: Self::uuid(),
client_uuid: self.uuid.clone(),
key,
},
value: val,
Expand All @@ -78,6 +82,16 @@ impl NkvClient {
pub async fn delete(&mut self, key: String) -> tokio::io::Result<ServerResponse> {
let req = ServerRequest::Delete(BaseMessage {
id: Self::uuid(),
client_uuid: self.uuid.clone(),
key,
});
self.send_request(&req).await
}

pub async fn unsubscribe(&mut self, key: String) -> tokio::io::Result<ServerResponse> {
let req = ServerRequest::Unsubscribe(BaseMessage {
id: Self::uuid(),
client_uuid: self.uuid.clone(),
key,
});
self.send_request(&req).await
Expand All @@ -97,7 +111,7 @@ impl NkvClient {
}));
}

let (mut subscriber, mut rx) = Subscriber::new(&self.addr, &key);
let (mut subscriber, mut rx) = Subscriber::new(&self.addr, &key, &self.uuid);

tokio::spawn(async move {
// TODO: stop when cancleed
Expand Down
28 changes: 14 additions & 14 deletions src/nkv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use crate::notifier::{Notifier, NotifierError, WriteStream};
use crate::persist_value::PersistValue;
use crate::trie::{Trie, TrieNode};
use std::fmt;
use std::net::SocketAddr;
use tokio::sync::Mutex;

#[derive(Debug)]
Expand Down Expand Up @@ -252,33 +251,34 @@ impl NotifyKeyValue {
pub async fn subscribe(
&mut self,
key: &str,
addr: SocketAddr,
uuid: String,
stream: WriteStream,
) -> Result<(), NotifierError> {
// println!("DEBUG: SUBSCRIBE STATE {:?}", self.state);
if let Some(val) = self.state.get_mut(key, None).await {
val.notifier.lock().await.subscribe(addr, stream).await;
val.notifier.lock().await.subscribe(uuid, stream).await?;
} else {
// Client can subscribe to a non-existent value
let val = self.create_value(key, Box::new([]));
val.notifier.lock().await.subscribe(addr, stream).await;
val.notifier.lock().await.subscribe(uuid, stream).await?;
self.state.insert(key, val);
}
Ok(())
}

pub async fn unsubscribe(&mut self, key: &str, addr: &SocketAddr) {
pub async fn unsubscribe(
&mut self,
key: &str,
uuid: String,
) -> Result<(), NotifyKeyValueError> {
if let Some(val) = self.state.get_mut(key, None).await {
match val
.notifier
val.notifier
.lock()
.await
.unsubscribe(key.to_string(), addr)
.await
{
Ok(_) => (),
Err(e) => eprintln!("Failed to unsubscribe {}", e),
}
.unsubscribe(key.to_string(), uuid)
.await?;
Ok(())
} else {
Err(NotifyKeyValueError::NotFound)
}
}
}
Expand Down
63 changes: 39 additions & 24 deletions src/notifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use core::fmt;
use serde::{Deserialize, Serialize};
use serde_json::to_vec;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;

use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
Expand Down Expand Up @@ -112,7 +111,7 @@ impl<T> StateBuf<T> {

#[derive(Debug)]
pub struct Notifier {
clients: Arc<Mutex<HashMap<SocketAddr, WriteStream>>>,
clients: Arc<Mutex<HashMap<String, WriteStream>>>,
// use a buffer to guarantee latest state on consumers
// for detailed information see DESIGN_DECISIONS.md
msg_buf: Arc<Mutex<StateBuf<Message>>>,
Expand Down Expand Up @@ -149,7 +148,7 @@ impl Notifier {
}

async fn send_notifications(
clients: Arc<Mutex<HashMap<SocketAddr, WriteStream>>>,
clients: Arc<Mutex<HashMap<String, WriteStream>>>,
msg_buf: Arc<Mutex<StateBuf<Message>>>,
) {
let buf_val = {
Expand All @@ -164,28 +163,32 @@ impl Notifier {
}
}

pub async fn subscribe(&self, addr: SocketAddr, stream: WriteStream) {
pub async fn subscribe(&self, uuid: String, stream: WriteStream) -> Result<(), NotifierError> {
let mut subscribers = self.clients.lock().await;
subscribers.insert(addr, stream);
subscribers.insert(uuid.clone(), stream);
match subscribers.get_mut(&uuid) {
Some(stream) => Notifier::send_bytes(&to_vec(&Message::Hello).unwrap(), stream).await,
None => return Err(NotifierError::SubscribtionNotFound),
}
}

pub async fn unsubscribe(&self, key: String, addr: &SocketAddr) -> Result<(), NotifierError> {
pub async fn unsubscribe(&self, key: String, uuid: String) -> Result<(), NotifierError> {
let mut clients = self.clients.lock().await;
Self::unsubscribe_impl(key, &mut clients, addr).await
Self::unsubscribe_impl(key, &mut clients, uuid).await
}

async fn unsubscribe_impl(
key: String,
clients: &mut HashMap<SocketAddr, WriteStream>,
addr: &SocketAddr,
clients: &mut HashMap<String, WriteStream>,
uuid: String,
) -> Result<(), NotifierError> {
match clients.get_mut(&addr) {
match clients.get_mut(&uuid) {
Some(stream) => {
Notifier::send_bytes(&to_vec(&Message::Close { key }).unwrap(), stream).await?
}
None => return Err(NotifierError::SubscribtionNotFound),
}
clients.remove(addr);
clients.remove(&uuid);
Ok(())
}

Expand Down Expand Up @@ -217,30 +220,30 @@ impl Notifier {
}

async fn broadcast_message(
clients: Arc<Mutex<HashMap<SocketAddr, WriteStream>>>,
clients: Arc<Mutex<HashMap<String, WriteStream>>>,
message: &Message,
) {
let json_bytes = to_vec(&message).unwrap();

let keys: Vec<std::net::SocketAddr> = {
let keys: Vec<String> = {
let client_guard = clients.lock().await;
client_guard.keys().cloned().collect()
};

let mut failed_addrs = Vec::new();
for addr in keys.iter() {
if let Some(stream) = clients.lock().await.get_mut(addr) {
for uuid in keys.iter() {
if let Some(stream) = clients.lock().await.get_mut(uuid) {
if let Err(e) = Notifier::send_bytes(&json_bytes, stream).await {
eprintln!("broadcast message: {}", e);
failed_addrs.push(addr.clone());
failed_addrs.push(uuid.clone());
continue;
}
}
}

let mut clients = clients.lock().await;
for addr in failed_addrs {
match Self::unsubscribe_impl("failed addr".to_string(), &mut clients, &addr).await {
for uuid in failed_addrs {
match Self::unsubscribe_impl("failed addr".to_string(), &mut clients, uuid).await {
Ok(_) => {}
Err(e) => eprintln!("Failed to unsubscribe: {}", e),
}
Expand Down Expand Up @@ -269,16 +272,18 @@ impl Notifier {
pub struct Subscriber {
addr: String,
key: String,
uuid: String,
tx: watch::Sender<Message>,
}

impl Subscriber {
pub fn new(addr: &str, key: &str) -> (Self, watch::Receiver<Message>) {
pub fn new(addr: &str, key: &str, uuid: &str) -> (Self, watch::Receiver<Message>) {
let (tx, rx) = watch::channel(Message::Hello);
(
Self {
addr: addr.to_string(),
key: key.to_string(),
uuid: uuid.to_string(),
tx,
},
rx,
Expand All @@ -304,12 +309,14 @@ impl Subscriber {
let req = ServerRequest::Subscribe(BaseMessage {
id: "0".to_string(),
key: self.key.to_string(),
client_uuid: self.uuid.to_string(),
});
let req = serde_json::to_string(&req)?;
writer.write_all(req.as_bytes()).await?;
writer.write_all(b"\n").await?;
writer.flush().await?;

// TODO: we need to talk about max buff
let mut buffer = [0; 1024];
loop {
let n = reader.read(&mut buffer).await?;
Expand All @@ -335,6 +342,7 @@ impl Subscriber {
#[cfg(test)]
mod tests {
use super::*;
use std::net::SocketAddr;
use tokio::io::{split, AsyncBufReadExt};
use tokio::net::TcpListener;

Expand All @@ -344,9 +352,10 @@ mod tests {
.parse()
.expect("Unable to parse socket address");
let key = "AWESOME_KEY".to_string();
let uuid = "AWESOME_UUID".to_string();

let mut notifier = Notifier::new();
let (mut subscriber, mut rx) = Subscriber::new(srv_addr.to_string().as_str(), &key);
let (mut subscriber, mut rx) = Subscriber::new(srv_addr.to_string().as_str(), &key, &uuid);
let listener = TcpListener::bind(srv_addr).await.unwrap();
let val: Box<[u8]> = "Bazinga".to_string().into_bytes().into_boxed_slice();
let vc = val.clone();
Expand All @@ -355,23 +364,29 @@ mod tests {
subscriber.start().await;
});

let handle = tokio::spawn(async move {
let (stream, addr) = listener.accept().await.unwrap();
let _handle = tokio::spawn(async move {
let (stream, _addr) = listener.accept().await.unwrap();
let (read_half, write_half) = split(stream);
let mut reader = tokio::io::BufReader::new(read_half);
let writer = BufWriter::new(write_half);

let mut buffer = String::new();
let _ = reader.read_line(&mut buffer).await;

notifier.subscribe(addr, writer).await;
let _ = notifier.subscribe(uuid, writer).await;
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
notifier
.send_update("AWESOME_KEY".to_string(), vc.clone())
.await;
sleep(Duration::from_secs(1)).await;
});

handle.await.unwrap();
// handle.await.unwrap();
assert_eq!(true, rx.changed().await.is_ok());
{
let msg = rx.borrow();
assert_eq!(*msg, Message::Hello);
}

assert_eq!(true, rx.changed().await.is_ok());
let msg = rx.borrow();
Expand Down
2 changes: 2 additions & 0 deletions src/request_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::fmt;
pub struct BaseMessage {
pub id: String,
pub key: String,
pub client_uuid: String,
}

#[derive(Debug, Deserialize, Serialize)]
Expand All @@ -27,6 +28,7 @@ pub enum ServerRequest {
Get(BaseMessage),
Delete(BaseMessage),
Subscribe(BaseMessage),
Unsubscribe(BaseMessage),
}

#[derive(Debug, serde::Deserialize, serde::Serialize)]
Expand Down
Loading

0 comments on commit 3d2678b

Please sign in to comment.