diff --git a/trollmoves/movers.py b/trollmoves/movers.py index e3219e52..ab7dccae 100644 --- a/trollmoves/movers.py +++ b/trollmoves/movers.py @@ -115,8 +115,7 @@ def move(self): def get_connection(self, hostname, port, username=None): """Get the connection.""" with self.active_connection_lock: - LOGGER.debug("Destination username and passwd: %s %s", - self._dest_username, self._dest_password) + LOGGER.debug("Destination username: %s", self._dest_username) LOGGER.debug('Getting connection to %s@%s:%s', username, hostname, port) try: @@ -297,38 +296,57 @@ class ScpMover(Mover): def open_connection(self): """Open a connection.""" - from paramiko import SSHClient, SSHException + ssh_connection = self._run_with_retries(self._open_connection, "ssh connect") + if ssh_connection is None: + raise IOError("Failed to ssh connect after 3 attempts") + return ssh_connection + + def _open_connection(self): + from paramiko import SSHException + + try: + ssh_connection = self._create_ssh_connection() + except SSHException as sshe: + LOGGER.exception("Failed to init SSHClient: %s", str(sshe)) + except socket.timeout as sto: + LOGGER.exception("SSH connection timed out: %s", str(sto)) + except Exception as err: + LOGGER.exception("Unknown exception at init SSHClient: %s", str(err)) + else: + return ssh_connection + + return None + + def _create_ssh_connection(self): + from paramiko import SSHClient - retries = 3 ssh_key_filename = self.attrs.get("ssh_key_filename", None) timeout = self.attrs.get("ssh_connection_timeout", None) - while retries > 0: - retries -= 1 - try: - ssh_connection = SSHClient() - ssh_connection.load_system_host_keys() - ssh_connection.connect(self.destination.hostname, - username=self._dest_username, - port=self.destination.port or 22, - key_filename=ssh_key_filename, - timeout=timeout) - LOGGER.debug("Successfully connected to %s:%s as %s", - self.destination.hostname, - self.destination.port or 22, - self._dest_username) - except SSHException as sshe: - LOGGER.exception("Failed to init SSHClient: %s", str(sshe)) - except socket.timeout as sto: - LOGGER.exception("SSH connection timed out: %s", str(sto)) - except Exception as err: - LOGGER.exception("Unknown exception at init SSHClient: %s", str(err)) - else: - return ssh_connection - ssh_connection.close() - time.sleep(2) - LOGGER.debug("Retrying ssh connect ...") - raise IOError("Failed to ssh connect after 3 attempts") + ssh_connection = SSHClient() + ssh_connection.load_system_host_keys() + ssh_connection.connect(self.destination.hostname, + username=self._dest_username, + port=self.destination.port or 22, + key_filename=ssh_key_filename, + timeout=timeout) + LOGGER.debug("Successfully connected to %s:%s as %s", + self.destination.hostname, + self.destination.port or 22, + self._dest_username) + return ssh_connection + + def _run_with_retries(self, func, name): + num_retries = self.attrs.get("num_ssh_retries", 3) + res = None + for _ in range(num_retries): + res = func() + if res: + break + time.sleep(2) + LOGGER.debug(f"Retrying {name} ...") + + return res @staticmethod def is_connected(connection): @@ -357,21 +375,16 @@ def move(self): def copy(self): """Upload the file.""" - from scp import SCPClient + _ = self._run_with_retries(self._copy, "SCP copy") - ssh_connection = self.get_connection(self.destination.hostname, - self.destination.port or 22, - self._dest_username) - - try: - scp = SCPClient(ssh_connection.get_transport()) - except Exception as err: - LOGGER.error("Failed to initiate SCPClient: %s", str(err)) - ssh_connection.close() - raise + def _copy(self): + from scp import SCPException + success = False try: + scp = self._get_scp_client() scp.put(self.origin, self.destination.path) + success = True except OSError as osex: if osex.errno == 2: LOGGER.error("No such file or directory. File not transfered: " @@ -380,6 +393,8 @@ def copy(self): else: LOGGER.error("OSError in scp.put: %s", str(osex)) raise + except SCPException as err: + LOGGER.error("SCP failed: %s", str(err)) except Exception as err: LOGGER.error("Something went wrong with scp: %s", str(err)) LOGGER.error("Exception name %s", type(err).__name__) @@ -388,6 +403,23 @@ def copy(self): finally: scp.close() + return success + + def _get_scp_client(self): + from scp import SCPClient + + ssh_connection = self.get_connection(self.destination.hostname, + self.destination.port or 22, + self._dest_username) + + try: + scp = SCPClient(ssh_connection.get_transport()) + except Exception as err: + LOGGER.error("Failed to initiate SCPClient: %s", str(err)) + ssh_connection.close() + raise + return scp + class SftpMover(Mover): """Move files over sftp.""" diff --git a/trollmoves/server.py b/trollmoves/server.py index e1944d39..a7aa6bee 100644 --- a/trollmoves/server.py +++ b/trollmoves/server.py @@ -579,6 +579,7 @@ def _read_ini_config(filename): _parse_nameserver(res[section], cp_[section]) _parse_addresses(res[section]) _parse_delete(res[section], cp_[section]) + _parse_ssh_retries(res[section], cp_[section]) if not _check_origin_and_listen(res, section): continue if not _check_topic(res, section): @@ -594,6 +595,7 @@ def _set_config_defaults(conf): conf.setdefault("transfer_req_timeout", 10 * DEFAULT_REQ_TIMEOUT) conf.setdefault("ssh_key_filename", None) conf.setdefault("delete", False) + conf.setdefault("num_ssh_retries", 3) def _parse_nameserver(conf, raw_conf): @@ -617,6 +619,12 @@ def _parse_delete(conf, raw_conf): conf["delete"] = val +def _parse_ssh_retries(conf, raw_conf): + val = raw_conf.getint("num_ssh_retries") + if val is not None: + conf["num_ssh_retries"] = val + + def _check_origin_and_listen(res, section): if ("origin" not in res[section]) and ('listen' not in res[section]): LOGGER.warning("Incomplete section %s: add an 'origin' or 'listen' item.", section) diff --git a/trollmoves/tests/test_server.py b/trollmoves/tests/test_server.py index fc3090d9..3fd6bbe9 100644 --- a/trollmoves/tests/test_server.py +++ b/trollmoves/tests/test_server.py @@ -300,3 +300,51 @@ def test_requestmanager_is_delete_set_True(patch_validate_file_pattern): port = 9876 req_man = RequestManager(port, attrs={'delete': True}) assert req_man._is_delete_set() is True + + +CONFIG_MINIMAL = """ +[test] +origin = foo +listen = bar +""" +CONFIG_NUM_SSH_RETRIES = CONFIG_MINIMAL + """ +num_ssh_retries = 5 +""" + + +def test_config_defaults(): + """Test that config defaults are set.""" + from trollmoves.server import read_config + + with NamedTemporaryFile(mode='w') as tmp_file: + tmp_file.write(CONFIG_MINIMAL) + tmp_file.file.flush() + + config = read_config(tmp_file.name) + + test_section = config["test"] + assert "origin" in test_section + assert "listen" in test_section + assert test_section["working_directory"] is None + assert test_section["compression"] is False + assert test_section["req_timeout"] == 1 + assert test_section["transfer_req_timeout"] == 10 + assert test_section["ssh_key_filename"] is None + assert test_section["delete"] is False + assert test_section["num_ssh_retries"] == 3 + assert test_section["nameserver"] is None + assert test_section["addresses"] is None + + +def test_config_num_ssh_retries(): + """Test that config defaults are set.""" + from trollmoves.server import read_config + + with NamedTemporaryFile(mode='w') as tmp_file: + tmp_file.write(CONFIG_NUM_SSH_RETRIES) + tmp_file.file.flush() + + config = read_config(tmp_file.name) + + test_section = config["test"] + assert test_section["num_ssh_retries"] == 5