diff --git a/README.md b/README.md
index 0901118..e3149aa 100644
--- a/README.md
+++ b/README.md
@@ -48,10 +48,19 @@ Currently, it works on most targets, but mainly tested the popular platforms whi
// let device = tun2::create_as_async(&cfg)?;
// let framed = device.into_framed();
-// let mut builder = StackBuilder::default();
-// let (runner, udp_socket, tcp_listener, stack) = builder.build();
-// tokio::task::spawn(runner);
-let (udp_socket, tcp_listener, stack) = StackBuilder::default().run();
+let (stack, runner, udp_socket, tcp_listener) = netstack_smoltcp::StackBuilder::default()
+ .stack_buffer_size(512)
+ .tcp_buffer_size(4096)
+ .enable_udp(true)
+ .enable_tcp(true)
+ .enable_icmp(true)
+ .build()
+ .unwrap();
+let mut udp_socket = udp_socket.unwrap(); // udp enabled
+let mut tcp_listener = tcp_listener.unwrap(); // tcp/icmp enabled
+if let Some(runner) = runner {
+ tokio::spawn(runner);
+}
let (mut stack_sink, mut stack_stream) = stack.split();
let (mut tun_sink, mut tun_stream) = framed.split();
@@ -105,18 +114,9 @@ shall be dual licensed as above, without any additional terms or conditions.
## Inspired By
-Special thanks to these amazing projects that inspired netstack-smoltcp:
+Special thanks to these amazing projects that inspired netstack-smoltcp (in no particular order):
- [shadowsocks-rust](https://github.com/shadowsocks/shadowsocks-rust/)
- [netstack-lwip](https://github.com/eycorsican/netstack-lwip/)
- [rust-tun-active](https://github.com/tun2proxy/rust-tun)
- [rust-tun](https://github.com/meh/rust-tun/)
-
-## Star History
-
-
-
-
+- [smoltcp](https://github.com/smoltcp-rs/smoltcp)
diff --git a/examples/forward.rs b/examples/forward.rs
index 20bf9cd..fa5b257 100644
--- a/examples/forward.rs
+++ b/examples/forward.rs
@@ -96,12 +96,7 @@ async fn main_exec(opt: Opt) {
.address("10.10.10.2")
.destination("10.10.10.1")
.mtu(tun::DEFAULT_MTU);
- #[cfg(not(any(
- target_arch = "mips",
- target_arch = "mips64",
- target_arch = "mipsel",
- target_arch = "mipsel64",
- )))]
+ #[cfg(not(any(target_arch = "mips", target_arch = "mips64",)))]
{
cfg.netmask("255.255.255.0");
}
@@ -109,15 +104,23 @@ async fn main_exec(opt: Opt) {
}
let device = tun::create_as_async(&cfg).unwrap();
- let mut builder = StackBuilder::default();
+ let mut builder = StackBuilder::default()
+ .enable_tcp(true)
+ .enable_udp(true)
+ .enable_icmp(true);
if let Some(device_broadcast) = get_device_broadcast(&device) {
builder = builder
// .add_ip_filter(Box::new(move |src, dst| *src != device_broadcast && *dst != device_broadcast));
.add_ip_filter_fn(move |src, dst| *src != device_broadcast && *dst != device_broadcast);
}
- let (runner, udp_socket, tcp_listener, stack) = builder.build();
- tokio_spawn!(runner);
+ let (stack, runner, udp_socket, tcp_listener) = builder.build().unwrap();
+ let udp_socket = udp_socket.unwrap(); // udp enabled
+ let tcp_listener = tcp_listener.unwrap(); // tcp enabled or icmp enabled
+
+ if let Some(runner) = runner {
+ tokio_spawn!(runner);
+ }
let framed = device.into_framed();
let (mut tun_sink, mut tun_stream) = framed.split();
diff --git a/src/stack.rs b/src/stack.rs
index 1d608a8..ad689a1 100644
--- a/src/stack.rs
+++ b/src/stack.rs
@@ -19,6 +19,9 @@ use crate::{
};
pub struct StackBuilder {
+ enable_udp: bool,
+ enable_tcp: bool,
+ enable_icmp: bool,
stack_buffer_size: usize,
udp_buffer_size: usize,
tcp_buffer_size: usize,
@@ -28,6 +31,9 @@ pub struct StackBuilder {
impl Default for StackBuilder {
fn default() -> Self {
Self {
+ enable_udp: false,
+ enable_tcp: false,
+ enable_icmp: false,
stack_buffer_size: 1024,
udp_buffer_size: 512,
tcp_buffer_size: 512,
@@ -38,6 +44,21 @@ impl Default for StackBuilder {
#[allow(unused)]
impl StackBuilder {
+ pub fn enable_udp(mut self, enable: bool) -> Self {
+ self.enable_udp = enable;
+ self
+ }
+
+ pub fn enable_tcp(mut self, enable: bool) -> Self {
+ self.enable_tcp = enable;
+ self
+ }
+
+ pub fn enable_icmp(mut self, enable: bool) -> Self {
+ self.enable_icmp = enable;
+ self
+ }
+
pub fn stack_buffer_size(mut self, size: usize) -> Self {
self.stack_buffer_size = size;
self
@@ -71,45 +92,119 @@ impl StackBuilder {
self
}
- pub fn build(self) -> (Runner, UdpSocket, TcpListener, Stack) {
+ pub fn build(
+ self,
+ ) -> std::io::Result<(
+ Stack,
+ Option,
+ Option,
+ Option,
+ )> {
let (stack_tx, stack_rx) = channel(self.stack_buffer_size);
- let (udp_tx, udp_rx) = channel(self.udp_buffer_size);
- let (tcp_tx, tcp_rx) = channel(self.tcp_buffer_size);
- let udp_socket = UdpSocket::new(udp_rx, stack_tx.clone());
- let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx);
+ let (udp_tx, udp_rx) = if self.enable_udp {
+ let (udp_tx, udp_rx) = channel(self.udp_buffer_size);
+ (Some(udp_tx), Some(udp_rx))
+ } else {
+ (None, None)
+ };
+
+ let (tcp_tx, tcp_rx) = if self.enable_tcp {
+ let (tcp_tx, tcp_rx) = channel(self.tcp_buffer_size);
+ (Some(tcp_tx), Some(tcp_rx))
+ } else {
+ (None, None)
+ };
+
+ // ICMP is handled by TCP's Interface.
+ // smoltcp's interface will always send replies to EchoRequest
+ if self.enable_icmp && !self.enable_tcp {
+ return Err(std::io::Error::new(
+ std::io::ErrorKind::InvalidInput,
+ "Enabling icmp requires enabling tcp",
+ ));
+ }
+ let icmp_tx = if self.enable_icmp {
+ if let Some(ref tcp_tx) = tcp_tx {
+ Some(tcp_tx.clone())
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
+ let udp_socket = if let Some(udp_rx) = udp_rx {
+ Some(UdpSocket::new(udp_rx, stack_tx.clone()))
+ } else {
+ None
+ };
+
+ let (tcp_runner, tcp_listener) = if let Some(tcp_rx) = tcp_rx {
+ let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx);
+ (Some(tcp_runner), Some(tcp_listener))
+ } else {
+ (None, None)
+ };
+
let stack = Stack {
ip_filters: self.ip_filters,
- sink_buf: None,
stack_rx,
+ sink_buf: None,
udp_tx,
tcp_tx,
+ icmp_tx,
};
- (tcp_runner, udp_socket, tcp_listener, stack)
- }
-
- pub fn run(self) -> (UdpSocket, TcpListener, Stack) {
- let (tcp_runner, udp_socket, tcp_listener, stack) = self.build();
- tokio::task::spawn(tcp_runner);
- (udp_socket, tcp_listener, stack)
- }
-
- pub fn run_local(self) -> (UdpSocket, TcpListener, Stack) {
- let (tcp_runner, udp_socket, tcp_listener, stack) = self.build();
- tokio::task::spawn_local(tcp_runner);
- (udp_socket, tcp_listener, stack)
+ Ok((stack, tcp_runner, udp_socket, tcp_listener))
}
}
pub struct Stack {
ip_filters: IpFilters<'static>,
- sink_buf: Option,
- udp_tx: Sender,
- tcp_tx: Sender,
+ sink_buf: Option<(AnyIpPktFrame, IpProtocol)>,
+ udp_tx: Option>,
+ tcp_tx: Option>,
+ icmp_tx: Option>,
stack_rx: Receiver,
}
+impl Stack {
+ fn poll_send(&mut self, _cx: &mut Context<'_>) -> Poll> {
+ let (item, proto) = match self.sink_buf.take() {
+ Some(val) => val,
+ None => return Poll::Ready(Ok(())),
+ };
+
+ let ready_res = match proto {
+ IpProtocol::Tcp => self.tcp_tx.as_mut().map(|tx| tx.try_reserve()),
+ IpProtocol::Udp => self.udp_tx.as_mut().map(|tx| tx.try_reserve()),
+ IpProtocol::Icmp | IpProtocol::Icmpv6 => {
+ self.icmp_tx.as_mut().map(|tx| tx.try_reserve())
+ }
+ _ => unreachable!(),
+ };
+
+ let Some(ready_res) = ready_res else {
+ return Poll::Ready(Ok(()));
+ };
+
+ let permit = match ready_res {
+ Ok(permit) => permit,
+ Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
+ self.sink_buf.replace((item, proto));
+ return Poll::Pending;
+ }
+ Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
+ return Poll::Ready(Err(channel_closed_err("channel is closed")));
+ }
+ };
+
+ permit.send(item);
+ Poll::Ready(Ok(()))
+ }
+}
+
// Recv from stack.
impl Stream for Stack {
type Item = io::Result;
@@ -127,29 +222,17 @@ impl Stream for Stack {
impl Sink for Stack {
type Error = io::Error;
- fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
+ fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> {
if self.sink_buf.is_none() {
Poll::Ready(Ok(()))
} else {
- self.poll_flush(cx)
+ Poll::Pending
}
}
fn start_send(mut self: Pin<&mut Self>, item: AnyIpPktFrame) -> Result<(), Self::Error> {
- self.sink_buf.replace(item);
- Ok(())
- }
-
- fn poll_flush(
- mut self: Pin<&mut Self>,
- _cx: &mut Context<'_>,
- ) -> Poll> {
- let Some(item) = self.sink_buf.take() else {
- return Poll::Ready(Ok(()));
- };
-
if item.is_empty() {
- return Poll::Ready(Ok(()));
+ return Ok(());
}
let packet = IpPacket::new_checked(item.as_slice()).map_err(|err| {
@@ -170,35 +253,24 @@ impl Sink for Stack {
dst_ip,
addr_allowed,
);
- return Poll::Ready(Ok(()));
+ return Ok(());
}
- match packet.protocol() {
- IpProtocol::Tcp => {
- self.tcp_tx
- .try_send(item)
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
- Poll::Ready(Ok(()))
- }
- IpProtocol::Udp => {
- self.udp_tx
- .try_send(item)
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
- Poll::Ready(Ok(()))
- }
- IpProtocol::Icmp | IpProtocol::Icmpv6 => {
- // ICMP is handled by TCP's Interface.
- // smoltcp's interface will always send replies to EchoRequest
- self.tcp_tx
- .try_send(item)
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
- Poll::Ready(Ok(()))
- }
- protocol => {
- debug!("tun IP packet ignored (protocol: {:?})", protocol);
- Poll::Ready(Ok(()))
- }
+ let protocol = packet.protocol();
+ if matches!(
+ protocol,
+ IpProtocol::Tcp | IpProtocol::Udp | IpProtocol::Icmp | IpProtocol::Icmpv6
+ ) {
+ self.sink_buf.replace((item, protocol));
+ } else {
+ debug!("tun IP packet ignored (protocol: {:?})", protocol);
}
+
+ Ok(())
+ }
+
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
+ self.poll_send(cx)
}
fn poll_close(
@@ -209,3 +281,10 @@ impl Sink for Stack {
Poll::Ready(Ok(()))
}
}
+
+fn channel_closed_err(err: E) -> std::io::Error
+where
+ E: Into>,
+{
+ std::io::Error::new(std::io::ErrorKind::BrokenPipe, err)
+}