Skip to content

Commit

Permalink
Perf changes (#357)
Browse files Browse the repository at this point in the history
* Performance improvements for reading output in all clients.
* Output reading for all clients has been changed to be less prone to race conditions.
* Parallel clients now read a common private key only once, reusing it for all clients it applies to, to improve performance.
* Updated changelog.
* Added test case for joining on parallel clients without ever running run_command. Updated join so that it does not raise an exception in that case.
  • Loading branch information
pkittenis authored Aug 20, 2022
1 parent 1b44e9a commit d812ff3
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 138 deletions.
20 changes: 20 additions & 0 deletions Changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
Change Log
============

2.12.0
+++++++

Changes
--------

* Added optional ``alias`` parameter to ``SSHClient`` and ``HostConfig`` for passing through from parallel clients.
Used to set an SSH host name alias, for cases where multiple hosts share the same real host name and there is a
need to differentiate output from otherwise identical host names - #355. Thank you @simonfelding.
* Parallel clients now read a common private key only once, reusing it for all clients it applies to,
to improve performance.
* Performance improvements for all clients when reading output.
* Output reading for all clients has been changed to be less prone to race conditions.

Fixes
------

* Calling ``ParallelSSHClient.join`` without ever running ``run_command`` would raise an exception. It is now a no-op.


2.11.1
+++++++

Expand Down
61 changes: 36 additions & 25 deletions pssh/clients/base/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from gevent import joinall, spawn, Timeout as GTimeout
from gevent.hub import Hub

from ..common import _validate_pkey_path
from ..common import _validate_pkey_path, _validate_pkey
from ...config import HostConfig
from ...constants import DEFAULT_RETRIES, RETRY_DELAY
from ...exceptions import HostArgumentError, Timeout, ShellError, HostConfigError
Expand All @@ -39,7 +39,7 @@ class BaseParallelSSHClient(object):
def __init__(self, hosts, user=None, password=None, port=None, pkey=None,
allow_agent=True,
num_retries=DEFAULT_RETRIES,
timeout=120, pool_size=10,
timeout=120, pool_size=100,
host_config=None, retry_delay=RETRY_DELAY,
identity_auth=True,
ipv6_only=False,
Expand All @@ -64,7 +64,8 @@ def __init__(self, hosts, user=None, password=None, port=None, pkey=None,
self.user = user
self.password = password
self.port = port
self.pkey = pkey
self.pkey = _validate_pkey(pkey)
self.__pkey_data = self._load_pkey_data(pkey) if pkey is not None else None
self.num_retries = num_retries
self.timeout = timeout
self._host_clients = {}
Expand Down Expand Up @@ -113,9 +114,26 @@ def hosts(self, _hosts):
self._host_clients.pop((i, host), None)
self._hosts = _hosts

def __del__(self):
self.disconnect()

def disconnect(self):
if not hasattr(self, '_host_clients'):
return
for s_client in self._host_clients.values():
try:
s_client.disconnect()
except Exception as ex:
logger.debug("Client disconnect failed with %s", ex)
pass
del s_client

def _check_host_config(self):
if self.host_config is None:
return
if not isinstance(self.host_config, list):
raise HostConfigError("Host configuration of type %s is invalid - valid types are List[HostConfig]",
type(self.host_config))
host_len = len(self.hosts)
if host_len != len(self.host_config):
raise ValueError(
Expand Down Expand Up @@ -231,7 +249,7 @@ def _get_output_from_cmds(self, cmds, raise_error=False):

def _get_output_from_greenlet(self, cmd_i, cmd, raise_error=False):
host = self.hosts[cmd_i]
alias = self._get_host_config(cmd_i, host).alias
alias = self._get_host_config(cmd_i).alias
try:
host_out = cmd.get()
return host_out
Expand All @@ -256,7 +274,7 @@ def get_last_output(self, cmds=None):
return self._get_output_from_cmds(
cmds, raise_error=False)

def _get_host_config(self, host_i, host):
def _get_host_config(self, host_i):
if self.host_config is None:
config = HostConfig(
user=self.user, port=self.port, password=self.password, private_key=self.pkey,
Expand All @@ -275,17 +293,13 @@ def _get_host_config(self, host_i, host):
alias=None,
)
return config
elif not isinstance(self.host_config, list):
raise HostConfigError("Host configuration of type %s is invalid - valid types are list[HostConfig]",
type(self.host_config))
config = self.host_config[host_i]
return config

def _run_command(self, host_i, host, command, sudo=False, user=None,
shell=None, use_pty=False,
encoding='utf-8', read_timeout=None):
"""Make SSHClient if needed, run command on host"""
logger.debug("_run_command with read timeout %s", read_timeout)
try:
_client = self._get_ssh_client(host_i, host)
host_out = _client.run_command(
Expand All @@ -311,13 +325,13 @@ def connect_auth(self):
:returns: list of greenlets to ``joinall`` with.
:rtype: list(:py:mod:`gevent.greenlet.Greenlet`)
"""
cmds = [spawn(self._get_ssh_client, i, host) for i, host in enumerate(self.hosts)]
cmds = [self.pool.spawn(self._get_ssh_client, i, host) for i, host in enumerate(self.hosts)]
return cmds

def _consume_output(self, stdout, stderr):
for line in stdout:
for _ in stdout:
pass
for line in stderr:
for _ in stderr:
pass

def join(self, output=None, consume_output=False, timeout=None):
Expand Down Expand Up @@ -346,6 +360,9 @@ def join(self, output=None, consume_output=False, timeout=None):
:rtype: ``None``"""
if output is None:
output = self.get_last_output()
if output is None:
logger.info("No last output to join on - run_command has never been run.")
return
elif not isinstance(output, list):
raise ValueError("Unexpected output object type")
cmds = [self.pool.spawn(self._join, host_out, timeout=timeout,
Expand Down Expand Up @@ -544,32 +561,26 @@ def _copy_remote_file(self, host_i, host, remote_file, local_file, recurse,
return client.copy_remote_file(
remote_file, local_file, recurse=recurse, **kwargs)

def _handle_greenlet_exc(self, func, host, *args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as ex:
raise ex

def _get_ssh_client(self, host_i, host):
logger.debug("Make client request for host %s, (host_i, host) in clients: %s",
host, (host_i, host) in self._host_clients)
_client = self._host_clients.get((host_i, host))
if _client is not None:
return _client
cfg = self._get_host_config(host_i, host)
cfg = self._get_host_config(host_i)
_pkey = self.pkey if cfg.private_key is None else cfg.private_key
_pkey_data = self._load_pkey_data(_pkey)
_client = self._make_ssh_client(host, cfg, _pkey_data)
self._host_clients[(host_i, host)] = _client
return _client

def _load_pkey_data(self, _pkey):
if isinstance(_pkey, str):
_validate_pkey_path(_pkey)
with open(_pkey, 'rb') as fh:
_pkey_data = fh.read()
return _pkey_data
return _pkey
if not isinstance(_pkey, str):
return _pkey
_pkey = _validate_pkey_path(_pkey)
with open(_pkey, 'rb') as fh:
_pkey_data = fh.read()
return _pkey_data

def _make_ssh_client(self, host, cfg, _pkey_data):
raise NotImplementedError
29 changes: 17 additions & 12 deletions pssh/clients/base/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,17 @@
from gevent import sleep, socket, Timeout as GTimeout
from gevent.hub import Hub
from gevent.select import poll, POLLIN, POLLOUT

from ssh2.utils import find_eol
from ssh2.exceptions import AgentConnectionError, AgentListIdentitiesError, \
AgentAuthenticationError, AgentGetIdentityError
from ssh2.utils import find_eol

from ..common import _validate_pkey
from ...constants import DEFAULT_RETRIES, RETRY_DELAY
from ..reader import ConcurrentRWBuffer
from ...constants import DEFAULT_RETRIES, RETRY_DELAY
from ...exceptions import UnknownHostError, AuthenticationError, \
ConnectionError, Timeout, NoIPv6AddressFoundError
from ...output import HostOutput, HostOutputBuffers, BufferData


Hub.NOT_ERROR = (Exception,)
host_logger = logging.getLogger('pssh.host_logger')
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -287,15 +285,15 @@ def _connect(self, host, port, retries=1):
raise unknown_ex from ex
for i, (family, _type, proto, _, sock_addr) in enumerate(addr_info):
try:
return self._connect_socket(family, _type, proto, sock_addr, host, port, retries)
return self._connect_socket(family, _type, sock_addr, host, port, retries)
except ConnectionRefusedError as ex:
if i+1 == len(addr_info):
logger.error("No available addresses from %s", [addr[4] for addr in addr_info])
ex.args += (host, port)
raise
continue

def _connect_socket(self, family, _type, proto, sock_addr, host, port, retries):
def _connect_socket(self, family, _type, sock_addr, host, port, retries):
self.sock = socket.socket(family, _type)
if self.timeout:
self.sock.settimeout(self.timeout)
Expand Down Expand Up @@ -428,6 +426,8 @@ def read_stderr(self, stderr_buffer, timeout=None):
:param stderr_buffer: Buffer to read from.
:type stderr_buffer: :py:class:`pssh.clients.reader.ConcurrentRWBuffer`
:param timeout: Timeout in seconds - defaults to no timeout.
:type timeout: int or float
:rtype: generator
"""
logger.debug("Reading from stderr buffer, timeout=%s", timeout)
Expand All @@ -439,6 +439,8 @@ def read_output(self, stdout_buffer, timeout=None):
:param stdout_buffer: Buffer to read from.
:type stdout_buffer: :py:class:`pssh.clients.reader.ConcurrentRWBuffer`
:param timeout: Timeout in seconds - defaults to no timeout.
:type timeout: int or float
:rtype: generator
"""
logger.debug("Reading from stdout buffer, timeout=%s", timeout)
Expand Down Expand Up @@ -492,14 +494,16 @@ def read_output_buffer(self, output_buffer, prefix=None,
encoding='utf-8'):
"""Read from output buffers and log to ``host_logger``.
:param output_buffer: Iterator containing buffer
:param output_buffer: Iterator containing buffer.
:type output_buffer: iterator
:param prefix: String to prefix log output to ``host_logger`` with
:param prefix: String to prefix log output to ``host_logger`` with.
:type prefix: str
:param callback: Function to call back once buffer is depleted:
:param callback: Function to call back once buffer is depleted.
:type callback: function
:param callback_args: Arguments for call back function
:param callback_args: Arguments for call back function.
:type callback_args: tuple
:param encoding: Encoding for output.
:type encoding: str
"""
prefix = '' if prefix is None else prefix
for line in output_buffer:
Expand Down Expand Up @@ -553,7 +557,7 @@ def run_command(self, command, sudo=False, user=None,
host_out = self._make_host_output(channel, encoding, _timeout)
return host_out

def _eagain_write_errcode(self, write_func, data, eagain, timeout=None):
def _eagain_write_errcode(self, write_func, data, eagain):
data_len = len(data)
total_written = 0
while total_written < data_len:
Expand All @@ -570,9 +574,10 @@ def _eagain_errcode(self, func, eagain, *args, **kwargs):
while ret == eagain:
self.poll()
ret = func(*args, **kwargs)
sleep()
return ret

def _eagain_write(self, write_func, data, timeout=None):
def _eagain_write(self, write_func, data):
raise NotImplementedError

def _eagain(self, func, *args, **kwargs):
Expand Down
31 changes: 10 additions & 21 deletions pssh/clients/native/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def __init__(self, hosts, user=None, password=None, port=22, pkey=None,
identity_auth=identity_auth,
ipv6_only=ipv6_only,
)
self.pkey = _validate_pkey(pkey)
self.proxy_host = proxy_host
self.proxy_port = proxy_port
self.proxy_pkey = _validate_pkey(proxy_pkey)
Expand Down Expand Up @@ -216,17 +215,6 @@ def run_command(self, command, sudo=False, user=None, stop_on_errors=True,
read_timeout=read_timeout,
)

def __del__(self):
if not hasattr(self, '_host_clients'):
return
for s_client in self._host_clients.values():
try:
s_client.disconnect()
except Exception as ex:
logger.debug("Client disconnect failed with %s", ex)
pass
del s_client

def _make_ssh_client(self, host, cfg, _pkey_data):
_client = SSHClient(
host, user=cfg.user or self.user, password=cfg.password or self.password, port=cfg.port or self.port,
Expand Down Expand Up @@ -371,16 +359,12 @@ def copy_remote_file(self, remote_file, local_file, recurse=False,
encoding=encoding)

def _scp_send(self, host_i, host, local_file, remote_file, recurse=False):
self._get_ssh_client(host_i, host)
return self._handle_greenlet_exc(
self._host_clients[(host_i, host)].scp_send, host,
local_file, remote_file, recurse=recurse)
_client = self._get_ssh_client(host_i, host)
return _client.scp_send(local_file, remote_file, recurse=recurse)

def _scp_recv(self, host_i, host, remote_file, local_file, recurse=False):
self._get_ssh_client(host_i, host)
return self._handle_greenlet_exc(
self._host_clients[(host_i, host)].scp_recv, host,
remote_file, local_file, recurse=recurse)
_client = self._get_ssh_client(host_i, host)
return _client.scp_recv(remote_file, local_file, recurse=recurse)

def scp_send(self, local_file, remote_file, recurse=False, copy_args=None):
"""Copy local file to remote file in parallel via SCP.
Expand All @@ -405,6 +389,11 @@ def scp_send(self, local_file, remote_file, recurse=False, copy_args=None):
:type local_file: str
:param remote_file: Remote filepath on remote host to copy file to
:type remote_file: str
:param copy_args: (Optional) format local_file and remote_file strings
with per-host arguments in ``copy_args``. ``copy_args`` length must
equal length of host list -
:py:class:`pssh.exceptions.HostArgumentError` is raised otherwise
:type copy_args: tuple or list
:param recurse: Whether or not to descend into directories recursively.
:type recurse: bool
Expand All @@ -416,7 +405,7 @@ def scp_send(self, local_file, remote_file, recurse=False, copy_args=None):
"""
copy_args = [{'local_file': local_file,
'remote_file': remote_file}
for i, host in enumerate(self.hosts)] \
for _ in self.hosts] \
if copy_args is None else copy_args
local_file = "%(local_file)s"
remote_file = "%(remote_file)s"
Expand Down
Loading

0 comments on commit d812ff3

Please sign in to comment.