Skip to content

Commit

Permalink
feat: roughly contains input (driver) option, add eager input option …
Browse files Browse the repository at this point in the history
…to config methods
  • Loading branch information
carlmontanari committed Jan 18, 2024
1 parent ab20115 commit c472d65
Show file tree
Hide file tree
Showing 23 changed files with 293 additions and 9 deletions.
8 changes: 7 additions & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ Changelog

- Expand `arista_eos` prompt pattern to handle super long config sections (things like qos queues and such). Thanks
to @MarkRudenko over in scrapli_cfg repo for finding this and providing the fix!

- Add `comms_roughly_match_inputs` option -- this uses a "rough" match when looking for inputs (commands/configs you
send) in output printed back on the channel. Basically, if all input characters show up in the output in the correct
order, then we assume the input was found. Of course this could be less "exacting" but it also *probably* is ok 99%
of the time :)
- Added an `eager_input` option to send operations -- this option completely skips checking for inputs being echoed back
on the channel. With the addition of the `comms_roughly_match_inputs` option this is *probably* unnecessary, but
could be useful for some corner cases.

## 2023.07.30

Expand Down
14 changes: 11 additions & 3 deletions scrapli/channel/async_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scrapli.channel.base_channel import BaseChannel, BaseChannelArgs
from scrapli.decorators import timeout_wrapper
from scrapli.exceptions import ScrapliAuthenticationFailed, ScrapliTimeout
from scrapli.helper import output_roughly_contains_input
from scrapli.transport.base import AsyncTransport


Expand Down Expand Up @@ -104,9 +105,16 @@ async def _read_until_input(self, channel_input: bytes) -> bytes:
while True:
buf += await self.read()

# replace any backspace chars (particular problem w/ junos), and remove any added spaces
# this is just for comparison of the inputs to what was read from channel
if processed_channel_input in b"".join(buf.lower().replace(b"\x08", b"").split()):
if not self._base_channel_args.comms_roughly_match_inputs:
# replace any backspace chars (particular problem w/ junos), and remove any added
# spaces this is just for comparison of the inputs to what was read from channel
# note (2024) this would be worked around by using the roughly contains search,
# *but* that is slower (probably immaterially for most people but... ya know...)
processed_buf = b"".join(buf.lower().replace(b"\x08", b"").split())

if processed_channel_input in processed_buf:
return buf
elif output_roughly_contains_input(input_=processed_channel_input, output=buf):
return buf

async def _read_until_prompt(self, buf: bytes = b"") -> bytes:
Expand Down
7 changes: 7 additions & 0 deletions scrapli/channel/base_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class BaseChannelArgs:
comms_prompt_search_depth: depth of the buffer to search in for searching for the prompt
in "read_until_prompt"; smaller number here will generally be faster, though may be less
reliable; default value is 1000
comms_roughly_match_inputs: indicates if the channel should "roughly" match inputs sent
to the device. If False (default) inputs are strictly checked, as in any input
*must* be read back exactly on the channel. When set to True all input chars *must*
be read back in order in the output and all chars must be present, but the *exact*
input string does not need to be seen. This can be useful if a device echoes back
extra characters or rewrites the terminal during command input.
timeout_ops: timeout_ops to assign to the channel, see above
channel_log: log "channel" output -- this would be the output you would normally see on a
terminal. If `True` logs to `scrapli_channel.log`, if a string is provided, logs to
Expand All @@ -61,6 +67,7 @@ class BaseChannelArgs:
comms_prompt_pattern: str = r"^[a-z0-9.\-@()/:]{1,32}[#>$]$"
comms_return_char: str = "\n"
comms_prompt_search_depth: int = 1000
comms_roughly_match_inputs: bool = False
timeout_ops: float = 30.0
channel_log: Union[str, bool, BytesIO] = False
channel_log_mode: str = "write"
Expand Down
14 changes: 11 additions & 3 deletions scrapli/channel/sync_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scrapli.channel.base_channel import BaseChannel, BaseChannelArgs
from scrapli.decorators import timeout_wrapper
from scrapli.exceptions import ScrapliAuthenticationFailed, ScrapliConnectionError, ScrapliTimeout
from scrapli.helper import output_roughly_contains_input
from scrapli.transport.base import Transport


Expand Down Expand Up @@ -104,9 +105,16 @@ def _read_until_input(self, channel_input: bytes) -> bytes:
while True:
buf += self.read()

# replace any backspace chars (particular problem w/ junos), and remove any added spaces
# this is just for comparison of the inputs to what was read from channel
if processed_channel_input in b"".join(buf.lower().replace(b"\x08", b"").split()):
if not self._base_channel_args.comms_roughly_match_inputs:
# replace any backspace chars (particular problem w/ junos), and remove any added
# spaces this is just for comparison of the inputs to what was read from channel
# note (2024) this would be worked around by using the roughly contains search,
# *but* that is slower (probably immaterially for most people but... ya know...)
processed_buf = b"".join(buf.lower().replace(b"\x08", b"").split())

if processed_channel_input in processed_buf:
return buf
elif output_roughly_contains_input(input_=processed_channel_input, output=buf):
return buf

