Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tcp/udp send failed when receivers not ready #10

Merged
merged 1 commit into from
Aug 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,19 @@ Currently, it works on most targets, but mainly tested the popular platforms whi
// let device = tun2::create_as_async(&cfg)?;
// let framed = device.into_framed();

// let mut builder = StackBuilder::default();
// let (runner, udp_socket, tcp_listener, stack) = builder.build();
// tokio::task::spawn(runner);
let (udp_socket, tcp_listener, stack) = StackBuilder::default().run();
let (stack, runner, udp_socket, tcp_listener) = netstack_smoltcp::StackBuilder::default()
.stack_buffer_size(512)
.tcp_buffer_size(4096)
.enable_udp(true)
.enable_tcp(true)
.enable_icmp(true)
.build()
.unwrap();
let mut udp_socket = udp_socket.unwrap(); // udp enabled
let mut tcp_listener = tcp_listener.unwrap(); // tcp/icmp enabled
if let Some(runner) = runner {
tokio::spawn(runner);
}

let (mut stack_sink, mut stack_stream) = stack.split();
let (mut tun_sink, mut tun_stream) = framed.split();
Expand Down Expand Up @@ -105,18 +114,9 @@ shall be dual licensed as above, without any additional terms or conditions.

## Inspired By

Special thanks to these amazing projects that inspired netstack-smoltcp:
Special thanks to these amazing projects that inspired netstack-smoltcp (in no particular order):
- [shadowsocks-rust](https://github.com/shadowsocks/shadowsocks-rust/)
- [netstack-lwip](https://github.com/eycorsican/netstack-lwip/)
- [rust-tun-active](https://github.com/tun2proxy/rust-tun)
- [rust-tun](https://github.com/meh/rust-tun/)

## Star History

<a href="https://star-history.com/#automesh-network/netstack-smoltcp&Date">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=automesh-network/netstack-smoltcp&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=automesh-network/netstack-smoltcp&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=automesh-network/netstack-smoltcp&type=Date" />
</picture>
</a>
- [smoltcp](https://github.com/smoltcp-rs/smoltcp)
21 changes: 12 additions & 9 deletions examples/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,28 +96,31 @@ async fn main_exec(opt: Opt) {
.address("10.10.10.2")
.destination("10.10.10.1")
.mtu(tun::DEFAULT_MTU);
#[cfg(not(any(
target_arch = "mips",
target_arch = "mips64",
target_arch = "mipsel",
target_arch = "mipsel64",
)))]
#[cfg(not(any(target_arch = "mips", target_arch = "mips64",)))]
{
cfg.netmask("255.255.255.0");
}
cfg.up();
}

let device = tun::create_as_async(&cfg).unwrap();
let mut builder = StackBuilder::default();
let mut builder = StackBuilder::default()
.enable_tcp(true)
.enable_udp(true)
.enable_icmp(true);
if let Some(device_broadcast) = get_device_broadcast(&device) {
builder = builder
// .add_ip_filter(Box::new(move |src, dst| *src != device_broadcast && *dst != device_broadcast));
.add_ip_filter_fn(move |src, dst| *src != device_broadcast && *dst != device_broadcast);
}

let (runner, udp_socket, tcp_listener, stack) = builder.build();
tokio_spawn!(runner);
let (stack, runner, udp_socket, tcp_listener) = builder.build().unwrap();
let udp_socket = udp_socket.unwrap(); // udp enabled
let tcp_listener = tcp_listener.unwrap(); // tcp enabled or icmp enabled

if let Some(runner) = runner {
tokio_spawn!(runner);
}

let framed = device.into_framed();
let (mut tun_sink, mut tun_stream) = framed.split();
Expand Down
205 changes: 142 additions & 63 deletions src/stack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ use crate::{
};

pub struct StackBuilder {
enable_udp: bool,
enable_tcp: bool,
enable_icmp: bool,
stack_buffer_size: usize,
udp_buffer_size: usize,
tcp_buffer_size: usize,
Expand All @@ -28,6 +31,9 @@ pub struct StackBuilder {
impl Default for StackBuilder {
fn default() -> Self {
Self {
enable_udp: false,
enable_tcp: false,
enable_icmp: false,
stack_buffer_size: 1024,
udp_buffer_size: 512,
tcp_buffer_size: 512,
Expand All @@ -38,6 +44,21 @@ impl Default for StackBuilder {

#[allow(unused)]
impl StackBuilder {
pub fn enable_udp(mut self, enable: bool) -> Self {
self.enable_udp = enable;
self
}

pub fn enable_tcp(mut self, enable: bool) -> Self {
self.enable_tcp = enable;
self
}

pub fn enable_icmp(mut self, enable: bool) -> Self {
self.enable_icmp = enable;
self
}

pub fn stack_buffer_size(mut self, size: usize) -> Self {
self.stack_buffer_size = size;
self
Expand Down Expand Up @@ -71,45 +92,119 @@ impl StackBuilder {
self
}

pub fn build(self) -> (Runner, UdpSocket, TcpListener, Stack) {
pub fn build(
self,
) -> std::io::Result<(
Stack,
Option<Runner>,
Option<UdpSocket>,
Option<TcpListener>,
)> {
let (stack_tx, stack_rx) = channel(self.stack_buffer_size);
let (udp_tx, udp_rx) = channel(self.udp_buffer_size);
let (tcp_tx, tcp_rx) = channel(self.tcp_buffer_size);

let udp_socket = UdpSocket::new(udp_rx, stack_tx.clone());
let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx);
let (udp_tx, udp_rx) = if self.enable_udp {
let (udp_tx, udp_rx) = channel(self.udp_buffer_size);
(Some(udp_tx), Some(udp_rx))
} else {
(None, None)
};

let (tcp_tx, tcp_rx) = if self.enable_tcp {
let (tcp_tx, tcp_rx) = channel(self.tcp_buffer_size);
(Some(tcp_tx), Some(tcp_rx))
} else {
(None, None)
};

// ICMP is handled by TCP's Interface.
// smoltcp's interface will always send replies to EchoRequest
if self.enable_icmp && !self.enable_tcp {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Enabling icmp requires enabling tcp",
));
}
let icmp_tx = if self.enable_icmp {
if let Some(ref tcp_tx) = tcp_tx {
Some(tcp_tx.clone())
} else {
None
}
} else {
None
};

let udp_socket = if let Some(udp_rx) = udp_rx {
Some(UdpSocket::new(udp_rx, stack_tx.clone()))
} else {
None
};

let (tcp_runner, tcp_listener) = if let Some(tcp_rx) = tcp_rx {
let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx);
(Some(tcp_runner), Some(tcp_listener))
} else {
(None, None)
};

let stack = Stack {
ip_filters: self.ip_filters,
sink_buf: None,
stack_rx,
sink_buf: None,
udp_tx,
tcp_tx,
icmp_tx,
};

(tcp_runner, udp_socket, tcp_listener, stack)
}

pub fn run(self) -> (UdpSocket, TcpListener, Stack) {
let (tcp_runner, udp_socket, tcp_listener, stack) = self.build();
tokio::task::spawn(tcp_runner);
(udp_socket, tcp_listener, stack)
}

pub fn run_local(self) -> (UdpSocket, TcpListener, Stack) {
let (tcp_runner, udp_socket, tcp_listener, stack) = self.build();
tokio::task::spawn_local(tcp_runner);
(udp_socket, tcp_listener, stack)
Ok((stack, tcp_runner, udp_socket, tcp_listener))
}
}

pub struct Stack {
ip_filters: IpFilters<'static>,
sink_buf: Option<AnyIpPktFrame>,
udp_tx: Sender<AnyIpPktFrame>,
tcp_tx: Sender<AnyIpPktFrame>,
sink_buf: Option<(AnyIpPktFrame, IpProtocol)>,
udp_tx: Option<Sender<AnyIpPktFrame>>,
tcp_tx: Option<Sender<AnyIpPktFrame>>,
icmp_tx: Option<Sender<AnyIpPktFrame>>,
stack_rx: Receiver<AnyIpPktFrame>,
}

impl Stack {
fn poll_send(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
let (item, proto) = match self.sink_buf.take() {
Some(val) => val,
None => return Poll::Ready(Ok(())),
};

let ready_res = match proto {
IpProtocol::Tcp => self.tcp_tx.as_mut().map(|tx| tx.try_reserve()),
IpProtocol::Udp => self.udp_tx.as_mut().map(|tx| tx.try_reserve()),
IpProtocol::Icmp | IpProtocol::Icmpv6 => {
self.icmp_tx.as_mut().map(|tx| tx.try_reserve())
}
_ => unreachable!(),
};

let Some(ready_res) = ready_res else {
return Poll::Ready(Ok(()));
};

let permit = match ready_res {
Ok(permit) => permit,
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
self.sink_buf.replace((item, proto));
return Poll::Pending;
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
return Poll::Ready(Err(channel_closed_err("channel is closed")));
}
};

permit.send(item);
Poll::Ready(Ok(()))
}
}

// Recv from stack.
impl Stream for Stack {
type Item = io::Result<AnyIpPktFrame>;
Expand All @@ -127,29 +222,17 @@ impl Stream for Stack {
impl Sink<AnyIpPktFrame> for Stack {
type Error = io::Error;

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.sink_buf.is_none() {
Poll::Ready(Ok(()))
} else {
self.poll_flush(cx)
Poll::Pending
}
}

fn start_send(mut self: Pin<&mut Self>, item: AnyIpPktFrame) -> Result<(), Self::Error> {
self.sink_buf.replace(item);
Ok(())
}

fn poll_flush(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
let Some(item) = self.sink_buf.take() else {
return Poll::Ready(Ok(()));
};

if item.is_empty() {
return Poll::Ready(Ok(()));
return Ok(());
}

let packet = IpPacket::new_checked(item.as_slice()).map_err(|err| {
Expand All @@ -170,35 +253,24 @@ impl Sink<AnyIpPktFrame> for Stack {
dst_ip,
addr_allowed,
);
return Poll::Ready(Ok(()));
return Ok(());
}

match packet.protocol() {
IpProtocol::Tcp => {
self.tcp_tx
.try_send(item)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
Poll::Ready(Ok(()))
}
IpProtocol::Udp => {
self.udp_tx
.try_send(item)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
Poll::Ready(Ok(()))
}
IpProtocol::Icmp | IpProtocol::Icmpv6 => {
// ICMP is handled by TCP's Interface.
// smoltcp's interface will always send replies to EchoRequest
self.tcp_tx
.try_send(item)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
Poll::Ready(Ok(()))
}
protocol => {
debug!("tun IP packet ignored (protocol: {:?})", protocol);
Poll::Ready(Ok(()))
}
let protocol = packet.protocol();
if matches!(
protocol,
IpProtocol::Tcp | IpProtocol::Udp | IpProtocol::Icmp | IpProtocol::Icmpv6
) {
self.sink_buf.replace((item, protocol));
} else {
debug!("tun IP packet ignored (protocol: {:?})", protocol);
}

Ok(())
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.poll_send(cx)
}

fn poll_close(
Expand All @@ -209,3 +281,10 @@ impl Sink<AnyIpPktFrame> for Stack {
Poll::Ready(Ok(()))
}
}

fn channel_closed_err<E>(err: E) -> std::io::Error
where
E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
std::io::Error::new(std::io::ErrorKind::BrokenPipe, err)
}