Skip to content

Commit

Permalink
Added timeout functionality to join and output reading for native cli…
Browse files Browse the repository at this point in the history
…ents.
  • Loading branch information
pkittenis committed Jan 30, 2018
1 parent 9e79ebe commit 088b1f7
Show file tree
Hide file tree
Showing 8 changed files with 482 additions and 354 deletions.
13 changes: 13 additions & 0 deletions Changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
Change Log
============

1.3.1
++++++

Changes
--------

* Added ``timeout`` optional parameter to ``join`` and ``run_command``, for reading output, on native clients.

Fixes
------

* From source builds when Cython is installed with recent versions of ``ssh2-python``.

1.3.0
++++++

Expand Down
724 changes: 393 additions & 331 deletions pssh/native/_ssh2.c

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pssh/native/_ssh2.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ from ..exceptions import SessionError
cdef bytes LINESEP = b'\n'


def _read_output(Session session, read_func):
def _read_output(Session session, read_func, timeout=None):
cdef Py_ssize_t _size
cdef bytes _data
cdef bytes remainder = b""
Expand All @@ -51,8 +51,10 @@ def _read_output(Session session, read_func):
_size, _data = read_func()
while _size == LIBSSH2_ERROR_EAGAIN or _size > 0:
if _size == LIBSSH2_ERROR_EAGAIN:
_wait_select(_sock, _session, None)
_wait_select(_sock, _session, timeout)
_size, _data = read_func()
if timeout is not None and _size == LIBSSH2_ERROR_EAGAIN:
break
while _size > 0:
while _pos < _size:
linesep = _data[:_size].find(LINESEP, _pos)
Expand Down
3 changes: 3 additions & 0 deletions pssh/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,6 @@ def __repr__(self):
stdout=self.stdout, stdin=self.stdin, stderr=self.stderr,
exception=self.exception, linesep=linesep,
exit_code=self.exit_code)

def __str__(self):
return self.__repr__()
30 changes: 21 additions & 9 deletions pssh/pssh2_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(self, hosts, user=None, password=None, port=None, pkey=None,

def run_command(self, command, sudo=False, user=None, stop_on_errors=True,
use_pty=False, host_args=None, shell=None,
encoding='utf-8'):
encoding='utf-8', timeout=None):
"""Run command on all hosts in parallel, honoring self.pool_size,
and return output dictionary.
Expand Down Expand Up @@ -152,6 +152,10 @@ def run_command(self, command, sudo=False, user=None, stop_on_errors=True,
:param encoding: Encoding to use for output. Must be valid
`Python codec <https://docs.python.org/2.7/library/codecs.html>`_
:type encoding: str
:param timeout: (Optional) Timeout in seconds for reading from stdout
or stderr. Defaults to no timeout. Reading from stdout/stderr will
timeout after this many seconds if remote output is not ready.
:type timeout: int
:rtype: Dictionary with host as key and
:py:class:`pssh.output.HostOutput` as value as per
Expand All @@ -169,24 +173,24 @@ def run_command(self, command, sudo=False, user=None, stop_on_errors=True,
string format
:raises: :py:class:`KeyError` on no host argument key in arguments
dict for cmd string format
:raises: :py:class:`pssh.exceptions.ProxyErrors` on errors connecting
:raises: :py:class:`pssh.exceptions.ProxyError` on errors connecting
to proxy if a proxy host has been set.
"""
return BaseParallelSSHClient.run_command(
self, command, stop_on_errors=stop_on_errors, host_args=host_args,
user=user, shell=shell, sudo=sudo,
encoding=encoding, use_pty=use_pty)
encoding=encoding, use_pty=use_pty, timeout=timeout)

def _run_command(self, host, command, sudo=False, user=None,
shell=None, use_pty=False,
encoding='utf-8'):
encoding='utf-8', timeout=None):
"""Make SSHClient if needed, run command on host"""
self._make_ssh_client(host)
return self.host_clients[host].run_command(
command, sudo=sudo, user=user, shell=shell,
use_pty=use_pty, encoding=encoding)
use_pty=use_pty, encoding=encoding, timeout=timeout)

def join(self, output, consume_output=False):
def join(self, output, consume_output=False, timeout=None):
"""Wait until all remote commands in output have finished
and retrieve exit codes. Does *not* block other commands from
running in parallel.
Expand All @@ -198,11 +202,17 @@ def join(self, output, consume_output=False):
buffers. Output buffers will be empty after ``join`` if set
to ``True``. Must be set to ``True`` to allow host logger to log
output on call to ``join`` when host logger has been enabled.
:type consume_output: bool"""
:type consume_output: bool
:param timeout: Timeout in seconds if remote command is not yet
finished.
:type timeout: int
:rtype: ``None``"""
for host in output:
if host not in self.host_clients or self.host_clients[host] is None:
continue
self.host_clients[host].wait_finished(output[host].channel)
self.host_clients[host].wait_finished(output[host].channel,
timeout=timeout)
if consume_output:
for line in output[host].stdout:
pass
Expand All @@ -217,6 +227,8 @@ def _get_exit_code(self, channel):
return channel.get_exit_status()

def _start_tunnel(self, host):
if host in self._tunnels:
return self._tunnels[host]
tunnel = Tunnel(
self.proxy_host, host, self.port, user=self.proxy_user,
password=self.proxy_password, port=self.proxy_port,
Expand Down Expand Up @@ -284,7 +296,7 @@ def copy_file(self, local_file, remote_file, recurse=False):
.. note ::
Remote directories in `remote_file` that do not exist will be
Remote directories in ``remote_file`` that do not exist will be
created as long as permissions allow.
"""
Expand Down
29 changes: 19 additions & 10 deletions pssh/ssh2_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,35 +253,43 @@ def execute(self, cmd, use_pty=False, channel=None):
self._eagain(channel.execute, cmd)
return channel

