Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix remote port forwarding reporting wrong hostname #48831

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2194,8 +2194,18 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co
return trace.Wrap(err)
}

// Set the src addr again since it may have been updated with a new port.
scx.SrcAddr = listener.Addr().String()
// If the client didn't request a specific port, the chosen port needs to
// be reported back.
srcHost, _, err := sshutils.SplitHostPort(scx.SrcAddr)
if err != nil {
return trace.Wrap(err)
}
_, listenPort, err := sshutils.SplitHostPort(listener.Addr().String())
if err != nil {
return trace.Wrap(err)
}
scx.SrcAddr = sshutils.JoinHostPort(srcHost, listenPort)

event := scx.GetPortForwardEvent()
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.Logger.WithError(err).Warn("Failed to emit audit event.")
Expand Down
76 changes: 49 additions & 27 deletions lib/srv/regular/sshserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
package regular

import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
Expand Down Expand Up @@ -709,40 +708,63 @@ func TestDirectTCPIP(t *testing.T) {
// "tcpip-forward" request and do remote port forwarding.
func TestTCPIPForward(t *testing.T) {
t.Parallel()
f := newFixtureWithoutDiskBasedLogging(t)

// Request a listener from the server.
listener, err := f.ssh.clt.Listen("tcp", "127.0.0.1:0")
hostname, err := os.Hostname()
require.NoError(t, err)
tests := []struct {
name string
listenAddr string
}{
{
name: "localhost",
listenAddr: "localhost:0",
},
{
name: "ip address",
listenAddr: "127.0.0.1:0",
},
{
name: "hostname",
listenAddr: hostname + ":0",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
f := newFixtureWithoutDiskBasedLogging(t)

// Start up a test server that uses the port forwarded listener.
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "hello, world")
}))
t.Cleanup(ts.Close)
ts.Listener = listener
ts.Start()
// Request a listener from the server.
listener, err := f.ssh.clt.Listen("tcp", tc.listenAddr)
require.NoError(t, err)

// Dial the test server over the SSH connection.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, &bytes.Buffer{})
require.NoError(t, err)
client := &http.Client{}
resp, err := client.Do(req)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, resp.Body.Close())
})
// Start up a test server that uses the port forwarded listener.
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "hello, world")
}))
t.Cleanup(ts.Close)
ts.Listener = listener
ts.Start()

// Dial the test server over the SSH connection.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil)
require.NoError(t, err)
resp, err := ts.Client().Do(req)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, resp.Body.Close())
})

// Make sure the response is what was expected.
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, []byte("hello, world\n"), body)
// Make sure the response is what was expected.
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, []byte("hello, world\n"), body)
})
}

t.Run("SessionJoinPrincipal cannot use tcpip-forward", func(t *testing.T) {
// Ensure that ssh client using SessionJoinPrincipal as Login, cannot
// connect using "tcpip-forward".
f := newFixtureWithoutDiskBasedLogging(t)
ctx := context.Background()
cliUsingSessionJoin := f.newSSHClient(ctx, t, &user.User{Username: teleport.SSHSessionJoinPrincipal})
_, err := cliUsingSessionJoin.Listen("tcp", "127.0.0.1:0")
Expand Down
3 changes: 3 additions & 0 deletions lib/sshutils/tcpip.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr str

dstHost, dstPort, err := SplitHostPort(conn.RemoteAddr().String())
if err != nil {
conn.Close()
logger.WithError(err).Warn("failed to parse addr")
return
}
Expand All @@ -112,13 +113,15 @@ func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr str
OrigPort: dstPort,
}
if err := req.CheckAndSetDefaults(); err != nil {
conn.Close()
logger.WithError(err).Warn("failed to create forwarded tcpip request")
return
}
reqBytes := ssh.Marshal(req)

ch, rch, err := sshConn.OpenChannel(teleport.ChanForwardedTCPIP, reqBytes)
if err != nil {
conn.Close()
logger.WithError(err).Warn("failed to open channel")
continue
}
Expand Down
Loading