Skip to content

Commit

Permalink
setup ExposeHostPorts forwards on container start
Browse files Browse the repository at this point in the history
Fixes testcontainers#2811

Previously ExposedHostPorts would start an SSHD container prior to
starting the testcontainer and inject a PostReadies lifecycle hook into
the testcontainer in order to set up remote port forwarding from the
host to the SSHD container so the testcontainer can talk to the host via
the SSHD container

This would be an issue if the testcontainer depends on accessing the
host port on startup ( e.g., a proxy server ) as the forwarding for the
host access isn't set up until all the WiatFor strategies on the
testcontainer have completed.

The fix is to move the forwarding setup to the PreCreates hook on the
testcontainer. Since remote forwarding doesn't establish a connection to
the host port until a connection is made to the remote port, this should
not be an issue even if the host isn't listening yet and ensures the
remote port is available to the testcontainer immediately.
  • Loading branch information
hathvi committed Oct 5, 2024
1 parent 7c53667 commit ad3f698
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 75 deletions.
4 changes: 2 additions & 2 deletions port_forwarding.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ func exposeHostPorts(ctx context.Context, req *ContainerRequest, ports ...int) (
// after the container is ready, create the SSH tunnel
// for each exposed port from the host.
sshdConnectHook = ContainerLifecycleHooks{
PostReadies: []ContainerHook{
func(ctx context.Context, c Container) error {
PreCreates: []ContainerRequestHook{
func(ctx context.Context, req ContainerRequest) error {
return sshdContainer.exposeHostPort(ctx, req.HostAccessPorts...)
},
},
Expand Down
124 changes: 51 additions & 73 deletions port_forwarding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/testcontainers/testcontainers-go"
tcexec "github.com/testcontainers/testcontainers-go/exec"
"github.com/testcontainers/testcontainers-go/network"
"github.com/testcontainers/testcontainers-go/wait"
)

const (
Expand All @@ -23,42 +22,59 @@ const (

func TestExposeHostPorts(t *testing.T) {
tests := []struct {
name string
numberOfPorts int
hasNetwork bool
hasHostAccess bool
name string
numberOfPorts int
hasNetwork bool
bindOnPostStarts bool
}{
{
name: "single port",
numberOfPorts: 1,
hasHostAccess: true,
},
{
name: "single port using a network",
numberOfPorts: 1,
hasNetwork: true,
hasHostAccess: true,
},
{
name: "multiple ports",
numberOfPorts: 3,
hasHostAccess: true,
},
{
name: "single port with cancellation",
numberOfPorts: 1,
hasHostAccess: false,
name: "multiple ports bound on PostStarts",
numberOfPorts: 3,
bindOnPostStarts: true,
},
}

for _, tc := range tests {
t.Run(tc.name, func(tt *testing.T) {
servers := make([]*httptest.Server, tc.numberOfPorts)
freePorts := make([]int, tc.numberOfPorts)
waitStrategies := make([]wait.Strategy, tc.numberOfPorts)
for i := range freePorts {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, expectedResponse)
}))
freePorts[i] = server.Listener.Addr().(*net.TCPAddr).Port

if !tc.bindOnPostStarts {
server.Start()
}

servers[i] = server
freePort := server.Listener.Addr().(*net.TCPAddr).Port
freePorts[i] = freePort
waitStrategies[i] = wait.
ForExec([]string{"wget", "-q", "-O", "-", fmt.Sprintf("http://%s:%d", testcontainers.HostInternal, freePort)}).
WithExitCodeMatcher(func(code int) bool {
return code == 0
}).
WithResponseMatcher(func(body io.Reader) bool {
bs, err := io.ReadAll(body)
require.NoError(tt, err)
return string(bs) == expectedResponse
})

tt.Cleanup(func() {
server.Close()
})
Expand All @@ -69,7 +85,26 @@ func TestExposeHostPorts(t *testing.T) {
ContainerRequest: testcontainers.ContainerRequest{
Image: "alpine:3.17",
HostAccessPorts: freePorts,
Cmd: []string{"top"},
WaitingFor: wait.ForAll(waitStrategies...),
LifecycleHooks: []testcontainers.ContainerLifecycleHooks{
{
PostStarts: []testcontainers.ContainerHook{
func(ctx context.Context, c testcontainers.Container) error {
if tc.bindOnPostStarts {
for _, server := range servers {
server.Start()
}
}

return nil
},
func(ctx context.Context, c testcontainers.Container) error {
return waitStrategies[0].WaitUntilReady(ctx, c)
},
},
},
},
Cmd: []string{"top"},
},
// }
Started: true,
Expand All @@ -87,66 +122,9 @@ func TestExposeHostPorts(t *testing.T) {
}

ctx := context.Background()
if !tc.hasHostAccess {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 10*time.Second)
defer cancel()
}

c, err := testcontainers.GenericContainer(ctx, req)
testcontainers.CleanupContainer(t, c)
require.NoError(tt, err)

if tc.hasHostAccess {
// create a container that has host access, which will
// automatically forward the port to the container
assertContainerHasHostAccess(tt, c, freePorts...)
} else {
// force cancellation because of timeout
time.Sleep(11 * time.Second)

assertContainerHasNoHostAccess(tt, c, freePorts...)
}
_ = c.Terminate(ctx)
})
}
}

func httpRequest(t *testing.T, c testcontainers.Container, port int) (int, string) {
// wgetHostInternal {
code, reader, err := c.Exec(
context.Background(),
[]string{"wget", "-q", "-O", "-", fmt.Sprintf("http://%s:%d", testcontainers.HostInternal, port)},
tcexec.Multiplexed(),
)
// }
require.NoError(t, err)

// read the response
bs, err := io.ReadAll(reader)
require.NoError(t, err)

return code, string(bs)
}

func assertContainerHasHostAccess(t *testing.T, c testcontainers.Container, ports ...int) {
for _, port := range ports {
code, response := httpRequest(t, c, port)
if code != 0 {
t.Fatalf("expected status code [%d] but got [%d]", 0, code)
}

if response != expectedResponse {
t.Fatalf("expected [%s] but got [%s]", expectedResponse, response)
}
}
}

func assertContainerHasNoHostAccess(t *testing.T, c testcontainers.Container, ports ...int) {
for _, port := range ports {
_, response := httpRequest(t, c, port)

if response == expectedResponse {
t.Fatalf("expected not to get [%s] but got [%s]", expectedResponse, response)
}
}
}

0 comments on commit ad3f698

Please sign in to comment.