diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index b11427118f89d..ba649ca03d2a5 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -2194,8 +2194,18 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co return trace.Wrap(err) } - // Set the src addr again since it may have been updated with a new port. - scx.SrcAddr = listener.Addr().String() + // If the client didn't request a specific port, the chosen port needs to + // be reported back. + srcHost, _, err := sshutils.SplitHostPort(scx.SrcAddr) + if err != nil { + return trace.Wrap(err) + } + _, listenPort, err := sshutils.SplitHostPort(listener.Addr().String()) + if err != nil { + return trace.Wrap(err) + } + scx.SrcAddr = sshutils.JoinHostPort(srcHost, listenPort) + event := scx.GetPortForwardEvent() if err := s.EmitAuditEvent(ctx, &event); err != nil { s.Logger.WithError(err).Warn("Failed to emit audit event.") diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 8d6a64154ee9b..9ffba1bebe480 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -19,7 +19,6 @@ package regular import ( - "bytes" "context" "crypto/tls" "encoding/json" @@ -709,40 +708,63 @@ func TestDirectTCPIP(t *testing.T) { // "tcpip-forward" request and do remote port forwarding. func TestTCPIPForward(t *testing.T) { t.Parallel() - f := newFixtureWithoutDiskBasedLogging(t) - - // Request a listener from the server. - listener, err := f.ssh.clt.Listen("tcp", "127.0.0.1:0") + hostname, err := os.Hostname() require.NoError(t, err) + tests := []struct { + name string + listenAddr string + }{ + { + name: "localhost", + listenAddr: "localhost:0", + }, + { + name: "ip address", + listenAddr: "127.0.0.1:0", + }, + { + name: "hostname", + listenAddr: hostname + ":0", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + f := newFixtureWithoutDiskBasedLogging(t) - // Start up a test server that uses the port forwarded listener. - ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "hello, world") - })) - t.Cleanup(ts.Close) - ts.Listener = listener - ts.Start() + // Request a listener from the server. + listener, err := f.ssh.clt.Listen("tcp", tc.listenAddr) + require.NoError(t, err) - // Dial the test server over the SSH connection. - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - t.Cleanup(cancel) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, &bytes.Buffer{}) - require.NoError(t, err) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, resp.Body.Close()) - }) + // Start up a test server that uses the port forwarded listener. + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "hello, world") + })) + t.Cleanup(ts.Close) + ts.Listener = listener + ts.Start() + + // Dial the test server over the SSH connection. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(cancel) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil) + require.NoError(t, err) + resp, err := ts.Client().Do(req) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, resp.Body.Close()) + }) - // Make sure the response is what was expected. - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Equal(t, []byte("hello, world\n"), body) + // Make sure the response is what was expected. + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, []byte("hello, world\n"), body) + }) + } t.Run("SessionJoinPrincipal cannot use tcpip-forward", func(t *testing.T) { // Ensure that ssh client using SessionJoinPrincipal as Login, cannot // connect using "tcpip-forward". + f := newFixtureWithoutDiskBasedLogging(t) ctx := context.Background() cliUsingSessionJoin := f.newSSHClient(ctx, t, &user.User{Username: teleport.SSHSessionJoinPrincipal}) _, err := cliUsingSessionJoin.Listen("tcp", "127.0.0.1:0") diff --git a/lib/sshutils/tcpip.go b/lib/sshutils/tcpip.go index 4b92efb0177ce..55ac4ea981cf3 100644 --- a/lib/sshutils/tcpip.go +++ b/lib/sshutils/tcpip.go @@ -101,6 +101,7 @@ func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr str dstHost, dstPort, err := SplitHostPort(conn.RemoteAddr().String()) if err != nil { + conn.Close() logger.WithError(err).Warn("failed to parse addr") return } @@ -112,6 +113,7 @@ func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr str OrigPort: dstPort, } if err := req.CheckAndSetDefaults(); err != nil { + conn.Close() logger.WithError(err).Warn("failed to create forwarded tcpip request") return } @@ -119,6 +121,7 @@ func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr str ch, rch, err := sshConn.OpenChannel(teleport.ChanForwardedTCPIP, reqBytes) if err != nil { + conn.Close() logger.WithError(err).Warn("failed to open channel") continue }