def _read_until_prompt(self, buf: bytes = b"") -> bytes:
Expand Down
89 changes: 88 additions & 1 deletion scrapli/driver/base/base_driver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""scrapli.driver.base.base_driver"""
"""scrapli.driver.base.base_driver""" # noqa: C0302
import importlib
from dataclasses import fields
from io import BytesIO
Expand Down Expand Up @@ -34,6 +34,7 @@ def __init__(
timeout_ops: float = 30.0,
comms_prompt_pattern: str = r"^[a-z0-9.\-@()/:]{1,48}[#>$]\s*$",
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -83,6 +84,12 @@ def __init__(
should be mostly sorted for you if using network drivers (i.e. `IOSXEDriver`).
Lastly, the case insensitive is just a convenience factor so i can be lazy.
comms_return_char: character to use to send returns to host
comms_roughly_match_inputs: indicates if the channel should "roughly" match inputs sent
to the device. If False (default) inputs are strictly checked, as in any input
*must* be read back exactly on the channel. When set to True all input chars *must*
be read back in order in the output and all chars must be present, but the *exact*
input string does not need to be seen. This can be useful if a device echoes back
extra characters or rewrites the terminal during command input.
ssh_config_file: string to path for ssh config file, True to use default ssh config file
or False to ignore default ssh config file
ssh_known_hosts_file: string to path for ssh known hosts file, True to use default known
Expand Down Expand Up @@ -149,6 +156,7 @@ class that extends the driver, instead allowing the community platforms to simpl
auth_passphrase_pattern=auth_passphrase_pattern,
comms_prompt_pattern=comms_prompt_pattern,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
timeout_ops=timeout_ops,
channel_log=channel_log,
channel_log_mode=channel_log_mode,
Expand Down Expand Up @@ -249,6 +257,7 @@ def __repr__(self) -> str:
f"timeout_ops={self._base_channel_args.timeout_ops!r}, "
f"comms_prompt_pattern={self._base_channel_args.comms_prompt_pattern!r}, "
f"comms_return_char={self._base_channel_args.comms_return_char!r}, "
f"comms_roughly_match_inputs={self._base_channel_args.comms_roughly_match_inputs!r}, "
f"ssh_config_file={self.ssh_config_file!r}, "
f"ssh_known_hosts_file={self.ssh_known_hosts_file!r}, "
f"on_init={self.on_init!r}, "
Expand Down Expand Up @@ -738,6 +747,84 @@ def comms_return_char(self, value: str) -> None:

self._base_channel_args.comms_return_char = value

@property
def comms_prompt_search_depth(self) -> int:
"""
Getter for `comms_prompt_search_depth` attribute
Args:
N/A
Returns:
int: comms_prompt_search_depth int
Raises:
N/A
"""
return self._base_channel_args.comms_prompt_search_depth

Check warning on line 765 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L765

Added line #L765 was not covered by tests

@comms_prompt_search_depth.setter
def comms_prompt_search_depth(self, value: int) -> None:
"""
Setter for `comms_prompt_search_depth` attribute
Args:
value: int value for comms_prompt_search_depth
Returns:
None
Raises:
ScrapliTypeError: if value is not of type int
"""
self.logger.debug(f"setting 'comms_prompt_search_depth' value to {value!r}")

Check warning on line 782 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L782

Added line #L782 was not covered by tests

if not isinstance(value, int):
raise ScrapliTypeError

Check warning on line 785 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L784-L785

Added lines #L784 - L785 were not covered by tests

self._base_channel_args.comms_prompt_search_depth = value

Check warning on line 787 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L787

Added line #L787 was not covered by tests

@property
def comms_roughly_match_inputs(self) -> bool:
"""
Getter for `comms_roughly_match_inputs` attribute
Args:
N/A
Returns:
bool: comms_roughly_match_inputs bool
Raises:
N/A
"""
return self._base_channel_args.comms_roughly_match_inputs

Check warning on line 804 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L804

Added line #L804 was not covered by tests

@comms_roughly_match_inputs.setter
def comms_roughly_match_inputs(self, value: bool) -> None:
"""
Setter for `comms_roughly_match_inputs` attribute
Args:
value: int value for comms_roughly_match_inputs
Returns:
None
Raises:
ScrapliTypeError: if value is not of type bool
"""
self.logger.debug(f"setting 'comms_roughly_match_inputs' value to {value!r}")

Check warning on line 821 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L821

Added line #L821 was not covered by tests

if not isinstance(value, bool):
raise ScrapliTypeError

Check warning on line 824 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L823-L824

Added lines #L823 - L824 were not covered by tests

self._base_channel_args.comms_roughly_match_inputs = value

Check warning on line 826 in scrapli/driver/base/base_driver.py

View check run for this annotation

Codecov / codecov/patch

scrapli/driver/base/base_driver.py#L826

Added line #L826 was not covered by tests

@property
def timeout_socket(self) -> float:
"""
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/arista_eos/async_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/arista_eos/sync_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/cisco_iosxe/async_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -143,6 +144,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/cisco_iosxe/sync_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -143,6 +144,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/cisco_iosxr/async_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -142,6 +143,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/cisco_iosxr/sync_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -145,6 +146,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/cisco_nxos/async_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/cisco_nxos/sync_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/juniper_junos/async_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -144,6 +145,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/core/juniper_junos/sync_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
timeout_transport: float = 30.0,
timeout_ops: float = 30.0,
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -144,6 +145,7 @@ def __init__(
timeout_transport=timeout_transport,
timeout_ops=timeout_ops,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
2 changes: 2 additions & 0 deletions scrapli/driver/generic/async_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
timeout_ops: float = 30.0,
comms_prompt_pattern: str = r"^\S{0,48}[#>$~@:\]]\s*$",
comms_return_char: str = "\n",
comms_roughly_match_inputs: bool = False,
ssh_config_file: Union[str, bool] = False,
ssh_known_hosts_file: Union[str, bool] = False,
on_init: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -64,6 +65,7 @@ def __init__(
timeout_ops=timeout_ops,
comms_prompt_pattern=comms_prompt_pattern,
comms_return_char=comms_return_char,
comms_roughly_match_inputs=comms_roughly_match_inputs,
ssh_config_file=ssh_config_file,
ssh_known_hosts_file=ssh_known_hosts_file,
on_init=on_init,
Expand Down
Loading

0 comments on commit c472d65

Please sign in to comment.