diff --git a/src/jobflow_remote/testing/__init__.py b/src/jobflow_remote/testing/__init__.py index 4aafe91e..53d40963 100644 --- a/src/jobflow_remote/testing/__init__.py +++ b/src/jobflow_remote/testing/__init__.py @@ -2,7 +2,7 @@ from typing import Callable, Optional, Union -from jobflow import job +from jobflow import Response, job @job @@ -41,3 +41,14 @@ def check_env_var() -> str: import os return os.environ.get("TESTING_ENV_VAR", "unset") + + +@job(big_files="data") +def add_big(a: float, b: float): + """Adds two numbers together and writes the answer to an artificially large file.""" + import pathlib + + result = a + b + with open("file.txt", "w") as f: + f.writelines([f"{result}"] * int(1e5)) + return Response({"data": pathlib.Path("file.txt"), "result": a + b}) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 903d0415..2dacb5cf 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -196,7 +196,14 @@ def write_tmp_settings( "host": "localhost", "port": db_port, "collection_name": "docs", - } + }, + "big_files_store": { + "type": "GridFSStore", + "database": store_database_name, + "host": "localhost", + "port": db_port, + "collection_name": "data", + }, }, queue={ "store": { diff --git a/tests/integration/test_slurm.py b/tests/integration/test_slurm.py index 7dc849e9..89f4922e 100644 --- a/tests/integration/test_slurm.py +++ b/tests/integration/test_slurm.py @@ -234,3 +234,32 @@ def test_exec_config(worker, job_controller, random_project_name): job = job_controller.get_jobs({})[0] output = job_controller.jobstore.get_output(uuid=job["uuid"]) assert output == random_project_name + + +@pytest.mark.parametrize( + "worker", + ["test_local_worker", "test_remote_worker"], +) +def test_additional_stores(worker, job_controller): + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.runner import Runner + from jobflow_remote.jobs.state import FlowState, JobState + from jobflow_remote.testing import add_big + + job = add_big(100, 100) + flow = Flow(job) + submit_flow(flow, worker=worker) + + assert job_controller.count_jobs({}) == 1 + assert job_controller.count_flows({}) == 1 + + runner = Runner() + runner.run(ticks=10) + + job_controller.get_jobs({})[0] + breakpoint() + + assert job_controller.count_jobs(state=JobState.FAILED) == 1 + assert job_controller.count_flows(state=FlowState.FAILED) == 1