diff --git a/printrun/device.py b/printrun/device.py new file mode 100644 index 000000000..1657d0c36 --- /dev/null +++ b/printrun/device.py @@ -0,0 +1,416 @@ +# This file is part of the Printrun suite. +# +# Printrun is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Printrun is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Printrun. If not, see . + +# Standard libraries: +import os +import platform +import re +import selectors +import socket +import time + +# Third-party libraries +import serial + +READ_EMPTY = b'' +"""Constant to represent empty or no data""" + +READ_EOF = None +"""Constant to represent an end-of-file""" + + +class Device(): + """Handler for serial and web socket connections. + + Provides the same functions for both so it abstracts what kind of + connection is being used. + + Parameters + ---------- + port : str, optional + Either a device name, such as '/dev/ttyUSB0' or 'COM3', or an URL with + port, such as '192.168.0.10:80' or 'http://www.example.com:8080'. + baudrate : int, optional + Communication speed in bit/s, such as 9600, 115200 or 250000. + (Default is 9600) + force_dtr : bool or None, optional + On serial connections, force the DTR bit to a specific logic level + (1 or 0) after a successful connection. Not all OS/drivers support + this functionality. By default it is set to "None" to let the system + handle it automatically. + parity_workaround : bool, optional + On serial connections, enable/disable a workaround on parity + checking. Not all platforms need to do this parity workaround, and + some drivers don't support it. By default it is disabled. + + Attributes + ---------- + is_connected + has_flow_control + + """ + + def __init__(self, port=None, baudrate=9600, force_dtr=None, + parity_workaround=False): + self.port = port + self.baudrate = baudrate + self.force_dtr = force_dtr + self.parity_workaround = parity_workaround + + # Private + self._device = None + self._is_connected = False + self._hostname = None + self._socketfile = None + self._port_number = None + self._read_buffer = [] + self._selector = None + self._timeout = 0.25 + self._type = None + + if port is not None: + self._parse_type() + + def connect(self, port=None, baudrate=None): + """Establishes the connection to the device. + + Parameters + ---------- + port : str, optional + See `port` attribute. Only required if it was not provided + already. + baudrate : int, optional + See `baudrate` attribute. Only required if it was not provided + already. + + Raises + ------ + DeviceError + If an error occurred when attempting to connect. + + """ + if port is not None: + self.port = port + if baudrate is not None: + self.baudrate = baudrate + + if self.port is not None: + self._parse_type() + getattr(self, "_connect_" + self._type)() + else: + raise DeviceError("No port or URL specified") + + def disconnect(self): + """Terminates the connection to the device.""" + if self._device is not None: + getattr(self, "_disconnect_" + self._type)() + + @property + def is_connected(self): + """True if connection to peer is alive. + + Warnings + -------- + Current implementation for socket connections only tracks status of + the connection but does not actually check it. So, if it is used to + check the connection before sending data, it might fail to prevent an + error being raised due to a lost connection. + + """ + if self._device is not None: + return getattr(self, "_is_connected_" + self._type)() + return False + + @property + def has_flow_control(self): + """True if the device has flow control mechanics.""" + if self._type == 'socket': + return True + return False + + def readline(self) -> bytes: + """Read one line from the device stream. + + Returns + ------- + bytes + Array containing the feedback received from the + device. `READ_EMPTY` will be returned if no data was + available. `READ_EOF` is returned if connection was terminated at + the other end. + + Raises + ------ + DeviceError + If connected peer is unreachable. + + """ + # TODO: silent fail on no device? return timeout? + if self._device is not None: + return getattr(self, "_readline_" + self._type)() + raise DeviceError("Attempted to read when disconnected") + + def reset(self): + """Attempt to reset the connection to the device. + + Warnings + -------- + Current implementation has no effect on socket connections. + + """ + if self._device is not None: + if self._type == 'serial': + getattr(self, "_reset_" + self._type)() + + def write(self, data: bytes): + """Write data to the connected peer. + + Parameters + ---------- + data: bytes + The bytes data to be written. This should be of type `bytes` (or + compatible such as `bytearray` or `memoryview`). Unicode strings + must be encoded. + + Raises + ------ + DeviceError + If connected peer is unreachable. + TypeError + If `data` is not of 'bytes' type. + + """ + if self._device is not None: + getattr(self, "_write_" + self._type)(data) + else: + raise DeviceError("Attempted to write when disconnected") + + def _parse_type(self): + # Guess which type of connection is being used + if self._is_url(self.port): + self._type = 'socket' + else: + self._type = 'serial' + + def _is_url(self, text): + # TODO: Rearrange to avoid long line + host_regexp = re.compile("^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$|^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$") + if ':' in text: + bits = text.split(":") + if len(bits) == 2: + self._hostname = bits[0] + try: + self._port_number = int(bits[1]) + if (host_regexp.match(self._hostname) and + 1 <= self._port_number <= 65535): + return True + except: + # TODO: avoid catch-all clauses + pass + return False + + # ------------------------------------------------------------------------ + # Serial Functions + # ------------------------------------------------------------------------ + def _connect_serial(self): + # Disable HUPCL + # TODO: Check if still required + self._disable_ttyhup() + + try: + # TODO: Check if this trick is still needed + if self.parity_workaround: + self._device = serial.Serial(port=self.port, + baudrate=self.baudrate, + timeout=0.25, + parity=serial.PARITY_ODD) + self._device.close() + self._device.parity = serial.PARITY_NONE + else: + self._device = serial.Serial(baudrate=self.baudrate, + timeout=0.25, + parity=serial.PARITY_NONE) + self._device.port = self.port + + # TODO: Check if this is still required + if self.force_dtr is not None: + self._device.dtr = self.force_dtr + + self._device.open() + + except (serial.SerialException, IOError) as e: + msg = "Could not connect to serial port '{}'".format(self.port) + raise DeviceError(msg, e) from e + + def _is_connected_serial(self): + return self._device.is_open + + def _disconnect_serial(self): + try: + self._device.close() + except serial.SerialException as e: + msg = "Error on serial disconnection" + raise DeviceError(msg, e) from e + + def _readline_serial(self): + try: + # Serial.readline() returns b'' (aka `READ_EMPTY`) on timeout + return self._device.readline() + except (serial.SerialException, OSError) as e: + msg = f"Unable to read from serial port '{self.port}'" + raise DeviceError(msg, e) from e + + def _reset_serial(self): + self._device.dtr = True + time.sleep(0.2) + self._device.dtr = False + + def _write_serial(self, data): + try: + self._device.write(data) + except serial.SerialException as e: + msg = "Unable to write to serial port '{self.port}'" + raise DeviceError(msg, e) from e + + def _disable_ttyhup(self): + if platform.system() == "Linux": + os.system("stty -F %s -hup" % self.port) + + # ------------------------------------------------------------------------ + # Socket Functions + # ------------------------------------------------------------------------ + def _connect_socket(self): + self._device = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._device.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + self._timeout = 0.25 + self._device.settimeout(1.0) + + try: + self._device.connect((self._hostname, self._port_number)) + # A single read timeout raises OSError for all later reads + # probably since python 3.5 use non blocking instead + self._device.settimeout(0) + self._socketfile = self._device.makefile('rwb', buffering=0) + self._selector = selectors.DefaultSelector() + self._selector.register(self._device, selectors.EVENT_READ) + self._is_connected = True + + except OSError as e: + self._disconnect_socket() + msg = "Could not connect to {}:{}".format(self._hostname, + self._port_number) + raise DeviceError(msg, e) from e + + def _is_connected_socket(self): + # TODO: current implementation tracks status of connection but + # does not actually check it. Ref. is_connected() + return self._is_connected + + def _disconnect_socket(self): + self._is_connected = False + try: + if self._socketfile is not None: + self._socketfile.close() + if self._selector is not None: + self._selector.unregister(self._device) + self._selector.close() + self._selector = None + self._device.close() + except OSError as e: + msg = "Error on socket disconnection" + raise DeviceError(msg, e) from e + + def _readline_socket(self): + SYS_AGAIN = None # python's marker for timeout/no data + # SYS_EOF = b'' # python's marker for EOF + try: + line = self._readline_buf() + if line: + return line + chunk_size = 256 + while True: + chunk = self._socketfile.read(chunk_size) + if (chunk is SYS_AGAIN and + self._selector.select(self._timeout)): + chunk = self._socketfile.read(chunk_size) + if chunk: + self._read_buffer.append(chunk) + line = self._readline_buf() + if line: + return line + elif chunk is SYS_AGAIN: + return READ_EMPTY + else: # chunk is SYS_EOF + line = b''.join(self._read_buffer) + self._read_buffer = [] + if line: + return line + self._is_connected = False + return READ_EOF + except OSError as e: + self._is_connected = False + msg = ("Unable to read from {}:{}. Connection lost" + ).format(self._hostname, self._port_number) + raise DeviceError(msg, e) from e + + def _readline_buf(self): + # Try to readline from buffer + if self._read_buffer: + chunk = self._read_buffer[-1] + eol = chunk.find(b'\n') + if eol >= 0: + line = b''.join(self._read_buffer[:-1]) + chunk[:(eol+1)] + self._read_buffer = [] + if eol + 1 < len(chunk): + self._read_buffer.append(chunk[(eol+1):]) + return line + return READ_EMPTY + + def _write_socket(self, data): + try: + self._socketfile.write(data) + try: + self._socketfile.flush() + except socket.timeout: + pass + except (OSError, RuntimeError) as e: + self._is_connected = False + msg = ("Unable to write to {}:{}. Connection lost" + ).format(self._hostname, self._port_number) + raise DeviceError(msg, e) from e + + +class DeviceError(Exception): + """Raised on any connection error. + + One exception groups all connection errors regardless of the underlying + connection or error type. + + Parameters + ---------- + msg : str + Error message. + cause : Exception, optional + Underlying error. + + Attributes + ---------- + cause + + """ + + def __init__(self, msg, cause=None): + super().__init__(msg) + self.cause = cause diff --git a/printrun/printcore.py b/printrun/printcore.py index ebb62afcc..22ad85381 100644 --- a/printrun/printcore.py +++ b/printrun/printcore.py @@ -20,22 +20,15 @@ print("You need to run this on Python 3") sys.exit(-1) -import serial -from select import error as SelectError import threading from queue import Queue, Empty as QueueEmpty import time -import platform -import os import logging import traceback -import errno -import socket -import re -import selectors from functools import wraps, reduce from collections import deque from printrun import gcoder +from printrun import device from .utils import set_utf8_locale, install_locale, decode_utf8 try: set_utf8_locale() @@ -52,19 +45,6 @@ def inner(*args, **kw): inner.lock = threading.Lock() return inner -def control_ttyhup(port, disable_hup): - """Controls the HUPCL""" - if platform.system() == "Linux": - if disable_hup: - os.system("stty -F %s -hup" % port) - else: - os.system("stty -F %s hup" % port) - -def enable_hup(port): - control_ttyhup(port, False) - -def disable_hup(port): - control_ttyhup(port, True) PR_EOF = None #printrun's marker for EOF PR_AGAIN = b'' #printrun's marker for timeout/no data @@ -167,11 +147,6 @@ def __init__(self, port = None, baud = None, dtr=None): self.readline_buf = [] self.selector = None self.event_handler = PRINTCORE_HANDLER - # Not all platforms need to do this parity workaround, and some drivers - # don't support it. Limit it to platforms that actually require it - # here to avoid doing redundant work elsewhere and potentially breaking - # things. - self.needs_parity_workaround = platform.system() == "linux" and os.path.exists("/etc/debian") for handler in self.event_handler: try: handler.on_init() except: logging.error(traceback.format_exc()) @@ -213,19 +188,9 @@ def disconnect(self): self.print_thread.join() self._stop_sender() try: - if self.selector is not None: - self.selector.unregister(self.printer_tcp) - self.selector.close() - self.selector = None - if self.printer_tcp is not None: - self.printer_tcp.close() - self.printer_tcp = None - self.printer.close() - except socket.error: - logging.error(traceback.format_exc()) - pass - except OSError: - logging.error(traceback.format_exc()) + self.printer.disconnect() + except device.DeviceError: + self.logError(traceback.format_exc()) pass for handler in self.event_handler: try: handler.on_disconnect() @@ -247,76 +212,15 @@ def connect(self, port = None, baud = None, dtr=None): if dtr is not None: self.dtr = dtr if self.port is not None and self.baud is not None: - # Connect to socket if "port" is an IP, device if not - host_regexp = re.compile("^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$|^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$") - is_serial = True - if ":" in self.port: - bits = self.port.split(":") - if len(bits) == 2: - hostname = bits[0] - try: - port_number = int(bits[1]) - if host_regexp.match(hostname) and 1 <= port_number <= 65535: - is_serial = False - except: - pass self.writefailures = 0 - if not is_serial: - self.printer_tcp = socket.socket(socket.AF_INET, - socket.SOCK_STREAM) - self.printer_tcp.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - self.timeout = 0.25 - self.printer_tcp.settimeout(1.0) - try: - self.printer_tcp.connect((hostname, port_number)) - #a single read timeout raises OSError for all later reads - #probably since python 3.5 - #use non blocking instead - self.printer_tcp.settimeout(0) - self.printer = self.printer_tcp.makefile('rwb', buffering=0) - self.selector = selectors.DefaultSelector() - self.selector.register(self.printer_tcp, selectors.EVENT_READ) - except socket.error as e: - if(e.strerror is None): e.strerror="" - self.logError(_("Could not connect to %s:%s:") % (hostname, port_number) + - "\n" + _("Socket error %s:") % e.errno + - "\n" + e.strerror) - self.printer = None - self.printer_tcp.close() - self.printer_tcp = None - return - else: - disable_hup(self.port) - self.printer_tcp = None - try: - if self.needs_parity_workaround: - self.printer = serial.Serial(port = self.port, - baudrate = self.baud, - timeout = 0.25, - parity = serial.PARITY_ODD) - self.printer.close() - self.printer.parity = serial.PARITY_NONE - else: - self.printer = serial.Serial(baudrate = self.baud, - timeout = 0.25, - parity = serial.PARITY_NONE) - self.printer.port = self.port - try: #this appears not to work on many platforms, so we're going to call it but not care if it fails - self.printer.dtr = dtr - except: - #self.logError(_("Could not set DTR on this platform")) #not sure whether to output an error message - pass - self.printer.open() - except serial.SerialException as e: - self.logError(_("Could not connect to %s at baudrate %s:") % (self.port, self.baud) + - "\n" + _("Serial error: %s") % e) - self.printer = None - return - except IOError as e: - self.logError(_("Could not connect to %s at baudrate %s:") % (self.port, self.baud) + - "\n" + _("IO error: %s") % e) - self.printer = None - return + self.printer = device.Device() + self.printer.force_dtr = self.dtr + try: + self.printer.connect(self.port, self.baud) + except device.DeviceError as e: + self.logError("Connection error: %s" % e) + self.printer = None + return for handler in self.event_handler: try: handler.on_connect() except: logging.error(traceback.format_exc()) @@ -335,57 +239,15 @@ def reset(self): hardware DTR flow control. It has no effect on socket connections. """ - if self.printer and not self.printer_tcp: - self.printer.dtr = 1 - time.sleep(0.2) - self.printer.dtr = 0 - - def _readline_buf(self): - "Try to readline from buffer" - if len(self.readline_buf): - chunk = self.readline_buf[-1] - eol = chunk.find(b'\n') - if eol >= 0: - line = b''.join(self.readline_buf[:-1]) + chunk[:(eol+1)] - self.readline_buf = [] - if eol + 1 < len(chunk): - self.readline_buf.append(chunk[(eol+1):]) - return line - return PR_AGAIN - - def _readline_nb(self): - "Non blocking readline. Socket based files do not support non blocking or timeouting readline" - if self.printer_tcp: - line = self._readline_buf() - if line: - return line - chunk_size = 256 - while True: - chunk = self.printer.read(chunk_size) - if chunk is SYS_AGAIN and self.selector.select(self.timeout): - chunk = self.printer.read(chunk_size) - #print('_readline_nb chunk', chunk, type(chunk)) - if chunk: - self.readline_buf.append(chunk) - line = self._readline_buf() - if line: - return line - elif chunk is SYS_AGAIN: - return PR_AGAIN - else: - #chunk == b'' means EOF - line = b''.join(self.readline_buf) - self.readline_buf = [] - self.stop_read_thread = True - return line if line else PR_EOF - else: # serial port - return self.printer.readline() + self.printer.reset() def _readline(self): try: - line_bytes = self._readline_nb() - if line_bytes is PR_EOF: - self.logError(_("Can't read from printer (disconnected?). line_bytes is None")) + line_bytes = self.printer.readline() + if line_bytes is device.READ_EOF: + self.logError("Can't read from printer (disconnected?)." + + " line_bytes is None") + self.stop_read_thread = True return PR_EOF line = line_bytes.decode('utf-8') @@ -400,38 +262,20 @@ def _readline(self): if self.loud: logging.info("RECV: %s" % line.rstrip()) return line except UnicodeDecodeError: - self.logError(_("Got rubbish reply from %s at baudrate %s:") % (self.port, self.baud) + - "\n" + _("Maybe a bad baudrate?")) + msg = ("Got rubbish reply from {0} at baudrate {1}:\n" + "Maybe a bad baudrate?").format(self.port, self.baud) + self.logError(msg) return None - except serial.SerialException as e: - self.logError(_("Can't read from printer (disconnected?) (SerialException): {0}").format(decode_utf8(str(e)))) + except device.DeviceError as e: + msg = ("Can't read from printer (disconnected?) {0}" + ).format(decode_utf8(str(e))) + self.logError(msg) return None - except socket.error as e: - self.logError(_("Can't read from printer (disconnected?) (Socket error {0}): {1}").format(e.errno, decode_utf8(e.strerror))) - return None - except (OSError, SelectError) as e: - # OSError and SelectError are the same thing since python 3.3 - if self.printer_tcp: - # SelectError branch, assume select is used only for socket printers - if len(e.args) > 1 and 'Bad file descriptor' in e.args[1]: - self.logError(_("Can't read from printer (disconnected?) (SelectError {0}): {1}").format(e.errno, decode_utf8(e.strerror))) - return None - else: - self.logError(_("SelectError ({0}): {1}").format(e.errno, decode_utf8(e.strerror))) - raise - else: - # OSError branch, serial printers - if e.errno == errno.EAGAIN: # Not a real error, no data was available - return "" - self.logError(_("Can't read from printer (disconnected?) (OS Error {0}): {1}").format(e.errno, e.strerror)) - return None def _listen_can_continue(self): - if self.printer_tcp: - return not self.stop_read_thread and self.printer return (not self.stop_read_thread and self.printer - and self.printer.is_open) + and self.printer.is_connected) def _listen_until_online(self): while not self.online and self._listen_can_continue(): @@ -762,7 +606,7 @@ def _sendnext(self): time.sleep(0.001) # Only wait for oks when using serial connections or when not using tcp # in streaming mode - if not self.printer_tcp or not self.tcp_streaming_mode: + if not self.printer.has_flow_control or not self.tcp_streaming_mode: self.clear = False if not (self.printing and self.printer and self.online): self.clear = True @@ -835,7 +679,7 @@ def _sendnext(self): def _send(self, command, lineno = 0, calcchecksum = False): # Only add checksums if over serial (tcp does the flow control itself) - if calcchecksum and not self.printer_tcp: + if calcchecksum and not self.printer.has_flow_control: prefix = "N" + str(lineno) + " " + command command = prefix + "*" + str(self._checksum(prefix)) if "M110" not in command: @@ -860,23 +704,8 @@ def _send(self, command, lineno = 0, calcchecksum = False): except: self.logError(traceback.format_exc()) try: self.printer.write((command + "\n").encode('ascii')) - if self.printer_tcp: - try: - self.printer.flush() - except socket.timeout: - pass self.writefailures = 0 - except socket.error as e: - if e.errno is None: - self.logError(_("Can't write to printer (disconnected ?):") + - "\n" + traceback.format_exc()) - else: - self.logError(_("Can't write to printer (disconnected?) (Socket error {0}): {1}").format(e.errno, decode_utf8(e.strerror))) - self.writefailures += 1 - except serial.SerialException as e: - self.logError(_("Can't write to printer (disconnected?) (SerialException): {0}").format(decode_utf8(str(e)))) - self.writefailures += 1 - except RuntimeError: - self.logError("Socket connection broken, disconnected.\n" + - traceback.format_exc()) + except device.DeviceError as e: + self.logError("Can't write to printer (disconnected?)" + "{0}".format(e)) self.writefailures += 1 diff --git a/tests/test_device.py b/tests/test_device.py new file mode 100644 index 000000000..ae8a97eec --- /dev/null +++ b/tests/test_device.py @@ -0,0 +1,410 @@ +"""Test suite for `printrun/device.py`""" +# How to run the tests (requires Python 3.11+): +# python3 -m unittest discover tests + +# Standard libraries: +import socket +import unittest +from unittest import mock + +# Third-party libraries: +import serial + +# Custom libraries: +# pylint: disable-next=no-name-in-module +from printrun import device + + +def mock_sttyhup(cls): + """Fake stty control""" + # Needed to avoid error: + # "stty: /mocked/port: No such file or directory" + cls.enterClassContext( + mock.patch("printrun.device.Device._disable_ttyhup")) + + +def patch_serial(function, **kwargs): + """Patch a function of serial.Serial""" + return mock.patch(f"serial.Serial.{function}", **kwargs) + + +def patch_serial_is_open(): + """Patch the serial.Serial class and make `is_open` always True""" + class_mock = mock.create_autospec(serial.Serial) + instance_mock = class_mock.return_value + instance_mock.is_open = True + return mock.patch("serial.Serial", class_mock) + + +def patch_socket(function, **kwargs): + """Patch a function of socket.socket""" + return mock.patch(f"socket.socket.{function}", **kwargs) + + +def patch_socketio(function, **kwargs): + """Patch a function of socket.SocketIO""" + return mock.patch(f"socket.SocketIO.{function}", **kwargs) + + +def setup_serial(test): + """Set up a Device through a mocked serial connection""" + dev = device.Device() + test.addCleanup(dev.disconnect) + mocked_open = test.enterContext(patch_serial("open")) + dev.connect("/mocked/port") + + return dev, mocked_open + + +def setup_socket(test): + """Set up a Device through a mocked socket connection""" + dev = device.Device() + test.addCleanup(dev.disconnect) + mocked_socket = test.enterContext(patch_socket("connect")) + dev.connect("127.0.0.1:80") + + return dev, mocked_socket + + +class TestInit(unittest.TestCase): + """Test Device constructor""" + + def test_type_serial(self): + """Check detecting serial devices""" + dev = device.Device("/any/port") + + with self.subTest("`serial` type is set"): + # pylint: disable-next=protected-access + self.assertEqual(dev._type, "serial") + + with self.subTest("No flow control is set"): + self.assertFalse(dev.has_flow_control) + + def test_type_socket(self): + """Check detecting socket devices""" + dev = device.Device("127.0.0.1:80") + + with self.subTest("Check `socket` type is set"): + # pylint: disable-next=protected-access + self.assertEqual(dev._type, "socket") + + with self.subTest("Check flow control is set"): + self.assertTrue(dev.has_flow_control) + + def test_default_type(self): + """`serial` type is assigned by default when type unknown""" + # If URL cannot be identified, a serial port is assumed + dev = device.Device("/any/port:") + # pylint: disable-next=protected-access + self.assertEqual(dev._type, "serial") + + +class TestDisconnect(unittest.TestCase): + """Test disconnect functionality""" + + @classmethod + def setUpClass(cls): + mock_sttyhup(cls) + + def test_silent_on_no_device(self): + """No error is raised when disconnecting a device not connected""" + dev = device.Device() + dev.disconnect() + + def test_socket_erorr(self): + """DeviceError is raised if socket fails at disconnect""" + dev, _ = setup_socket(self) + with mock.patch('socket.socket.close', side_effect=socket.error): + with self.assertRaises(device.DeviceError): + dev.disconnect() + + def test_serial_erorr(self): + """DeviceError is raised if serial fails at disconnect""" + dev, _ = setup_serial(self) + with patch_serial("close", side_effect=serial.SerialException): + with self.assertRaises(device.DeviceError): + dev.disconnect() + + +class TestConnect(unittest.TestCase): + """Test connect functionality""" + + @classmethod + def setUpClass(cls): + mock_sttyhup(cls) + + def setUp(self): + self.dev = device.Device() + self.addCleanup(self.dev.disconnect) + + def _fake_serial_connect(self, port=None, baudrate=None, **kargs): + # Mock a serial connection with optional keyword arguments + with patch_serial("open", **kargs) as mocked_open: + self.dev.connect(port=port, baudrate=baudrate) + mocked_open.assert_called() + + def _fake_socket_connect(self, port=None, **kargs): + # Mock a socket connection with optional keyword arguments + with patch_socket("connect", **kargs) as mocked_connect: + self.dev.connect(port) + mocked_connect.assert_called_once() + + def test_error_on_no_device(self): + """DeviceError is raised when connecting to no port/URL""" + with self.assertRaises(device.DeviceError): + self.dev.connect() + self.assertFalse(self.dev.is_connected) + + def test_erorr_on_bad_port(self): + """DeviceError is raised when port does not exist""" + # Serial raises a FileNotFoundError + with self.assertRaises(device.DeviceError): + self.dev.connect("/non/existent/port") + self.assertFalse(self.dev.is_connected) + + def test_call_socket_connect(self): + """socket.socket.connect is called and `is_connected` is set""" + self._fake_socket_connect("127.0.0.1:80") + self.assertTrue(self.dev.is_connected) + + def test_call_serial_open(self): + """serial.Serial.open is called and `is_connected` is set""" + with patch_serial_is_open() as mocked_serial: + self.dev.connect("/mocked/port") + mocked_serial.return_value.open.assert_called_once() + self.assertTrue(self.dev.is_connected) + + def test_set_baudrate(self): + """Successful connection sets `port` and `baudrate`""" + self._fake_serial_connect("/mocked/port", 250000) + self.assertTrue(self.dev.port == "/mocked/port") + self.assertTrue(self.dev.baudrate == 250000) + + def test_set_dtr(self): + """Test no error raised on setting DTR on connect""" + self._fake_serial_connect("/mocked/port", dtr=True) + + def test_connect_already_connected(self): + """Test connecting an already connected device""" + self._fake_serial_connect("/mocked/port") + self._fake_serial_connect("/mocked/port2") + self.assertTrue(self.dev.port == "/mocked/port2") + + def test_connect_serial_to_socket(self): + """Test connecting from a port to a socket""" + # pylint: disable=protected-access + self._fake_serial_connect("/mocked/port") + self.assertEqual(self.dev._type, "serial") + self._fake_socket_connect("127.0.0.1:80") + self.assertEqual(self.dev._type, "socket") + + def test_socket_error(self): + """DeviceError is raised on socket.error on connect""" + with self.assertRaises(device.DeviceError): + self._fake_socket_connect("127.0.0.1:80", side_effect=socket.error) + self.assertFalse(self.dev.is_connected) + + +class TestReset(unittest.TestCase): + """Test reset functionality""" + + @classmethod + def setUpClass(cls): + mock_sttyhup(cls) + + def setUp(self): + self.serial_dev, _ = setup_serial(self) + self.socket_dev, _ = setup_socket(self) + + def test_reset_serial(self): + # TODO: this simply tests that no errors are raised + self.serial_dev.reset() + + def test_reset_socket(self): + # TODO: this simply tests that no errors are raised + self.socket_dev.reset() + + def test_reset_disconnected(self): + # TODO: this simply tests that no errors are raised + dev = device.Device("/a/port") + dev.reset() + + +class TestReadSerial(unittest.TestCase): + """Test readline functionality on serial connections""" + + @classmethod + def setUpClass(cls): + mock_sttyhup(cls) + + def setUp(self): + self.dev, _ = setup_serial(self) + + def _fake_read(self, **kargs): + # Allows mocking a serial read operation for different return values + with patch_serial("readline", **kargs) as mocked_read: + data = self.dev.readline() + mocked_read.assert_called_once() + return data + + def test_calls_readline(self): + """serial.Serial.readline is called""" + self._fake_read() + + def test_read_data(self): + """Data returned by serial.Serial.readline is passed as is""" + data = self._fake_read(return_value=b"data\n") + self.assertEqual(data, b"data\n") + + def test_read_serial_exception(self): + """DeviceError is raised on serial error during reading""" + with self.assertRaises(device.DeviceError): + self._fake_read(side_effect=serial.SerialException) + + def test_read_empty(self): + """READ_EMPTY is returned when there's nothing to read""" + # Serial.readline() returns b'' (aka `READ_EMPTY`) on timeout + self.assertEqual(self._fake_read(return_value=b''), device.READ_EMPTY) + + def test_read_disconnected(self): + """DeviceError is raised when reading from a disconnected device""" + dev = device.Device("/a/port") + with self.assertRaises(device.DeviceError): + dev.readline() + + +class TestReadSocket(unittest.TestCase): + """Test readline functionality on socket connections""" + + @classmethod + def setUpClass(cls): + mock_sttyhup(cls) + + def setUp(self): + self.dev, _ = setup_socket(self) + + def _fake_read(self, **kargs): + with patch_socketio("read", **kargs) as mocked_read: + data = self.dev.readline() + mocked_read.assert_called() + return data + + def test_read_empty(self): + """READ_EMPTY is returned when there's nothing to read""" + # If the socket is non-blocking and no bytes are available, + # None is returned by readinto() + # Device remains connected in this scenario + data = self._fake_read(return_value=None) + self.assertEqual(data, device.READ_EMPTY) + self.assertTrue(self.dev.is_connected) + + def test_read_eof(self): + """READ_EOF is returned when connection is terminated""" + # A 0 return value from readinto() indicates that the + # connection was shutdown at the other end + # Device is no longer connected in this scenario + data = self._fake_read(return_value=0) + self.assertEqual(data, device.READ_EOF) + self.assertFalse(self.dev.is_connected) + + def test_read_no_endpoint(self): + """DeviceError is raised when connection is lost""" + # OSError: [Errno 107] Transport endpoint is not connected + # Thrown when trying to read but connection was lost + with self.assertRaises(device.DeviceError): + self.dev.readline() + self.assertFalse(self.dev.is_connected) + + def test_read_data(self): + """Data returned by socket.socket.read is passed as is""" + with mock.patch('socket.SocketIO.read', return_value=b"data\n"): + self.assertEqual(self.dev.readline(), b"data\n") + + +class TestWriteSerial(unittest.TestCase): + """Test write functionality on serial connections""" + + @classmethod + def setUpClass(cls): + mock_sttyhup(cls) + + def _setup_serial_write(self, side_effect=None): + # Set up a mocked serial with optional side effects for the + # serial.Serial.write function + class_mock = mock.create_autospec(serial.Serial) + instance_mock = class_mock.return_value + instance_mock.is_open = True + if side_effect is not None: + instance_mock.write.side_effect = side_effect + mocked_serial = self.enterContext(mock.patch("serial.Serial", + class_mock)) + + dev = device.Device() + self.addCleanup(dev.disconnect) + dev.connect("/mocked/port") + + return dev, mocked_serial + + def test_write_no_device(self): + """DeviceError is raised when device is not connected""" + # This test serves for socket connections as well, this functionality + # is independent of the underlying connection type + empty_dev = device.Device() + with self.assertRaises(device.DeviceError): + empty_dev.write("test") + + def test_calls_serial_write(self): + """serial.Serial.write is called""" + dev, mocked_serial = self._setup_serial_write() + dev.write("test") + mocked_serial.return_value.write.assert_called_once_with("test") + + def test_write_serial_error(self): + """DeviceError is raised on serial error during writing""" + dev, _ = self._setup_serial_write(serial.SerialException) + with self.assertRaises(device.DeviceError): + dev.write("test") + + +class TestWriteSocket(unittest.TestCase): + """Test write functionality on socket connections""" + + @classmethod + def setUpClass(cls): + mock_sttyhup(cls) + + def setUp(self): + self.dev, _ = setup_socket(self) + + def _fake_write(self, data, **kwargs): + # Perform a fake write operation. `kwargs` allows to set different + # return values for the write operation + with patch_socketio("write", **kwargs) as mocked_write: + self.dev.write(data) + mocked_write.assert_called_once_with(data) + + def test_calls_socket_write(self): + """socket.socket.write is called""" + self._fake_write(b"test") + + def test_write_errors(self): + """DeviceError is raised on socket errors during writing""" + # On errors during writing, the function is expected to raise a + # DeviceError and terminate the connection + self.assertTrue(self.dev.is_connected) + for e in [OSError, RuntimeError]: + with self.subTest(error=e): + with self.assertRaises(device.DeviceError): + self._fake_write(b"test", side_effect=e) + self.assertFalse(self.dev.is_connected) + + def test_not_bytes(self): + """TypeError is raised if argument is not of bytes type""" + with self.assertRaises(TypeError): + self.dev.write("string") + + def test_flush_timeout(self): + """Silent on socket timeout during flushing""" + # Current behavior is to silently ignore socket.timeout + with mock.patch('socket.SocketIO.flush', side_effect=socket.timeout): + self._fake_write(b"test") diff --git a/tests/test_printcore.py b/tests/test_printcore.py index 5bce6da19..8d1dea4df 100644 --- a/tests/test_printcore.py +++ b/tests/test_printcore.py @@ -40,7 +40,7 @@ def mock_sttyhup(cls): # Needed to avoid error: # "stty: /mocked/port: No such file or directory" cls.enterClassContext( - mock.patch("printrun.printcore.control_ttyhup")) + mock.patch("printrun.device.Device._disable_ttyhup")) def mock_serial(test, read_function=slow_printer): @@ -261,7 +261,7 @@ def setUp(self): def test_calls_serial_close(self): """Test that serial.Serial.close() is called""" self.core.disconnect() - self.mocked_serial.return_value.close.assert_called_once() + self.mocked_serial.return_value.close.assert_called() def test_calls_socket_close(self): """Test that socket.socket.close() is called"""