Skip to content

Commit

Permalink
adding functional tests and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
carlmontanari committed Feb 1, 2020
1 parent 96ae96f commit 1f1cfcb
Show file tree
Hide file tree
Showing 28 changed files with 708 additions and 68 deletions.
60 changes: 58 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,13 +1,69 @@
lint:
python -m isort -rc -w 100 -y .
python -m isort -rc -y .
python -m black .
python -m pylama .
python -m pydocstyle .
find nssh -type f \( -iname "*.py" ! -iname "ptyprocess.py" \) | xargs darglint

cov:
python -m pytest \
--cov=nssh \
--cov-report html \
--cov-report term \
tests/

cov_unit:
python -m pytest \
--cov=nssh \
--cov-report html \
--cov-report term \
tests/unit/
tests/unit/

cov_functional:
python -m pytest \
--cov=nssh \
--cov-report html \
--cov-report term \
tests/functional/

test:
python -m pytest tests/

test_unit:
python -m pytest tests/unit/

test_functional:
python -m pytest tests/functional/

test_iosxe:
python -m pytest -v \
tests/unit \
tests/functional/cisco_iosxe

test_nxos:
python -m pytest -v \
tests/unit \
tests/functional/cisco_nxos

test_iosxr:
python -m pytest -v \
tests/unit \
tests/functional/cisco_iosxr

test_eos:
python -m pytest -v \
tests/unit \
tests/functional/arista_eos

test_junos:
python -m pytest -v \
tests/unit \
tests/functional/juniper_junos

