From 46d1092ae6249a2bb790cb43ee939d957e3d9669 Mon Sep 17 00:00:00 2001 From: Vasyl Gello Date: Thu, 18 Jul 2024 11:48:42 +0300 Subject: [PATCH] Try to fix #4 ... by catching TCP RST packets in WritePackets and sending them during the next WritePackets call where no RST packet is being sent Signed-off-by: Vasyl Gello --- src/netstack/yggdrasil.go | 51 +++++++++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/src/netstack/yggdrasil.go b/src/netstack/yggdrasil.go index 3a582aa..fb282b4 100644 --- a/src/netstack/yggdrasil.go +++ b/src/netstack/yggdrasil.go @@ -1,6 +1,7 @@ package netstack import ( + "container/list" "log" "net" @@ -12,6 +13,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" ) type YggdrasilNIC struct { @@ -20,15 +22,17 @@ type YggdrasilNIC struct { dispatcher stack.NetworkDispatcher readBuf []byte writeBuf []byte + rstPackets *list.List } func (s *YggdrasilNetstack) NewYggdrasilNIC(ygg *core.Core) tcpip.Error { rwc := ipv6rwc.NewReadWriteCloser(ygg) mtu := rwc.MTU() nic := &YggdrasilNIC{ - ipv6rwc: rwc, - readBuf: make([]byte, mtu), - writeBuf: make([]byte, mtu), + ipv6rwc: rwc, + readBuf: make([]byte, mtu), + writeBuf: make([]byte, mtu), + rstPackets: list.New(), } if err := s.stack.CreateNIC(1, nic); err != nil { return err @@ -93,24 +97,51 @@ func (*YggdrasilNIC) LinkAddress() tcpip.LinkAddress { return "" } func (*YggdrasilNIC) Wait() {} +func (e *YggdrasilNIC) writePacket( + pkt *stack.PacketBuffer, +) tcpip.Error { + vv := pkt.ToView() + n, err := vv.Read(e.writeBuf) + if err != nil { + return &tcpip.ErrAborted{} + } + _, err = e.ipv6rwc.Write(e.writeBuf[:n]) + if err != nil { + return &tcpip.ErrAborted{} + } + return nil +} + func (e *YggdrasilNIC) WritePackets( list stack.PacketBufferList, ) (int, tcpip.Error) { var i int = 0 + var err tcpip.Error = nil + var rstCaught = false for i, pkt := range list.AsSlice() { - vv := pkt.ToView() - n, err := vv.Read(e.writeBuf) - if err != nil { - log.Println(err) - return i - 1, &tcpip.ErrAborted{} + if pkt.Network().TransportProtocol() == tcp.ProtocolNumber { + tcpHeader := header.TCP(pkt.TransportHeader().Slice()) + if (tcpHeader.Flags() & header.TCPFlagRst) == header.TCPFlagRst { + e.rstPackets.PushFront(pkt) + rstCaught = true + continue + } } - _, err = e.ipv6rwc.Write(e.writeBuf[:n]) + err = e.writePacket(pkt) if err != nil { log.Println(err) - return i - 1, &tcpip.ErrAborted{} + return i - 1, err } } + if !rstCaught { + for rstPkt := e.rstPackets.Front(); rstPkt != nil; rstPkt = rstPkt.Next() { + _ = e.writePacket(rstPkt.Value.(*stack.PacketBuffer)) + } + + _ = e.rstPackets.Init() + } + return i, nil }