From 33a0a8c3b8621b098ca7554cc763a8c6fd3e2573 Mon Sep 17 00:00:00 2001 From: Vasyl Gello Date: Mon, 15 Jul 2024 13:15:22 +0300 Subject: [PATCH] Fix unix domain socket creation/cleanup logic * If there is another instance still listening on the same Unix domain socket, bail out * If there is a leftover from crashed yggstack etc, clean the socket file and proceed Signed-off-by: Vasyl Gello --- cmd/yggstack/main.go | 94 ++++++++++++++++++++++++++++---------------- 1 file changed, 61 insertions(+), 33 deletions(-) diff --git a/cmd/yggstack/main.go b/cmd/yggstack/main.go index aff7ae0..067f530 100644 --- a/cmd/yggstack/main.go +++ b/cmd/yggstack/main.go @@ -12,6 +12,7 @@ import ( "os" "os/signal" "regexp" + "runtime" "strings" "syscall" @@ -273,44 +274,48 @@ func main() { // Create SOCKS server { - if socks != nil { - if nameserver != nil { - if strings.Contains(*socks, ":") { - logger.Infof("Starting SOCKS server on %s", *socks) - resolver := types.NewNameResolver(s, *nameserver) - socksOptions := []socks5.Option{ - socks5.WithDial(s.DialContext), - socks5.WithResolver(resolver), - } - if logger.GetLevel("debug") { - socksOptions = append(socksOptions, socks5.WithLogger(logger)) - } - server := socks5.NewServer(socksOptions...) - go server.ListenAndServe("tcp", *socks) // nolint:errcheck - } else { - logger.Infof("Starting SOCKS server with socket file %s", *socks) - _, err := os.Stat(*socks) - if os.IsNotExist(err) { - n.socks5Listener, err = net.Listen("unix", *socks) + if socks != nil && *socks != "" { + socksOptions := []socks5.Option{ + socks5.WithDial(s.DialContext), + } + if nameserver != nil && *nameserver != "" { + resolver := types.NewNameResolver(s, *nameserver) + socksOptions = append(socksOptions, socks5.WithResolver(resolver)) + } else { + logger.Warningf("DNS nameserver is not set!") + logger.Warningf("SOCKS server will not be able to resolve hostnames other than .pk.ygg !") + } + if logger.GetLevel("debug") { + socksOptions = append(socksOptions, socks5.WithLogger(logger)) + } + server := socks5.NewServer(socksOptions...) + if strings.Contains(*socks, ":") { + logger.Infof("Starting SOCKS server on %s", *socks) + go server.ListenAndServe("tcp", *socks) // nolint:errcheck + } else { + logger.Infof("Starting SOCKS server with socket file %s", *socks) + n.socks5Listener, err = net.Listen("unix", *socks) + if err != nil { + // If address in use, try connecting to + // the socket to see if other yggstack + // instance is listening on it + + if isErrorAddressAlreadyInUse(err) { + _, err = net.Dial("unix", *socks) if err != nil { - panic(err) - } - resolver := types.NewNameResolver(s, *nameserver) - socksOptions := []socks5.Option{ - socks5.WithDial(s.DialContext), - socks5.WithResolver(resolver), - } - if logger.GetLevel("debug") { - socksOptions = append(socksOptions, socks5.WithLogger(logger)) + // Unlink dead socket if not connected + err = os.RemoveAll(*socks) + if err != nil { + panic(err) + } + } else { + panic(fmt.Errorf("Another yggstack instance is listening on socket '%s'", *socks)) } - server := socks5.NewServer(socksOptions...) - go server.Serve(n.socks5Listener) // nolint:errcheck - } else if err != nil { - logger.Errorf("Cannot create socket file %s: %s", *socks, err) } else { - panic(errors.New(fmt.Sprintf("Socket file %s already exists", *socks))) + panic(err) } } + go server.Serve(n.socks5Listener) // nolint:errcheck } } } @@ -349,11 +354,34 @@ func main() { _ = n.multicast.Stop() if n.socks5Listener != nil { _ = n.socks5Listener.Close() + _ = os.RemoveAll(*socks) logger.Infof("Stopped UNIX socket listener") } n.core.Stop() } +// Helper to detect if socket address is in use +// https://stackoverflow.com/a/52152912 +func isErrorAddressAlreadyInUse(err error) bool { + var eOsSyscall *os.SyscallError + if !errors.As(err, &eOsSyscall) { + return false + } + var errErrno syscall.Errno // doesn't need a "*" (ptr) because it's already a ptr (uintptr) + if !errors.As(eOsSyscall, &errErrno) { + return false + } + if errors.Is(errErrno, syscall.EADDRINUSE) { + return true + } + const WSAEADDRINUSE = 10048 + if runtime.GOOS == "windows" && errErrno == WSAEADDRINUSE { + return true + } + return false +} + +// Helper to set logging level func setLogLevel(loglevel string, logger *log.Logger) { levels := [...]string{"error", "warn", "info", "debug", "trace"} loglevel = strings.ToLower(loglevel)