.PHONY: docs
docs:
python -m pdoc \
--html \
--output-dir docs \
nssh \
--force
52 changes: 44 additions & 8 deletions nssh/channel/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,40 @@ def __init__(
self.comms_ansi = comms_ansi
self.timeout_ops = timeout_ops

def __str__(self) -> str:
"""
Magic str method for Channel
Args:
N/A # noqa
Returns:
N/A # noqa
Raises:
N/A # noqa
"""
return "nssh Channel Object"

def __repr__(self) -> str:
"""
Magic repr method for Channel
Args:
N/A # noqa
Returns:
repr: repr for class object
Raises:
N/A # noqa
"""
class_dict = self.__dict__.copy()
class_dict.pop("transport")
return f"nssh Channel {class_dict}"

def _restructure_output(self, output: bytes, strip_prompt: bool = False) -> bytes:
"""
Clean up preceding empty lines, and strip prompt if desired
Expand All @@ -70,8 +104,12 @@ def _restructure_output(self, output: bytes, strip_prompt: bool = False) -> byte
"""
output = normalize_lines(output)
# purge empty rows before actual output
output = b"\n".join([row for row in output.splitlines() if row])

# TODO -- purge empty rows before actual output
# this was used to remove duplicate line feeds in output, but that causes some issues for
# testing where we want to match the normal output we see as users... so i think this
# should be removed -- or optional?
# output = b"\n".join([row for row in output.splitlines() if row])

if not strip_prompt:
return output
Expand Down Expand Up @@ -116,7 +154,6 @@ def _read_until_input(self, channel_input: bytes) -> bytes:
"""
output = b""
# TODO -- make sure the appending works same as += (who knows w/ bytes!)
while channel_input not in output:
output += self._read_chunk()
return output
Expand Down Expand Up @@ -148,7 +185,7 @@ def _read_until_prompt(self, output: bytes = b"", prompt: str = "") -> bytes:
# parsing if a prompt-like thing is at the end of the output
# TODO -- at one point this was bytes -> str w/ `unicode-escape` have not tested
# on many live devices if keeping this all bytes works!!!
output = re.sub(b"\r", b"\n", output.strip())
output = re.sub(b"\r", b"", output.strip())
channel_match = re.search(prompt_pattern, output)
if channel_match:
self.transport.set_blocking(True)
Expand Down Expand Up @@ -209,12 +246,11 @@ def send_inputs(
result = Result(self.transport.host, channel_input)
raw_result, processed_result = self._send_input(channel_input, strip_prompt)
result.raw_result = raw_result.decode()
result.record_result(processed_result.decode())
result.record_result(processed_result.decode().strip())
results.append(result)
return results

# TODO - uncomment!
#@operation_timeout("timeout_ops")
@operation_timeout("timeout_ops")
def _send_input(self, channel_input: str, strip_prompt: bool) -> Tuple[bytes, bytes]:
"""
Send input to device and return results
Expand Down Expand Up @@ -280,7 +316,7 @@ def send_inputs_interact(
channel_input, expectation, response, finale, hidden_response
)
result.raw_result = raw_result.decode()
result.record_result(processed_result.decode())
result.record_result(processed_result.decode().strip())
results.append(result)
return results

Expand Down
3 changes: 2 additions & 1 deletion nssh/driver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""nssh.driver"""
from nssh.driver.driver import NSSH
from nssh.driver.network_driver import NetworkDriver

__all__ = ("NSSH",)
__all__ = ("NSSH", "NetworkDriver")
4 changes: 4 additions & 0 deletions nssh/driver/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""nssh.driver.core"""
from nssh.driver.core.cisco_iosxe.driver import IOSXEDriver

__all__ = ("IOSXEDriver",)
7 changes: 4 additions & 3 deletions nssh/driver/core/cisco_iosxe/driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""nssh.driver.core.cisco_iosxe.driver"""
from typing import Any, Dict

from nssh.driver.core.driver import NetworkDriver, PrivilegeLevel
from nssh.driver import NetworkDriver
from nssh.driver.network_driver import PrivilegeLevel

PRIVS = {
"exec": (
Expand Down Expand Up @@ -36,7 +37,7 @@
PrivilegeLevel(
r"^[a-z0-9.\-@/:]{1,32}\(config\)#$",
"configuration",
"priv",
"privilege_exec",
"end",
None,
None,
Expand All @@ -50,7 +51,7 @@
PrivilegeLevel(
r"^[a-z0-9.\-@/:]{1,32}\(config[a-z0-9.\-@/:]{1,16}\)#$",
"special_configuration",
"priv",
"privilege_exec",
"end",
None,
None,
Expand Down
52 changes: 48 additions & 4 deletions nssh/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import logging
import os
import re
from typing import Any, Callable, Dict, Tuple, Union
from types import TracebackType
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

from nssh.channel import CHANNEL_ARGS, Channel
from nssh.helper import get_external_function, validate_external_function
Expand All @@ -27,7 +28,6 @@
"paramiko": MIKO_TRANSPORT_ARGS,
}


LOG = logging.getLogger("nssh_base")


Expand All @@ -39,6 +39,7 @@ def __init__(
auth_username: str = "",
auth_password: str = "",
auth_public_key: str = "",
auth_strict_key: bool = True,
timeout_socket: int = 5,
timeout_ssh: int = 5000,
timeout_ops: int = 10,
Expand All @@ -60,14 +61,17 @@ def __init__(
N/A # noqa
"""
# TODO -- docstring
self.host = host.strip()
if not isinstance(port, int):
raise TypeError(f"port should be int, got {type(port)}")
self.port = port

self.auth_username: str = ""
self.auth_password: str = ""
self.auth_public_key: bytes = b""
if not isinstance(auth_strict_key, bool):
raise TypeError(f"auth_strict_key should be bool, got {type(auth_strict_key)}")
self.auth_strict_key = auth_strict_key
self._setup_auth(auth_username, auth_password, auth_public_key)

self.timeout_socket = int(timeout_socket)
Expand Down Expand Up @@ -189,7 +193,7 @@ def _setup_session(
except TypeError:
self.session_disable_paging = session_disable_paging
else:
self.session_disable_paging = session_disable_paging
self.session_disable_paging = "terminal length 0"

@staticmethod
def _set_session_pre_login_handler(
Expand Down Expand Up @@ -301,3 +305,43 @@ def close(self) -> None:
"""
self.transport.close()

def __enter__(self) -> "NSSH":
"""
Enter method for context manager
Args:
N/A # noqa
Returns:
self: instance of self
Raises:
N/A # noqa
"""
self.open()
return self

def __exit__(
self,
exception_type: Optional[Type[BaseException]],
exception_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""
Exit method to cleanup for context manager
Args:
exception_type: exception type being raised
exception_value: message from exception being raised
traceback: traceback from exception being raised
Returns:
N/A # noqa
Raises:
N/A # noqa
"""
self.close()
39 changes: 32 additions & 7 deletions nssh/driver/core/driver.py → nssh/driver/network_driver.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,34 @@
"""nssh.driver.core.driver"""
"""nssh.base"""
import collections
import logging
import re
from io import TextIOWrapper
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from nssh import NSSH
from nssh.driver.driver import NSSH
from nssh.exceptions import CouldNotAcquirePrivLevel, UnknownPrivLevel
from nssh.helper import _textfsm_get_template, get_prompt_pattern, textfsm_parse
from nssh.result import Result
from nssh.transport import (
MIKO_TRANSPORT_ARGS,
SSH2_TRANSPORT_ARGS,
SYSTEM_SSH_TRANSPORT_ARGS,
MikoTransport,
SSH2Transport,
SystemSSHTransport,
Transport,
)

TRANSPORT_CLASS: Dict[str, Callable[..., Transport]] = {
"system": SystemSSHTransport,
"ssh2": SSH2Transport,
"paramiko": MikoTransport,
}
TRANSPORT_ARGS: Dict[str, Tuple[str, ...]] = {
"system": SYSTEM_SSH_TRANSPORT_ARGS,
"ssh2": SSH2_TRANSPORT_ARGS,
"paramiko": MIKO_TRANSPORT_ARGS,
}

PrivilegeLevel = collections.namedtuple(
"PrivilegeLevel",
Expand All @@ -25,6 +46,8 @@

PRIVS: Dict[str, PrivilegeLevel] = {}

LOG = logging.getLogger("nssh_base")


class NetworkDriver(NSSH):
def __init__(self, auth_secondary: str = "", **kwargs: Any):
Expand Down Expand Up @@ -55,14 +78,12 @@ def _determine_current_priv(self, current_prompt: str) -> PrivilegeLevel:
current_prompt: string of current prompt
Returns:
priv_level: NamedTuple of current privilege level
PrivilegeLevel: NamedTuple of current privilege level
Raises:
UnknownPrivLevel: if privilege level cannot be determined # noqa
# NOTE: darglint raises DAR401 for some reason hence the noqa...
UnknownPrivLevel: if privilege level cannot be determined
"""
# TODO -- fix above note...
for priv_level in self.privs.values():
prompt_pattern = get_prompt_pattern("", priv_level.pattern)
if re.search(prompt_pattern, current_prompt.encode()):
Expand Down Expand Up @@ -126,6 +147,10 @@ def _deescalate(self) -> None:
current_priv = self._determine_current_priv(self.channel.get_prompt())
if current_priv.deescalate:
next_priv = self.privs.get(current_priv.deescalate_priv, None)
if not next_priv:
raise UnknownPrivLevel(
"NetworkDriver has no default priv levels, set them or use a network driver"
)
self.channel.comms_prompt_pattern = next_priv.pattern
self.channel.send_inputs(current_priv.deescalate)

Expand Down
3 changes: 0 additions & 3 deletions nssh/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,6 @@ def _textfsm_get_template(platform: str, command: str) -> Optional[TextIO]:
"""
try:
from textfsm.clitable import CliTable # pylint: disable=C0415

# TODO -- dont think we *need* ntc_templates since we can pass string path to template
import ntc_templates # pylint: disable=C0415,W0611
except ModuleNotFoundError as exc:
err = f"Module '{exc.name}' not installed!"
msg = f"***** {err} {'*' * (80 - len(err))}"
Expand Down
Loading

0 comments on commit 1f1cfcb

Please sign in to comment.