Skip to content

Commit

Permalink
bring up multi-ip tun on linux
Browse files Browse the repository at this point in the history
  • Loading branch information
JackDoanRivian committed Sep 20, 2024
1 parent 735edd0 commit 7d41f3f
Showing 1 changed file with 97 additions and 63 deletions.
160 changes: 97 additions & 63 deletions overlay/tun_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type tun struct {
io.ReadWriteCloser
fd int
Device string
cidr netip.Prefix
vpnNetworks []netip.Prefix
MaxMTU int
DefaultMTU int
TXQueueLen int
Expand All @@ -40,18 +40,16 @@ type tun struct {
l *logrus.Logger
}

func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}

type ifReq struct {
Name [16]byte
Flags uint16
pad [8]byte
}

type ifreqAddr struct {
Name [16]byte
Addr unix.RawSockaddrInet4
pad [8]byte
}

type ifreqMTU struct {
Name [16]byte
MTU int32
Expand All @@ -64,10 +62,10 @@ type ifreqQLEN struct {
pad [8]byte
}

func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")

t, err := newTunGeneric(c, l, file, cidr)
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
return nil, err
}
Expand All @@ -77,7 +75,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
return t, nil
}

func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
Expand Down Expand Up @@ -112,7 +110,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
name := strings.Trim(string(req.Name[:]), "\x00")

file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, cidr)
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
return nil, err
}
Expand All @@ -122,11 +120,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
return t, nil
}

func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) {
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
cidr: cidr,
vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
l: l,
Expand All @@ -148,7 +146,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref
}

func (t *tun) reload(c *config.C, initial bool) error {
routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
Expand Down Expand Up @@ -190,11 +188,13 @@ func (t *tun) reload(c *config.C, initial bool) error {
}

if oldDefaultMTU != newDefaultMTU {
err := t.setDefaultRoute()
if err != nil {
t.l.Warn(err)
} else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
for i := range t.vpnNetworks {
err := t.setDefaultRoute(t.vpnNetworks[i])
if err != nil {
t.l.Warn(err)
} else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
}
}
}

Expand Down Expand Up @@ -265,22 +265,70 @@ func (t *tun) deviceBytes() (o [16]byte) {
return
}

func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
for i := range al {
if al[i].Equal(x) {
return true
}
}
return false
}

// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
func (t *tun) addIPs(link netlink.Link) error {
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
for i := range t.vpnNetworks {
//todo I wish I didn't need to stringify and re-parse
nlAddr, err := netlink.ParseAddr(t.vpnNetworks[i].String())
if err != nil {
return err
}
newAddrs[i] = nlAddr
}

//add all new addresses
for i := range newAddrs {
//todo do we want to stack errors and try as many ops as possible?
//todo AddrReplace should still add new IPs I think
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
return err
}
//newAddrs is the same order as vpnNetworks so this is fine
if err := t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
return err
}
}

//iterate over remainder, remove whoever shouldn't be there
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
if err != nil {
return fmt.Errorf("failed to get tun address list: %s", err)
}

for i := range al {
if hasNetlinkAddr(newAddrs, al[i]) {
continue
}
err = netlink.AddrDel(link, &al[i])
if err != nil {
t.l.WithError(err).Error("failed to remove address from tun address list")
} else {
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
}
}

return nil
}

func (t *tun) Activate() error {
devName := t.deviceBytes()

if t.useSystemRoutes {
t.watchRoutes()
}

var addr, mask [4]byte

//TODO: IPV6-WORK
addr = t.cidr.Addr().As4()
tmask := net.CIDRMask(t.cidr.Bits(), 32)
copy(mask[:], tmask)

s, err := unix.Socket(
unix.AF_INET,
unix.AF_INET, //todo do we ever need INET6 here?
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
Expand All @@ -289,31 +337,19 @@ func (t *tun) Activate() error {
}
t.ioctlFd = uintptr(s)

ifra := ifreqAddr{
Name: devName,
Addr: unix.RawSockaddrInet4{
Family: unix.AF_INET,
Addr: addr,
},
}

// Set the device ip address
if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err)
}

// Set the device network
ifra.Addr.Addr = mask
if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun netmask: %s", err)
}

// Set the device name
ifrf := ifReq{Name: devName}
if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set tun device name: %s", err)
}

link, err := netlink.LinkByName(t.Device)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
}

t.deviceIndex = link.Attrs().Index

// Setup our default MTU
t.setMTU()

Expand All @@ -330,13 +366,7 @@ func (t *tun) Activate() error {
return fmt.Errorf("failed to bring the tun device up: %s", err)
}

link, err := netlink.LinkByName(t.Device)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
}
t.deviceIndex = link.Attrs().Index

if err = t.setDefaultRoute(); err != nil {
if err = t.addIPs(link); err != nil {
return err
}

Expand All @@ -350,6 +380,7 @@ func (t *tun) Activate() error {
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to run tun device: %s", err)
}
//todo do we want to keep the link-local address?

return nil
}
Expand All @@ -363,12 +394,12 @@ func (t *tun) setMTU() {
}
}

func (t *tun) setDefaultRoute() error {
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
// Default route

dr := &net.IPNet{
IP: t.cidr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()),
IP: cidr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
}

nr := netlink.Route{
Expand All @@ -377,7 +408,7 @@ func (t *tun) setDefaultRoute() error {
MTU: t.DefaultMTU,
AdvMSS: t.advMSS(Route{}),
Scope: unix.RT_SCOPE_LINK,
Src: net.IP(t.cidr.Addr().AsSlice()),
Src: net.IP(cidr.Addr().AsSlice()),
Protocol: unix.RTPROT_KERNEL,
Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST,
Expand Down Expand Up @@ -463,10 +494,6 @@ func (t *tun) removeRoutes(routes []Route) {
}
}

func (t *tun) Cidr() netip.Prefix {
return t.cidr
}

func (t *tun) Name() string {
return t.Device
}
Expand Down Expand Up @@ -523,9 +550,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
}

gwAddr = gwAddr.Unmap()
if !t.cidr.Contains(gwAddr) {
withinNetworks := false
for i := range t.vpnNetworks {
if t.vpnNetworks[i].Contains(gwAddr) {
withinNetworks = true
break
}
}
if !withinNetworks {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
return
}

Expand Down

0 comments on commit 7d41f3f

Please sign in to comment.