Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v14] Support double dash delimiter in tsh ssh #47495

Merged
merged 3 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tool/tsh/common/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -3492,6 +3492,11 @@ func onSSH(cf *CLIConf) error {

tc.AllowHeadless = true

// Support calling `tsh ssh -- <command>` (with a double dash before the command)
if len(cf.RemoteCommand) > 0 && strings.TrimSpace(cf.RemoteCommand[0]) == "--" {
cf.RemoteCommand = cf.RemoteCommand[1:]
}

tc.Stdin = os.Stdin
err = retryWithAccessRequest(cf, tc, func() error {
err = client.RetryWithRelogin(cf.Context, tc, func() error {
Expand Down
177 changes: 177 additions & 0 deletions tool/tsh/common/tsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2040,6 +2040,183 @@ func TestAccessRequestOnLeaf(t *testing.T) {
require.NoError(t, err)
}

// TestSSHCommand tests that a user can access a single SSH node and run commands.
func TestSSHCommands(t *testing.T) {
modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

accessRoleName := "access"
sshHostname := "test-ssh-server"

accessUser, err := types.NewUser(accessRoleName)
require.NoError(t, err)
accessUser.SetRoles([]string{accessRoleName})

user, err := user.Current()
require.NoError(t, err)
accessUser.SetLogins([]string{user.Username})

traits := map[string][]string{
constants.TraitLogins: {user.Username},
}
accessUser.SetTraits(traits)

connector := mockConnector(t)
rootServerOpts := []testserver.TestServerOptFunc{
testserver.WithBootstrap(connector, accessUser),
testserver.WithHostname(sshHostname),
testserver.WithClusterName(t, "root"),
testserver.WithSSHLabel(accessRoleName, "true"),
testserver.WithSSHPublicAddrs("127.0.0.1:0"),
testserver.WithConfig(func(cfg *servicecfg.Config) {
cfg.SSH.Enabled = true
cfg.SSH.PublicAddrs = []utils.NetAddr{cfg.SSH.Addr}
cfg.SSH.DisableCreateHostUser = true
}),
}
rootServer := testserver.MakeTestServer(t, rootServerOpts...)

rootProxyAddr, err := rootServer.ProxyWebAddr()
require.NoError(t, err)

require.EventuallyWithT(t, func(t *assert.CollectT) {
rootNodes, err := rootServer.GetAuthServer().GetNodes(ctx, apidefaults.Namespace)
if !assert.NoError(t, err) || !assert.Len(t, rootNodes, 1) {
return
}
}, 10*time.Second, 100*time.Millisecond)

tmpHomePath := t.TempDir()
rootAuth := rootServer.GetAuthServer()

err = Run(ctx, []string{
"login",
"--insecure",
"--proxy", rootProxyAddr.String(),
"--user", user.Username,
}, setHomePath(tmpHomePath), setMockSSOLogin(rootAuth, accessUser, connector.GetName()))
require.NoError(t, err)

tests := []struct {
name string
args []string
expected string
shouldErr bool
}{
{
// Test that a simple echo works.
name: "ssh simple command",
expected: "this is a test message",
args: []string{
fmt.Sprintf("%s@%s", user.Username, sshHostname),
"echo",
"this is a test message",
},
shouldErr: false,
},
{
// Test that commands can be prefixed with a double dash.
name: "ssh command with double dash",
expected: "this is a test message",
args: []string{
fmt.Sprintf("%s@%s", user.Username, sshHostname),
"--",
"echo",
"this is a test message",
},
shouldErr: false,
},
{
// Test that a double dash is not removed from the middle of a command.
name: "ssh command with double dash in the middle",
expected: "-- this is a test message",
args: []string{
fmt.Sprintf("%s@%s", user.Username, sshHostname),
"echo",
"--",
"this is a test message",
},
shouldErr: false,
},
{
// Test that quoted commands work (e.g. `tsh ssh 'echo test'`)
name: "ssh command literal",
expected: "this is a test message",
args: []string{
fmt.Sprintf("%s@%s", user.Username, sshHostname),
"echo this is a test message",
},
shouldErr: false,
},
{
// Test that a double dash is passed as-is in a quoted command (which should fail).
name: "ssh command literal with double dash err",
expected: "",
args: []string{
fmt.Sprintf("%s@%s", user.Username, sshHostname),
"-- echo this is a test message",
},
shouldErr: true,
},
{
// Test that a double dash is not removed from the middle of a quoted command.
name: "ssh command literal with double dash in the middle",
expected: "-- this is a test message",
args: []string{
fmt.Sprintf("%s@%s", user.Username, sshHostname),
"echo", "-- this is a test message",
},
shouldErr: false,
},
{
// Test tsh ssh -- hostname command
name: "delimiter before host and command",
expected: "this is a test message",
args: []string{
"--", sshHostname, "echo", "this is a test message",
},
shouldErr: false,
},
}

for _, test := range tests {
test := test
ctx := context.Background()
t.Run(test.name, func(t *testing.T) {
t.Parallel()

stdout := &output{buf: bytes.Buffer{}}
stderr := &output{buf: bytes.Buffer{}}
args := append(
[]string{
"ssh",
"--insecure",
"--proxy", rootProxyAddr.String(),
},
test.args...,
)

err := Run(ctx, args, setHomePath(tmpHomePath),
func(conf *CLIConf) error {
conf.overrideStdin = &bytes.Buffer{}
conf.OverrideStdout = stdout
conf.overrideStderr = stderr
return nil
},
)

if test.shouldErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, strings.TrimSpace(stdout.String()))
require.Empty(t, stderr.String())
}
})
}
}

// tryCreateTrustedCluster performs several attempts to create a trusted cluster,
// retries on connection problems and access denied errors to let caches
// propagate and services to start
Expand Down
Loading