Skip to content

Commit

Permalink
Add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-janssen committed Mar 18, 2024
1 parent 629ea23 commit 1f94020
Show file tree
Hide file tree
Showing 18 changed files with 204 additions and 189 deletions.
2 changes: 1 addition & 1 deletion pysqa/cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pysqa.utils.execute import execute_command


def command_line(arguments_lst=None, execute_command=execute_command):
def command_line(arguments_lst: list=None, execute_command: callable = execute_command):
"""
Parse the command line arguments.
Expand Down
6 changes: 3 additions & 3 deletions pysqa/executor/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)


def execute_files_from_list(tasks_in_progress_dict, cache_directory, executor):
def execute_files_from_list(tasks_in_progress_dict: dict, cache_directory: str, executor):
file_lst = os.listdir(cache_directory)
for file_name_in in file_lst:
key = file_name_in.split(".in.pl")[0]
Expand All @@ -37,7 +37,7 @@ def execute_files_from_list(tasks_in_progress_dict, cache_directory, executor):
)


def execute_tasks(cores, cache_directory):
def execute_tasks(cores: int, cache_directory: str):
tasks_in_progress_dict = {}
with PyMPIExecutor(
max_workers=cores,
Expand All @@ -58,7 +58,7 @@ def execute_tasks(cores, cache_directory):
)


def command_line(arguments_lst=None):
def command_line(arguments_lst: list=None):
if arguments_lst is None:
arguments_lst = sys.argv[1:]
cores_arg = arguments_lst[arguments_lst.index("--cores") + 1]
Expand Down
6 changes: 3 additions & 3 deletions pysqa/executor/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class Executor(FutureExecutor):
def __init__(self, cwd=None, queue_adapter=None, queue_adapter_kwargs=None):
def __init__(self, cwd: str=None, queue_adapter=None, queue_adapter_kwargs: dict=None):
self._task_queue = queue.Queue()
self._memory_dict = {}
self._cache_directory = os.path.abspath(os.path.expanduser(cwd))
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(self, cwd=None, queue_adapter=None, queue_adapter_kwargs=None):
)
self._process.start()

def submit(self, fn, *args, **kwargs):
def submit(self, fn: callable, *args, **kwargs):
funct_dict = serialize_funct(fn, *args, **kwargs)
key = list(funct_dict.keys())[0]
if key not in self._memory_dict.keys():
Expand All @@ -53,7 +53,7 @@ def submit(self, fn, *args, **kwargs):
self._task_queue.put({key: self._memory_dict[key]})
return self._memory_dict[key]

def shutdown(self, wait=True, *, cancel_futures=False):
def shutdown(self, wait: bool = True, *, cancel_futures: bool = False):
if cancel_futures:
cancel_items_in_queue(que=self._task_queue)
self._task_queue.put({"shutdown": True, "wait": wait})
Expand Down
12 changes: 6 additions & 6 deletions pysqa/executor/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def deserialize(funct_dict):
return {}


def find_executed_tasks(future_queue, cache_directory):
def find_executed_tasks(future_queue: queue.Queue, cache_directory: str):
task_memory_dict = {}
while True:
task_dict = {}
Expand All @@ -32,13 +32,13 @@ def find_executed_tasks(future_queue, cache_directory):
)


def read_from_file(file_name):
def read_from_file(file_name: str) -> dict:
name = file_name.split("/")[-1].split(".")[0]
with open(file_name, "rb") as f:
return {name: f.read()}


def reload_previous_futures(future_queue, future_dict, cache_directory):
def reload_previous_futures(future_queue: queue.Queue, future_dict: dict, cache_directory: str):
file_lst = os.listdir(cache_directory)
for f in file_lst:
if f.endswith(".in.pl"):
Expand All @@ -54,16 +54,16 @@ def reload_previous_futures(future_queue, future_dict, cache_directory):
future_queue.put({key: future_dict[key]})


def serialize_result(result_dict):
def serialize_result(result_dict: dict):
return {k: cloudpickle.dumps(v) for k, v in result_dict.items()}


def serialize_funct(fn, *args, **kwargs):
def serialize_funct(fn: callable, *args, **kwargs):
binary = cloudpickle.dumps({"fn": fn, "args": args, "kwargs": kwargs})
return {fn.__name__ + _get_hash(binary=binary): binary}


def write_to_file(funct_dict, state, cache_directory):
def write_to_file(funct_dict: dict, state, cache_directory: str):
file_name_lst = []
for k, v in funct_dict.items():
file_name = _get_file_name(name=k, state=state)
Expand Down
30 changes: 15 additions & 15 deletions pysqa/ext/modular.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class ModularQueueAdapter(BasisQueueAdapter):
def __init__(self, config, directory="~/.queues", execute_command=execute_command):
def __init__(self, config: dict, directory: str = "~/.queues", execute_command: callable = execute_command):
super(ModularQueueAdapter, self).__init__(
config=config, directory=directory, execute_command=execute_command
)
Expand All @@ -26,16 +26,16 @@ def __init__(self, config, directory="~/.queues", execute_command=execute_comman

def submit_job(
self,
queue=None,
job_name=None,
working_directory=None,
cores=None,
memory_max=None,
run_time_max=None,
dependency_list=None,
command=None,
queue: str = None,
job_name: str = None,
working_directory: str = None,
cores: int = None,
memory_max: str = None,
run_time_max: int = None,
dependency_list: list[str] = None,
command: str = None,
**kwargs,
):
) -> int:
"""
Args:
Expand Down Expand Up @@ -79,7 +79,7 @@ def submit_job(
else:
return None

def enable_reservation(self, process_id):
def enable_reservation(self, process_id: int):
"""
Args:
Expand All @@ -103,7 +103,7 @@ def enable_reservation(self, process_id):
else:
return None

def delete_job(self, process_id):
def delete_job(self, process_id: int):
"""
Args:
Expand All @@ -127,7 +127,7 @@ def delete_job(self, process_id):
else:
return None

def get_queue_status(self, user=None):
def get_queue_status(self, user: str = None) -> pandas.DataFrame:
"""
Args:
Expand Down Expand Up @@ -155,11 +155,11 @@ def get_queue_status(self, user=None):
return df[df["user"] == user]

@staticmethod
def _resolve_queue_id(process_id, cluster_dict):
def _resolve_queue_id(process_id: int, cluster_dict: dict):
cluster_queue_id = int(process_id / 10)
cluster_module = cluster_dict[process_id - cluster_queue_id * 10]
return cluster_module, cluster_queue_id

@staticmethod
def _switch_cluster_command(cluster_module):
def _switch_cluster_command(cluster_module: str):
return ["module", "--quiet", "swap", "cluster/{};".format(cluster_module)]
79 changes: 47 additions & 32 deletions pysqa/ext/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


class RemoteQueueAdapter(BasisQueueAdapter):
def __init__(self, config, directory="~/.queues", execute_command=execute_command):
def __init__(self, config: dict, directory: str = "~/.queues", execute_command: callable = execute_command):
super(RemoteQueueAdapter, self).__init__(
config=config, directory=directory, execute_command=execute_command
)
Expand Down Expand Up @@ -78,22 +78,37 @@ def __init__(self, config, directory="~/.queues", execute_command=execute_comman
self._ssh_proxy_connection = None
self._remote_flag = True

def convert_path_to_remote(self, path):
def convert_path_to_remote(self, path: str):
working_directory = os.path.abspath(os.path.expanduser(path))
return self._get_remote_working_dir(working_directory=working_directory)

def submit_job(
self,
queue=None,
job_name=None,
working_directory=None,
cores=None,
memory_max=None,
run_time_max=None,
dependency_list=None,
command=None,
queue: str = None,
job_name: str = None,
working_directory: str = None,
cores: int = None,
memory_max: int = None,
run_time_max: int = None,
dependency_list: list[str] = None,
command: str = None,
**kwargs,
):
) -> int:
"""
Args:
queue (str/None):
job_name (str/None):
working_directory (str/None):
cores (int/None):
memory_max (int/None):
run_time_max (int/None):
dependency_list (list/None):
command (str/None):
Returns:
int:
"""
if dependency_list is not None:
raise NotImplementedError(
"Submitting jobs with dependencies to a remote cluster is not yet supported."
Expand All @@ -102,7 +117,7 @@ def submit_job(
output = self._execute_remote_command(command=command)
return int(output.split()[-1])

def enable_reservation(self, process_id):
def enable_reservation(self, process_id: int) -> str:
"""
Args:
Expand All @@ -115,7 +130,7 @@ def enable_reservation(self, process_id):
command=self._reservation_command(job_id=process_id)
)

