diff --git a/README.md b/README.md index 0901118..e3149aa 100644 --- a/README.md +++ b/README.md @@ -48,10 +48,19 @@ Currently, it works on most targets, but mainly tested the popular platforms whi // let device = tun2::create_as_async(&cfg)?; // let framed = device.into_framed(); -// let mut builder = StackBuilder::default(); -// let (runner, udp_socket, tcp_listener, stack) = builder.build(); -// tokio::task::spawn(runner); -let (udp_socket, tcp_listener, stack) = StackBuilder::default().run(); +let (stack, runner, udp_socket, tcp_listener) = netstack_smoltcp::StackBuilder::default() + .stack_buffer_size(512) + .tcp_buffer_size(4096) + .enable_udp(true) + .enable_tcp(true) + .enable_icmp(true) + .build() + .unwrap(); +let mut udp_socket = udp_socket.unwrap(); // udp enabled +let mut tcp_listener = tcp_listener.unwrap(); // tcp/icmp enabled +if let Some(runner) = runner { + tokio::spawn(runner); +} let (mut stack_sink, mut stack_stream) = stack.split(); let (mut tun_sink, mut tun_stream) = framed.split(); @@ -105,18 +114,9 @@ shall be dual licensed as above, without any additional terms or conditions. ## Inspired By -Special thanks to these amazing projects that inspired netstack-smoltcp: +Special thanks to these amazing projects that inspired netstack-smoltcp (in no particular order): - [shadowsocks-rust](https://github.com/shadowsocks/shadowsocks-rust/) - [netstack-lwip](https://github.com/eycorsican/netstack-lwip/) - [rust-tun-active](https://github.com/tun2proxy/rust-tun) - [rust-tun](https://github.com/meh/rust-tun/) - -## Star History - - - - - - Star History Chart - - +- [smoltcp](https://github.com/smoltcp-rs/smoltcp) diff --git a/examples/forward.rs b/examples/forward.rs index 20bf9cd..fa5b257 100644 --- a/examples/forward.rs +++ b/examples/forward.rs @@ -96,12 +96,7 @@ async fn main_exec(opt: Opt) { .address("10.10.10.2") .destination("10.10.10.1") .mtu(tun::DEFAULT_MTU); - #[cfg(not(any( - target_arch = "mips", - target_arch = "mips64", - target_arch = "mipsel", - target_arch = "mipsel64", - )))] + #[cfg(not(any(target_arch = "mips", target_arch = "mips64",)))] { cfg.netmask("255.255.255.0"); } @@ -109,15 +104,23 @@ async fn main_exec(opt: Opt) { } let device = tun::create_as_async(&cfg).unwrap(); - let mut builder = StackBuilder::default(); + let mut builder = StackBuilder::default() + .enable_tcp(true) + .enable_udp(true) + .enable_icmp(true); if let Some(device_broadcast) = get_device_broadcast(&device) { builder = builder // .add_ip_filter(Box::new(move |src, dst| *src != device_broadcast && *dst != device_broadcast)); .add_ip_filter_fn(move |src, dst| *src != device_broadcast && *dst != device_broadcast); } - let (runner, udp_socket, tcp_listener, stack) = builder.build(); - tokio_spawn!(runner); + let (stack, runner, udp_socket, tcp_listener) = builder.build().unwrap(); + let udp_socket = udp_socket.unwrap(); // udp enabled + let tcp_listener = tcp_listener.unwrap(); // tcp enabled or icmp enabled + + if let Some(runner) = runner { + tokio_spawn!(runner); + } let framed = device.into_framed(); let (mut tun_sink, mut tun_stream) = framed.split(); diff --git a/src/stack.rs b/src/stack.rs index 1d608a8..ad689a1 100644 --- a/src/stack.rs +++ b/src/stack.rs @@ -19,6 +19,9 @@ use crate::{ }; pub struct StackBuilder { + enable_udp: bool, + enable_tcp: bool, + enable_icmp: bool, stack_buffer_size: usize, udp_buffer_size: usize, tcp_buffer_size: usize, @@ -28,6 +31,9 @@ pub struct StackBuilder { impl Default for StackBuilder { fn default() -> Self { Self { + enable_udp: false, + enable_tcp: false, + enable_icmp: false, stack_buffer_size: 1024, udp_buffer_size: 512, tcp_buffer_size: 512, @@ -38,6 +44,21 @@ impl Default for StackBuilder { #[allow(unused)] impl StackBuilder { + pub fn enable_udp(mut self, enable: bool) -> Self { + self.enable_udp = enable; + self + } + + pub fn enable_tcp(mut self, enable: bool) -> Self { + self.enable_tcp = enable; + self + } + + pub fn enable_icmp(mut self, enable: bool) -> Self { + self.enable_icmp = enable; + self + } + pub fn stack_buffer_size(mut self, size: usize) -> Self { self.stack_buffer_size = size; self @@ -71,45 +92,119 @@ impl StackBuilder { self } - pub fn build(self) -> (Runner, UdpSocket, TcpListener, Stack) { + pub fn build( + self, + ) -> std::io::Result<( + Stack, + Option, + Option, + Option, + )> { let (stack_tx, stack_rx) = channel(self.stack_buffer_size); - let (udp_tx, udp_rx) = channel(self.udp_buffer_size); - let (tcp_tx, tcp_rx) = channel(self.tcp_buffer_size); - let udp_socket = UdpSocket::new(udp_rx, stack_tx.clone()); - let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx); + let (udp_tx, udp_rx) = if self.enable_udp { + let (udp_tx, udp_rx) = channel(self.udp_buffer_size); + (Some(udp_tx), Some(udp_rx)) + } else { + (None, None) + }; + + let (tcp_tx, tcp_rx) = if self.enable_tcp { + let (tcp_tx, tcp_rx) = channel(self.tcp_buffer_size); + (Some(tcp_tx), Some(tcp_rx)) + } else { + (None, None) + }; + + // ICMP is handled by TCP's Interface. + // smoltcp's interface will always send replies to EchoRequest + if self.enable_icmp && !self.enable_tcp { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Enabling icmp requires enabling tcp", + )); + } + let icmp_tx = if self.enable_icmp { + if let Some(ref tcp_tx) = tcp_tx { + Some(tcp_tx.clone()) + } else { + None + } + } else { + None + }; + + let udp_socket = if let Some(udp_rx) = udp_rx { + Some(UdpSocket::new(udp_rx, stack_tx.clone())) + } else { + None + }; + + let (tcp_runner, tcp_listener) = if let Some(tcp_rx) = tcp_rx { + let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx); + (Some(tcp_runner), Some(tcp_listener)) + } else { + (None, None) + }; + let stack = Stack { ip_filters: self.ip_filters, - sink_buf: None, stack_rx, + sink_buf: None, udp_tx, tcp_tx, + icmp_tx, }; - (tcp_runner, udp_socket, tcp_listener, stack) - } - - pub fn run(self) -> (UdpSocket, TcpListener, Stack) { - let (tcp_runner, udp_socket, tcp_listener, stack) = self.build(); - tokio::task::spawn(tcp_runner); - (udp_socket, tcp_listener, stack) - } - - pub fn run_local(self) -> (UdpSocket, TcpListener, Stack) { - let (tcp_runner, udp_socket, tcp_listener, stack) = self.build(); - tokio::task::spawn_local(tcp_runner); - (udp_socket, tcp_listener, stack) + Ok((stack, tcp_runner, udp_socket, tcp_listener)) } } pub struct Stack { ip_filters: IpFilters<'static>, - sink_buf: Option, - udp_tx: Sender, - tcp_tx: Sender, + sink_buf: Option<(AnyIpPktFrame, IpProtocol)>, + udp_tx: Option>, + tcp_tx: Option>, + icmp_tx: Option>, stack_rx: Receiver, } +impl Stack { + fn poll_send(&mut self, _cx: &mut Context<'_>) -> Poll> { + let (item, proto) = match self.sink_buf.take() { + Some(val) => val, + None => return Poll::Ready(Ok(())), + }; + + let ready_res = match proto { + IpProtocol::Tcp => self.tcp_tx.as_mut().map(|tx| tx.try_reserve()), + IpProtocol::Udp => self.udp_tx.as_mut().map(|tx| tx.try_reserve()), + IpProtocol::Icmp | IpProtocol::Icmpv6 => { + self.icmp_tx.as_mut().map(|tx| tx.try_reserve()) + } + _ => unreachable!(), + }; + + let Some(ready_res) = ready_res else { + return Poll::Ready(Ok(())); + }; + + let permit = match ready_res { + Ok(permit) => permit, + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { + self.sink_buf.replace((item, proto)); + return Poll::Pending; + } + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { + return Poll::Ready(Err(channel_closed_err("channel is closed"))); + } + }; + + permit.send(item); + Poll::Ready(Ok(())) + } +} + // Recv from stack. impl Stream for Stack { type Item = io::Result; @@ -127,29 +222,17 @@ impl Stream for Stack { impl Sink for Stack { type Error = io::Error; - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { if self.sink_buf.is_none() { Poll::Ready(Ok(())) } else { - self.poll_flush(cx) + Poll::Pending } } fn start_send(mut self: Pin<&mut Self>, item: AnyIpPktFrame) -> Result<(), Self::Error> { - self.sink_buf.replace(item); - Ok(()) - } - - fn poll_flush( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - let Some(item) = self.sink_buf.take() else { - return Poll::Ready(Ok(())); - }; - if item.is_empty() { - return Poll::Ready(Ok(())); + return Ok(()); } let packet = IpPacket::new_checked(item.as_slice()).map_err(|err| { @@ -170,35 +253,24 @@ impl Sink for Stack { dst_ip, addr_allowed, ); - return Poll::Ready(Ok(())); + return Ok(()); } - match packet.protocol() { - IpProtocol::Tcp => { - self.tcp_tx - .try_send(item) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; - Poll::Ready(Ok(())) - } - IpProtocol::Udp => { - self.udp_tx - .try_send(item) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; - Poll::Ready(Ok(())) - } - IpProtocol::Icmp | IpProtocol::Icmpv6 => { - // ICMP is handled by TCP's Interface. - // smoltcp's interface will always send replies to EchoRequest - self.tcp_tx - .try_send(item) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; - Poll::Ready(Ok(())) - } - protocol => { - debug!("tun IP packet ignored (protocol: {:?})", protocol); - Poll::Ready(Ok(())) - } + let protocol = packet.protocol(); + if matches!( + protocol, + IpProtocol::Tcp | IpProtocol::Udp | IpProtocol::Icmp | IpProtocol::Icmpv6 + ) { + self.sink_buf.replace((item, protocol)); + } else { + debug!("tun IP packet ignored (protocol: {:?})", protocol); } + + Ok(()) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_send(cx) } fn poll_close( @@ -209,3 +281,10 @@ impl Sink for Stack { Poll::Ready(Ok(())) } } + +fn channel_closed_err(err: E) -> std::io::Error +where + E: Into>, +{ + std::io::Error::new(std::io::ErrorKind::BrokenPipe, err) +}