Skip to content

Commit

Permalink
allow specify the jobs for wait
Browse files Browse the repository at this point in the history
  • Loading branch information
Nanguage committed Jul 20, 2023
1 parent 1a71474 commit 90ec71b
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
2 changes: 1 addition & 1 deletion executor/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .core import Engine, EngineSetting
from .job import LocalJob, ThreadJob, ProcessJob

__version__ = '0.2.1'
__version__ = '0.2.2'

__all__ = [
'Engine', 'EngineSetting',
Expand Down
30 changes: 25 additions & 5 deletions executor/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,23 @@ def wait_job(
def wait(
self,
timeout: T.Optional[float] = None,
time_delta: float = 0.2):
time_delta: float = 0.2,
select_jobs: T.Optional[T.Callable[[Jobs], T.List[Job]]] = None,
):
"""Block until all jobs are finished or timeout.
Args:
timeout: Timeout in seconds.
time_delta: Time interval to check job status.
select_jobs: Function to select jobs to wait.
"""
if select_jobs is None:
select_jobs = (
lambda jobs: jobs.running.values() + jobs.pending.values()
)
total_time = timeout if timeout is not None else float('inf')
while True:
n_wait_jobs = len(self.jobs.running) + len(self.jobs.pending)
n_wait_jobs = len(select_jobs(self.jobs))
if n_wait_jobs == 0:
break
if total_time <= 0:
Expand All @@ -265,11 +272,24 @@ def wait(
async def wait_async(
self,
timeout: T.Optional[float] = None,
time_delta: float = 0.2):
"""Asynchronous interface for wait."""
time_delta: float = 0.2,
select_jobs: T.Optional[T.Callable[[Jobs], T.List[Job]]] = None,
):
"""Asynchronous interface for wait.
Block until all jobs are finished or timeout.
Args:
timeout: Timeout in seconds.
time_delta: Time interval to check job status.
select_jobs: Function to select jobs to wait.
"""
if select_jobs is None:
select_jobs = (
lambda jobs: jobs.running.values() + jobs.pending.values()
)
total_time = timeout if timeout is not None else float('inf')
while True:
n_wait_jobs = len(self.jobs.running) + len(self.jobs.pending)
n_wait_jobs = len(select_jobs(self.jobs))
if n_wait_jobs == 0:
break
if total_time <= 0:
Expand Down
25 changes: 25 additions & 0 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,31 @@ async def test_join_jobs():
assert job2.result() == 9


@pytest.mark.asyncio
async def test_wait_jobs():
def sleep_square(x):
time.sleep(x)
return x**2
engine = Engine()
job1 = ThreadJob(sleep_square, (1,))
job2 = ThreadJob(sleep_square, (2,))
job3 = ThreadJob(sleep_square, (3,))
await engine.submit_async(job1, job2, job3)

def select_func(jobs):
if (job1.status == "running") or (job1.status == "pending"):
return [job1]
else:
return []
await engine.wait_async(
select_jobs=select_func
)
assert job1.status == "done"
assert job2.status == "running"
assert job3.status == "running"
await engine.wait_async()


def test_engine_start_stop():
engine = Engine()
engine.start()
Expand Down

0 comments on commit 90ec71b

Please sign in to comment.