def delete_job(self, process_id):
def delete_job(self, process_id: int) -> str:
"""
Args:
Expand All @@ -128,7 +143,7 @@ def delete_job(self, process_id):
command=self._delete_command(job_id=process_id)
)

def get_queue_status(self, user=None):
def get_queue_status(self, user: str = None) -> pandas.DataFrame:
"""
Args:
Expand All @@ -147,7 +162,7 @@ def get_queue_status(self, user=None):
else:
return df[df["user"] == user]

def get_job_from_remote(self, working_directory):
def get_job_from_remote(self, working_directory: str):
"""
Get the results of the calculation - this is necessary when the calculation was executed on a remote host.
"""
Expand Down Expand Up @@ -176,7 +191,7 @@ def get_job_from_remote(self, working_directory):
if self._ssh_delete_file_on_remote:
self._execute_remote_command(command="rm -r " + remote_working_directory)

def transfer_file(self, file, transfer_back=False, delete_file_on_remote=False):
def transfer_file(self, file: str, transfer_back: bool = False, delete_file_on_remote: bool = False):
working_directory = os.path.abspath(os.path.expanduser(file))
remote_working_directory = self._get_remote_working_dir(
working_directory=working_directory
Expand All @@ -200,7 +215,7 @@ def _check_ssh_connection(self):
if self._ssh_connection is None:
self._ssh_connection = self._open_ssh_connection()

def _transfer_files(self, file_dict, sftp=None, transfer_back=False):
def _transfer_files(self, file_dict: dict, sftp=None, transfer_back: bool = False):
if sftp is None:
if self._ssh_continous_connection:
self._check_ssh_connection()
Expand Down Expand Up @@ -346,13 +361,13 @@ def _get_queue_status_command(self):

def _submit_command(
self,
queue=None,
job_name=None,
working_directory=None,
cores=None,
memory_max=None,
run_time_max=None,
command_str=None,
queue: str = None,
job_name: str = None,
working_directory: str = None,
cores: int = None,
memory_max: int = None,
run_time_max: int = None,
command_str: str = None,
):
command = self._remote_command() + "--submit "
if queue is not None:
Expand All @@ -371,13 +386,13 @@ def _submit_command(
command += '--command "' + command_str + '" '
return command

def _delete_command(self, job_id):
def _delete_command(self, job_id: int) -> str:
return self._remote_command() + "--delete --id " + str(job_id)

def _reservation_command(self, job_id):
def _reservation_command(self, job_id: int) -> str:
return self._remote_command() + "--reservation --id " + str(job_id)

def _execute_remote_command(self, command):
def _execute_remote_command(self, command: str):
if self._ssh_continous_connection:
self._check_ssh_connection()
ssh = self._ssh_connection
Expand All @@ -390,13 +405,13 @@ def _execute_remote_command(self, command):
ssh.close()
return output

def _get_remote_working_dir(self, working_directory):
def _get_remote_working_dir(self, working_directory: str):
return os.path.join(
self._ssh_remote_path,
os.path.relpath(working_directory, self._ssh_local_path),
)

def _create_remote_dir(self, directory):
def _create_remote_dir(self, directory: str):
if isinstance(directory, str):
self._execute_remote_command(command="mkdir -p " + directory)
elif isinstance(directory, list):
Expand All @@ -407,7 +422,7 @@ def _create_remote_dir(self, directory):
else:
raise TypeError()

def _transfer_data_to_remote(self, working_directory):
def _transfer_data_to_remote(self, working_directory: str):
working_directory = os.path.abspath(os.path.expanduser(working_directory))
remote_working_directory = self._get_remote_working_dir(
working_directory=working_directory
Expand All @@ -432,7 +447,7 @@ def _transfer_data_to_remote(self, working_directory):
self._create_remote_dir(directory=new_dir_list)
self._transfer_files(file_dict=file_dict, sftp=None, transfer_back=False)

def _get_user(self):
def _get_user(self) -> str:
"""
Returns:
Expand All @@ -441,7 +456,7 @@ def _get_user(self):
return self._ssh_username

@staticmethod
def _get_file_transfer(file, local_dir, remote_dir):
def _get_file_transfer(file: str, local_dir: str, remote_dir: str) -> str:
return os.path.abspath(
os.path.join(remote_dir, os.path.relpath(file, local_dir))
)
Loading

0 comments on commit 1f94020

Please sign in to comment.