def read_stderr(self, channel):
def read_stderr(self, channel, timeout=None):
"""Read standard error buffer from channel.
:param channel: Channel to read output from.
:type channel: :py:class:`ssh2.channel.Channel`
"""
return _read_output(self.session, channel.read_stderr)
return _read_output(self.session, channel.read_stderr, timeout=timeout)

def read_output(self, channel):
def read_output(self, channel, timeout=None):
"""Read standard output buffer from channel.
:param channel: Channel to read output from.
:type channel: :py:class:`ssh2.channel.Channel`
"""
return _read_output(self.session, channel.read)
return _read_output(self.session, channel.read, timeout=timeout)

def wait_finished(self, channel):
def wait_finished(self, channel, timeout=None):
"""Wait for EOF from channel, close channel and wait for
close acknowledgement.
Used to wait for remote command completion and be able to gather
exit code.
:param channel: The channel to use
:param channel: The channel to use.
:type channel: :py:class:`ssh2.channel.Channel`
"""
if channel is None:
return
self._eagain(channel.wait_eof)
# If .eof() returns EAGAIN after a select with a timeout, it means
# it reached timeout without EOF and the connection should not be
# closed as the command is still running.
ret = channel.wait_eof()
while ret == LIBSSH2_ERROR_EAGAIN:
wait_select(self.session, timeout=timeout)
ret = channel.wait_eof()
if ret == LIBSSH2_ERROR_EAGAIN and timeout is not None:
return
self._eagain(channel.close)
self._eagain(channel.wait_closed)

Expand Down Expand Up @@ -317,7 +325,7 @@ def read_output_buffer(self, output_buffer, prefix=None,

def run_command(self, command, sudo=False, user=None,
use_pty=False, shell=None,
encoding='utf-8'):
encoding='utf-8', timeout=None):
"""Run remote command.
:param command: Command to run.
Expand Down Expand Up @@ -352,9 +360,10 @@ def run_command(self, command, sudo=False, user=None,
channel = self.execute(_command, use_pty=use_pty)
return channel, self.host, \
self.read_output_buffer(
self.read_output(channel), encoding=encoding), \
self.read_output(channel, timeout=timeout),
encoding=encoding), \
self.read_output_buffer(
self.read_stderr(channel), encoding=encoding,
self.read_stderr(channel, timeout=timeout), encoding=encoding,
prefix='\t[err]'), channel

def _make_sftp(self):
Expand Down
8 changes: 7 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

import os
import platform
from setuptools import setup, find_packages
from platform import python_version
Expand All @@ -36,7 +37,11 @@
'optimize.use_switch': True,
'wraparound': False,
}
cython_args = {'cython_directives': cython_directives} if USING_CYTHON else {}
_embedded_lib = bool(os.environ.get('EMBEDDED_LIB', 1))

cython_args = {'cython_directives': cython_directives,
'cython_compile_time_env': {'EMBEDDED_LIB': _embedded_lib},
} if USING_CYTHON else {}

_libs = ['ssh2'] if platform.system() != 'Windows' else [
# For libssh2 OpenSSL backend on Windows.
Expand All @@ -47,6 +52,7 @@

ext = 'pyx' if USING_CYTHON else 'c'
_comp_args = ["-O3"] if platform.system() != 'Windows' else None

extensions = [
Extension('pssh.native._ssh2',
sources=['pssh/native/_ssh2.%s' % ext],
Expand Down
23 changes: 22 additions & 1 deletion tests/test_pssh_ssh2_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,7 @@ def test_run_command_user_sudo(self):
self.client.join(output)
stderr = list(output[self.host].stderr)
self.assertTrue(len(stderr) > 0)
self.assertTrue(output[self.host].exit_code == 1)
self.assertEqual(output[self.host].exit_code, 1)

def test_run_command_shell(self):
output = self.client.run_command(self.cmd, shell="bash -c")
Expand Down Expand Up @@ -1163,6 +1163,27 @@ def test_host_no_client(self):
output = {'blah': None}
self.client.join(output)

def test_join_timeout(self):
client = ParallelSSHClient([self.host], port=self.port,
pkey=self.user_key)
output = client.run_command('sleep 2')
client.join(output, timeout=1)
self.assertFalse(output[self.host].channel.eof())
client.join(output, timeout=2)
self.assertTrue(output[self.host].channel.eof())

def test_read_timeout(self):
client = ParallelSSHClient([self.host], port=self.port,
pkey=self.user_key)
output = client.run_command('sleep 2', timeout=1)
stdout = list(output[self.host].stdout)
self.assertFalse(output[self.host].channel.eof())
self.assertEqual(len(stdout), 0)
list(output[self.host].stdout)
list(output[self.host].stdout)
client.join(output)
self.assertTrue(output[self.host].channel.eof())

## OpenSSHServer needs to run in its own thread for this test to work
## Race conditions otherwise.
#
Expand Down

0 comments on commit 088b1f7

Please sign in to comment.