diff --git a/Changelog.rst b/Changelog.rst index 8f1b93dc..826ead55 100644 --- a/Changelog.rst +++ b/Changelog.rst @@ -1,6 +1,26 @@ Change Log ============ +2.12.0 ++++++++ + +Changes +-------- + +* Added ``alias`` optional parameter to ``SSHClient`` and ``HostConfig`` for passing through from parallel clients. + Used to set an SSH host name alias, for cases where the real host name is the same and there is a need to + differentiate output from otherwise identical host names - #355. Thank you @simonfelding. +* Parallel clients now read a common private key only once, reusing it for all clients it applies to, + to improve performance. +* Performance improvements for all clients when reading output. +* Output reading for all clients has been changed to be less prone to race conditions. + +Fixes +------ + +* Calling ``ParallelSSHClient.join`` without ever running ``run_command`` would raise exception. Is now a no-op. + + 2.11.1 +++++++ diff --git a/pssh/clients/base/parallel.py b/pssh/clients/base/parallel.py index f2b83df7..b216d8f9 100644 --- a/pssh/clients/base/parallel.py +++ b/pssh/clients/base/parallel.py @@ -23,7 +23,7 @@ from gevent import joinall, spawn, Timeout as GTimeout from gevent.hub import Hub -from ..common import _validate_pkey_path +from ..common import _validate_pkey_path, _validate_pkey from ...config import HostConfig from ...constants import DEFAULT_RETRIES, RETRY_DELAY from ...exceptions import HostArgumentError, Timeout, ShellError, HostConfigError @@ -39,7 +39,7 @@ class BaseParallelSSHClient(object): def __init__(self, hosts, user=None, password=None, port=None, pkey=None, allow_agent=True, num_retries=DEFAULT_RETRIES, - timeout=120, pool_size=10, + timeout=120, pool_size=100, host_config=None, retry_delay=RETRY_DELAY, identity_auth=True, ipv6_only=False, @@ -64,7 +64,8 @@ def __init__(self, hosts, user=None, password=None, port=None, pkey=None, self.user = user self.password = password self.port = port - self.pkey = pkey + self.pkey = _validate_pkey(pkey) + self.__pkey_data = self._load_pkey_data(pkey) if pkey is not None else None self.num_retries = num_retries self.timeout = timeout self._host_clients = {} @@ -113,9 +114,26 @@ def hosts(self, _hosts): self._host_clients.pop((i, host), None) self._hosts = _hosts + def __del__(self): + self.disconnect() + + def disconnect(self): + if not hasattr(self, '_host_clients'): + return + for s_client in self._host_clients.values(): + try: + s_client.disconnect() + except Exception as ex: + logger.debug("Client disconnect failed with %s", ex) + pass + del s_client + def _check_host_config(self): if self.host_config is None: return + if not isinstance(self.host_config, list): + raise HostConfigError("Host configuration of type %s is invalid - valid types are List[HostConfig]", + type(self.host_config)) host_len = len(self.hosts) if host_len != len(self.host_config): raise ValueError( @@ -231,7 +249,7 @@ def _get_output_from_cmds(self, cmds, raise_error=False): def _get_output_from_greenlet(self, cmd_i, cmd, raise_error=False): host = self.hosts[cmd_i] - alias = self._get_host_config(cmd_i, host).alias + alias = self._get_host_config(cmd_i).alias try: host_out = cmd.get() return host_out @@ -256,7 +274,7 @@ def get_last_output(self, cmds=None): return self._get_output_from_cmds( cmds, raise_error=False) - def _get_host_config(self, host_i, host): + def _get_host_config(self, host_i): if self.host_config is None: config = HostConfig( user=self.user, port=self.port, password=self.password, private_key=self.pkey, @@ -275,9 +293,6 @@ def _get_host_config(self, host_i, host): alias=None, ) return config - elif not isinstance(self.host_config, list): - raise HostConfigError("Host configuration of type %s is invalid - valid types are list[HostConfig]", - type(self.host_config)) config = self.host_config[host_i] return config @@ -285,7 +300,6 @@ def _run_command(self, host_i, host, command, sudo=False, user=None, shell=None, use_pty=False, encoding='utf-8', read_timeout=None): """Make SSHClient if needed, run command on host""" - logger.debug("_run_command with read timeout %s", read_timeout) try: _client = self._get_ssh_client(host_i, host) host_out = _client.run_command( @@ -311,13 +325,13 @@ def connect_auth(self): :returns: list of greenlets to ``joinall`` with. :rtype: list(:py:mod:`gevent.greenlet.Greenlet`) """ - cmds = [spawn(self._get_ssh_client, i, host) for i, host in enumerate(self.hosts)] + cmds = [self.pool.spawn(self._get_ssh_client, i, host) for i, host in enumerate(self.hosts)] return cmds def _consume_output(self, stdout, stderr): - for line in stdout: + for _ in stdout: pass - for line in stderr: + for _ in stderr: pass def join(self, output=None, consume_output=False, timeout=None): @@ -346,6 +360,9 @@ def join(self, output=None, consume_output=False, timeout=None): :rtype: ``None``""" if output is None: output = self.get_last_output() + if output is None: + logger.info("No last output to join on - run_command has never been run.") + return elif not isinstance(output, list): raise ValueError("Unexpected output object type") cmds = [self.pool.spawn(self._join, host_out, timeout=timeout, @@ -544,19 +561,13 @@ def _copy_remote_file(self, host_i, host, remote_file, local_file, recurse, return client.copy_remote_file( remote_file, local_file, recurse=recurse, **kwargs) - def _handle_greenlet_exc(self, func, host, *args, **kwargs): - try: - return func(*args, **kwargs) - except Exception as ex: - raise ex - def _get_ssh_client(self, host_i, host): logger.debug("Make client request for host %s, (host_i, host) in clients: %s", host, (host_i, host) in self._host_clients) _client = self._host_clients.get((host_i, host)) if _client is not None: return _client - cfg = self._get_host_config(host_i, host) + cfg = self._get_host_config(host_i) _pkey = self.pkey if cfg.private_key is None else cfg.private_key _pkey_data = self._load_pkey_data(_pkey) _client = self._make_ssh_client(host, cfg, _pkey_data) @@ -564,12 +575,12 @@ def _get_ssh_client(self, host_i, host): return _client def _load_pkey_data(self, _pkey): - if isinstance(_pkey, str): - _validate_pkey_path(_pkey) - with open(_pkey, 'rb') as fh: - _pkey_data = fh.read() - return _pkey_data - return _pkey + if not isinstance(_pkey, str): + return _pkey + _pkey = _validate_pkey_path(_pkey) + with open(_pkey, 'rb') as fh: + _pkey_data = fh.read() + return _pkey_data def _make_ssh_client(self, host, cfg, _pkey_data): raise NotImplementedError diff --git a/pssh/clients/base/single.py b/pssh/clients/base/single.py index 17db66d4..6229e0f4 100644 --- a/pssh/clients/base/single.py +++ b/pssh/clients/base/single.py @@ -23,19 +23,17 @@ from gevent import sleep, socket, Timeout as GTimeout from gevent.hub import Hub from gevent.select import poll, POLLIN, POLLOUT - -from ssh2.utils import find_eol from ssh2.exceptions import AgentConnectionError, AgentListIdentitiesError, \ AgentAuthenticationError, AgentGetIdentityError +from ssh2.utils import find_eol from ..common import _validate_pkey -from ...constants import DEFAULT_RETRIES, RETRY_DELAY from ..reader import ConcurrentRWBuffer +from ...constants import DEFAULT_RETRIES, RETRY_DELAY from ...exceptions import UnknownHostError, AuthenticationError, \ ConnectionError, Timeout, NoIPv6AddressFoundError from ...output import HostOutput, HostOutputBuffers, BufferData - Hub.NOT_ERROR = (Exception,) host_logger = logging.getLogger('pssh.host_logger') logger = logging.getLogger(__name__) @@ -287,7 +285,7 @@ def _connect(self, host, port, retries=1): raise unknown_ex from ex for i, (family, _type, proto, _, sock_addr) in enumerate(addr_info): try: - return self._connect_socket(family, _type, proto, sock_addr, host, port, retries) + return self._connect_socket(family, _type, sock_addr, host, port, retries) except ConnectionRefusedError as ex: if i+1 == len(addr_info): logger.error("No available addresses from %s", [addr[4] for addr in addr_info]) @@ -295,7 +293,7 @@ def _connect(self, host, port, retries=1): raise continue - def _connect_socket(self, family, _type, proto, sock_addr, host, port, retries): + def _connect_socket(self, family, _type, sock_addr, host, port, retries): self.sock = socket.socket(family, _type) if self.timeout: self.sock.settimeout(self.timeout) @@ -428,6 +426,8 @@ def read_stderr(self, stderr_buffer, timeout=None): :param stderr_buffer: Buffer to read from. :type stderr_buffer: :py:class:`pssh.clients.reader.ConcurrentRWBuffer` + :param timeout: Timeout in seconds - defaults to no timeout. + :type timeout: int or float :rtype: generator """ logger.debug("Reading from stderr buffer, timeout=%s", timeout) @@ -439,6 +439,8 @@ def read_output(self, stdout_buffer, timeout=None): :param stdout_buffer: Buffer to read from. :type stdout_buffer: :py:class:`pssh.clients.reader.ConcurrentRWBuffer` + :param timeout: Timeout in seconds - defaults to no timeout. + :type timeout: int or float :rtype: generator """ logger.debug("Reading from stdout buffer, timeout=%s", timeout) @@ -492,14 +494,16 @@ def read_output_buffer(self, output_buffer, prefix=None, encoding='utf-8'): """Read from output buffers and log to ``host_logger``. - :param output_buffer: Iterator containing buffer + :param output_buffer: Iterator containing buffer. :type output_buffer: iterator - :param prefix: String to prefix log output to ``host_logger`` with + :param prefix: String to prefix log output to ``host_logger`` with. :type prefix: str - :param callback: Function to call back once buffer is depleted: + :param callback: Function to call back once buffer is depleted. :type callback: function - :param callback_args: Arguments for call back function + :param callback_args: Arguments for call back function. :type callback_args: tuple + :param encoding: Encoding for output. + :type encoding: str """ prefix = '' if prefix is None else prefix for line in output_buffer: @@ -553,7 +557,7 @@ def run_command(self, command, sudo=False, user=None, host_out = self._make_host_output(channel, encoding, _timeout) return host_out - def _eagain_write_errcode(self, write_func, data, eagain, timeout=None): + def _eagain_write_errcode(self, write_func, data, eagain): data_len = len(data) total_written = 0 while total_written < data_len: @@ -570,9 +574,10 @@ def _eagain_errcode(self, func, eagain, *args, **kwargs): while ret == eagain: self.poll() ret = func(*args, **kwargs) + sleep() return ret - def _eagain_write(self, write_func, data, timeout=None): + def _eagain_write(self, write_func, data): raise NotImplementedError def _eagain(self, func, *args, **kwargs): diff --git a/pssh/clients/native/parallel.py b/pssh/clients/native/parallel.py index 9cb94995..c8c879c2 100644 --- a/pssh/clients/native/parallel.py +++ b/pssh/clients/native/parallel.py @@ -127,7 +127,6 @@ def __init__(self, hosts, user=None, password=None, port=22, pkey=None, identity_auth=identity_auth, ipv6_only=ipv6_only, ) - self.pkey = _validate_pkey(pkey) self.proxy_host = proxy_host self.proxy_port = proxy_port self.proxy_pkey = _validate_pkey(proxy_pkey) @@ -216,17 +215,6 @@ def run_command(self, command, sudo=False, user=None, stop_on_errors=True, read_timeout=read_timeout, ) - def __del__(self): - if not hasattr(self, '_host_clients'): - return - for s_client in self._host_clients.values(): - try: - s_client.disconnect() - except Exception as ex: - logger.debug("Client disconnect failed with %s", ex) - pass - del s_client - def _make_ssh_client(self, host, cfg, _pkey_data): _client = SSHClient( host, user=cfg.user or self.user, password=cfg.password or self.password, port=cfg.port or self.port, @@ -371,16 +359,12 @@ def copy_remote_file(self, remote_file, local_file, recurse=False, encoding=encoding) def _scp_send(self, host_i, host, local_file, remote_file, recurse=False): - self._get_ssh_client(host_i, host) - return self._handle_greenlet_exc( - self._host_clients[(host_i, host)].scp_send, host, - local_file, remote_file, recurse=recurse) + _client = self._get_ssh_client(host_i, host) + return _client.scp_send(local_file, remote_file, recurse=recurse) def _scp_recv(self, host_i, host, remote_file, local_file, recurse=False): - self._get_ssh_client(host_i, host) - return self._handle_greenlet_exc( - self._host_clients[(host_i, host)].scp_recv, host, - remote_file, local_file, recurse=recurse) + _client = self._get_ssh_client(host_i, host) + return _client.scp_recv(remote_file, local_file, recurse=recurse) def scp_send(self, local_file, remote_file, recurse=False, copy_args=None): """Copy local file to remote file in parallel via SCP. @@ -405,6 +389,11 @@ def scp_send(self, local_file, remote_file, recurse=False, copy_args=None): :type local_file: str :param remote_file: Remote filepath on remote host to copy file to :type remote_file: str + :param copy_args: (Optional) format local_file and remote_file strings + with per-host arguments in ``copy_args``. ``copy_args`` length must + equal length of host list - + :py:class:`pssh.exceptions.HostArgumentError` is raised otherwise + :type copy_args: tuple or list :param recurse: Whether or not to descend into directories recursively. :type recurse: bool @@ -416,7 +405,7 @@ def scp_send(self, local_file, remote_file, recurse=False, copy_args=None): """ copy_args = [{'local_file': local_file, 'remote_file': remote_file} - for i, host in enumerate(self.hosts)] \ + for _ in self.hosts] \ if copy_args is None else copy_args local_file = "%(local_file)s" remote_file = "%(remote_file)s" diff --git a/pssh/clients/native/single.py b/pssh/clients/native/single.py index 335dfa56..1256ae4d 100644 --- a/pssh/clients/native/single.py +++ b/pssh/clients/native/single.py @@ -18,7 +18,6 @@ import logging import os from collections import deque -from warnings import warn from gevent import sleep, spawn, get_hub from gevent.lock import RLock @@ -33,11 +32,10 @@ from .tunnel import FORWARDER from ..base.single import BaseSSHClient -from ...output import HostOutput +from ...constants import DEFAULT_RETRIES, RETRY_DELAY from ...exceptions import SessionError, SFTPError, \ SFTPIOError, Timeout, SCPError, ProxyError -from ...constants import DEFAULT_RETRIES, RETRY_DELAY - +from ...output import HostOutput logger = logging.getLogger(__name__) THREAD_POOL = get_hub().threadpool @@ -64,7 +62,8 @@ def __init__(self, host, identity_auth=True, ipv6_only=False, ): - """:param host: Host name or IP to connect to. + """ + :param host: Host name or IP to connect to. :type host: str :param user: User to connect as. Defaults to logged in user. :type user: str @@ -134,7 +133,8 @@ def __init__(self, host, identity_auth=identity_auth, ) proxy_host = '127.0.0.1' - self._chan_lock = RLock() + self._chan_stdout_lock = RLock() + self._chan_stderr_lock = RLock() super(SSHClient, self).__init__( host, user=user, password=password, alias=alias, port=port, pkey=pkey, num_retries=num_retries, retry_delay=retry_delay, @@ -231,6 +231,8 @@ def _init_session(self, retries=1): return self._connect_init_session_retry(retries=retries+1) msg = "Error connecting to host %s:%s - %s" logger.error(msg, self.host, self.port, ex) + if not self.sock.closed: + self.sock.close() if isinstance(ex, SSH2Timeout): raise Timeout(msg, self.host, self.port, ex) raise @@ -269,11 +271,7 @@ def open_session(self): chan = self._open_session() except Exception as ex: raise SessionError(ex) - if self.forward_ssh_agent and not self._forward_requested: - if not hasattr(chan, 'request_auth_agent'): - warn("Requested SSH Agent forwarding but libssh2 version used " - "does not support it - ignoring") - return chan + # if self.forward_ssh_agent and not self._forward_requested: # self._eagain(chan.request_auth_agent) # self._forward_requested = True return chan @@ -303,18 +301,19 @@ def execute(self, cmd, use_pty=False, channel=None): self._eagain(channel.execute, cmd) return channel - def _read_output_to_buffer(self, read_func, _buffer): + def _read_output_to_buffer(self, read_func, _buffer, is_stderr=False): + _lock = self._chan_stderr_lock if is_stderr else self._chan_stdout_lock try: while True: - with self._chan_lock: + with _lock: size, data = read_func() - while size == LIBSSH2_ERROR_EAGAIN: + if size == LIBSSH2_ERROR_EAGAIN: self.poll() - with self._chan_lock: - size, data = read_func() + continue if size <= 0: break _buffer.write(data) + sleep() finally: _buffer.eof.set() @@ -342,7 +341,7 @@ def wait_finished(self, host_output, timeout=None): self.close_channel(channel) def close_channel(self, channel): - with self._chan_lock: + with self._chan_stdout_lock, self._chan_stderr_lock: logger.debug("Closing channel") self._eagain(channel.close) @@ -353,7 +352,6 @@ def _make_sftp_eagain(self): return self._eagain(self.session.sftp_init) def _make_sftp(self): - """Make SFTP client from open transport""" try: sftp = self._make_sftp_eagain() except Exception as ex: @@ -361,7 +359,7 @@ def _make_sftp(self): return sftp def _mkdir(self, sftp, directory): - """Make directory via SFTP channel + """Make directory via SFTP channel. :param sftp: SFTP client object :type sftp: :py:class:`ssh2.sftp.SFTP` @@ -431,10 +429,22 @@ def _sftp_put(self, remote_fh, local_file): data = local_fh.read(self._BUF_SIZE) def sftp_put(self, sftp, local_file, remote_file): + """Perform an SFTP put - copy local file path to remote via SFTP. + + :param sftp: SFTP client object. + :type sftp: :py:class:`ssh2.sftp.SFTP` + :param local_file: Local filepath to copy to remote host. + :type local_file: str + :param remote_file: Remote filepath on remote host to copy file to. + :type remote_file: str + + :raises: :py:class:`pssh.exceptions.SFTPIOError` on I/O errors writing + via SFTP. + """ mode = LIBSSH2_SFTP_S_IRUSR | \ - LIBSSH2_SFTP_S_IWUSR | \ - LIBSSH2_SFTP_S_IRGRP | \ - LIBSSH2_SFTP_S_IROTH + LIBSSH2_SFTP_S_IWUSR | \ + LIBSSH2_SFTP_S_IRGRP | \ + LIBSSH2_SFTP_S_IROTH f_flags = LIBSSH2_FXF_CREAT | LIBSSH2_FXF_WRITE | LIBSSH2_FXF_TRUNC with self._sftp_openfh( sftp.open, remote_file, f_flags, mode) as remote_fh: @@ -560,6 +570,9 @@ def scp_recv(self, remote_file, local_file, recurse=False, sftp=None, :type local_file: str :param recurse: Whether or not to recursively copy directories :type recurse: bool + :param sftp: The SFTP channel to use instead of creating a new one. + Only used when ``recurse`` is ``True``. + :type sftp: :py:class:`ssh2.sftp.SFTP` :param encoding: Encoding to use for file paths when recursion is enabled. :type encoding: str @@ -617,6 +630,9 @@ def scp_send(self, local_file, remote_file, recurse=False, sftp=None): :type local_file: str :param remote_file: Remote filepath on remote host to copy file to :type remote_file: str + :param sftp: The SFTP channel to use instead of creating a new one. + Only used when ``recurse`` is ``True``. + :type sftp: :py:class:`ssh2.sftp.SFTP` :param recurse: Whether or not to descend into directories recursively. :type recurse: bool @@ -737,12 +753,12 @@ def poll(self, timeout=None): LIBSSH2_SESSION_BLOCK_OUTBOUND, ) - def _eagain_write(self, write_func, data, timeout=None): + def _eagain_write(self, write_func, data): """Write data with given write_func for an ssh2-python session while handling EAGAIN and resuming writes from last written byte on each call to write_func. """ - return self._eagain_write_errcode(write_func, data, LIBSSH2_ERROR_EAGAIN, timeout=timeout) + return self._eagain_write_errcode(write_func, data, LIBSSH2_ERROR_EAGAIN) - def eagain_write(self, write_func, data, timeout=None): - return self._eagain_write(write_func, data, timeout=timeout) + def eagain_write(self, write_func, data): + return self._eagain_write(write_func, data) diff --git a/pssh/clients/native/tunnel.py b/pssh/clients/native/tunnel.py index 5748a3c2..30c9ca01 100644 --- a/pssh/clients/native/tunnel.py +++ b/pssh/clients/native/tunnel.py @@ -18,10 +18,7 @@ import logging from threading import Thread, Event -try: - from queue import Queue -except ImportError: - from Queue import Queue +from queue import Queue from gevent import spawn, joinall, get_hub, sleep from gevent.server import StreamServer @@ -193,7 +190,7 @@ def _read_forward_sock(self, forward_sock, channel): sleep(.01) continue try: - self._client._eagain_write(channel.write, data) + self._client.eagain_write(channel.write, data) except Exception as ex: logger.error("Error writing data to channel - %s", ex) raise diff --git a/pssh/clients/reader.py b/pssh/clients/reader.py index 2fb19094..c6b69b2a 100644 --- a/pssh/clients/reader.py +++ b/pssh/clients/reader.py @@ -17,31 +17,41 @@ from io import BytesIO -from gevent import sleep from gevent.event import Event from gevent.lock import RLock +class _Eof(Event): + def __init__(self, unread_data): + self._unread_data = unread_data + Event.__init__(self) + + def set(self): + self._unread_data.set() + Event.set(self) + + class ConcurrentRWBuffer(object): """Concurrent reader/writer of bytes for use from multiple greenlets. Supports both concurrent reading and writing. - Iterate on buffer object to read data, yielding greenlet if no data exists + Iterate on buffer object to read data, yielding event loop if no data exists until self.eof has been set. - Writers should ``eof.set()`` when finished writing data via ``write``. + Writers should call ``ConcurrentRWBuffer.eof.set()`` when finished writing data via ``write``. Readers can use ``read()`` to get any available data or ``None``. """ - __slots__ = ('_buffer', '_read_pos', '_write_pos', 'eof', '_lock') + __slots__ = ('_buffer', '_read_pos', '_write_pos', 'eof', '_lock', '_unread_data') def __init__(self): self._buffer = BytesIO() self._read_pos = 0 self._write_pos = 0 - self.eof = Event() self._lock = RLock() + self._unread_data = Event() + self.eof = _Eof(self._unread_data) def write(self, data): """Write data to buffer. @@ -53,14 +63,17 @@ def write(self, data): if not self._buffer.tell() == self._write_pos: self._buffer.seek(self._write_pos) self._write_pos += self._buffer.write(data) + if not self._unread_data.is_set() and self._read_pos < self._write_pos: + self._unread_data.set() def read(self): - """Read available data, or return None + """Read available data, or return None. :rtype: bytes """ with self._lock: if self._write_pos == 0 or self._read_pos == self._write_pos: + self._unread_data.clear() return elif not self._buffer.tell() == self._read_pos: self._buffer.seek(self._read_pos) @@ -73,5 +86,5 @@ def __iter__(self): data = self.read() if data: yield data - elif self._read_pos == self._write_pos: - sleep(.1) + else: + self._unread_data.wait() diff --git a/pssh/clients/ssh/single.py b/pssh/clients/ssh/single.py index 9eee161e..ee3db7e0 100644 --- a/pssh/clients/ssh/single.py +++ b/pssh/clients/ssh/single.py @@ -19,18 +19,17 @@ from gevent import sleep, spawn, Timeout as GTimeout, joinall from ssh import options -from ssh.session import Session, SSH_READ_PENDING, SSH_WRITE_PENDING -from ssh.key import import_privkey_file, import_cert_file, copy_cert_to_privkey,\ - import_privkey_base64 -from ssh.exceptions import EOF from ssh.error_codes import SSH_AGAIN +from ssh.exceptions import EOF +from ssh.key import import_privkey_file, import_cert_file, copy_cert_to_privkey, \ + import_privkey_base64 +from ssh.session import Session, SSH_READ_PENDING, SSH_WRITE_PENDING from ..base.single import BaseSSHClient from ..common import _validate_pkey_path -from ...output import HostOutput -from ...exceptions import SessionError, Timeout from ...constants import DEFAULT_RETRIES, RETRY_DELAY - +from ...exceptions import SessionError, Timeout +from ...output import HostOutput logger = logging.getLogger(__name__) @@ -240,25 +239,22 @@ def execute(self, cmd, use_pty=False, channel=None): if use_pty: self._eagain(channel.request_pty, timeout=self.timeout) logger.debug("Executing command '%s'", cmd) - self._eagain(channel.request_exec, cmd, timeout=self.timeout) + self._eagain(channel.request_exec, cmd) return channel def _read_output_to_buffer(self, channel, _buffer, is_stderr=False): - while True: - self.poll() - try: - size, data = channel.read_nonblocking(is_stderr=is_stderr) - except EOF: - _buffer.eof.set() - sleep(.1) - return - if size > 0: - _buffer.write(data) - else: - # Yield event loop to other greenlets if we have no data to - # send back, meaning the generator does not yield and can there - # for block other generators/greenlets from running. - sleep(.1) + try: + while True: + self.poll() + try: + size, data = channel.read_nonblocking(is_stderr=is_stderr) + except EOF: + return + if size > 0: + _buffer.write(data) + sleep() + finally: + _buffer.eof.set() def wait_finished(self, host_output, timeout=None): """Wait for EOF from channel and close channel. @@ -315,7 +311,7 @@ def close_channel(self, channel): :type channel: :py:class:`ssh.channel.Channel` """ logger.debug("Closing channel") - self._eagain(channel.close, timeout=self.timeout) + self._eagain(channel.close) def poll(self, timeout=None): """ssh-python based co-operative gevent poll on session socket. @@ -331,5 +327,5 @@ def _eagain(self, func, *args, **kwargs): """Run function given and handle EAGAIN for an ssh-python session""" return self._eagain_errcode(func, SSH_AGAIN, *args, **kwargs) - def _eagain_write(self, write_func, data, timeout=None): - return self._eagain_write_errcode(write_func, data, SSH_AGAIN, timeout=timeout) + def _eagain_write(self, write_func, data): + return self._eagain_write_errcode(write_func, data, SSH_AGAIN) diff --git a/pssh/output.py b/pssh/output.py index 01d709b2..52297bac 100644 --- a/pssh/output.py +++ b/pssh/output.py @@ -44,7 +44,7 @@ def __init__(self, reader, rw_buffer): """ :param reader: Greenlet reading data from channel and writing to rw_buffer :type reader: :py:class:`gevent.Greenlet` - :param rw_bufffer: Read/write buffer + :param rw_buffer: Read/write buffer :type rw_buffer: :py:class:`pssh.clients.reader.ConcurrentRWBuffer` """ self.reader = reader diff --git a/tests/native/test_parallel_client.py b/tests/native/test_parallel_client.py index 9250a4df..72ea2cae 100644 --- a/tests/native/test_parallel_client.py +++ b/tests/native/test_parallel_client.py @@ -1307,42 +1307,53 @@ def test_read_timeout(self): self.assertRaises(Timeout, list, host_out.stdout) self.assertFalse(client.finished(output)) client.join(output) - # import ipdb; ipdb.set_trace() for host_out in output: - stdout = list(output[0].stdout) + stdout = list(host_out.stdout) self.assertEqual(len(stdout), 3) self.assertTrue(client.finished(output)) + def test_finished_no_run_command(self): + client = ParallelSSHClient([self.host], port=self.port, + pkey=self.user_key, num_retries=1) + client.join() + self.assertTrue(client.finished()) + def test_partial_read_timeout_close_cmd(self): - self.assertTrue(self.client.finished()) - output = self.client.run_command('while true; do echo a line; sleep .1; done', - use_pty=True, read_timeout=.15) + client = ParallelSSHClient([self.host], port=self.port, + pkey=self.user_key, num_retries=1) + self.assertTrue(client.finished()) + output = client.run_command('while true; do echo a line; sleep .01; done', + use_pty=True, read_timeout=.2) stdout = [] try: - with GTimeout(seconds=.25): + with GTimeout(seconds=.3): for line in output[0].stdout: stdout.append(line) except Timeout: pass self.assertTrue(len(stdout) > 0) + # Allow some more output to be generated + sleep(.1) output[0].client.close_channel(output[0].channel) - self.client.join(output) + client.join(output) # Should not timeout with GTimeout(seconds=.5): stdout = list(output[0].stdout) self.assertTrue(len(stdout) > 0) def test_partial_read_timeout_join_no_output(self): - self.assertTrue(self.client.finished()) - self.client.run_command('while true; do echo a line; sleep .1; done') + client = ParallelSSHClient([self.host], port=self.port, + pkey=self.user_key, num_retries=1) + self.assertTrue(client.finished()) + client.run_command('while true; do echo a line; sleep .01; done') try: with GTimeout(seconds=.1): - self.client.join() + client.join() except GTimeout: pass else: raise Exception("Should have timed out") - output = self.client.get_last_output() + output = client.get_last_output() stdout = [] try: with GTimeout(seconds=.1): @@ -1353,7 +1364,7 @@ def test_partial_read_timeout_join_no_output(self): else: raise Exception("Should have timed out") self.assertTrue(len(stdout) > 0) - self.assertRaises(Timeout, self.client.join, timeout=.1) + self.assertRaises(Timeout, client.join, timeout=.1) stdout = [] try: with GTimeout(seconds=.2): @@ -1375,9 +1386,11 @@ def test_partial_read_timeout_join_no_output(self): else: raise Exception("Should have timed out") self.assertTrue(len(stdout) > 0) + # Allow some more output to be generated + sleep(.1) output[0].client.close_channel(output[0].channel) - self.client.join() - self.assertTrue(self.client.finished()) + client.join() + self.assertTrue(client.finished()) stdout = list(output[0].stdout) self.assertTrue(len(stdout) > 0)