Skip to content

Commit

Permalink
src: connection: Fix panic when DNS lookup fails
Browse files Browse the repository at this point in the history
  • Loading branch information
joaoantoniocardoso authored and patrickelectric committed Jan 17, 2024
1 parent d828e6e commit be0bc1e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 26 deletions.
16 changes: 16 additions & 0 deletions src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,19 @@ pub fn connect<M: Message>(address: &str) -> io::Result<Box<dyn MavConnection<M>
protocol_err
}
}

/// Returns the socket address for the given address.
pub(crate) fn get_socket_addr<T: std::net::ToSocketAddrs>(
address: T,
) -> Result<std::net::SocketAddr, io::Error> {
let addr = match address.to_socket_addrs()?.next() {
Some(addr) => addr,
None => {
return Err(io::Error::new(
io::ErrorKind::Other,
"Host address lookup failed",
));
}
};
Ok(addr)
}
15 changes: 5 additions & 10 deletions src/connection/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use std::net::{TcpListener, TcpStream};
use std::sync::Mutex;
use std::time::Duration;

use super::get_socket_addr;

/// TCP MAVLink connection
pub fn select_protocol<M: Message>(
Expand All @@ -26,11 +28,8 @@ pub fn select_protocol<M: Message>(
}

pub fn tcpout<T: ToSocketAddrs>(address: T) -> io::Result<TcpConnection> {
let addr = address
.to_socket_addrs()
.unwrap()
.next()
.expect("Host address lookup failed.");
let addr = get_socket_addr(address)?;

let socket = TcpStream::connect(addr)?;
socket.set_read_timeout(Some(Duration::from_millis(100)))?;

Expand All @@ -45,11 +44,7 @@ pub fn tcpout<T: ToSocketAddrs>(address: T) -> io::Result<TcpConnection> {
}

pub fn tcpin<T: ToSocketAddrs>(address: T) -> io::Result<TcpConnection> {
let addr = address
.to_socket_addrs()
.unwrap()
.next()
.expect("Invalid address");
let addr = get_socket_addr(address)?;
let listener = TcpListener::bind(addr)?;

//For now we only accept one incoming stream: this blocks until we get one
Expand Down
22 changes: 6 additions & 16 deletions src/connection/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use std::net::ToSocketAddrs;
use std::net::{SocketAddr, UdpSocket};
use std::sync::Mutex;

use super::get_socket_addr;

/// UDP MAVLink connection
pub fn select_protocol<M: Message>(
Expand All @@ -28,34 +30,22 @@ pub fn select_protocol<M: Message>(
}

pub fn udpbcast<T: ToSocketAddrs>(address: T) -> io::Result<UdpConnection> {
let addr = address
.to_socket_addrs()
.unwrap()
.next()
.expect("Invalid address");
let socket = UdpSocket::bind("0.0.0.0:0").unwrap();
let addr = get_socket_addr(address)?;
let socket = UdpSocket::bind("0.0.0.0:0")?;
socket
.set_broadcast(true)
.expect("Couldn't bind to broadcast address.");
UdpConnection::new(socket, false, Some(addr))
}

pub fn udpout<T: ToSocketAddrs>(address: T) -> io::Result<UdpConnection> {
let addr = address
.to_socket_addrs()
.unwrap()
.next()
.expect("Invalid address");
let addr = get_socket_addr(address)?;
let socket = UdpSocket::bind("0.0.0.0:0")?;
UdpConnection::new(socket, false, Some(addr))
}

pub fn udpin<T: ToSocketAddrs>(address: T) -> io::Result<UdpConnection> {
let addr = address
.to_socket_addrs()
.unwrap()
.next()
.expect("Invalid address");
let addr = get_socket_addr(address)?;
let socket = UdpSocket::bind(addr)?;
UdpConnection::new(socket, true, None)
}
Expand Down

0 comments on commit be0bc1e

Please sign in to comment.