diff --git a/sockets/sockets.go b/sockets/sockets.go index b0eae239..202ced41 100644 --- a/sockets/sockets.go +++ b/sockets/sockets.go @@ -2,13 +2,19 @@ package sockets import ( + "context" "errors" + "fmt" "net" "net/http" + "syscall" "time" ) -const defaultTimeout = 10 * time.Second +const ( + defaultTimeout = 10 * time.Second + maxUnixSocketPathSize = len(syscall.RawSockaddrUnix{}.Path) +) // ErrProtocolNotAvailable is returned when a given transport protocol is not provided by the operating system. var ErrProtocolNotAvailable = errors.New("protocol not available") @@ -35,3 +41,18 @@ func ConfigureTransport(tr *http.Transport, proto, addr string) error { } return nil } + +func configureUnixTransport(tr *http.Transport, proto, addr string) error { + if len(addr) > maxUnixSocketPathSize { + return fmt.Errorf("Unix socket path %q is too long", addr) + } + // No need for compression in local communications. + tr.DisableCompression = true + dialer := &net.Dialer{ + Timeout: defaultTimeout, + } + tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + return dialer.DialContext(ctx, proto, addr) + } + return nil +} diff --git a/sockets/sockets_unix.go b/sockets/sockets_unix.go index 78a34a98..80d3dfac 100644 --- a/sockets/sockets_unix.go +++ b/sockets/sockets_unix.go @@ -1,33 +1,14 @@ -//go:build !windows +//go:build unix package sockets import ( - "context" - "fmt" "net" "net/http" "syscall" "time" ) -const maxUnixSocketPathSize = len(syscall.RawSockaddrUnix{}.Path) - -func configureUnixTransport(tr *http.Transport, proto, addr string) error { - if len(addr) > maxUnixSocketPathSize { - return fmt.Errorf("unix socket path %q is too long", addr) - } - // No need for compression in local communications. - tr.DisableCompression = true - dialer := &net.Dialer{ - Timeout: defaultTimeout, - } - tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { - return dialer.DialContext(ctx, proto, addr) - } - return nil -} - func configureNpipeTransport(tr *http.Transport, proto, addr string) error { return ErrProtocolNotAvailable } diff --git a/sockets/sockets_windows.go b/sockets/sockets_windows.go index 7acafc5a..d4f2e788 100644 --- a/sockets/sockets_windows.go +++ b/sockets/sockets_windows.go @@ -9,10 +9,6 @@ import ( "github.com/Microsoft/go-winio" ) -func configureUnixTransport(tr *http.Transport, proto, addr string) error { - return ErrProtocolNotAvailable -} - func configureNpipeTransport(tr *http.Transport, proto, addr string) error { // No need for compression in local communications. tr.DisableCompression = true diff --git a/sockets/unix_socket.go b/sockets/unix_socket.go index b9233521..6a1c934d 100644 --- a/sockets/unix_socket.go +++ b/sockets/unix_socket.go @@ -1,5 +1,3 @@ -//go:build !windows - /* Package sockets is a simple unix domain socket wrapper. @@ -105,7 +103,7 @@ func NewUnixSocketWithOpts(path string, opts ...SockOption) (net.Listener, error // an option, and skip changing umask if default permissions are used. origUmask := syscall.Umask(0o777) l, err := net.Listen("unix", path) - syscall.Umask(origUmask) + umask(origUmask) if err != nil { return nil, err } diff --git a/sockets/unix_socket_test.go b/sockets/unix_socket_test.go index e4ae0e37..6563452b 100644 --- a/sockets/unix_socket_test.go +++ b/sockets/unix_socket_test.go @@ -1,12 +1,10 @@ -//go:build !windows - package sockets import ( "fmt" + "io/ioutil" "net" "os" - "syscall" "testing" ) @@ -52,26 +50,15 @@ func TestNewUnixSocket(t *testing.T) { } func TestUnixSocketWithOpts(t *testing.T) { - uid, gid := os.Getuid(), os.Getgid() - perms := os.FileMode(0o660) - path := "/tmp/test.sock" - echoStr := "hello" - l, err := NewUnixSocketWithOpts(path, WithChown(uid, gid), WithChmod(perms)) + socketFile, err := ioutil.TempFile("", "test*.sock") if err != nil { t.Fatal(err) } + defer socketFile.Close() + + l := createTestUnixSocket(t, socketFile.Name()) defer l.Close() - p, err := os.Stat(path) - if err != nil { - t.Fatal(err) - } - if p.Mode().Perm() != perms { - t.Fatalf("unexpected file permissions: expected: %#o, got: %#o", perms, p.Mode().Perm()) - } - if stat, ok := p.Sys().(*syscall.Stat_t); ok { - if stat.Uid != uint32(uid) || stat.Gid != uint32(gid) { - t.Fatalf("unexpected file ownership: expected: %d:%d, got: %d:%d", uid, gid, stat.Uid, stat.Gid) - } - } - runTest(t, path, l, echoStr) + + echoStr := "hello" + runTest(t, socketFile.Name(), l, echoStr) } diff --git a/sockets/unix_socket_test_unix.go b/sockets/unix_socket_test_unix.go new file mode 100644 index 00000000..aa1c6e67 --- /dev/null +++ b/sockets/unix_socket_test_unix.go @@ -0,0 +1,32 @@ +//go:build unix + +package sockets + +import ( + "net" + "os" + "syscall" + "testing" +) + +func createTestUnixSocket(t *testing.T, path string) (listener net.Listener) { + uid, gid := os.Getuid(), os.Getgid() + perms := os.FileMode(0660) + l, err := NewUnixSocketWithOpts(path, WithChown(uid, gid), WithChmod(perms)) + if err != nil { + t.Fatal(err) + } + p, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if p.Mode().Perm() != perms { + t.Fatalf("unexpected file permissions: expected: %#o, got: %#o", perms, p.Mode().Perm()) + } + if stat, ok := p.Sys().(*syscall.Stat_t); ok { + if stat.Uid != uint32(uid) || stat.Gid != uint32(gid) { + t.Fatalf("unexpected file ownership: expected: %d:%d, got: %d:%d", uid, gid, stat.Uid, stat.Gid) + } + } + return l +} diff --git a/sockets/unix_socket_test_windows.go b/sockets/unix_socket_test_windows.go new file mode 100644 index 00000000..e68aca0b --- /dev/null +++ b/sockets/unix_socket_test_windows.go @@ -0,0 +1,14 @@ +package sockets + +import ( + "net" + "testing" +) + +func createTestUnixSocket(t *testing.T, path string) (listener net.Listener) { + l, err := NewUnixSocketWithOpts(path) + if err != nil { + t.Fatal(err) + } + return l +} diff --git a/sockets/unix_socket_unix.go b/sockets/unix_socket_unix.go new file mode 100644 index 00000000..6bf0ae8d --- /dev/null +++ b/sockets/unix_socket_unix.go @@ -0,0 +1,9 @@ +//go:build unix + +package sockets + +import "syscall" + +func umask(newmask int) (oldmask int) { + return syscall.Umask(0777) +} diff --git a/sockets/unix_socket_windows.go b/sockets/unix_socket_windows.go new file mode 100644 index 00000000..e8b295a6 --- /dev/null +++ b/sockets/unix_socket_windows.go @@ -0,0 +1,5 @@ +package sockets + +func umask(newmask int) (oldmask int) { + return newmask +}