diff --git a/inside.go b/inside.go index 9250b5e5a..ff9e80b6e 100644 --- a/inside.go +++ b/inside.go @@ -83,6 +83,10 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) { } out = iputil.CreateRejectPacket(packet, out) + if len(out) == 0 { + return + } + _, err := f.readers[q].Write(out) if err != nil { f.l.WithError(err).Error("Failed to write to tun") @@ -94,12 +98,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * return } - // Use some out buffer space to build the packet before encryption - // Need 40 bytes for the reject packet (20 byte ipv4 header, 20 byte tcp rst packet) - // Leave 100 bytes for the encrypted packet (60 byte Nebula header, 40 byte reject packet) - out = out[:140] - outPacket := iputil.CreateRejectPacket(packet, out[100:]) - f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, outPacket, nb, out, q) + out = iputil.CreateRejectPacket(packet, out) + if len(out) == 0 { + return + } + + if len(out) > iputil.MaxRejectPacketSize { + if f.l.GetLevel() >= logrus.InfoLevel { + f.l. + WithField("packet", packet). + WithField("outPacket", out). + Info("rejectOutside: packet too big, not sending") + } + return + } + + f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q) } func (f *Interface) Handshake(vpnIp iputil.VpnIp) { diff --git a/iputil/packet.go b/iputil/packet.go index 74ae37f09..b18e52447 100644 --- a/iputil/packet.go +++ b/iputil/packet.go @@ -6,8 +6,19 @@ import ( "golang.org/x/net/ipv4" ) +const ( + // Need 96 bytes for the largest reject packet: + // - 20 byte ipv4 header + // - 8 byte icmpv4 header + // - 68 byte body (60 byte max orig ipv4 header + 8 byte orig icmpv4 header) + MaxRejectPacketSize = ipv4.HeaderLen + 8 + 60 + 8 +) + func CreateRejectPacket(packet []byte, out []byte) []byte { - // TODO ipv4 only, need to fix when inside supports ipv6 + if len(packet) < ipv4.HeaderLen || int(packet[0]>>4) != ipv4.Version { + return nil + } + switch packet[9] { case 6: // tcp return ipv4CreateRejectTCPPacket(packet, out) @@ -19,20 +30,28 @@ func CreateRejectPacket(packet []byte, out []byte) []byte { func ipv4CreateRejectICMPPacket(packet []byte, out []byte) []byte { ihl := int(packet[0]&0x0f) << 2 - // ICMP reply includes header and first 8 bytes of the packet + if len(packet) < ihl { + // We need at least this many bytes for this to be a valid packet + return nil + } + + // ICMP reply includes original header and first 8 bytes of the packet packetLen := len(packet) if packetLen > ihl+8 { packetLen = ihl + 8 } outLen := ipv4.HeaderLen + 8 + packetLen + if outLen > cap(out) { + return nil + } - out = out[:(outLen)] + out = out[:outLen] ipHdr := out[0:ipv4.HeaderLen] - ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2) // version, ihl - ipHdr[1] = 0 // DSCP, ECN - binary.BigEndian.PutUint16(ipHdr[2:], uint16(ipv4.HeaderLen+8+packetLen)) // Total Length + ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2) // version, ihl + ipHdr[1] = 0 // DSCP, ECN + binary.BigEndian.PutUint16(ipHdr[2:], uint16(outLen)) // Total Length ipHdr[4] = 0 // id ipHdr[5] = 0 // . @@ -76,7 +95,15 @@ func ipv4CreateRejectTCPPacket(packet []byte, out []byte) []byte { ihl := int(packet[0]&0x0f) << 2 outLen := ipv4.HeaderLen + tcpLen - out = out[:(outLen)] + if len(packet) < ihl+tcpLen { + // We need at least this many bytes for this to be a valid packet + return nil + } + if outLen > cap(out) { + return nil + } + + out = out[:outLen] ipHdr := out[0:ipv4.HeaderLen] ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2) // version, ihl diff --git a/iputil/packet_test.go b/iputil/packet_test.go new file mode 100644 index 000000000..e1d0d95d8 --- /dev/null +++ b/iputil/packet_test.go @@ -0,0 +1,73 @@ +package iputil + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/net/ipv4" +) + +func Test_CreateRejectPacket(t *testing.T) { + h := ipv4.Header{ + Len: 20, + Src: net.IPv4(10, 0, 0, 1), + Dst: net.IPv4(10, 0, 0, 2), + Protocol: 1, // ICMP + } + + b, err := h.Marshal() + if err != nil { + t.Fatalf("h.Marhshal: %v", err) + } + b = append(b, []byte{0, 3, 0, 4}...) + + expectedLen := ipv4.HeaderLen + 8 + h.Len + 4 + out := make([]byte, expectedLen) + rejectPacket := CreateRejectPacket(b, out) + assert.NotNil(t, rejectPacket) + assert.Len(t, rejectPacket, expectedLen) + + // ICMP with max header len + h = ipv4.Header{ + Len: 60, + Src: net.IPv4(10, 0, 0, 1), + Dst: net.IPv4(10, 0, 0, 2), + Protocol: 1, // ICMP + Options: make([]byte, 40), + } + + b, err = h.Marshal() + if err != nil { + t.Fatalf("h.Marhshal: %v", err) + } + b = append(b, []byte{0, 3, 0, 4, 0, 0, 0, 0}...) + + expectedLen = MaxRejectPacketSize + out = make([]byte, MaxRejectPacketSize) + rejectPacket = CreateRejectPacket(b, out) + assert.NotNil(t, rejectPacket) + assert.Len(t, rejectPacket, expectedLen) + + // TCP with max header len + h = ipv4.Header{ + Len: 60, + Src: net.IPv4(10, 0, 0, 1), + Dst: net.IPv4(10, 0, 0, 2), + Protocol: 6, // TCP + Options: make([]byte, 40), + } + + b, err = h.Marshal() + if err != nil { + t.Fatalf("h.Marhshal: %v", err) + } + b = append(b, []byte{0, 3, 0, 4}...) + b = append(b, make([]byte, 16)...) + + expectedLen = ipv4.HeaderLen + 20 + out = make([]byte, expectedLen) + rejectPacket = CreateRejectPacket(b, out) + assert.NotNil(t, rejectPacket) + assert.Len(t, rejectPacket, expectedLen) +} diff --git a/outside.go b/outside.go index 4139830b2..29189110c 100644 --- a/outside.go +++ b/outside.go @@ -406,7 +406,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { - f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, out, q) + // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore + // This gives us a buffer to build the reject packet in + f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).WithField("fwPacket", fwPacket). WithField("reason", dropReason).