From 1b44e9afa391e7c31a812bf99634ae48c6fbc99f Mon Sep 17 00:00:00 2001 From: simonfelding <45149055+simonfelding@users.noreply.github.com> Date: Sat, 20 Aug 2022 11:59:02 +0200 Subject: [PATCH] Add optional alias parameter to host config (#355) * added optional alias parameter to single clients and HostConfig for configuration from parallel clients. This is useful for weird ssh proxies like cyberark PAM. Without this, it is difficult to identify the source of the output, as they all have the same host name. * Added tests. * Updated docstrings. --- pssh/clients/base/parallel.py | 5 +++-- pssh/clients/base/single.py | 5 +++-- pssh/clients/native/parallel.py | 1 + pssh/clients/native/single.py | 11 +++++++---- pssh/clients/ssh/parallel.py | 1 + pssh/clients/ssh/single.py | 6 ++++-- pssh/config.py | 9 +++++++-- pssh/output.py | 12 ++++++++---- tests/native/test_parallel_client.py | 5 +++++ tests/native/test_single_client.py | 7 +++++++ tests/test_host_config.py | 5 ++++- 11 files changed, 50 insertions(+), 17 deletions(-) diff --git a/pssh/clients/base/parallel.py b/pssh/clients/base/parallel.py index 43f05db1..f2b83df7 100644 --- a/pssh/clients/base/parallel.py +++ b/pssh/clients/base/parallel.py @@ -231,6 +231,7 @@ def _get_output_from_cmds(self, cmds, raise_error=False): def _get_output_from_greenlet(self, cmd_i, cmd, raise_error=False): host = self.hosts[cmd_i] + alias = self._get_host_config(cmd_i, host).alias try: host_out = cmd.get() return host_out @@ -239,8 +240,7 @@ def _get_output_from_greenlet(self, cmd_i, cmd, raise_error=False): ex = Timeout() if raise_error: raise ex - return HostOutput(host, None, None, None, - exception=ex) + return HostOutput(host, None, None, None, exception=ex, alias=alias) def get_last_output(self, cmds=None): """Get output for last commands executed by ``run_command``. @@ -272,6 +272,7 @@ def _get_host_config(self, host_i, host): gssapi_server_identity=self.gssapi_server_identity, gssapi_client_identity=self.gssapi_client_identity, gssapi_delegate_credentials=self.gssapi_delegate_credentials, + alias=None, ) return config elif not isinstance(self.host_config, list): diff --git a/pssh/clients/base/single.py b/pssh/clients/base/single.py index 4255a14f..17db66d4 100644 --- a/pssh/clients/base/single.py +++ b/pssh/clients/base/single.py @@ -159,7 +159,7 @@ class BaseSSHClient(object): def __init__(self, host, user=None, password=None, port=None, - pkey=None, + pkey=None, alias=None, num_retries=DEFAULT_RETRIES, retry_delay=RETRY_DELAY, allow_agent=True, timeout=None, @@ -171,6 +171,7 @@ def __init__(self, host, ): self._auth_thread_pool = _auth_thread_pool self.host = host + self.alias = alias self.user = user if user else getuser() self.password = password self.port = port if port else 22 @@ -409,7 +410,7 @@ def _make_host_output(self, channel, encoding, read_timeout): stdout=BufferData(rw_buffer=_stdout_buffer, reader=_stdout_reader), stderr=BufferData(rw_buffer=_stderr_buffer, reader=_stderr_reader)) host_out = HostOutput( - host=self.host, channel=channel, stdin=Stdin(channel, self), + host=self.host, alias=self.alias, channel=channel, stdin=Stdin(channel, self), client=self, encoding=encoding, read_timeout=read_timeout, buffers=_buffers, ) diff --git a/pssh/clients/native/parallel.py b/pssh/clients/native/parallel.py index 0b95a51e..9cb94995 100644 --- a/pssh/clients/native/parallel.py +++ b/pssh/clients/native/parallel.py @@ -231,6 +231,7 @@ def _make_ssh_client(self, host, cfg, _pkey_data): _client = SSHClient( host, user=cfg.user or self.user, password=cfg.password or self.password, port=cfg.port or self.port, pkey=_pkey_data, num_retries=cfg.num_retries or self.num_retries, + alias=cfg.alias, timeout=cfg.timeout or self.timeout, allow_agent=cfg.allow_agent or self.allow_agent, retry_delay=cfg.retry_delay or self.retry_delay, proxy_host=cfg.proxy_host or self.proxy_host, diff --git a/pssh/clients/native/single.py b/pssh/clients/native/single.py index 2dc4e7cf..335dfa56 100644 --- a/pssh/clients/native/single.py +++ b/pssh/clients/native/single.py @@ -50,7 +50,7 @@ class SSHClient(BaseSSHClient): def __init__(self, host, user=None, password=None, port=None, - pkey=None, + pkey=None, alias=None, num_retries=DEFAULT_RETRIES, retry_delay=RETRY_DELAY, allow_agent=True, timeout=None, @@ -70,6 +70,8 @@ def __init__(self, host, :type user: str :param password: Password to use for password authentication. :type password: str + :param alias: Use an alias for this host. + :type alias: str :param port: SSH port to connect to. Defaults to SSH default (22) :type port: int :param pkey: Private key file path to use for authentication. Path must @@ -115,6 +117,7 @@ def __init__(self, host, self.keepalive_seconds = keepalive_seconds self._keepalive_greenlet = None self._proxy_client = None + self.alias = alias self.host = host self.port = port if port is not None else 22 if proxy_host is not None: @@ -133,7 +136,7 @@ def __init__(self, host, proxy_host = '127.0.0.1' self._chan_lock = RLock() super(SSHClient, self).__init__( - host, user=user, password=password, port=port, pkey=pkey, + host, user=user, password=password, alias=alias, port=port, pkey=pkey, num_retries=num_retries, retry_delay=retry_delay, allow_agent=allow_agent, _auth_thread_pool=_auth_thread_pool, timeout=timeout, @@ -146,7 +149,7 @@ def _shell(self, channel): return self._eagain(channel.shell) def _connect_proxy(self, proxy_host, proxy_port, proxy_pkey, - user=None, password=None, + user=None, password=None, alias=None, num_retries=DEFAULT_RETRIES, retry_delay=RETRY_DELAY, allow_agent=True, timeout=None, @@ -156,7 +159,7 @@ def _connect_proxy(self, proxy_host, proxy_port, proxy_pkey, assert isinstance(self.port, int) try: self._proxy_client = SSHClient( - proxy_host, port=proxy_port, pkey=proxy_pkey, + proxy_host, port=proxy_port, pkey=proxy_pkey, alias=alias, num_retries=num_retries, user=user, password=password, retry_delay=retry_delay, allow_agent=allow_agent, timeout=timeout, forward_ssh_agent=forward_ssh_agent, diff --git a/pssh/clients/ssh/parallel.py b/pssh/clients/ssh/parallel.py index bd7a11a8..7ab4aafd 100644 --- a/pssh/clients/ssh/parallel.py +++ b/pssh/clients/ssh/parallel.py @@ -217,6 +217,7 @@ def _make_ssh_client(self, host, cfg, _pkey_data): _client = SSHClient( host, user=cfg.user or self.user, password=cfg.password or self.password, port=cfg.port or self.port, pkey=_pkey_data, num_retries=cfg.num_retries or self.num_retries, + alias=cfg.alias, timeout=cfg.timeout or self.timeout, allow_agent=cfg.allow_agent or self.allow_agent, retry_delay=cfg.retry_delay or self.retry_delay, _auth_thread_pool=cfg.auth_thread_pool or self._auth_thread_pool, diff --git a/pssh/clients/ssh/single.py b/pssh/clients/ssh/single.py index 855ada98..9eee161e 100644 --- a/pssh/clients/ssh/single.py +++ b/pssh/clients/ssh/single.py @@ -40,7 +40,7 @@ class SSHClient(BaseSSHClient): def __init__(self, host, user=None, password=None, port=None, - pkey=None, + pkey=None, alias=None, cert_file=None, num_retries=DEFAULT_RETRIES, retry_delay=RETRY_DELAY, @@ -60,6 +60,8 @@ def __init__(self, host, :type password: str :param port: SSH port to connect to. Defaults to SSH default (22) :type port: int + :param alias: Use an alias for this host. + :type alias: str :param pkey: Private key file path to use for authentication. Path must be either absolute path or relative to user home directory like ``~/``. @@ -114,7 +116,7 @@ def __init__(self, host, self.gssapi_client_identity = gssapi_client_identity self.gssapi_delegate_credentials = gssapi_delegate_credentials super(SSHClient, self).__init__( - host, user=user, password=password, port=port, pkey=pkey, + host, user=user, password=password, port=port, pkey=pkey, alias=alias, num_retries=num_retries, retry_delay=retry_delay, allow_agent=allow_agent, _auth_thread_pool=_auth_thread_pool, diff --git a/pssh/config.py b/pssh/config.py index 5c1cc949..65fc3e34 100644 --- a/pssh/config.py +++ b/pssh/config.py @@ -25,7 +25,7 @@ class HostConfig(object): Used to hold individual configuration for each host in ParallelSSHClient host list. """ __slots__ = ('user', 'port', 'password', 'private_key', 'allow_agent', - 'num_retries', 'retry_delay', 'timeout', 'identity_auth', + 'alias', 'num_retries', 'retry_delay', 'timeout', 'identity_auth', 'proxy_host', 'proxy_port', 'proxy_user', 'proxy_password', 'proxy_pkey', 'keepalive_seconds', 'ipv6_only', 'cert_file', 'auth_thread_pool', 'gssapi_auth', 'gssapi_server_identity', 'gssapi_client_identity', 'gssapi_delegate_credentials', @@ -33,7 +33,7 @@ class HostConfig(object): ) def __init__(self, user=None, port=None, password=None, private_key=None, - allow_agent=None, num_retries=None, retry_delay=None, timeout=None, + allow_agent=None, alias=None, num_retries=None, retry_delay=None, timeout=None, identity_auth=None, proxy_host=None, proxy_port=None, proxy_user=None, proxy_password=None, proxy_pkey=None, @@ -58,6 +58,8 @@ def __init__(self, user=None, port=None, password=None, private_key=None, :type private_key: str :param allow_agent: Enable/disable SSH agent authentication. :type allow_agent: bool + :param alias: Use an alias for this host. + :type alias: str or int :param num_retries: Number of retry attempts before giving up on connection and SSH operations. :type num_retries: int @@ -103,6 +105,7 @@ def __init__(self, user=None, port=None, password=None, private_key=None, self.password = password self.private_key = private_key self.allow_agent = allow_agent + self.alias = alias self.num_retries = num_retries self.timeout = timeout self.retry_delay = retry_delay @@ -130,6 +133,8 @@ def _sanity_checks(self): raise ValueError("Port %s is not an integer" % (self.port,)) if self.password is not None and not isinstance(self.password, str): raise ValueError("Password %s is not a string" % (self.password,)) + if self.alias is not None and not isinstance(self.alias, str): + raise ValueError("Alias %s is not a string" % (self.alias,)) if self.private_key is not None and not ( isinstance(self.private_key, str) or isinstance(self.private_key, bytes) ): diff --git a/pssh/output.py b/pssh/output.py index c7e9375e..01d709b2 100644 --- a/pssh/output.py +++ b/pssh/output.py @@ -55,12 +55,12 @@ class HostOutput(object): """Host output""" __slots__ = ('host', 'channel', 'stdin', - 'client', 'exception', 'encoding', 'read_timeout', - 'buffers', + 'client', 'alias', 'exception', + 'encoding', 'read_timeout', 'buffers', ) def __init__(self, host, channel, stdin, - client, exception=None, encoding='utf-8', read_timeout=None, + client, alias=None, exception=None, encoding='utf-8', read_timeout=None, buffers=None): """ :param host: Host name output is for @@ -71,6 +71,8 @@ def __init__(self, host, channel, stdin, :type stdin: :py:func:`file`-like object :param client: `SSHClient` output is coming from. :type client: :py:class:`pssh.clients.base.single.BaseSSHClient` + :param alias: Host alias. + :type alias: str :param exception: Exception from host if any :type exception: :py:class:`Exception` or ``None`` :param read_timeout: Timeout in seconds for reading from buffers. @@ -82,6 +84,7 @@ def __init__(self, host, channel, stdin, self.channel = channel self.stdin = stdin self.client = client + self.alias = alias self.exception = exception self.encoding = encoding self.read_timeout = read_timeout @@ -117,12 +120,13 @@ def exit_code(self): def __repr__(self): return "\thost={host}{linesep}" \ + "\talias={alias}{linesep}" \ "\texit_code={exit_code}{linesep}" \ "\tchannel={channel}{linesep}" \ "\texception={exception}{linesep}" \ "\tencoding={encoding}{linesep}" \ "\tread_timeout={read_timeout}".format( - host=self.host, channel=self.channel, + host=self.host, alias=self.alias, channel=self.channel, exception=self.exception, linesep=linesep, exit_code=self.exit_code, encoding=self.encoding, read_timeout=self.read_timeout, ) diff --git a/tests/native/test_parallel_client.py b/tests/native/test_parallel_client.py index 5b760985..9250a4df 100644 --- a/tests/native/test_parallel_client.py +++ b/tests/native/test_parallel_client.py @@ -930,6 +930,7 @@ def test_host_config(self): servers = [] password = 'overriden_pass' fake_key = 'FAKE KEY' + aliases = [f"alias for host {host_i}" for host_i, _ in enumerate(hosts)] for host_i, (host, port) in enumerate(hosts): server = OpenSSHServer(listen_ip=host, port=port) server.start_server() @@ -937,12 +938,14 @@ def test_host_config(self): host_config[host_i].user = self.user host_config[host_i].password = password host_config[host_i].private_key = self.user_key + host_config[host_i].alias = aliases[host_i] servers.append(server) host_config[1].private_key = fake_key client = ParallelSSHClient([h for h, _ in hosts], host_config=host_config, num_retries=1) output = client.run_command(self.cmd, stop_on_errors=False) + client.join(output) self.assertEqual(len(hosts), len(output)) try: @@ -954,6 +957,8 @@ def test_host_config(self): self.assertEqual(client._host_clients[0, hosts[0][0]].user, self.user) self.assertEqual(client._host_clients[0, hosts[0][0]].password, password) self.assertEqual(client._host_clients[0, hosts[0][0]].pkey, open(os.path.abspath(self.user_key), 'rb').read()) + self.assertEqual(set(aliases), set([client.alias for client in output])) + for server in servers: server.stop() diff --git a/tests/native/test_single_client.py b/tests/native/test_single_client.py index 61df944e..762a6366 100644 --- a/tests/native/test_single_client.py +++ b/tests/native/test_single_client.py @@ -180,6 +180,13 @@ def test_execute(self): exit_code = host_out.channel.get_exit_status() self.assertEqual(host_out.exit_code, 0) self.assertEqual(expected, output) + + def test_alias(self): + client = SSHClient(self.host, port=self.port, + pkey=self.user_key, num_retries=1, + alias='test') + host_out = client.run_command(self.cmd) + self.assertEqual(host_out.alias, 'test') def test_open_session_timeout(self): client = SSHClient(self.host, port=self.port, diff --git a/tests/test_host_config.py b/tests/test_host_config.py index 0bdf863c..7327fb78 100644 --- a/tests/test_host_config.py +++ b/tests/test_host_config.py @@ -26,6 +26,7 @@ def test_host_config_entries(self): user = 'user' port = 22 password = 'password' + alias = 'alias' private_key = 'private key' allow_agent = False num_retries = 1 @@ -43,7 +44,7 @@ def test_host_config_entries(self): gssapi_client_identity = 'some_id' gssapi_delegate_credentials = True cfg = HostConfig( - user=user, port=port, password=password, private_key=private_key, + user=user, port=port, password=password, alias=alias, private_key=private_key, allow_agent=allow_agent, num_retries=num_retries, retry_delay=retry_delay, timeout=timeout, identity_auth=identity_auth, proxy_host=proxy_host, ipv6_only=ipv6_only, @@ -59,6 +60,7 @@ def test_host_config_entries(self): self.assertEqual(cfg.user, user) self.assertEqual(cfg.port, port) self.assertEqual(cfg.password, password) + self.assertEqual(cfg.alias, alias) self.assertEqual(cfg.private_key, private_key) self.assertEqual(cfg.allow_agent, allow_agent) self.assertEqual(cfg.num_retries, num_retries) @@ -79,6 +81,7 @@ def test_host_config_bad_entries(self): self.assertRaises(ValueError, HostConfig, user=22) self.assertRaises(ValueError, HostConfig, password=22) self.assertRaises(ValueError, HostConfig, port='22') + self.assertRaises(ValueError, HostConfig, alias=2) self.assertRaises(ValueError, HostConfig, private_key=1) self.assertRaises(ValueError, HostConfig, allow_agent=1) self.assertRaises(ValueError, HostConfig, num_retries='')