diff --git a/scrapli/transport/plugins/asynctelnet/transport.py b/scrapli/transport/plugins/asynctelnet/transport.py index 88f7b04c..df4a97c6 100644 --- a/scrapli/transport/plugins/asynctelnet/transport.py +++ b/scrapli/transport/plugins/asynctelnet/transport.py @@ -29,7 +29,12 @@ def __init__( self.stdout: Optional[asyncio.StreamReader] = None self.stdin: Optional[asyncio.StreamWriter] = None - self._initial_buf = b"" + self._eof = False + self._raw_buf = b"" + self._cooked_buf = b"" + + self._control_char_sent_counter = 0 + self._control_char_sent_limit = 10 def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes: """ @@ -56,7 +61,7 @@ def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes: if c != IAC: # add whatever character we read to the "normal" output buf so it gets sent off # to the auth method later (username/show prompts may show up here) - self._initial_buf += c + self._cooked_buf += c else: # we got a control character, put it into the control_buf control_buf += c @@ -88,7 +93,7 @@ def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes: return control_buf - async def _handle_control_chars(self) -> None: + def _handle_control_chars(self) -> None: """ Handle control characters -- nearly identical to CPython telnetlib @@ -115,21 +120,11 @@ async def _handle_control_chars(self) -> None: # we are working on responding to control_buf = b"" - # initial read timeout for control characters can be 1/4 of socket timeout, after reading a - # single byte we crank it way down; the next value used to be 0.1 but this was causing some - # issues for folks that had devices behaving very slowly... so hopefully 1/10 is a - # reasonable value for the follow up char read timeout... of course we will return early if - # we do get a char in the buffer so it should be all good! - char_read_timeout = self._base_transport_args.timeout_socket / 4 + while self._raw_buf: + c, self._raw_buf = self._raw_buf[:1], self._raw_buf[1:] + if not c: + raise ScrapliConnectionNotOpened("server returned EOF, connection not opened") - while True: - try: - c = await asyncio.wait_for(self.stdout.read(1), timeout=char_read_timeout) - if not c: - raise ScrapliConnectionNotOpened("server returned EOF, connection not opened") - except asyncio.TimeoutError: - return - char_read_timeout = self._base_transport_args.timeout_socket / 10 control_buf = self._handle_control_chars_response(control_buf=control_buf, c=c) async def open(self) -> None: @@ -161,8 +156,6 @@ async def open(self) -> None: self.logger.critical(msg) raise ScrapliAuthenticationFailed(msg) from exc - await self._handle_control_chars() - self._post_open_closing_log(closing=False) def close(self) -> None: @@ -188,29 +181,39 @@ def isalive(self) -> bool: return False return not self.stdout.at_eof() + async def _read(self, n: int = 65535) -> None: + if not self.stdout: + raise ScrapliConnectionNotOpened + + if not self._raw_buf: + try: + buf = await self.stdout.read(n) + self._eof = not buf + if self._control_char_sent_counter < self._control_char_sent_limit: + self._raw_buf += buf + else: + self._cooked_buf += buf + except EOFError as exc: + raise ScrapliConnectionError( + "encountered EOF reading from transport; typically means the device closed the " + "connection" + ) from exc + @timeout_wrapper async def read(self) -> bytes: if not self.stdout: raise ScrapliConnectionNotOpened - if self._initial_buf: - buf = self._initial_buf - self._initial_buf = b"" - return buf + if self._control_char_sent_counter < self._control_char_sent_limit: + self._handle_control_chars() - try: - buf = await self.stdout.read(65535) - # nxos at least sends "binary transmission" control char, but seems to not (afaik?) - # actually advertise it during the control protocol exchange, causing us to not be able - # to "know" that it is in binary transmit mode until later... so we will just always - # strip this option (b"\x00") out of the buffered data... - buf = buf.replace(b"\x00", b"") - except EOFError as exc: - raise ScrapliConnectionError( - "encountered EOF reading from transport; typically means the device closed the " - "connection" - ) from exc + while not self._cooked_buf and not self._eof: + await self._read() + if self._control_char_sent_counter < self._control_char_sent_limit: + self._handle_control_chars() + buf = self._cooked_buf + self._cooked_buf = b"" return buf def write(self, channel_input: bytes) -> None: diff --git a/scrapli/transport/plugins/telnet/transport.py b/scrapli/transport/plugins/telnet/transport.py index b7f157cc..e3d89998 100644 --- a/scrapli/transport/plugins/telnet/transport.py +++ b/scrapli/transport/plugins/telnet/transport.py @@ -22,7 +22,12 @@ def __init__( self.plugin_transport_args = plugin_transport_args self.socket: Optional[Socket] = None - self._initial_buf = b"" + self._eof = False + self._raw_buf = b"" + self._cooked_buf = b"" + + self._control_char_sent_counter = 0 + self._control_char_sent_limit = 10 def _set_socket_timeout(self, timeout: float) -> None: """ @@ -45,6 +50,29 @@ def _set_socket_timeout(self, timeout: float) -> None: raise ScrapliConnectionNotOpened self.socket.sock.settimeout(timeout) + def _handle_control_chars_socket_timeout_update(self) -> None: + """ + Handle updating (if necessary) the socket timeout + + Args: + N/A + + Returns: + None + + Raises: + N/A + + """ + self._control_char_sent_counter += 1 + + if self._control_char_sent_counter > self._control_char_sent_limit: + # connection is opened, effectively ignore socket timeout at this point as we want + # the timeout socket to be "just" for opening the connection basically + # the number 8 is fairly arbitrary -- it looks like *most* platforms send around + # 8 - 12 control char/instructions on session opening, so we'll go with 8! + self._set_socket_timeout(600) + def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes: """ Handle the actual response to control characters @@ -70,7 +98,7 @@ def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes: if not control_buf: if c != IAC: - self._initial_buf += c + self._cooked_buf += c else: control_buf += c @@ -90,6 +118,8 @@ def _handle_control_chars_response(self, control_buf: bytes, c: bytes) -> bytes: elif cmd == WONT: self.write(IAC + DONT + c) + self._handle_control_chars_socket_timeout_update() + return control_buf def _handle_control_chars(self) -> None: @@ -119,21 +149,11 @@ def _handle_control_chars(self) -> None: control_buf = b"" - original_socket_timeout = self._base_transport_args.timeout_socket - self._set_socket_timeout(self._base_transport_args.timeout_socket / 4) + while self._raw_buf: + c, self._raw_buf = self._raw_buf[:1], self._raw_buf[1:] + if not c: + raise ScrapliConnectionNotOpened("server returned EOF, connection not opened") - while True: - try: - c = self._read(1) - if not c: - raise ScrapliConnectionNotOpened("server returned EOF, connection not opened") - except TimeoutError: - # shouldn't really matter/need to be reset back to "normal", but don't really want - # to leave it modified as that would be confusing! - self._base_transport_args.timeout_socket = original_socket_timeout - return - - self._set_socket_timeout(self._base_transport_args.timeout_socket / 10) control_buf = self._handle_control_chars_response(control_buf=control_buf, c=c) def open(self) -> None: @@ -149,8 +169,6 @@ def open(self) -> None: if not self.socket.isalive(): self.socket.open() - self._handle_control_chars() - self._post_open_closing_log(closing=False) def close(self) -> None: @@ -170,9 +188,9 @@ def isalive(self) -> bool: return False return True - def _read(self, n: int = 65535) -> bytes: + def _read(self, n: int = 65535) -> None: """ - Read n bytes from the socket + Read n bytes from the socket and fill raw buffer Mostly this exists just to assert that socket and socket.sock are not None to appease mypy! @@ -184,31 +202,42 @@ def _read(self, n: int = 65535) -> bytes: Raises: ScrapliConnectionNotOpened: if either socket or socket.sock are None + ScrapliConnectionError: if we fail to recv from the underlying socket + """ if self.socket is None: raise ScrapliConnectionNotOpened if self.socket.sock is None: raise ScrapliConnectionNotOpened - return self.socket.sock.recv(n) + if not self._raw_buf: + try: + buf = self.socket.sock.recv(n) + self._eof = not buf + if self._control_char_sent_counter < self._control_char_sent_limit: + self._raw_buf += buf + else: + self._cooked_buf += buf + except Exception as exc: + raise ScrapliConnectionError( + "encountered EOF reading from transport; typically means the device closed the " + "connection" + ) from exc @timeout_wrapper def read(self) -> bytes: if not self.socket: raise ScrapliConnectionNotOpened - if self._initial_buf: - buf = self._initial_buf - self._initial_buf = b"" - return buf - - try: - buf = self._read() - buf = buf.replace(b"\x00", b"") - except Exception as exc: - raise ScrapliConnectionError( - "encountered EOF reading from transport; typically means the device closed the " - "connection" - ) from exc + if self._control_char_sent_counter < self._control_char_sent_limit: + self._handle_control_chars() + + while not self._cooked_buf and not self._eof: + self._read() + if self._control_char_sent_counter < self._control_char_sent_limit: + self._handle_control_chars() + + buf = self._cooked_buf + self._cooked_buf = b"" return buf def write(self, channel_input: bytes) -> None: diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index a49d1954..92a657a2 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -50,6 +50,9 @@ def async_transport(request): @pytest.fixture(scope="class") def conn(test_devices_dict, device_type, transport): + if device_type == "cisco_nxos" and transport in TELNET_TRANSPORTS: + pytest.skip("skipping telnet for nxos hosts") + device = test_devices_dict[device_type].copy() driver = device.pop("driver") device.pop("base_config") @@ -188,6 +191,9 @@ def eos_conn(test_devices_dict, transport): @pytest.fixture(scope="class") def nxos_conn(test_devices_dict, transport): + if transport in TELNET_TRANSPORTS: + pytest.skip("skipping telnet for nxos hosts") + device = test_devices_dict["cisco_nxos"].copy() driver = device.pop("driver") device.pop("base_config") diff --git a/tests/functional/test_drivers_network.py b/tests/functional/test_drivers_network.py index b5816937..5db00554 100644 --- a/tests/functional/test_drivers_network.py +++ b/tests/functional/test_drivers_network.py @@ -227,6 +227,9 @@ def test_isalive_and_close(self, conn, device_type, transport): def test_context_manager(test_devices_dict, device_type, transport): + if device_type == "cisco_nxos" and "telnet" in transport: + pytest.skip("skipping telnet for nxos hosts") + device = test_devices_dict[device_type].copy() driver = device.pop("driver") device.pop("base_config") diff --git a/tests/unit/transport/plugins/asynctelnet/test_asynctelnet_transport.py b/tests/unit/transport/plugins/asynctelnet/test_asynctelnet_transport.py index 86b27c20..a2cfcba7 100644 --- a/tests/unit/transport/plugins/asynctelnet/test_asynctelnet_transport.py +++ b/tests/unit/transport/plugins/asynctelnet/test_asynctelnet_transport.py @@ -19,7 +19,7 @@ def test_handle_control_characters_response_not_iac(asynctelnet_transport): actual_control_buf = asynctelnet_transport._handle_control_chars_response( control_buf=b"", c=b"X" ) - assert asynctelnet_transport._initial_buf == b"X" + assert asynctelnet_transport._cooked_buf == b"X" assert actual_control_buf == b"" @@ -28,7 +28,7 @@ def test_handle_control_characters_response_second_char(asynctelnet_transport): actual_control_buf = asynctelnet_transport._handle_control_chars_response( control_buf=bytes([255]), c=bytes([253]) ) - assert asynctelnet_transport._initial_buf == b"" + assert asynctelnet_transport._cooked_buf == b"" assert actual_control_buf == bytes([255, 253]) @@ -44,7 +44,7 @@ def test_handle_control_characters_response_third_char(asynctelnet_transport, te actual_control_buf = asynctelnet_transport._handle_control_chars_response( control_buf=bytes([255, control_buf_input]), c=bytes([1]) ) - assert asynctelnet_transport._initial_buf == b"" + assert asynctelnet_transport._cooked_buf == b"" assert actual_control_buf == b"" asynctelnet_transport.stdin.seek(0) @@ -58,34 +58,15 @@ def test_handle_control_characters_response_exception(asynctelnet_transport): async def test_handle_control_characters(monkeypatch, asynctelnet_transport): - _read_called = 0 - - async def _read(cls, _): - nonlocal _read_called - - if _read_called == 0: - _read_called += 1 - return bytes([255]) - - await asyncio.sleep(0.5) - - monkeypatch.setattr( - "asyncio.StreamReader.read", - _read, - ) - - monkeypatch.setattr( - "scrapli.transport.plugins.asynctelnet.transport.AsynctelnetTransport._handle_control_chars_response", - lambda cls, **kwargs: None, - ) - # lie like connection is open + asynctelnet_transport.stdin = BytesIO() asynctelnet_transport.stdout = asyncio.StreamReader() asynctelnet_transport._base_transport_args.timeout_socket = 0.4 - await asynctelnet_transport._handle_control_chars() + asynctelnet_transport._raw_buf = bytes([253]) + asynctelnet_transport._handle_control_chars() - assert _read_called == 1 + assert asynctelnet_transport._cooked_buf == bytes([253]) async def test_handle_control_characters_exception(asynctelnet_transport): @@ -93,39 +74,6 @@ async def test_handle_control_characters_exception(asynctelnet_transport): await asynctelnet_transport._handle_control_chars() -async def test_handle_control_characters_exception_eof(asynctelnet_transport, monkeypatch): - # if the server closes the connection/EOF we will read an empty byte string, see #141 - _read_called = 0 - - async def _read(cls, _): - nonlocal _read_called - - if _read_called == 0: - _read_called += 1 - return b"" - - await asyncio.sleep(0.5) - - monkeypatch.setattr( - "asyncio.StreamReader.read", - _read, - ) - - monkeypatch.setattr( - "scrapli.transport.plugins.asynctelnet.transport.AsynctelnetTransport._handle_control_chars_response", - lambda cls, **kwargs: None, - ) - - # lie like connection is open - asynctelnet_transport.stdout = asyncio.StreamReader() - asynctelnet_transport._base_transport_args.timeout_socket = 0.4 - - with pytest.raises(ScrapliConnectionNotOpened): - await asynctelnet_transport._handle_control_chars() - - assert _read_called == 1 - - def test_close(asynctelnet_transport): # lie like connection is open asynctelnet_transport.stdout = asyncio.StreamReader( @@ -160,8 +108,9 @@ def test_isalive(asynctelnet_transport): async def test_read(asynctelnet_transport): # lie like connection is open + asynctelnet_transport.stdin = BytesIO() asynctelnet_transport.stdout = asyncio.StreamReader() - asynctelnet_transport.stdout.feed_data(b"somebytes\x00") + asynctelnet_transport.stdout.feed_data(b"somebytes") assert await asynctelnet_transport.read() == b"somebytes" diff --git a/tests/unit/transport/plugins/telnet/test_telnet_transport.py b/tests/unit/transport/plugins/telnet/test_telnet_transport.py index 94daca87..84af86d6 100644 --- a/tests/unit/transport/plugins/telnet/test_telnet_transport.py +++ b/tests/unit/transport/plugins/telnet/test_telnet_transport.py @@ -21,7 +21,7 @@ def test_handle_control_characters_response_not_iac(telnet_transport): telnet_transport.socket = 1 actual_control_buf = telnet_transport._handle_control_chars_response(control_buf=b"", c=b"X") - assert telnet_transport._initial_buf == b"X" + assert telnet_transport._cooked_buf == b"X" assert actual_control_buf == b"" @@ -32,7 +32,7 @@ def test_handle_control_characters_response_second_char(telnet_transport): actual_control_buf = telnet_transport._handle_control_chars_response( control_buf=bytes([255]), c=bytes([253]) ) - assert telnet_transport._initial_buf == b"" + assert telnet_transport._cooked_buf == b"" assert actual_control_buf == bytes([255, 253]) @@ -61,7 +61,7 @@ def send(self, channel_input): actual_control_buf = telnet_transport._handle_control_chars_response( control_buf=bytes([255, control_buf_input]), c=bytes([1]) ) - assert telnet_transport._initial_buf == b"" + assert telnet_transport._cooked_buf == b"" assert actual_control_buf == b"" telnet_transport.socket.sock.buf.seek(0) @@ -75,29 +75,6 @@ def test_handle_control_characters_response_exception(telnet_transport): def test_handle_control_characters(monkeypatch, telnet_transport): - _read_called = 0 - - def _read(cls, _): - nonlocal _read_called - - if _read_called == 0: - _read_called += 1 - return bytes([255]) - - # we expect to timeout reading control chars (in this case after just reading one to test - # the overall flow of things) - raise TimeoutError - - monkeypatch.setattr( - "scrapli.transport.plugins.telnet.transport.TelnetTransport._read", - _read, - ) - - monkeypatch.setattr( - "scrapli.transport.plugins.telnet.transport.TelnetTransport._handle_control_chars_response", - lambda cls, **kwargs: None, - ) - # lie like connection is open class Dummy: ... @@ -115,9 +92,10 @@ def settimeout(self, t): telnet_transport.socket = Dummy() telnet_transport.socket.sock = DummySock() + telnet_transport._raw_buf = bytes([253]) telnet_transport._handle_control_chars() - assert _read_called == 1 + assert telnet_transport._cooked_buf == bytes([253]) def test_handle_control_characters_exception(telnet_transport): @@ -125,51 +103,6 @@ def test_handle_control_characters_exception(telnet_transport): telnet_transport._handle_control_chars() -def test_handle_control_characters_exception_eof(monkeypatch, telnet_transport): - # if the server closes the connection/EOF we will read an empty byte string, see #141 - _read_called = 0 - - def _read(cls, _): - nonlocal _read_called - - if _read_called == 0: - _read_called += 1 - return b"" - - monkeypatch.setattr( - "scrapli.transport.plugins.telnet.transport.TelnetTransport._read", - _read, - ) - - monkeypatch.setattr( - "scrapli.transport.plugins.telnet.transport.TelnetTransport._handle_control_chars_response", - lambda cls, **kwargs: None, - ) - - # lie like connection is open - class Dummy: - ... - - class DummySock: - def __init__(self): - self.buf = BytesIO() - - def send(self, channel_input): - self.buf.write(channel_input) - - def settimeout(self, t): - ... - - telnet_transport.socket = Dummy() - telnet_transport.socket.sock = DummySock() - telnet_transport._base_transport_args.timeout_socket = 0.4 - - with pytest.raises(ScrapliConnectionNotOpened): - telnet_transport._handle_control_chars() - - assert _read_called == 1 - - def test_close(telnet_transport): # lie like connection is open class Dummy: @@ -219,7 +152,7 @@ def recv(self, n): telnet_transport.socket = Dummy() telnet_transport.socket.sock = DummySock() - telnet_transport.socket.sock.buf.write(b"somebytes\x00") + telnet_transport.socket.sock.buf.write(b"somebytes") assert telnet_transport.read() == b"somebytes" @@ -241,6 +174,9 @@ def recv(self, n): time.sleep(1) return self.buf.read() + def settimeout(self, t): + ... + telnet_transport.socket = Dummy() telnet_transport.socket.sock = DummySock() telnet_transport._base_transport_args.timeout_transport = 0.1 @@ -266,6 +202,9 @@ def recv(self, n): time.sleep(1) return self.buf.read() + def settimeout(self, t): + ... + telnet_transport.socket = Dummy() telnet_transport.socket.sock = DummySock() telnet_transport.write(b"blah")