Skip to content

Commit

Permalink
Merge pull request #253 from haccht/fix_sync_telnet_transport
Browse files Browse the repository at this point in the history
Fix sync telnet transport
  • Loading branch information
carlmontanari authored Sep 3, 2022
2 parents 2a1bf05 + 2ffa4a0 commit 2c6de7b
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 202 deletions.
73 changes: 38 additions & 35 deletions scrapli/transport/plugins/asynctelnet/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ def __init__(
self.stdout: Optional[asyncio.StreamReader] = None
self.stdin: Optional[asyncio.StreamWriter] = None

self._initial_buf = b""
self._eof = False
self._raw_buf = b""
self._cooked_buf = b""

self._control_char_sent_counter = 0
self._control_char_sent_limit = 10

def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes:
"""
Expand All @@ -56,7 +61,7 @@ def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes:
if c != IAC:
# add whatever character we read to the "normal" output buf so it gets sent off
# to the auth method later (username/show prompts may show up here)
self._initial_buf += c
self._cooked_buf += c
else:
# we got a control character, put it into the control_buf
control_buf += c
Expand Down Expand Up @@ -88,7 +93,7 @@ def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes:

return control_buf

async def _handle_control_chars(self) -> None:
def _handle_control_chars(self) -> None:
"""
Handle control characters -- nearly identical to CPython telnetlib
Expand All @@ -115,21 +120,11 @@ async def _handle_control_chars(self) -> None:
# we are working on responding to
control_buf = b""

# initial read timeout for control characters can be 1/4 of socket timeout, after reading a
# single byte we crank it way down; the next value used to be 0.1 but this was causing some
# issues for folks that had devices behaving very slowly... so hopefully 1/10 is a
# reasonable value for the follow up char read timeout... of course we will return early if
# we do get a char in the buffer so it should be all good!
char_read_timeout = self._base_transport_args.timeout_socket / 4
while self._raw_buf:
c, self._raw_buf = self._raw_buf[:1], self._raw_buf[1:]
if not c:
raise ScrapliConnectionNotOpened("server returned EOF, connection not opened")

while True:
try:
c = await asyncio.wait_for(self.stdout.read(1), timeout=char_read_timeout)
if not c:
raise ScrapliConnectionNotOpened("server returned EOF, connection not opened")
except asyncio.TimeoutError:
return
char_read_timeout = self._base_transport_args.timeout_socket / 10
control_buf = self._handle_control_chars_response(control_buf=control_buf, c=c)

async def open(self) -> None:
Expand Down Expand Up @@ -161,8 +156,6 @@ async def open(self) -> None:
self.logger.critical(msg)
raise ScrapliAuthenticationFailed(msg) from exc

await self._handle_control_chars()

self._post_open_closing_log(closing=False)

def close(self) -> None:
Expand All @@ -188,29 +181,39 @@ def isalive(self) -> bool:
return False
return not self.stdout.at_eof()

async def _read(self, n: int = 65535) -> None:
if not self.stdout:
raise ScrapliConnectionNotOpened

if not self._raw_buf:
try:
buf = await self.stdout.read(n)
self._eof = not buf
if self._control_char_sent_counter < self._control_char_sent_limit:
self._raw_buf += buf
else:
self._cooked_buf += buf
except EOFError as exc:
raise ScrapliConnectionError(
"encountered EOF reading from transport; typically means the device closed the "
"connection"
) from exc

@timeout_wrapper
async def read(self) -> bytes:
if not self.stdout:
raise ScrapliConnectionNotOpened

if self._initial_buf:
buf = self._initial_buf
self._initial_buf = b""
return buf
if self._control_char_sent_counter < self._control_char_sent_limit:
self._handle_control_chars()

try:
buf = await self.stdout.read(65535)
# nxos at least sends "binary transmission" control char, but seems to not (afaik?)
# actually advertise it during the control protocol exchange, causing us to not be able
# to "know" that it is in binary transmit mode until later... so we will just always
# strip this option (b"\x00") out of the buffered data...
buf = buf.replace(b"\x00", b"")
except EOFError as exc:
raise ScrapliConnectionError(
"encountered EOF reading from transport; typically means the device closed the "
"connection"
) from exc
while not self._cooked_buf and not self._eof:
await self._read()
if self._control_char_sent_counter < self._control_char_sent_limit:
self._handle_control_chars()

buf = self._cooked_buf
self._cooked_buf = b""
return buf

def write(self, channel_input: bytes) -> None:
Expand Down
97 changes: 63 additions & 34 deletions scrapli/transport/plugins/telnet/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ def __init__(
self.plugin_transport_args = plugin_transport_args

self.socket: Optional[Socket] = None
self._initial_buf = b""
self._eof = False
self._raw_buf = b""
self._cooked_buf = b""

self._control_char_sent_counter = 0
self._control_char_sent_limit = 10

def _set_socket_timeout(self, timeout: float) -> None:
"""
Expand All @@ -45,6 +50,29 @@ def _set_socket_timeout(self, timeout: float) -> None:
raise ScrapliConnectionNotOpened
self.socket.sock.settimeout(timeout)

def _handle_control_chars_socket_timeout_update(self) -> None:
"""
Handle updating (if necessary) the socket timeout
Args:
N/A
Returns:
None
Raises:
N/A
"""
self._control_char_sent_counter += 1

if self._control_char_sent_counter > self._control_char_sent_limit:
# connection is opened, effectively ignore socket timeout at this point as we want
# the timeout socket to be "just" for opening the connection basically
# the number 8 is fairly arbitrary -- it looks like *most* platforms send around
# 8 - 12 control char/instructions on session opening, so we'll go with 8!
self._set_socket_timeout(600)

def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes:
"""
Handle the actual response to control characters
Expand All @@ -70,7 +98,7 @@ def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes:

if not control_buf:
if c != IAC:
self._initial_buf += c
self._cooked_buf += c
else:
control_buf += c

Expand All @@ -90,6 +118,8 @@ def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes:
elif cmd == WONT:
self.write(IAC + DONT + c)

self._handle_control_chars_socket_timeout_update()

return control_buf

def _handle_control_chars(self) -> None:
Expand Down Expand Up @@ -119,21 +149,11 @@ def _handle_control_chars(self) -> None:

control_buf = b""

original_socket_timeout = self._base_transport_args.timeout_socket
self._set_socket_timeout(self._base_transport_args.timeout_socket / 4)
while self._raw_buf:
c, self._raw_buf = self._raw_buf[:1], self._raw_buf[1:]
if not c:
raise ScrapliConnectionNotOpened("server returned EOF, connection not opened")

while True:
try:
c = self._read(1)
if not c:
raise ScrapliConnectionNotOpened("server returned EOF, connection not opened")
except TimeoutError:
# shouldn't really matter/need to be reset back to "normal", but don't really want
# to leave it modified as that would be confusing!
self._base_transport_args.timeout_socket = original_socket_timeout
return

self._set_socket_timeout(self._base_transport_args.timeout_socket / 10)
control_buf = self._handle_control_chars_response(control_buf=control_buf, c=c)

def open(self) -> None:
Expand All @@ -149,8 +169,6 @@ def open(self) -> None:
if not self.socket.isalive():
self.socket.open()

self._handle_control_chars()

self._post_open_closing_log(closing=False)

def close(self) -> None:
Expand All @@ -170,9 +188,9 @@ def isalive(self) -> bool:
return False
return True

def _read(self, n: int = 65535) -> bytes:
def _read(self, n: int = 65535) -> None:
"""
Read n bytes from the socket
Read n bytes from the socket and fill raw buffer
Mostly this exists just to assert that socket and socket.sock are not None to appease mypy!
Expand All @@ -184,31 +202,42 @@ def _read(self, n: int = 65535) -> bytes:
Raises:
ScrapliConnectionNotOpened: if either socket or socket.sock are None
ScrapliConnectionError: if we fail to recv from the underlying socket
"""
if self.socket is None:
raise ScrapliConnectionNotOpened
if self.socket.sock is None:
raise ScrapliConnectionNotOpened
return self.socket.sock.recv(n)
if not self._raw_buf:
try:
buf = self.socket.sock.recv(n)
self._eof = not buf
if self._control_char_sent_counter < self._control_char_sent_limit:
self._raw_buf += buf
else:
self._cooked_buf += buf
except Exception as exc:
raise ScrapliConnectionError(
"encountered EOF reading from transport; typically means the device closed the "
"connection"
) from exc

@timeout_wrapper
def read(self) -> bytes:
if not self.socket:
raise ScrapliConnectionNotOpened

if self._initial_buf:
buf = self._initial_buf
self._initial_buf = b""
return buf

try:
buf = self._read()
buf = buf.replace(b"\x00", b"")
except Exception as exc:
raise ScrapliConnectionError(
"encountered EOF reading from transport; typically means the device closed the "
"connection"
) from exc
if self._control_char_sent_counter < self._control_char_sent_limit:
self._handle_control_chars()

while not self._cooked_buf and not self._eof:
self._read()
if self._control_char_sent_counter < self._control_char_sent_limit:
self._handle_control_chars()

buf = self._cooked_buf
self._cooked_buf = b""
return buf

def write(self, channel_input: bytes) -> None:
Expand Down
6 changes: 6 additions & 0 deletions tests/functional/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def async_transport(request):

@pytest.fixture(scope="class")
def conn(test_devices_dict, device_type, transport):
if device_type == "cisco_nxos" and transport in TELNET_TRANSPORTS:
pytest.skip("skipping telnet for nxos hosts")

device = test_devices_dict[device_type].copy()
driver = device.pop("driver")
device.pop("base_config")
Expand Down Expand Up @@ -188,6 +191,9 @@ def eos_conn(test_devices_dict, transport):

@pytest.fixture(scope="class")
def nxos_conn(test_devices_dict, transport):
if transport in TELNET_TRANSPORTS:
pytest.skip("skipping telnet for nxos hosts")

device = test_devices_dict["cisco_nxos"].copy()
driver = device.pop("driver")
device.pop("base_config")
Expand Down
3 changes: 3 additions & 0 deletions tests/functional/test_drivers_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ def test_isalive_and_close(self, conn, device_type, transport):


def test_context_manager(test_devices_dict, device_type, transport):
if device_type == "cisco_nxos" and "telnet" in transport:
pytest.skip("skipping telnet for nxos hosts")

device = test_devices_dict[device_type].copy()
driver = device.pop("driver")
device.pop("base_config")
Expand Down
Loading

0 comments on commit 2c6de7b

Please sign in to comment.