Skip to content

Commit

Permalink
Add test for jobs with callable kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
ml-evs committed Jan 29, 2024
1 parent 7ffb688 commit 9b59171
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ rst-roles = "class, func, ref, obj"
[tool.mypy]
ignore_missing_imports = true
no_strict_optional = true
follow_imports = "skip"

[tool.pytest.ini_options]
filterwarnings = [
Expand Down
13 changes: 13 additions & 0 deletions src/jobflow_remote/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""A series of toy workflows that can be used for testing."""
from typing import Callable, Optional, Union

from jobflow import job

Expand All @@ -20,3 +21,15 @@ def write_file(n):
with open("results.txt", "w") as f:
f.write(str(n))
return


@job
def arithmetic(
a: Union[float, list[float]],
b: Union[float, list[float]],
op: Optional[Callable] = None,
) -> Optional[float]:
if op:
return op(a, b)

return None
44 changes: 44 additions & 0 deletions tests/integration/test_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,50 @@ def test_submit_flow_with_dependencies(worker, job_controller):
), f"Flows not marked as completed, full flow info:\n{job_controller.get_flows({})}"


@pytest.mark.parametrize(
"worker",
["test_local_worker", "test_remote_worker"],
)
def test_job_with_callable_kwarg(worker, job_controller):
"""Test whether a callable can be successfully provided as a keyword
argument to a job.
"""
import math

from jobflow import Flow

from jobflow_remote import submit_flow
from jobflow_remote.jobs.runner import Runner
from jobflow_remote.jobs.state import FlowState, JobState
from jobflow_remote.testing import arithmetic

job_1 = arithmetic(1, -2, op=math.copysign)
job_2 = arithmetic([job_1.output], [1], op=math.dist)
job_3 = arithmetic(job_2.output, 2, op=math.pow)

flow = Flow([job_1, job_2, job_3])
submit_flow(flow, worker=worker)

runner = Runner()
runner.run(ticks=10)

assert job_controller.count_jobs({}) == 3
assert len(job_controller.get_jobs({})) == 3
assert job_controller.count_flows({}) == 1

jobs = job_controller.get_jobs({})
outputs = [job_controller.jobstore.get_output(uuid=job["uuid"]) for job in jobs]
assert outputs == [-1, 2, 4]

assert (
job_controller.count_jobs(state=JobState.COMPLETED) == 3
), f"Jobs not marked as completed, full job info:\n{job_controller.get_jobs({})}"
assert (
job_controller.count_flows(state=FlowState.COMPLETED) == 1
), f"Flows not marked as completed, full flow info:\n{job_controller.get_flows({})}"


@pytest.mark.parametrize(
"worker",
["test_local_worker", "test_remote_worker"],
Expand Down

0 comments on commit 9b59171

Please sign in to comment.