diff --git a/pyproject.toml b/pyproject.toml index 8c9b70fe..f27ffdca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,6 +92,7 @@ rst-roles = "class, func, ref, obj" [tool.mypy] ignore_missing_imports = true no_strict_optional = true +follow_imports = "skip" [tool.pytest.ini_options] filterwarnings = [ diff --git a/src/jobflow_remote/testing/__init__.py b/src/jobflow_remote/testing/__init__.py index 3eecee92..8ff03cb7 100644 --- a/src/jobflow_remote/testing/__init__.py +++ b/src/jobflow_remote/testing/__init__.py @@ -1,4 +1,5 @@ """A series of toy workflows that can be used for testing.""" +from typing import Callable, Optional, Union from jobflow import job @@ -20,3 +21,15 @@ def write_file(n): with open("results.txt", "w") as f: f.write(str(n)) return + + +@job +def arithmetic( + a: Union[float, list[float]], + b: Union[float, list[float]], + op: Optional[Callable] = None, +) -> Optional[float]: + if op: + return op(a, b) + + return None diff --git a/tests/integration/test_slurm.py b/tests/integration/test_slurm.py index eb85f71c..9c935ff3 100644 --- a/tests/integration/test_slurm.py +++ b/tests/integration/test_slurm.py @@ -134,6 +134,50 @@ def test_submit_flow_with_dependencies(worker, job_controller): ), f"Flows not marked as completed, full flow info:\n{job_controller.get_flows({})}" +@pytest.mark.parametrize( + "worker", + ["test_local_worker", "test_remote_worker"], +) +def test_job_with_callable_kwarg(worker, job_controller): + """Test whether a callable can be successfully provided as a keyword + argument to a job. + + """ + import math + + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.runner import Runner + from jobflow_remote.jobs.state import FlowState, JobState + from jobflow_remote.testing import arithmetic + + job_1 = arithmetic(1, -2, op=math.copysign) + job_2 = arithmetic([job_1.output], [1], op=math.dist) + job_3 = arithmetic(job_2.output, 2, op=math.pow) + + flow = Flow([job_1, job_2, job_3]) + submit_flow(flow, worker=worker) + + runner = Runner() + runner.run(ticks=10) + + assert job_controller.count_jobs({}) == 3 + assert len(job_controller.get_jobs({})) == 3 + assert job_controller.count_flows({}) == 1 + + jobs = job_controller.get_jobs({}) + outputs = [job_controller.jobstore.get_output(uuid=job["uuid"]) for job in jobs] + assert outputs == [-1, 2, 4] + + assert ( + job_controller.count_jobs(state=JobState.COMPLETED) == 3 + ), f"Jobs not marked as completed, full job info:\n{job_controller.get_jobs({})}" + assert ( + job_controller.count_flows(state=FlowState.COMPLETED) == 1 + ), f"Flows not marked as completed, full flow info:\n{job_controller.get_flows({})}" + + @pytest.mark.parametrize( "worker", ["test_local_worker", "test_remote_worker"],