diff --git a/examples/forward.rs b/examples/forward.rs index 1c35f04..347707d 100644 --- a/examples/forward.rs +++ b/examples/forward.rs @@ -92,7 +92,7 @@ async fn main_exec(opt: Opt) { if fd >= 0 { cfg.raw_fd(fd); } else { - cfg.tun_name("utun8") + cfg.tun_name(&opt.interface) .address("10.10.10.2") .destination("10.10.10.1") .mtu(tun2::DEFAULT_MTU); diff --git a/src/runner.rs b/src/runner.rs index 5eb49c5..63b9675 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -38,4 +38,4 @@ impl Future for BoxFuture<'_, T> { } } -pub type Runner = BoxFuture<'static, ()>; +pub type Runner = BoxFuture<'static, std::io::Result<()>>; diff --git a/src/stack.rs b/src/stack.rs index aca3b6c..cc8a326 100644 --- a/src/stack.rs +++ b/src/stack.rs @@ -119,10 +119,8 @@ impl StackBuilder { // ICMP is handled by TCP's Interface. // smoltcp's interface will always send replies to EchoRequest if self.enable_icmp && !self.enable_tcp { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "Enabling icmp requires enabling tcp", - )); + use std::io::{Error, ErrorKind::InvalidInput}; + return Err(Error::new(InvalidInput, "ICMP requires TCP")); } let icmp_tx = if self.enable_icmp { tcp_tx.clone() @@ -133,7 +131,7 @@ impl StackBuilder { let udp_socket = udp_rx.map(|udp_rx| UdpSocket::new(udp_rx, stack_tx.clone())); let (tcp_runner, tcp_listener) = if let Some(tcp_rx) = tcp_rx { - let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx); + let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx)?; (Some(tcp_runner), Some(tcp_listener)) } else { (None, None) diff --git a/src/tcp.rs b/src/tcp.rs index 3e35652..a3ef591 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -78,11 +78,13 @@ impl TcpListenerRunner { Runner::new(async move { let notify = Arc::new(Notify::new()); let (socket_tx, socket_rx) = unbounded_channel::(); - tokio::select! { - _ = Self::handle_packet(notify.clone(), iface_ingress_tx, iface_ingress_tx_avail.clone(), tcp_rx, stream_tx, socket_tx) => {} - _ = Self::handle_socket(notify, device, iface, iface_ingress_tx_avail, sockets, socket_rx) => {} - } + let res = tokio::select! { + v = Self::handle_packet(notify.clone(), iface_ingress_tx, iface_ingress_tx_avail.clone(), tcp_rx, stream_tx, socket_tx) => v, + v = Self::handle_socket(notify, device, iface, iface_ingress_tx_avail, sockets, socket_rx) => v, + }; + res?; trace!("VirtDevice::poll thread exited"); + Ok(()) }) } @@ -93,7 +95,7 @@ impl TcpListenerRunner { mut tcp_rx: Receiver, stream_tx: UnboundedSender, socket_tx: UnboundedSender, - ) { + ) -> std::io::Result<()> { while let Some(frame) = tcp_rx.recv().await { let packet = match IpPacket::new_checked(frame.as_slice()) { Ok(p) => p, @@ -107,7 +109,7 @@ impl TcpListenerRunner { if matches!(packet.protocol(), IpProtocol::Icmp | IpProtocol::Icmpv6) { iface_ingress_tx .send(frame) - .expect("channel already closed"); + .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?; iface_ingress_tx_avail.store(true, Ordering::Release); notify.notify_one(); continue; @@ -165,19 +167,20 @@ impl TcpListenerRunner { notify: notify.clone(), control: control.clone(), }) - .expect("channel already closed"); + .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?; socket_tx .send(TcpSocketCreation { control, socket }) - .expect("channel already closed"); + .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?; } // Pipeline tcp stream packet iface_ingress_tx .send(frame) - .expect("channel already closed"); + .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?; iface_ingress_tx_avail.store(true, Ordering::Release); notify.notify_one(); } + Ok(()) } async fn handle_socket( @@ -187,7 +190,7 @@ impl TcpListenerRunner { iface_ingress_tx_avail: Arc, mut sockets: HashMap, mut socket_rx: UnboundedReceiver, - ) { + ) -> std::io::Result<()> { let mut socket_set = SocketSet::new(vec![]); loop { while let Ok(TcpSocketCreation { control, socket }) = socket_rx.try_recv() { @@ -354,9 +357,9 @@ impl TcpListener { pub(super) fn new( tcp_rx: Receiver, stack_tx: Sender, - ) -> (Runner, Self) { + ) -> std::io::Result<(Runner, Self)> { let (mut device, iface_ingress_tx, iface_ingress_tx_avail) = VirtualDevice::new(stack_tx); - let iface = Self::create_interface(&mut device); + let iface = Self::create_interface(&mut device)?; let (stream_tx, stream_rx) = unbounded_channel(); @@ -370,10 +373,10 @@ impl TcpListener { HashMap::new(), ); - (runner, Self { stream_rx }) + Ok((runner, Self { stream_rx })) } - fn create_interface(device: &mut D) -> Interface + fn create_interface(device: &mut D) -> std::io::Result where D: Device + ?Sized, { @@ -391,13 +394,13 @@ impl TcpListener { iface .routes_mut() .add_default_ipv4_route(Ipv4Address::new(0, 0, 0, 1)) - .expect("IPv4 default route"); + .map_err(|e| std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, e))?; iface .routes_mut() .add_default_ipv6_route(Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1)) - .expect("IPv6 default route"); + .map_err(|e| std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, e))?; iface.set_any_ip(true); - iface + Ok(iface) } }