Skip to content

Commit

Permalink
Merge pull request #275 from pyiron/type_hints
Browse files Browse the repository at this point in the history
Add type hints
  • Loading branch information
jan-janssen authored Mar 19, 2024
2 parents 629ea23 + 48d2426 commit 793dd56
Show file tree
Hide file tree
Showing 18 changed files with 274 additions and 194 deletions.
5 changes: 4 additions & 1 deletion pysqa/cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
import json
import os
import sys
from typing import Optional

from pysqa.queueadapter import QueueAdapter
from pysqa.utils.execute import execute_command


def command_line(arguments_lst=None, execute_command=execute_command):
def command_line(
arguments_lst: Optional[list] = None, execute_command: callable = execute_command
):
"""
Parse the command line arguments.
Expand Down
9 changes: 6 additions & 3 deletions pysqa/executor/backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import Optional
import sys

from pympipool.mpi import PyMPIExecutor
Expand All @@ -10,7 +11,9 @@
)


def execute_files_from_list(tasks_in_progress_dict, cache_directory, executor):
def execute_files_from_list(
tasks_in_progress_dict: dict, cache_directory: str, executor
):
file_lst = os.listdir(cache_directory)
for file_name_in in file_lst:
key = file_name_in.split(".in.pl")[0]
Expand All @@ -37,7 +40,7 @@ def execute_files_from_list(tasks_in_progress_dict, cache_directory, executor):
)


def execute_tasks(cores, cache_directory):
def execute_tasks(cores: int, cache_directory: str):
tasks_in_progress_dict = {}
with PyMPIExecutor(
max_workers=cores,
Expand All @@ -58,7 +61,7 @@ def execute_tasks(cores, cache_directory):
)


def command_line(arguments_lst=None):
def command_line(arguments_lst: Optional[list] = None):
if arguments_lst is None:
arguments_lst = sys.argv[1:]
cores_arg = arguments_lst[arguments_lst.index("--cores") + 1]
Expand Down
12 changes: 9 additions & 3 deletions pysqa/executor/executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import queue
from typing import Optional
from concurrent.futures import Future, Executor as FutureExecutor

from pympipool.shared import cancel_items_in_queue, RaisingThread
Expand All @@ -12,7 +13,12 @@


class Executor(FutureExecutor):
def __init__(self, cwd=None, queue_adapter=None, queue_adapter_kwargs=None):
def __init__(
self,
cwd: Optional[str] = None,
queue_adapter=None,
queue_adapter_kwargs: Optional[dict] = None,
):
self._task_queue = queue.Queue()
self._memory_dict = {}
self._cache_directory = os.path.abspath(os.path.expanduser(cwd))
Expand Down Expand Up @@ -42,7 +48,7 @@ def __init__(self, cwd=None, queue_adapter=None, queue_adapter_kwargs=None):
)
self._process.start()

def submit(self, fn, *args, **kwargs):
def submit(self, fn: callable, *args, **kwargs):
funct_dict = serialize_funct(fn, *args, **kwargs)
key = list(funct_dict.keys())[0]
if key not in self._memory_dict.keys():
Expand All @@ -53,7 +59,7 @@ def submit(self, fn, *args, **kwargs):
self._task_queue.put({key: self._memory_dict[key]})
return self._memory_dict[key]

def shutdown(self, wait=True, *, cancel_futures=False):
def shutdown(self, wait: bool = True, *, cancel_futures: bool = False):
if cancel_futures:
cancel_items_in_queue(que=self._task_queue)
self._task_queue.put({"shutdown": True, "wait": wait})
Expand Down
16 changes: 9 additions & 7 deletions pysqa/executor/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
import cloudpickle


def deserialize(funct_dict):
def deserialize(funct_dict: dict) -> dict:
try:
return {k: cloudpickle.loads(v) for k, v in funct_dict.items()}
except EOFError:
return {}


def find_executed_tasks(future_queue, cache_directory):
def find_executed_tasks(future_queue: queue.Queue, cache_directory: str):
task_memory_dict = {}
while True:
task_dict = {}
Expand All @@ -32,13 +32,15 @@ def find_executed_tasks(future_queue, cache_directory):
)


def read_from_file(file_name):
def read_from_file(file_name: str) -> dict:
name = file_name.split("/")[-1].split(".")[0]
with open(file_name, "rb") as f:
return {name: f.read()}


def reload_previous_futures(future_queue, future_dict, cache_directory):
def reload_previous_futures(
future_queue: queue.Queue, future_dict: dict, cache_directory: str
):
file_lst = os.listdir(cache_directory)
for f in file_lst:
if f.endswith(".in.pl"):
Expand All @@ -54,16 +56,16 @@ def reload_previous_futures(future_queue, future_dict, cache_directory):
future_queue.put({key: future_dict[key]})


def serialize_result(result_dict):
def serialize_result(result_dict: dict):
return {k: cloudpickle.dumps(v) for k, v in result_dict.items()}


def serialize_funct(fn, *args, **kwargs):
def serialize_funct(fn: callable, *args, **kwargs):
binary = cloudpickle.dumps({"fn": fn, "args": args, "kwargs": kwargs})
return {fn.__name__ + _get_hash(binary=binary): binary}


def write_to_file(funct_dict, state, cache_directory):
def write_to_file(funct_dict: dict, state, cache_directory: str):
file_name_lst = []
for k, v in funct_dict.items():
file_name = _get_file_name(name=k, state=state)
Expand Down
37 changes: 21 additions & 16 deletions pysqa/ext/modular.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
# coding: utf-8
# Copyright (c) Jan Janssen

from typing import Optional
import pandas

from pysqa.utils.basic import BasisQueueAdapter
from pysqa.utils.execute import execute_command


class ModularQueueAdapter(BasisQueueAdapter):
def __init__(self, config, directory="~/.queues", execute_command=execute_command):
def __init__(
self,
config: dict,
directory: str = "~/.queues",
execute_command: callable = execute_command,
):
super(ModularQueueAdapter, self).__init__(
config=config, directory=directory, execute_command=execute_command
)
Expand All @@ -26,16 +31,16 @@ def __init__(self, config, directory="~/.queues", execute_command=execute_comman

def submit_job(
self,
queue=None,
job_name=None,
working_directory=None,
cores=None,
memory_max=None,
run_time_max=None,
dependency_list=None,
command=None,
queue: Optional[str] = None,
job_name: Optional[str] = None,
working_directory: Optional[str] = None,
cores: Optional[int] = None,
memory_max: Optional[str] = None,
run_time_max: Optional[int] = None,
dependency_list: Optional[list[str]] = None,
command: Optional[str] = None,
**kwargs,
):
) -> int:
"""
Args:
Expand Down Expand Up @@ -79,7 +84,7 @@ def submit_job(
else:
return None

def enable_reservation(self, process_id):
def enable_reservation(self, process_id: int):
"""
Args:
Expand All @@ -103,7 +108,7 @@ def enable_reservation(self, process_id):
else:
return None

def delete_job(self, process_id):
def delete_job(self, process_id: int):
"""
Args:
Expand All @@ -127,7 +132,7 @@ def delete_job(self, process_id):
else:
return None

def get_queue_status(self, user=None):
def get_queue_status(self, user: Optional[str] = None) -> pandas.DataFrame:
"""
Args:
Expand Down Expand Up @@ -155,11 +160,11 @@ def get_queue_status(self, user=None):
return df[df["user"] == user]

@staticmethod
def _resolve_queue_id(process_id, cluster_dict):
def _resolve_queue_id(process_id: int, cluster_dict: dict):
cluster_queue_id = int(process_id / 10)
cluster_module = cluster_dict[process_id - cluster_queue_id * 10]
return cluster_module, cluster_queue_id

@staticmethod
def _switch_cluster_command(cluster_module):
def _switch_cluster_command(cluster_module: str):
return ["module", "--quiet", "swap", "cluster/{};".format(cluster_module)]
Loading

0 comments on commit 793dd56

Please sign in to comment.