diff --git a/sisyphus/job.py b/sisyphus/job.py index a1db331..5901482 100644 --- a/sisyphus/job.py +++ b/sisyphus/job.py @@ -18,7 +18,7 @@ import sys import time import traceback -from typing import List, Iterator +from typing import List, Iterator, Type, TypeVar from sisyphus import block, tools from sisyphus.task import Task @@ -61,12 +61,15 @@ def job_finished(path): ) +T = TypeVar("T", bound="Job") + + class JobSingleton(type): """Meta class to ensure that every Job with the same hash value is only created once""" - def __call__(cls, *args, **kwargs): + def __call__(cls: Type[T], *args, **kwargs) -> T: """Implemented to ensure that each job is created only once""" try: if "sis_tags" in kwargs: @@ -109,6 +112,7 @@ def __call__(cls, *args, **kwargs): else: # create new object job = super(Job, cls).__new__(cls) + assert isinstance(job, Job) job._sis_tags = tags # store _sis_id @@ -175,7 +179,7 @@ def get_lock(cls): Job._lock_storage.append(multiprocessing.Lock()) return Job._lock_storage[Job._lock_index] - def __new__(cls, *args, **kwargs): + def __new__(cls: Type[T], *args, **kwargs) -> T: # Make sure unpickled jobs stay singletons assert len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], str) sis_cache_key = args[0]