Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move message-queue to a fully binary representation #454

Merged
merged 4 commits into from
Nov 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion coordinator/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ async fn handle_processor_messages<D: Db, Pro: Processors, P: P2p>(
mut db: D,
key: Zeroizing<<Ristretto as Ciphersuite>::F>,
serai: Arc<Serai>,
mut processors: Pro,
processors: Pro,
p2p: P,
cosign_channel: mpsc::UnboundedSender<CosignedBlock>,
network: NetworkId,
Expand Down
8 changes: 4 additions & 4 deletions coordinator/src/processors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ pub struct Message {
#[async_trait::async_trait]
pub trait Processors: 'static + Send + Sync + Clone {
async fn send(&self, network: NetworkId, msg: impl Send + Into<CoordinatorMessage>);
async fn recv(&mut self, network: NetworkId) -> Message;
async fn ack(&mut self, msg: Message);
async fn recv(&self, network: NetworkId) -> Message;
async fn ack(&self, msg: Message);
}

#[async_trait::async_trait]
Expand All @@ -28,7 +28,7 @@ impl Processors for Arc<MessageQueue> {
let msg = borsh::to_vec(&msg).unwrap();
self.queue(metadata, msg).await;
}
async fn recv(&mut self, network: NetworkId) -> Message {
async fn recv(&self, network: NetworkId) -> Message {
let msg = self.next(Service::Processor(network)).await;
assert_eq!(msg.from, Service::Processor(network));

Expand All @@ -40,7 +40,7 @@ impl Processors for Arc<MessageQueue> {

return Message { id, network, msg };
}
async fn ack(&mut self, msg: Message) {
async fn ack(&self, msg: Message) {
MessageQueue::ack(self, Service::Processor(msg.network), msg.id).await
}
}
4 changes: 2 additions & 2 deletions coordinator/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ impl Processors for MemProcessors {
let processor = processors.entry(network).or_insert_with(VecDeque::new);
processor.push_back(msg.into());
}
async fn recv(&mut self, _: NetworkId) -> Message {
async fn recv(&self, _: NetworkId) -> Message {
todo!()
}
async fn ack(&mut self, _: Message) {
async fn ack(&self, _: Message) {
todo!()
}
}
Expand Down
11 changes: 3 additions & 8 deletions message-queue/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@ rustdoc-args = ["--cfg", "docsrs"]
[dependencies]
# Macros
once_cell = { version = "1", default-features = false }
serde = { version = "1", default-features = false, features = ["std", "derive"] }

# Encoders
hex = { version = "0.4", default-features = false, features = ["std"] }
borsh = { version = "1", default-features = false, features = ["std", "derive", "de_strict_order"] }
serde_json = { version = "1", default-features = false, features = ["std"] }

# Libs
zeroize = { version = "1", default-features = false, features = ["std"] }
Expand All @@ -37,16 +35,13 @@ log = { version = "0.4", default-features = false, features = ["std"] }
env_logger = { version = "0.10", default-features = false, features = ["humantime"] }

# Uses a single threaded runtime since this shouldn't ever be CPU-bound
tokio = { version = "1", default-features = false, features = ["rt", "time", "macros"] }
tokio = { version = "1", default-features = false, features = ["rt", "time", "io-util", "net", "macros"] }

serai-db = { path = "../common/db", features = ["rocksdb"], optional = true }

serai-env = { path = "../common/env" }

serai-primitives = { path = "../substrate/primitives", features = ["borsh", "serde"] }

jsonrpsee = { version = "0.16", default-features = false, features = ["server"], optional = true }
simple-request = { path = "../common/request", default-features = false }
serai-primitives = { path = "../substrate/primitives", features = ["borsh"] }

[features]
binaries = ["serai-db", "jsonrpsee"]
binaries = ["serai-db"]
183 changes: 94 additions & 89 deletions message-queue/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@ use ciphersuite::{
};
use schnorr_signatures::SchnorrSignature;

use serde::{Serialize, Deserialize};

use simple_request::{hyper::Request, Client};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
};

use serai_env as env;

use crate::{Service, Metadata, QueuedMessage, message_challenge, ack_challenge};
#[rustfmt::skip]
use crate::{Service, Metadata, QueuedMessage, MessageQueueRequest, message_challenge, ack_challenge};

pub struct MessageQueue {
pub service: Service,
priv_key: Zeroizing<<Ristretto as Ciphersuite>::F>,
pub_key: <Ristretto as Ciphersuite>::G,
client: Client,
url: String,
}

Expand All @@ -37,17 +38,8 @@ impl MessageQueue {
if !url.contains(':') {
url += ":2287";
}
if !url.starts_with("http://") {
url = "http://".to_string() + &url;
}

MessageQueue {
service,
pub_key: Ristretto::generator() * priv_key.deref(),
priv_key,
client: Client::with_connection_pool(),
url,
}
MessageQueue { service, pub_key: Ristretto::generator() * priv_key.deref(), priv_key, url }
}

pub fn from_env(service: Service) -> MessageQueue {
Expand All @@ -72,60 +64,14 @@ impl MessageQueue {
Self::new(service, url, priv_key)
}

async fn json_call(&self, method: &'static str, params: serde_json::Value) -> serde_json::Value {
#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
struct JsonRpcRequest {
jsonrpc: &'static str,
method: &'static str,
params: serde_json::Value,
id: u64,
}

let mut res = loop {
// Make the request
match self
.client
.request(
Request::post(&self.url)
.header("Content-Type", "application/json")
.body(
serde_json::to_vec(&JsonRpcRequest {
jsonrpc: "2.0",
method,
params: params.clone(),
id: 0,
})
.unwrap()
.into(),
)
.unwrap(),
)
.await
{
Ok(req) => {
// Get the response
match req.body().await {
Ok(res) => break res,
Err(e) => {
dbg!(e);
}
}
}
Err(e) => {
dbg!(e);
}
}

// Sleep for a second before trying again
tokio::time::sleep(core::time::Duration::from_secs(1)).await;
#[must_use]
async fn send(socket: &mut TcpStream, msg: MessageQueueRequest) -> bool {
let msg = borsh::to_vec(&msg).unwrap();
let Ok(_) = socket.write_all(&u32::try_from(msg.len()).unwrap().to_le_bytes()).await else {
return false;
};

let json: serde_json::Value =
serde_json::from_reader(&mut res).expect("message-queue returned invalid JSON");
if json.get("result").is_none() {
panic!("call failed: {json}");
}
json
let Ok(_) = socket.write_all(&msg).await else { return false };
true
}

pub async fn queue(&self, metadata: Metadata, msg: Vec<u8>) {
Expand All @@ -146,30 +92,76 @@ impl MessageQueue {
)
.serialize();

let json = self.json_call("queue", serde_json::json!([metadata, msg, sig])).await;
if json.get("result") != Some(&serde_json::Value::Bool(true)) {
panic!("failed to queue message: {json}");
let msg = MessageQueueRequest::Queue { meta: metadata, msg, sig };
let mut first = true;
loop {
// Sleep, so we don't hammer re-attempts
if !first {
tokio::time::sleep(core::time::Duration::from_secs(5)).await;
}
first = false;

let Ok(mut socket) = TcpStream::connect(&self.url).await else { continue };
if !Self::send(&mut socket, msg.clone()).await {
continue;
}
if socket.read_u8().await.ok() != Some(1) {
continue;
}
break;
}
}

pub async fn next(&self, from: Service) -> QueuedMessage {
loop {
let json = self.json_call("next", serde_json::json!([from, self.service])).await;

// Convert from a Value to a type via reserialization
let msg: Option<QueuedMessage> = serde_json::from_str(
&serde_json::to_string(
&json.get("result").expect("successful JSON RPC call didn't have result"),
)
.unwrap(),
)
.expect("next didn't return an Option<QueuedMessage>");

// If there wasn't a message, check again in 1s
let Some(msg) = msg else {
tokio::time::sleep(core::time::Duration::from_secs(1)).await;
let msg = MessageQueueRequest::Next { from, to: self.service };
let mut first = true;
'outer: loop {
if !first {
tokio::time::sleep(core::time::Duration::from_secs(5)).await;
continue;
}
first = false;

let Ok(mut socket) = TcpStream::connect(&self.url).await else { continue };

loop {
if !Self::send(&mut socket, msg.clone()).await {
continue 'outer;
}
let Ok(status) = socket.read_u8().await else {
continue 'outer;
};
// If there wasn't a message, check again in 1s
if status == 0 {
tokio::time::sleep(core::time::Duration::from_secs(1)).await;
continue;
}
assert_eq!(status, 1);
break;
}

// Timeout after 5 seconds in case there's an issue with the length handling
let Ok(msg) = tokio::time::timeout(core::time::Duration::from_secs(5), async {
// Read the message length
let Ok(len) = socket.read_u32_le().await else {
return vec![];
};
let mut buf = vec![0; usize::try_from(len).unwrap()];
// Read the message
let Ok(_) = socket.read_exact(&mut buf).await else {
return vec![];
};
buf
})
.await
else {
continue;
};
if msg.is_empty() {
continue;
}

let msg: QueuedMessage = borsh::from_slice(msg.as_slice()).unwrap();

// Verify the message
// Verify the sender is sane
Expand Down Expand Up @@ -202,9 +194,22 @@ impl MessageQueue {
)
.serialize();

let json = self.json_call("ack", serde_json::json!([from, self.service, id, sig])).await;
if json.get("result") != Some(&serde_json::Value::Bool(true)) {
panic!("failed to ack message {id}: {json}");
let msg = MessageQueueRequest::Ack { from, to: self.service, id, sig };
let mut first = true;
loop {
if !first {
tokio::time::sleep(core::time::Duration::from_secs(5)).await;
}
first = false;

let Ok(mut socket) = TcpStream::connect(&self.url).await else { continue };
if !Self::send(&mut socket, msg.clone()).await {
continue;
}
if socket.read_u8().await.ok() != Some(1) {
continue;
}
break;
}
}
}
Loading