diff --git a/benches/nkv_bench.rs b/benches/nkv_bench.rs index 52fdbd1..6c38b39 100644 --- a/benches/nkv_bench.rs +++ b/benches/nkv_bench.rs @@ -125,9 +125,10 @@ fn bench_server(c: &mut Criterion) { // not used with channels let url = "127.0.0.1:8091"; - let srv = Server::new(url.to_string(), temp_dir.path().to_path_buf()) - .await - .unwrap(); + let (srv, _cancel) = + Server::new(url.to_string(), temp_dir.path().to_path_buf()) + .await + .unwrap(); let put_tx = srv.put_tx(); @@ -160,9 +161,10 @@ fn bench_server(c: &mut Criterion) { // not used with channels let url = "127.0.0.1:8091"; - let srv = Server::new(url.to_string(), temp_dir.path().to_path_buf()) - .await - .unwrap(); + let (srv, _cancel) = + Server::new(url.to_string(), temp_dir.path().to_path_buf()) + .await + .unwrap(); let put_tx = srv.put_tx(); let get_tx = srv.get_tx(); @@ -202,9 +204,10 @@ fn bench_server(c: &mut Criterion) { // not used with channels let url = "127.0.0.1:8091"; - let srv = Server::new(url.to_string(), temp_dir.path().to_path_buf()) - .await - .unwrap(); + let (srv, _cancel) = + Server::new(url.to_string(), temp_dir.path().to_path_buf()) + .await + .unwrap(); let put_tx = srv.put_tx(); let del_tx = srv.del_tx(); diff --git a/src/server/main.rs b/src/server/main.rs index 2975abe..228e5c5 100644 --- a/src/server/main.rs +++ b/src/server/main.rs @@ -18,7 +18,7 @@ async fn main() { let temp_dir = TempDir::new().expect("Failed to create temporary directory"); // creates a task where it waits to serve threads - let srv = srv::Server::new(url.to_string(), temp_dir.path().to_path_buf()) + let (mut srv, _cancel) = srv::Server::new(url.to_string(), temp_dir.path().to_path_buf()) .await .unwrap(); diff --git a/src/srv.rs b/src/srv.rs index 33349c1..1b13b81 100644 --- a/src/srv.rs +++ b/src/srv.rs @@ -10,7 +10,7 @@ use std::net::SocketAddr; use std::sync::Arc; use tokio::io::{split, AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}; use tokio::net::TcpListener; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; use crate::nkv; use crate::notifier::WriteStream; @@ -51,15 +51,20 @@ pub struct Server { get_tx: mpsc::UnboundedSender, del_tx: mpsc::UnboundedSender, sub_tx: mpsc::UnboundedSender, + cancel_rx: oneshot::Receiver<()>, } impl Server { - pub async fn new(addr: String, path: std::path::PathBuf) -> std::io::Result { + pub async fn new( + addr: String, + path: std::path::PathBuf, + ) -> std::io::Result<(Self, oneshot::Sender<()>)> { let (put_tx, mut put_rx) = mpsc::unbounded_channel::(); let (get_tx, mut get_rx) = mpsc::unbounded_channel::(); let (del_tx, mut del_rx) = mpsc::unbounded_channel::(); let (sub_tx, mut sub_rx) = mpsc::unbounded_channel::(); - let (_cancel_tx, mut cancel_rx) = mpsc::unbounded_channel::(); + let (cancel_tx, cancel_rx) = oneshot::channel(); + let (usr_cancel_tx, mut usr_cancel_rx) = oneshot::channel(); let mut nkv = nkv::NotifyKeyValue::new(path)?; let addr: SocketAddr = addr.parse().expect("Unable to parse addr"); @@ -70,13 +75,13 @@ impl Server { get_tx, del_tx, sub_tx, + cancel_rx, }; // Spawn task to handle Asynchronous access to notify key value // storage via channels tokio::spawn(async move { - let mut cancelled = false; - while !cancelled { + loop { tokio::select! { Some(req) = put_rx.recv() => { nkv.put(&req.key, req.value).await; @@ -109,16 +114,21 @@ impl Server { } let _ = req.resp_tx.send(err).await; } - Some(_) = cancel_rx.recv() => { cancelled = true } - else => { break; } + + _ = &mut usr_cancel_rx => { + _ = cancel_tx.send(()); + return; + } + + else => { return; } } } }); - Ok(srv) + Ok((srv, usr_cancel_tx)) } - pub async fn serve(&self) { + pub async fn serve(&mut self) { let listener = TcpListener::bind(self.addr).await.unwrap(); loop { let put_tx = self.put_tx(); @@ -126,44 +136,51 @@ impl Server { let del_tx = self.del_tx(); let sub_tx = self.sub_tx(); - let (stream, addr) = listener.accept().await.unwrap(); - let (read_half, write_half) = split(stream); - let mut reader = BufReader::new(read_half); - let writer = BufWriter::new(write_half); - - tokio::spawn(async move { - let mut buffer = String::new(); - match reader.read_line(&mut buffer).await { - Ok(0) => { - // Connection was closed - return; - } - Ok(_) => match serde_json::from_str::(&buffer.trim()) { - Ok(request) => { - match request { - ServerRequest::Put(PutMessage { .. }) => { - Self::handle_put(writer, put_tx.clone(), request).await - } - ServerRequest::Get(BaseMessage { .. }) => { - Self::handle_get(writer, get_tx.clone(), request).await - } - ServerRequest::Delete(BaseMessage { .. }) => { - Self::handle_delete(writer, del_tx.clone(), request).await + tokio::select! { + Ok((stream, addr)) = listener.accept() => { + let (read_half, write_half) = split(stream); + let mut reader = BufReader::new(read_half); + let writer = BufWriter::new(write_half); + + tokio::spawn(async move { + let mut buffer = String::new(); + match reader.read_line(&mut buffer).await { + Ok(0) => { + // Connection was closed + return; + } + Ok(_) => match serde_json::from_str::(&buffer.trim()) { + Ok(request) => { + match request { + ServerRequest::Put(PutMessage { .. }) => { + Self::handle_put(writer, put_tx.clone(), request).await + } + ServerRequest::Get(BaseMessage { .. }) => { + Self::handle_get(writer, get_tx.clone(), request).await + } + ServerRequest::Delete(BaseMessage { .. }) => { + Self::handle_delete(writer, del_tx.clone(), request).await + } + ServerRequest::Subscribe(BaseMessage { .. }) => { + Self::handle_sub(writer, sub_tx.clone(), request, addr).await + } + }; } - ServerRequest::Subscribe(BaseMessage { .. }) => { - Self::handle_sub(writer, sub_tx.clone(), request, addr).await + Err(e) => { + eprintln!("Failed to parse JSON: {}", e); } - }; + }, + Err(_) => { + eprintln!("Failed to match request"); + } } - Err(e) => { - eprintln!("Failed to parse JSON: {}", e); - } - }, - Err(_) => { - eprintln!("Failed to match request"); - } + }); } - }); + + _ = &mut self.cancel_rx => { + return; + } + } } } @@ -331,7 +348,7 @@ mod tests { let temp_dir = TempDir::new().expect("Failed to create temporary directory"); let url = "127.0.0.1:8091"; - let srv = Server::new(url.to_string(), temp_dir.path().to_path_buf()) + let (mut srv, _cancel) = Server::new(url.to_string(), temp_dir.path().to_path_buf()) .await .unwrap(); @@ -401,7 +418,7 @@ mod tests { let temp_dir = TempDir::new().expect("Failed to create temporary directory"); let url = "127.0.0.1:8092"; - let srv = Server::new(url.to_string(), temp_dir.path().to_path_buf()) + let (mut srv, _cancel) = Server::new(url.to_string(), temp_dir.path().to_path_buf()) .await .unwrap(); tokio::spawn(async move {