Skip to content

Commit

Permalink
Implement AF_UNIX sockets on Windows
Browse files Browse the repository at this point in the history
See moby/moby#36442

Signed-off-by: Marat Radchenko <[email protected]>
  • Loading branch information
slonopotamus committed Jan 24, 2022
1 parent 58542c7 commit c430aa9
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 52 deletions.
25 changes: 25 additions & 0 deletions sockets/sockets.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@
package sockets

import (
"context"
"errors"
"fmt"
"net"
"net/http"
"syscall"
"time"
)

// ErrProtocolNotAvailable is returned when a given transport protocol is not provided by the operating system.
Expand All @@ -24,3 +29,23 @@ func ConfigureTransport(tr *http.Transport, proto, addr string) error {
}
return nil
}

const (
defaultTimeout = 10 * time.Second
maxUnixSocketPathSize = len(syscall.RawSockaddrUnix{}.Path)
)

func configureUnixTransport(tr *http.Transport, proto, addr string) error {
if len(addr) > maxUnixSocketPathSize {
return fmt.Errorf("Unix socket path %q is too long", addr)
}
// No need for compression in local communications.
tr.DisableCompression = true
dialer := &net.Dialer{
Timeout: defaultTimeout,
}
tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, proto, addr)
}
return nil
}
23 changes: 1 addition & 22 deletions sockets/sockets_unix.go
Original file line number Diff line number Diff line change
@@ -1,36 +1,15 @@
//go:build !windows
// +build !windows

package sockets

import (
"context"
"fmt"
"net"
"net/http"
"syscall"
"time"
)

const (
defaultTimeout = 10 * time.Second
maxUnixSocketPathSize = len(syscall.RawSockaddrUnix{}.Path)
)

func configureUnixTransport(tr *http.Transport, proto, addr string) error {
if len(addr) > maxUnixSocketPathSize {
return fmt.Errorf("Unix socket path %q is too long", addr)
}
// No need for compression in local communications.
tr.DisableCompression = true
dialer := &net.Dialer{
Timeout: defaultTimeout,
}
tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, proto, addr)
}
return nil
}

func configureNpipeTransport(tr *http.Transport, proto, addr string) error {
return ErrProtocolNotAvailable
}
Expand Down
4 changes: 0 additions & 4 deletions sockets/sockets_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ import (
"github.com/Microsoft/go-winio"
)

func configureUnixTransport(tr *http.Transport, proto, addr string) error {
return ErrProtocolNotAvailable
}

func configureNpipeTransport(tr *http.Transport, proto, addr string) error {
// No need for compression in local communications.
tr.DisableCompression = true
Expand Down
6 changes: 2 additions & 4 deletions sockets/unix_socket.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
// +build !windows

/*
Package sockets is a simple unix domain socket wrapper.
Expand Down Expand Up @@ -103,9 +101,9 @@ func NewUnixSocketWithOpts(path string, opts ...SockOption) (net.Listener, error
// We don't use "defer" here, to reset the umask to its original value as soon
// as possible. Ideally we'd be able to detect if WithChmod() was passed as
// an option, and skip changing umask if default permissions are used.
origUmask := syscall.Umask(0777)
origUmask := umask(0777)
l, err := net.Listen("unix", path)
syscall.Umask(origUmask)
umask(origUmask)
if err != nil {
return nil, err
}
Expand Down
23 changes: 1 addition & 22 deletions sockets/unix_socket_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
// +build !windows

package sockets

import (
"fmt"
"net"
"os"
"syscall"
"testing"
)

Expand Down Expand Up @@ -52,26 +49,8 @@ func TestNewUnixSocket(t *testing.T) {
}

func TestUnixSocketWithOpts(t *testing.T) {
uid, gid := os.Getuid(), os.Getgid()
perms := os.FileMode(0660)
path := "/tmp/test.sock"
echoStr := "hello"
l, err := NewUnixSocketWithOpts(path, WithChown(uid, gid), WithChmod(perms))
if err != nil {
t.Fatal(err)
}
l, path := createTestUnixSocket(t)
defer l.Close()
p, err := os.Stat(path)
if err != nil {
t.Fatal(err)
}
if p.Mode().Perm() != perms {
t.Fatalf("unexpected file permissions: expected: %#o, got: %#o", perms, p.Mode().Perm())
}
if stat, ok := p.Sys().(*syscall.Stat_t); ok {
if stat.Uid != uint32(uid) || stat.Gid != uint32(gid) {
t.Fatalf("unexpected file ownership: expected: %d:%d, got: %d:%d", uid, gid, stat.Uid, stat.Gid)
}
}
runTest(t, path, l, echoStr)
}
34 changes: 34 additions & 0 deletions sockets/unix_socket_test_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//go:build !windows
// +build !windows

package sockets

import (
"os"
"syscall"
"testing"
)

func createTestUnixSocket(t *testing.T) (listener net.Listener, path string) {
uid, gid := os.Getuid(), os.Getgid()
perms := os.FileMode(0660)
path = "/tmp/test.sock"
echoStr := "hello"
l, err := NewUnixSocketWithOpts(path, WithChown(uid, gid), WithChmod(perms))
if err != nil {
t.Fatal(err)
}
p, err := os.Stat(path)
if err != nil {
t.Fatal(err)
}
if p.Mode().Perm() != perms {
t.Fatalf("unexpected file permissions: expected: %#o, got: %#o", perms, p.Mode().Perm())
}
if stat, ok := p.Sys().(*syscall.Stat_t); ok {
if stat.Uid != uint32(uid) || stat.Gid != uint32(gid) {
t.Fatalf("unexpected file ownership: expected: %d:%d, got: %d:%d", uid, gid, stat.Uid, stat.Gid)
}
}
return l, path
}
25 changes: 25 additions & 0 deletions sockets/unix_socket_test_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package sockets

import (
"io/ioutil"
"net"
"testing"
)

func testSocketPath() string {
file, err := ioutil.TempFile("", "test*.sock")
if err != nil {
panic(err)
}
defer file.Close()
return file.Name()
}

func createTestUnixSocket(t *testing.T) (listener net.Listener, path string) {
path = testSocketPath()
l, err := NewUnixSocketWithOpts(path)
if err != nil {
t.Fatal(err)
}
return l, path
}
10 changes: 10 additions & 0 deletions sockets/unix_socket_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//go:build !windows
// +build !windows

package sockets

import "syscall"

func umask(newmask int) (oldmask int) {
return syscall.Umask(0777)
}
5 changes: 5 additions & 0 deletions sockets/unix_socket_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sockets

func umask(newmask int) (oldmask int) {
return newmask
}

0 comments on commit c430aa9

Please sign in to comment.