diff --git a/docs/changelog.md b/docs/changelog.md index b0a104cd..16cc7719 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -5,7 +5,13 @@ Changelog - Expand `arista_eos` prompt pattern to handle super long config sections (things like qos queues and such). Thanks to @MarkRudenko over in scrapli_cfg repo for finding this and providing the fix! - +- Add `comms_roughly_match_inputs` option -- this uses a "rough" match when looking for inputs (commands/configs you + send) in output printed back on the channel. Basically, if all input characters show up in the output in the correct + order, then we assume the input was found. Of course this could be less "exacting" but it also *probably* is ok 99% + of the time :) +- Added an `eager_input` option to send operations -- this option completely skips checking for inputs being echoed back + on the channel. With the addition of the `comms_roughly_match_inputs` option this is *probably* unnecessary, but + could be useful for some corner cases. ## 2023.07.30 diff --git a/scrapli/channel/async_channel.py b/scrapli/channel/async_channel.py index cb63c9c9..9985a467 100644 --- a/scrapli/channel/async_channel.py +++ b/scrapli/channel/async_channel.py @@ -10,6 +10,7 @@ from scrapli.channel.base_channel import BaseChannel, BaseChannelArgs from scrapli.decorators import timeout_wrapper from scrapli.exceptions import ScrapliAuthenticationFailed, ScrapliTimeout +from scrapli.helper import output_roughly_contains_input from scrapli.transport.base import AsyncTransport @@ -104,9 +105,16 @@ async def _read_until_input(self, channel_input: bytes) -> bytes: while True: buf += await self.read() - # replace any backspace chars (particular problem w/ junos), and remove any added spaces - # this is just for comparison of the inputs to what was read from channel - if processed_channel_input in b"".join(buf.lower().replace(b"\x08", b"").split()): + if not self._base_channel_args.comms_roughly_match_inputs: + # replace any backspace chars (particular problem w/ junos), and remove any added + # spaces this is just for comparison of the inputs to what was read from channel + # note (2024) this would be worked around by using the roughly contains search, + # *but* that is slower (probably immaterially for most people but... ya know...) + processed_buf = b"".join(buf.lower().replace(b"\x08", b"").split()) + + if processed_channel_input in processed_buf: + return buf + elif output_roughly_contains_input(input_=processed_channel_input, output=buf): return buf async def _read_until_prompt(self, buf: bytes = b"") -> bytes: diff --git a/scrapli/channel/base_channel.py b/scrapli/channel/base_channel.py index 8c62f453..b6b6d6b0 100644 --- a/scrapli/channel/base_channel.py +++ b/scrapli/channel/base_channel.py @@ -39,6 +39,12 @@ class BaseChannelArgs: comms_prompt_search_depth: depth of the buffer to search in for searching for the prompt in "read_until_prompt"; smaller number here will generally be faster, though may be less reliable; default value is 1000 + comms_roughly_match_inputs: indicates if the channel should "roughly" match inputs sent + to the device. If False (default) inputs are strictly checked, as in any input + *must* be read back exactly on the channel. When set to True all input chars *must* + be read back in order in the output and all chars must be present, but the *exact* + input string does not need to be seen. This can be useful if a device echoes back + extra characters or rewrites the terminal during command input. timeout_ops: timeout_ops to assign to the channel, see above channel_log: log "channel" output -- this would be the output you would normally see on a terminal. If `True` logs to `scrapli_channel.log`, if a string is provided, logs to @@ -61,6 +67,7 @@ class BaseChannelArgs: comms_prompt_pattern: str = r"^[a-z0-9.\-@()/:]{1,32}[#>$]$" comms_return_char: str = "\n" comms_prompt_search_depth: int = 1000 + comms_roughly_match_inputs: bool = False timeout_ops: float = 30.0 channel_log: Union[str, bool, BytesIO] = False channel_log_mode: str = "write" diff --git a/scrapli/channel/sync_channel.py b/scrapli/channel/sync_channel.py index e4289eb8..871dbb7f 100644 --- a/scrapli/channel/sync_channel.py +++ b/scrapli/channel/sync_channel.py @@ -10,6 +10,7 @@ from scrapli.channel.base_channel import BaseChannel, BaseChannelArgs from scrapli.decorators import timeout_wrapper from scrapli.exceptions import ScrapliAuthenticationFailed, ScrapliConnectionError, ScrapliTimeout +from scrapli.helper import output_roughly_contains_input from scrapli.transport.base import Transport @@ -104,9 +105,16 @@ def _read_until_input(self, channel_input: bytes) -> bytes: while True: buf += self.read() - # replace any backspace chars (particular problem w/ junos), and remove any added spaces - # this is just for comparison of the inputs to what was read from channel - if processed_channel_input in b"".join(buf.lower().replace(b"\x08", b"").split()): + if not self._base_channel_args.comms_roughly_match_inputs: + # replace any backspace chars (particular problem w/ junos), and remove any added + # spaces this is just for comparison of the inputs to what was read from channel + # note (2024) this would be worked around by using the roughly contains search, + # *but* that is slower (probably immaterially for most people but... ya know...) + processed_buf = b"".join(buf.lower().replace(b"\x08", b"").split()) + + if processed_channel_input in processed_buf: + return buf + elif output_roughly_contains_input(input_=processed_channel_input, output=buf): return buf def _read_until_prompt(self, buf: bytes = b"") -> bytes: diff --git a/scrapli/driver/base/base_driver.py b/scrapli/driver/base/base_driver.py index 1da5e673..3cb74544 100644 --- a/scrapli/driver/base/base_driver.py +++ b/scrapli/driver/base/base_driver.py @@ -1,4 +1,4 @@ -"""scrapli.driver.base.base_driver""" +"""scrapli.driver.base.base_driver""" # noqa: C0302 import importlib from dataclasses import fields from io import BytesIO @@ -34,6 +34,7 @@ def __init__( timeout_ops: float = 30.0, comms_prompt_pattern: str = r"^[a-z0-9.\-@()/:]{1,48}[#>$]\s*$", comms_return_char: str = "\n", + comms_roughly_match_inputs: bool = False, ssh_config_file: Union[str, bool] = False, ssh_known_hosts_file: Union[str, bool] = False, on_init: Optional[Callable[..., Any]] = None, @@ -83,6 +84,12 @@ def __init__( should be mostly sorted for you if using network drivers (i.e. `IOSXEDriver`). Lastly, the case insensitive is just a convenience factor so i can be lazy. comms_return_char: character to use to send returns to host + comms_roughly_match_inputs: indicates if the channel should "roughly" match inputs sent + to the device. If False (default) inputs are strictly checked, as in any input + *must* be read back exactly on the channel. When set to True all input chars *must* + be read back in order in the output and all chars must be present, but the *exact* + input string does not need to be seen. This can be useful if a device echoes back + extra characters or rewrites the terminal during command input. ssh_config_file: string to path for ssh config file, True to use default ssh config file or False to ignore default ssh config file ssh_known_hosts_file: string to path for ssh known hosts file, True to use default known @@ -149,6 +156,7 @@ class that extends the driver, instead allowing the community platforms to simpl auth_passphrase_pattern=auth_passphrase_pattern, comms_prompt_pattern=comms_prompt_pattern, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, timeout_ops=timeout_ops, channel_log=channel_log, channel_log_mode=channel_log_mode, @@ -249,6 +257,7 @@ def __repr__(self) -> str: f"timeout_ops={self._base_channel_args.timeout_ops!r}, " f"comms_prompt_pattern={self._base_channel_args.comms_prompt_pattern!r}, " f"comms_return_char={self._base_channel_args.comms_return_char!r}, " + f"comms_roughly_match_inputs={self._base_channel_args.comms_roughly_match_inputs!r}, " f"ssh_config_file={self.ssh_config_file!r}, " f"ssh_known_hosts_file={self.ssh_known_hosts_file!r}, " f"on_init={self.on_init!r}, " @@ -738,6 +747,84 @@ def comms_return_char(self, value: str) -> None: self._base_channel_args.comms_return_char = value + @property + def comms_prompt_search_depth(self) -> int: + """ + Getter for `comms_prompt_search_depth` attribute + + Args: + N/A + + Returns: + int: comms_prompt_search_depth int + + Raises: + N/A + + """ + return self._base_channel_args.comms_prompt_search_depth + + @comms_prompt_search_depth.setter + def comms_prompt_search_depth(self, value: int) -> None: + """ + Setter for `comms_prompt_search_depth` attribute + + Args: + value: int value for comms_prompt_search_depth + + Returns: + None + + Raises: + ScrapliTypeError: if value is not of type int + + """ + self.logger.debug(f"setting 'comms_prompt_search_depth' value to {value!r}") + + if not isinstance(value, int): + raise ScrapliTypeError + + self._base_channel_args.comms_prompt_search_depth = value + + @property + def comms_roughly_match_inputs(self) -> bool: + """ + Getter for `comms_roughly_match_inputs` attribute + + Args: + N/A + + Returns: + bool: comms_roughly_match_inputs bool + + Raises: + N/A + + """ + return self._base_channel_args.comms_roughly_match_inputs + + @comms_roughly_match_inputs.setter + def comms_roughly_match_inputs(self, value: bool) -> None: + """ + Setter for `comms_roughly_match_inputs` attribute + + Args: + value: int value for comms_roughly_match_inputs + + Returns: + None + + Raises: + ScrapliTypeError: if value is not of type bool + + """ + self.logger.debug(f"setting 'comms_roughly_match_inputs' value to {value!r}") + + if not isinstance(value, bool): + raise ScrapliTypeError + + self._base_channel_args.comms_roughly_match_inputs = value + @property def timeout_socket(self) -> float: """ diff --git a/scrapli/driver/core/arista_eos/async_driver.py b/scrapli/driver/core/arista_eos/async_driver.py index de250f52..7182f7f1 100644 --- a/scrapli/driver/core/arista_eos/async_driver.py +++ b/scrapli/driver/core/arista_eos/async_driver.py @@ -66,6 +66,7 @@ def __init__( timeout_transport: float = 30.0, timeout_ops: float = 30.0, comms_return_char: str = "\n", + comms_roughly_match_inputs: bool = False, ssh_config_file: Union[str, bool] = False, ssh_known_hosts_file: Union[str, bool] = False, on_init: Optional[Callable[..., Any]] = None, @@ -146,6 +147,7 @@ def __init__( timeout_transport=timeout_transport, timeout_ops=timeout_ops, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, diff --git a/scrapli/driver/core/arista_eos/sync_driver.py b/scrapli/driver/core/arista_eos/sync_driver.py index 094f810b..b9f42ef1 100644 --- a/scrapli/driver/core/arista_eos/sync_driver.py +++ b/scrapli/driver/core/arista_eos/sync_driver.py @@ -66,6 +66,7 @@ def __init__( timeout_transport: float = 30.0, timeout_ops: float = 30.0, comms_return_char: str = "\n", + comms_roughly_match_inputs: bool = False, ssh_config_file: Union[str, bool] = False, ssh_known_hosts_file: Union[str, bool] = False, on_init: Optional[Callable[..., Any]] = None, @@ -146,6 +147,7 @@ def __init__( timeout_transport=timeout_transport, timeout_ops=timeout_ops, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, diff --git a/scrapli/driver/core/cisco_iosxe/async_driver.py b/scrapli/driver/core/cisco_iosxe/async_driver.py index 5b85c78d..3193bf0f 100644 --- a/scrapli/driver/core/cisco_iosxe/async_driver.py +++ b/scrapli/driver/core/cisco_iosxe/async_driver.py @@ -66,6 +66,7 @@ def __init__( timeout_transport: float = 30.0, timeout_ops: float = 30.0, comms_return_char: str = "\n", + comms_roughly_match_inputs: bool = False, ssh_config_file: Union[str, bool] = False, ssh_known_hosts_file: Union[str, bool] = False, on_init: Optional[Callable[..., Any]] = None, @@ -143,6 +144,7 @@ def __init__( timeout_transport=timeout_transport, timeout_ops=timeout_ops, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, diff --git a/scrapli/driver/core/cisco_iosxe/sync_driver.py b/scrapli/driver/core/cisco_iosxe/sync_driver.py index 441b20ec..6d21e153 100644 --- a/scrapli/driver/core/cisco_iosxe/sync_driver.py +++ b/scrapli/driver/core/cisco_iosxe/sync_driver.py @@ -66,6 +66,7 @@ def __init__( timeout_transport: float = 30.0, timeout_ops: float = 30.0, comms_return_char: str = "\n", + comms_roughly_match_inputs: bool = False, ssh_config_file: Union[str, bool] = False, ssh_known_hosts_file: Union[str, bool] = False, on_init: Optional[Callable[..., Any]] = None, @@ -143,6 +144,7 @@ def __init__( timeout_transport=timeout_transport, timeout_ops=timeout_ops, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, diff --git a/scrapli/driver/core/cisco_iosxr/async_driver.py b/scrapli/driver/core/cisco_iosxr/async_driver.py index 4d1faec0..f2cd150d 100644 --- a/scrapli/driver/core/cisco_iosxr/async_driver.py +++ b/scrapli/driver/core/cisco_iosxr/async_driver.py @@ -65,6 +65,7 @@ def __init__( timeout_transport: float = 30.0, timeout_ops: float = 30.0, comms_return_char: str = "\n", + comms_roughly_match_inputs: bool = False, ssh_config_file: Union[str, bool] = False, ssh_known_hosts_file: Union[str, bool] = False, on_init: Optional[Callable[..., Any]] = None, @@ -142,6 +143,7 @@ def __init__( timeout_transport=timeout_transport, timeout_ops=timeout_ops, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, diff --git a/scrapli/driver/core/cisco_iosxr/sync_driver.py b/scrapli/driver/core/cisco_iosxr/sync_driver.py index 284bd8c3..95f94495 100644 --- a/scrapli/driver/core/cisco_iosxr/sync_driver.py +++ b/scrapli/driver/core/cisco_iosxr/sync_driver.py @@ -68,6 +68,7 @@ def __init__( timeout_transport: float = 30.0, timeout_ops: float = 30.0, comms_return_char: str = "\n", + comms_roughly_match_inputs: bool = False, ssh_config_file: Union[str, bool] = False, ssh_known_hosts_file: Union[str, bool] = False, on_init: Optional[Callable[..., Any]] = None, @@ -145,6 +146,7 @@ def __init__( timeout_transport=timeout_transport, timeout_ops=timeout_ops, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, diff --git a/scrapli/driver/core/cisco_nxos/async_driver.py b/scrapli/driver/core/cisco_nxos/async_driver.py index b3751f14..be7df6d9 100644 --- a/scrapli/driver/core/cisco_nxos/async_driver.py +++ b/scrapli/driver/core/cisco_nxos/async_driver.py @@ -66,6 +66,7 @@ def __init__( timeout_transport: float = 30.0, timeout_ops: float = 30.0, comms_return_char: str = "\n", + comms_roughly_match_inputs: bool = False, ssh_config_file: Union[str, bool] = False, ssh_known_hosts_file: Union[str, bool] = False, on_init: Optional[Callable[..., Any]] = None, @@ -146,6 +147,7 @@ def __init__( timeout_transport=timeout_transport, timeout_ops=timeout_ops, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, diff --git a/scrapli/driver/core/cisco_nxos/sync_driver.py b/scrapli/driver/core/cisco_nxos/sync_driver.py index ba831a77..69ccd3a9 100644 --- a/scrapli/driver/core/cisco_nxos/sync_driver.py +++ b/scrapli/driver/core/cisco_nxos/sync_driver.py @@ -66,6 +66,7 @@ def __init__( timeout_transport: float = 30.0, timeout_ops: float = 30.0, comms_return_char: str = "\n", + comms_roughly_match_inputs: bool = False, ssh_config_file: Union[str, bool] = False, ssh_known_hosts_file: Union[str, bool] = False, on_init: Optional[Callable[..., Any]] = None, @@ -146,6 +147,7 @@ def __init__( timeout_transport=timeout_transport, timeout_ops=timeout_ops, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, diff --git a/scrapli/driver/core/juniper_junos/async_driver.py b/scrapli/driver/core/juniper_junos/async_driver.py index c208319e..db98aea9 100644 --- a/scrapli/driver/core/juniper_junos/async_driver.py +++ b/scrapli/driver/core/juniper_junos/async_driver.py @@ -67,6 +67,7 @@ def __init__( timeout_transport: float = 30.0, timeout_ops: float = 30.0, comms_return_char: str = "\n", + comms_roughly_match_inputs: bool = False, ssh_config_file: Union[str, bool] = False, ssh_known_hosts_file: Union[str, bool] = False, on_init: Optional[Callable[..., Any]] = None, @@ -144,6 +145,7 @@ def __init__( timeout_transport=timeout_transport, timeout_ops=timeout_ops, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, diff --git a/scrapli/driver/core/juniper_junos/sync_driver.py b/scrapli/driver/core/juniper_junos/sync_driver.py index 16549430..80351d2c 100644 --- a/scrapli/driver/core/juniper_junos/sync_driver.py +++ b/scrapli/driver/core/juniper_junos/sync_driver.py @@ -67,6 +67,7 @@ def __init__( timeout_transport: float = 30.0, timeout_ops: float = 30.0, comms_return_char: str = "\n", + comms_roughly_match_inputs: bool = False, ssh_config_file: Union[str, bool] = False, ssh_known_hosts_file: Union[str, bool] = False, on_init: Optional[Callable[..., Any]] = None, @@ -144,6 +145,7 @@ def __init__( timeout_transport=timeout_transport, timeout_ops=timeout_ops, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, diff --git a/scrapli/driver/generic/async_driver.py b/scrapli/driver/generic/async_driver.py index 3a015ad0..ffb6bff4 100644 --- a/scrapli/driver/generic/async_driver.py +++ b/scrapli/driver/generic/async_driver.py @@ -35,6 +35,7 @@ def __init__( timeout_ops: float = 30.0, comms_prompt_pattern: str = r"^\S{0,48}[#>$~@:\]]\s*$", comms_return_char: str = "\n", + comms_roughly_match_inputs: bool = False, ssh_config_file: Union[str, bool] = False, ssh_known_hosts_file: Union[str, bool] = False, on_init: Optional[Callable[..., Any]] = None, @@ -64,6 +65,7 @@ def __init__( timeout_ops=timeout_ops, comms_prompt_pattern=comms_prompt_pattern, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, diff --git a/scrapli/driver/generic/sync_driver.py b/scrapli/driver/generic/sync_driver.py index 4adfe2cb..54dcbd21 100644 --- a/scrapli/driver/generic/sync_driver.py +++ b/scrapli/driver/generic/sync_driver.py @@ -35,6 +35,7 @@ def __init__( timeout_ops: float = 30.0, comms_prompt_pattern: str = r"^\S{0,48}[#>$~@:\]]\s*$", comms_return_char: str = "\n", + comms_roughly_match_inputs: bool = False, ssh_config_file: Union[str, bool] = False, ssh_known_hosts_file: Union[str, bool] = False, on_init: Optional[Callable[..., Any]] = None, @@ -64,6 +65,7 @@ def __init__( timeout_ops=timeout_ops, comms_prompt_pattern=comms_prompt_pattern, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, diff --git a/scrapli/driver/network/async_driver.py b/scrapli/driver/network/async_driver.py index b23bfd1b..d68493f8 100644 --- a/scrapli/driver/network/async_driver.py +++ b/scrapli/driver/network/async_driver.py @@ -29,6 +29,7 @@ def __init__( timeout_transport: float = 30.0, timeout_ops: float = 30.0, comms_return_char: str = "\n", + comms_roughly_match_inputs: bool = False, ssh_config_file: Union[str, bool] = False, ssh_known_hosts_file: Union[str, bool] = False, on_init: Optional[Callable[..., Any]] = None, @@ -64,6 +65,7 @@ def __init__( timeout_transport=timeout_transport, timeout_ops=timeout_ops, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, @@ -489,6 +491,7 @@ async def send_configs( stop_on_failed: bool = False, privilege_level: str = "", eager: bool = False, + eager_input: bool = False, timeout_ops: Optional[float] = None, ) -> MultiResponse: """ @@ -510,6 +513,8 @@ async def send_configs( eager: if eager is True we do not read until prompt is seen at each command sent to the channel. Do *not* use this unless you know what you are doing as it is possible that it can make scrapli less reliable! + eager_input: when true does *not* try to read our input off the channel -- generally + this should be left alone unless you know what you are doing! timeout_ops: timeout ops value for this operation; only sets the timeout_ops value for the duration of the operation, value is reset to initial value after operation is completed. Note that this is the timeout value PER CONFIG sent, not for the total @@ -537,6 +542,7 @@ async def send_configs( failed_when_contains=failed_when_contains, stop_on_failed=stop_on_failed, eager=eager, + eager_input=eager_input, timeout_ops=timeout_ops, ) @@ -554,6 +560,7 @@ async def send_config( stop_on_failed: bool = False, privilege_level: str = "", eager: bool = False, + eager_input: bool = False, timeout_ops: Optional[float] = None, ) -> Response: """ @@ -575,6 +582,8 @@ async def send_config( eager: if eager is True we do not read until prompt is seen at each command sent to the channel. Do *not* use this unless you know what you are doing as it is possible that it can make scrapli less reliable! + eager_input: when true does *not* try to read our input off the channel -- generally + this should be left alone unless you know what you are doing! timeout_ops: timeout ops value for this operation; only sets the timeout_ops value for the duration of the operation, value is reset to initial value after operation is completed. Note that this is the timeout value PER CONFIG sent, not for the total @@ -597,6 +606,7 @@ async def send_config( stop_on_failed=stop_on_failed, privilege_level=privilege_level, eager=eager, + eager_input=eager_input, timeout_ops=timeout_ops, ) return self._post_send_config(config=config, multi_response=multi_response) @@ -610,6 +620,7 @@ async def send_configs_from_file( stop_on_failed: bool = False, privilege_level: str = "", eager: bool = False, + eager_input: bool = False, timeout_ops: Optional[float] = None, ) -> MultiResponse: """ @@ -631,6 +642,8 @@ async def send_configs_from_file( eager: if eager is True we do not read until prompt is seen at each command sent to the channel. Do *not* use this unless you know what you are doing as it is possible that it can make scrapli less reliable! + eager_input: when true does *not* try to read our input off the channel -- generally + this should be left alone unless you know what you are doing! timeout_ops: timeout ops value for this operation; only sets the timeout_ops value for the duration of the operation, value is reset to initial value after operation is completed. Note that this is the timeout value PER CONFIG sent, not for the total @@ -652,5 +665,6 @@ async def send_configs_from_file( stop_on_failed=stop_on_failed, privilege_level=privilege_level, eager=eager, + eager_input=eager_input, timeout_ops=timeout_ops, ) diff --git a/scrapli/driver/network/sync_driver.py b/scrapli/driver/network/sync_driver.py index a2187ffd..19457e34 100644 --- a/scrapli/driver/network/sync_driver.py +++ b/scrapli/driver/network/sync_driver.py @@ -29,6 +29,7 @@ def __init__( timeout_transport: float = 30.0, timeout_ops: float = 30.0, comms_return_char: str = "\n", + comms_roughly_match_inputs: bool = False, ssh_config_file: Union[str, bool] = False, ssh_known_hosts_file: Union[str, bool] = False, on_init: Optional[Callable[..., Any]] = None, @@ -64,6 +65,7 @@ def __init__( timeout_transport=timeout_transport, timeout_ops=timeout_ops, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, @@ -489,6 +491,7 @@ def send_configs( stop_on_failed: bool = False, privilege_level: str = "", eager: bool = False, + eager_input: bool = False, timeout_ops: Optional[float] = None, ) -> MultiResponse: """ @@ -510,6 +513,8 @@ def send_configs( eager: if eager is True we do not read until prompt is seen at each command sent to the channel. Do *not* use this unless you know what you are doing as it is possible that it can make scrapli less reliable! + eager_input: when true does *not* try to read our input off the channel -- generally + this should be left alone unless you know what you are doing! timeout_ops: timeout ops value for this operation; only sets the timeout_ops value for the duration of the operation, value is reset to initial value after operation is completed. Note that this is the timeout value PER CONFIG sent, not for the total @@ -537,6 +542,7 @@ def send_configs( failed_when_contains=failed_when_contains, stop_on_failed=stop_on_failed, eager=eager, + eager_input=eager_input, timeout_ops=timeout_ops, ) @@ -554,6 +560,7 @@ def send_config( stop_on_failed: bool = False, privilege_level: str = "", eager: bool = False, + eager_input: bool = False, timeout_ops: Optional[float] = None, ) -> Response: """ @@ -575,6 +582,8 @@ def send_config( eager: if eager is True we do not read until prompt is seen at each command sent to the channel. Do *not* use this unless you know what you are doing as it is possible that it can make scrapli less reliable! + eager_input: when true does *not* try to read our input off the channel -- generally + this should be left alone unless you know what you are doing! timeout_ops: timeout ops value for this operation; only sets the timeout_ops value for the duration of the operation, value is reset to initial value after operation is completed. Note that this is the timeout value PER CONFIG sent, not for the total @@ -597,6 +606,7 @@ def send_config( stop_on_failed=stop_on_failed, privilege_level=privilege_level, eager=eager, + eager_input=eager_input, timeout_ops=timeout_ops, ) return self._post_send_config(config=config, multi_response=multi_response) @@ -610,6 +620,7 @@ def send_configs_from_file( stop_on_failed: bool = False, privilege_level: str = "", eager: bool = False, + eager_input: bool = False, timeout_ops: Optional[float] = None, ) -> MultiResponse: """ @@ -631,6 +642,8 @@ def send_configs_from_file( eager: if eager is True we do not read until prompt is seen at each command sent to the channel. Do *not* use this unless you know what you are doing as it is possible that it can make scrapli less reliable! + eager_input: when true does *not* try to read our input off the channel -- generally + this should be left alone unless you know what you are doing! timeout_ops: timeout ops value for this operation; only sets the timeout_ops value for the duration of the operation, value is reset to initial value after operation is completed. Note that this is the timeout value PER CONFIG sent, not for the total @@ -652,5 +665,6 @@ def send_configs_from_file( stop_on_failed=stop_on_failed, privilege_level=privilege_level, eager=eager, + eager_input=eager_input, timeout_ops=timeout_ops, ) diff --git a/scrapli/factory.py b/scrapli/factory.py index d3a2ca35..990c9a41 100644 --- a/scrapli/factory.py +++ b/scrapli/factory.py @@ -44,6 +44,7 @@ def _build_provided_kwargs_dict( # pylint: disable=R0914 timeout_transport: Optional[float], timeout_ops: Optional[float], comms_return_char: Optional[str], + comms_roughly_match_inputs: Optional[bool], ssh_config_file: Optional[Union[str, bool]], ssh_known_hosts_file: Optional[Union[str, bool]], on_init: Optional[Callable[..., Any]], @@ -97,6 +98,7 @@ def _build_provided_kwargs_dict( # pylint: disable=R0914 "timeout_transport": timeout_transport, "timeout_ops": timeout_ops, "comms_return_char": comms_return_char, + "comms_roughly_match_inputs": comms_roughly_match_inputs, "ssh_config_file": ssh_config_file, "ssh_known_hosts_file": ssh_known_hosts_file, "on_init": on_init, @@ -351,6 +353,7 @@ def __new__( # pylint: disable=R0914 timeout_transport: Optional[float] = None, timeout_ops: Optional[float] = None, comms_return_char: Optional[str] = None, + comms_roughly_match_inputs: Optional[bool] = None, ssh_config_file: Optional[Union[str, bool]] = None, ssh_known_hosts_file: Optional[Union[str, bool]] = None, on_init: Optional[Callable[..., Any]] = None, @@ -388,6 +391,12 @@ def __new__( # pylint: disable=R0914 timeout_transport: timeout for ssh|telnet transport in seconds timeout_ops: timeout for ssh channel operations comms_return_char: character to use to send returns to host + comms_roughly_match_inputs: indicates if the channel should "roughly" match inputs sent + to the device. If False (default) inputs are strictly checked, as in any input + *must* be read back exactly on the channel. When set to True all input chars *must* + be read back in order in the output and all chars must be present, but the *exact* + input string does not need to be seen. This can be useful if a device echoes back + extra characters or rewrites the terminal during command input. ssh_config_file: string to path for ssh config file, True to use default ssh config file or False to ignore default ssh config file ssh_known_hosts_file: string to path for ssh known hosts file, True to use default known @@ -474,6 +483,7 @@ class that extends the driver, instead allowing the community platforms to simpl timeout_transport=timeout_transport, timeout_ops=timeout_ops, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, @@ -642,6 +652,7 @@ def __new__( # pylint: disable=R0914 timeout_transport: Optional[float] = None, timeout_ops: Optional[float] = None, comms_return_char: Optional[str] = None, + comms_roughly_match_inputs: Optional[bool] = None, ssh_config_file: Optional[Union[str, bool]] = None, ssh_known_hosts_file: Optional[Union[str, bool]] = None, on_init: Optional[Callable[..., Any]] = None, @@ -679,6 +690,12 @@ def __new__( # pylint: disable=R0914 timeout_transport: timeout for ssh|telnet transport in seconds timeout_ops: timeout for ssh channel operations comms_return_char: character to use to send returns to host + comms_roughly_match_inputs: indicates if the channel should "roughly" match inputs sent + to the device. If False (default) inputs are strictly checked, as in any input + *must* be read back exactly on the channel. When set to True all input chars *must* + be read back in order in the output and all chars must be present, but the *exact* + input string does not need to be seen. This can be useful if a device echoes back + extra characters or rewrites the terminal during command input. ssh_config_file: string to path for ssh config file, True to use default ssh config file or False to ignore default ssh config file ssh_known_hosts_file: string to path for ssh known hosts file, True to use default known @@ -765,6 +782,7 @@ class that extends the driver, instead allowing the community platforms to simpl timeout_transport=timeout_transport, timeout_ops=timeout_ops, comms_return_char=comms_return_char, + comms_roughly_match_inputs=comms_roughly_match_inputs, ssh_config_file=ssh_config_file, ssh_known_hosts_file=ssh_known_hosts_file, on_init=on_init, diff --git a/scrapli/helper.py b/scrapli/helper.py index 1465b773..dd15f2a4 100644 --- a/scrapli/helper.py +++ b/scrapli/helper.py @@ -6,7 +6,7 @@ from io import BytesIO, TextIOWrapper from pathlib import Path from shutil import get_terminal_size -from typing import Any, Dict, List, Optional, TextIO, Union +from typing import Any, Dict, List, Optional, TextIO, Tuple, Union from warnings import warn from scrapli.exceptions import ScrapliValueError @@ -306,3 +306,57 @@ def user_warning(title: str, message: str) -> None: if Settings.SUPPRESS_USER_WARNINGS is False: warn(warning_message) + + +def output_roughly_contains_input(input_: bytes, output: bytes) -> True: + """ + Return True if all characters in input are contained in order in the given output. + + Args: + input_: the input presented to a device + output: the output echoed on the channel + + Returns: + bool: True if the input is "roughly" contained in the output, otherwise False + + Raises: + N/A + + """ + if output in input_: + return True + + if len(output) < len(input_): + return False + + for char in input_: + should_continue, output = _roughly_contains_input_iter_output_for_input_char(char, output) + + if not should_continue: + return False + + return True + + +def _roughly_contains_input_iter_output_for_input_char( + char: int, output: bytes +) -> Tuple[bool, bytes]: + """ + Iterates over chars in the output to find input, returns remaining output bytes if input found. + + Args: + char: input char to find in output + output: the output echoed on the channel + + Returns: + output: bool indicating char was found, and remaining output chars to continue searching in + + Raises: + N/A + + """ + for index, output_char in enumerate(output): + if char == output_char: + return True, output[index + 1 :] + + return False, b"" diff --git a/tests/unit/channel/test_async_channel.py b/tests/unit/channel/test_async_channel.py index fc3264b7..d648c277 100644 --- a/tests/unit/channel/test_async_channel.py +++ b/tests/unit/channel/test_async_channel.py @@ -96,6 +96,28 @@ async def _read(cls): assert actual_read_output == expected_read_output +async def test_channel_read_until_input_roughly(monkeypatch, async_channel): + async_channel._base_channel_args.comms_roughly_match_inputs = True + + expected_read_output = b"read_data\nthis foo is bar my baz input" + _read_counter = 0 + + async def _read(cls): + nonlocal _read_counter + + if _read_counter == 0: + _read_counter += 1 + return b"read_data\x1b[0;0m\n" + + return b"this foo is bar my baz input" + + monkeypatch.setattr("scrapli.transport.base.async_transport.AsyncTransport.read", _read) + + actual_read_output = await async_channel._read_until_input(channel_input=b"thisismyinput") + + assert actual_read_output == expected_read_output + + async def test_channel_read_until_input_no_input(async_channel): assert await async_channel._read_until_input(channel_input=b"") == b"" diff --git a/tests/unit/channel/test_sync_channel.py b/tests/unit/channel/test_sync_channel.py index 8c373f49..8b69c5ab 100644 --- a/tests/unit/channel/test_sync_channel.py +++ b/tests/unit/channel/test_sync_channel.py @@ -88,6 +88,28 @@ def _read(cls): assert actual_read_output == expected_read_output +def test_channel_read_until_input_roughly(monkeypatch, sync_channel): + sync_channel._base_channel_args.comms_roughly_match_inputs = True + + expected_read_output = b"read_data\nthis foo is bar my baz input" + _read_counter = 0 + + def _read(cls): + nonlocal _read_counter + + if _read_counter == 0: + _read_counter += 1 + return b"read_data\x1b[0;0m\n" + + return b"this foo is bar my baz input" + + monkeypatch.setattr("scrapli.transport.base.sync_transport.Transport.read", _read) + + actual_read_output = sync_channel._read_until_input(channel_input=b"thisismyinput") + + assert actual_read_output == expected_read_output + + def test_channel_read_until_input_no_input(sync_channel): assert sync_channel._read_until_input(channel_input=b"") == b""