diff --git a/pysqa/ext/remote.py b/pysqa/ext/remote.py index d03d0dfe..8ffc3c5e 100644 --- a/pysqa/ext/remote.py +++ b/pysqa/ext/remote.py @@ -141,7 +141,7 @@ def get_job_from_remote(self, working_directory): if self._ssh_delete_file_on_remote: self._execute_remote_command(command="rm -r " + remote_working_directory) - def transfer_file(self, file, transfer_back=False): + def transfer_file(self, file, transfer_back=False, delete_file_on_remote=False): working_directory = os.path.abspath(os.path.expanduser(file)) remote_working_directory = self._get_remote_working_dir( working_directory=working_directory @@ -152,7 +152,7 @@ def transfer_file(self, file, transfer_back=False): sftp=None, transfer_back=transfer_back, ) - if self._ssh_delete_file_on_remote and transfer_back: + if self._ssh_delete_file_on_remote and transfer_back and delete_file_on_remote: self._execute_remote_command(command="rm " + remote_working_directory) def __del__(self): diff --git a/pysqa/queueadapter.py b/pysqa/queueadapter.py index 95dd7c4b..df32da94 100644 --- a/pysqa/queueadapter.py +++ b/pysqa/queueadapter.py @@ -201,16 +201,20 @@ def get_job_from_remote(self, working_directory): """ self._adapter.get_job_from_remote(working_directory=working_directory) - def transfer_file_to_remote(self, file, transfer_back=False): + def transfer_file_to_remote(self, file, transfer_back=False, delete_file_on_remote=False): """ + Transfer file from remote host to local host Args: file (str): transfer_back (bool): - Returns: - str: + delete_file_on_remote (bool): """ - self._adapter.transfer_file(file=file, transfer_back=transfer_back) + self._adapter.transfer_file( + file=file, + transfer_back=transfer_back, + delete_file_on_remote=delete_file_on_remote + ) def convert_path_to_remote(self, path): """ diff --git a/pysqa/utils/basic.py b/pysqa/utils/basic.py index 9b702a15..29d8ab5b 100644 --- a/pysqa/utils/basic.py +++ b/pysqa/utils/basic.py @@ -295,7 +295,7 @@ def get_job_from_remote(self, working_directory): def convert_path_to_remote(self, path): raise NotImplementedError - def transfer_file(self, file, transfer_back=False): + def transfer_file(self, file, transfer_back=False, delete_file_on_remote=False): raise NotImplementedError def check_queue_parameters( diff --git a/tests/test_slurm.py b/tests/test_slurm.py index a1af4b1a..db71781c 100644 --- a/tests/test_slurm.py +++ b/tests/test_slurm.py @@ -230,7 +230,7 @@ def execute_command( slurm_tmp._adapter.convert_path_to_remote(path="test") with self.assertRaises(NotImplementedError): - slurm_tmp._adapter.transfer_file(file="test", transfer_back=False) + slurm_tmp._adapter.transfer_file(file="test", transfer_back=False, delete_file_on_remote=False) with self.assertRaises(NotImplementedError): slurm_tmp._adapter.get_job_from_remote(working_directory=".")