Skip to content

Commit

Permalink
Merge pull request #172 from Carter12s/fix-connection-header-shenanigans
Browse files Browse the repository at this point in the history
Fix a lot of header shenanigans
  • Loading branch information
Carter12s authored Jul 5, 2024
2 parents 03d0e57 + ac79ab8 commit 99d27ab
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 89 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

### Fixed
- Bug with ros1 native publishers not parsing connection headers correctly

### Changed

Expand Down
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

104 changes: 53 additions & 51 deletions roslibrust/src/ros1/publisher.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::ros1::{names::Name, tcpros::ConnectionHeader};
use crate::ros1::{
names::Name,
tcpros::{self, ConnectionHeader},
};
use abort_on_drop::ChildTask;
use roslibrust_codegen::RosMessageType;
use std::{
Expand Down Expand Up @@ -91,65 +94,64 @@ impl Publication {
log::info!(
"Received connection from subscriber at {peer_addr} for topic {topic_name}"
);
let mut connection_header = Vec::with_capacity(16 * 1024);
if let Ok(bytes) = stream.read_buf(&mut connection_header).await {
if let Ok(connection_header) =
ConnectionHeader::from_bytes(&connection_header[..bytes])
{
log::debug!(
"Received subscribe request for {:?} with md5sum {:?}",
connection_header.topic,
connection_header.md5sum
);
// I can't find documentation for this anywhere, but when using
// `rostopic hz` with one of our publishers I discovered that the rospy code sent "*" as the md5sum
// To indicate a "generic subscription"...
// I also discovered that `rostopic echo` does not send a md5sum (even thou ros documentation says its required)
if let Some(connection_md5sum) = connection_header.md5sum {
if connection_md5sum != "*" {
if let Some(local_md5sum) = &responding_conn_header.md5sum {
if connection_md5sum != *local_md5sum {
log::warn!(

// Read the connection header:
let connection_header = match tcpros::recieve_header(&mut stream).await {
Ok(header) => header,
Err(e) => {
log::error!("Failed to read connection header: {e:?}");
stream
.shutdown()
.await
.expect("Unable to shutdown tcpstream");
continue;
}
};

log::debug!(
"Received subscribe request for {:?} with md5sum {:?}",
connection_header.topic,
connection_header.md5sum
);
// I can't find documentation for this anywhere, but when using
// `rostopic hz` with one of our publishers I discovered that the rospy code sent "*" as the md5sum
// To indicate a "generic subscription"...
// I also discovered that `rostopic echo` does not send a md5sum (even thou ros documentation says its required)
if let Some(connection_md5sum) = connection_header.md5sum {
if connection_md5sum != "*" {
if let Some(local_md5sum) = &responding_conn_header.md5sum {
if connection_md5sum != *local_md5sum {
log::warn!(
"Got subscribe request for {}, but md5sums do not match. Expected {:?}, received {:?}",
topic_name,
local_md5sum,
connection_md5sum,
);
// Close the TCP connection
stream
.shutdown()
.await
.expect("Unable to shutdown tcpstream");
continue;
}
}
// Close the TCP connection
stream
.shutdown()
.await
.expect("Unable to shutdown tcpstream");
continue;
}
}
// Write our own connection header in response
let response_header_bytes = responding_conn_header
.to_bytes(false)
.expect("Couldn't serialize connection header");
stream
.write(&response_header_bytes[..])
.await
.expect("Unable to respond on tcpstream");
let mut wlock = subscriber_streams.write().await;
wlock.push(stream);
log::debug!(
"Added stream for topic {:?} to subscriber {}",
connection_header.topic,
peer_addr
);
} else {
let header_str = connection_header[..bytes]
.into_iter()
.map(|ch| if *ch < 128 { *ch as char } else { '.' })
.collect::<String>();
log::error!(
"Failed to parse connection header: ({bytes} bytes) {header_str}",
)
}
}
// Write our own connection header in response
let response_header_bytes = responding_conn_header
.to_bytes(false)
.expect("Couldn't serialize connection header");
stream
.write(&response_header_bytes[..])
.await
.expect("Unable to respond on tcpstream");
let mut wlock = subscriber_streams.write().await;
wlock.push(stream);
log::debug!(
"Added stream for topic {:?} to subscriber {}",
connection_header.topic,
peer_addr
);
}
}
});
Expand Down
19 changes: 2 additions & 17 deletions roslibrust/src/ros1/service_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use abort_on_drop::ChildTask;
use log::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

use crate::ros1::tcpros::ConnectionHeader;
use crate::ros1::tcpros::{self, ConnectionHeader};

use super::{names::Name, NodeHandle};

Expand Down Expand Up @@ -170,22 +170,7 @@ impl ServiceServerLink {
// Probably it is better to try to send an error back?
debug!("Received service_request connection from {peer_addr} for {service_name}");

// Get the header from the stream:
let mut header_len_bytes = [0u8; 4];
if let Err(e) = stream.read_exact(&mut header_len_bytes).await {
warn!("Communication error while handling service request connection for {service_name}, could not get header length: {e:?}");
// TODO returning here simply closes the socket? Should we respond with an error instead?
return;
}
let header_len = u32::from_le_bytes(header_len_bytes) as usize;

let mut connection_header = vec![0u8; header_len];
if let Err(e) = stream.read_exact(&mut connection_header).await {
warn!("Communication error while handling service request connection for {service_name}, could not get header body: {e:?}");
// TODO returning here simply closes the socket? Should we respond with an error instead?
return;
}
let connection_header = match ConnectionHeader::from_bytes(&connection_header) {
let connection_header = match tcpros::recieve_header(&mut stream).await {
Ok(header) => header,
Err(e) => {
warn!("Communication error while handling service request connection for {service_name}, could not parse header: {e:?}");
Expand Down
10 changes: 3 additions & 7 deletions roslibrust/src/ros1/subscriber.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use tokio::{
},
};

use super::tcpros;

pub struct Subscriber<T> {
receiver: broadcast::Receiver<Vec<u8>>,
_phantom: PhantomData<T>,
Expand Down Expand Up @@ -153,13 +155,7 @@ async fn establish_publisher_connection(
let conn_header_bytes = conn_header.to_bytes(true)?;
stream.write_all(&conn_header_bytes[..]).await?;

let mut header_len_bytes = [0u8; 4];
let _header_bytes = stream.read_exact(&mut header_len_bytes).await?;
let header_len = u32::from_le_bytes(header_len_bytes) as usize;

let mut responded_header_bytes = vec![0u8; header_len];
let bytes = stream.read_exact(&mut responded_header_bytes).await?;
if let Ok(responded_header) = ConnectionHeader::from_bytes(&responded_header_bytes[..bytes]) {
if let Ok(responded_header) = tcpros::recieve_header(&mut stream).await {
if conn_header.md5sum == responded_header.md5sum {
log::debug!(
"Established connection with publisher for {:?}",
Expand Down
57 changes: 47 additions & 10 deletions roslibrust/src/ros1/tcpros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl ConnectionHeader {
let mut field = vec![0u8; field_length];
cursor.read_exact(&mut field)?;
let field = String::from_utf8(field).map_err(|e| {
warn!("Failed to parse field in connection header as valid utf8: {e:#?}");
warn!("Failed to parse field in connection header as valid utf8: {e:#?}, Full header: {header_data:#?}");
std::io::ErrorKind::InvalidData
})?;
let equals_pos = match field.find('=') {
Expand Down Expand Up @@ -87,9 +87,9 @@ impl ConnectionHeader {
// If you do `rosservice call /my_service` and hit TAB you'll see this field in the connection header
// we can ignore it
} else if field.starts_with("error=") {
log::error!("Error reported in TCPROS connection header: {field}");
log::error!("Error reported in TCPROS connection header: {field}, full header: {header_data:#?}");
} else {
log::warn!("Encountered unhandled field in connection header: {field}");
log::warn!("Encountered unhandled field in connection header: {field}, full header: {header_data:#?}");
}
}

Expand Down Expand Up @@ -194,16 +194,13 @@ pub async fn establish_connection(
},
)?;

// Write our own connection header to the stream
let conn_header_bytes = conn_header.to_bytes(true)?;
stream.write_all(&conn_header_bytes[..]).await?;

let mut header_len_bytes = [0u8; 4];
let _header_bytes = stream.read_exact(&mut header_len_bytes).await?;
let header_len = u32::from_le_bytes(header_len_bytes) as usize;

let mut responded_header_bytes = Vec::with_capacity(header_len);
let bytes = stream.read_buf(&mut responded_header_bytes).await?;
if let Ok(_responded_header) = ConnectionHeader::from_bytes(&responded_header_bytes[..bytes]) {
// Recieve the header from the server
let responded_header = recieve_header(&mut stream).await;
if let Ok(_responded_header) = responded_header {
// TODO we should really examine this md5sum logic...
// according to the ROS documentation, the service isn't required to respond
// with anything other than caller_id
Expand All @@ -228,6 +225,22 @@ pub async fn establish_connection(
.map_err(std::io::Error::from)
}

// Reads a complete ROS connection header from the given stream
pub async fn recieve_header(stream: &mut TcpStream) -> Result<ConnectionHeader, std::io::Error> {
// Bring trait def into scope
use tokio::io::AsyncReadExt;
// Recieve the header length
let mut header_len_bytes = [0u8; 4];
let _num_bytes_read = stream.read_exact(&mut header_len_bytes).await?;
// This is the length of the header itself
let header_len = u32::from_le_bytes(header_len_bytes) as usize;

// Initialize a buffer to hold the header
let mut header_bytes = vec![0u8; header_len];
let _num_bytes_read = stream.read_exact(&mut header_bytes).await?;
ConnectionHeader::from_bytes(&header_bytes)
}

#[cfg(test)]
mod test {
use super::ConnectionHeader;
Expand Down Expand Up @@ -258,4 +271,28 @@ mod test {
assert_eq!(header.topic, Some("/chatter".to_owned()));
assert_eq!(header.topic_type, "std_msgs/String");
}

#[test_log::test]
fn example_from_testing() {
// example taken from `rostopic echo` with our ros1_talker example
let bytes: Vec<u8> = vec![
37, 0, 0, 0, 99, 97, 108, 108, 101, 114, 105, 100, 61, 47, 114, 111, 115, 116, 111,
112, 105, 99, 95, 49, 49, 54, 56, 95, 49, 55, 50, 48, 50, 49, 53, 56, 51, 56, 57, 48,
50, 39, 0, 0, 0, 109, 100, 53, 115, 117, 109, 61, 57, 57, 50, 99, 101, 56, 97, 49, 54,
56, 55, 99, 101, 99, 56, 99, 56, 98, 100, 56, 56, 51, 101, 99, 55, 51, 99, 97, 52, 49,
100, 49, 31, 0, 0, 0, 109, 101, 115, 115, 97, 103, 101, 95, 100, 101, 102, 105, 110,
105, 116, 105, 111, 110, 61, 115, 116, 114, 105, 110, 103, 32, 100, 97, 116, 97, 10,
13, 0, 0, 0, 116, 99, 112, 95, 110, 111, 100, 101, 108, 97, 121, 61, 48, 14, 0, 0, 0,
116, 111, 112, 105, 99, 61, 47, 99, 104, 97, 116, 116, 101, 114, 20, 0, 0, 0, 116, 121,
112, 101, 61, 115, 116, 100, 95, 109, 115, 103, 115, 47, 83, 116, 114, 105, 110, 103,
];

let header = ConnectionHeader::from_bytes(&bytes).unwrap();
assert_eq!(header.caller_id, "/rostopic_1168_1720215838902");
assert_eq!(header.topic_type, "std_msgs/String");
assert_eq!(
header.md5sum,
Some("992ce8a1687cec8c8bd883ec73ca41d1".to_string())
);
}
}
2 changes: 1 addition & 1 deletion roslibrust_genmsg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ itertools = "0.12"
lazy_static = "1.4"
log = "0.4"
minijinja = "2.0"
roslibrust_codegen = { path = "../roslibrust_codegen", version = "0.9.0" }
roslibrust_codegen = { path = "../roslibrust_codegen", version = "0.10.0" }
serde = { version = "1", features = ["derive"] }
serde_json = "1"

Expand Down

0 comments on commit 99d27ab

Please sign in to comment.