diff --git a/src/jobflow_remote/remote/data.py b/src/jobflow_remote/remote/data.py index 15b80059..aa67e5f6 100644 --- a/src/jobflow_remote/remote/data.py +++ b/src/jobflow_remote/remote/data.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect import io import logging import os @@ -7,12 +8,19 @@ from typing import Any import orjson +from jobflow.core.job import Job from jobflow.core.store import JobStore from maggma.stores.mongolike import JSONStore from monty.json import jsanitize +from jobflow_remote.jobs.data import RemoteError from jobflow_remote.utils.data import uuid_to_path +JOB_INIT_ARGS = {k for k in inspect.signature(Job).parameters.keys() if k != "kwargs"} +"""A set of the arguments of the Job constructor which +can be used to detect additional custom arguments +""" + def get_job_path( job_id: str, index: int | None, base_path: str | Path | None = None @@ -165,4 +173,16 @@ def resolve_job_dict_args(job_dict: dict, store: JobStore) -> dict: # substitution is in place job_dict["function_args"] = resolved_args job_dict["function_kwargs"] = resolved_kwargs + + additional_store_names = set(job_dict.keys()) - JOB_INIT_ARGS + for store_name in additional_store_names: + # Exclude MSON fields + if store_name.startswith("@"): + continue + if store_name not in store.additional_stores: + raise RemoteError( + f"Additional store {store_name!r} is not configured for this project.", + no_retry=True, + ) + return job_dict diff --git a/src/jobflow_remote/testing/__init__.py b/src/jobflow_remote/testing/__init__.py index 4aafe91e..9981ec95 100644 --- a/src/jobflow_remote/testing/__init__.py +++ b/src/jobflow_remote/testing/__init__.py @@ -2,7 +2,7 @@ from typing import Callable, Optional, Union -from jobflow import job +from jobflow import Response, job @job @@ -41,3 +41,27 @@ def check_env_var() -> str: import os return os.environ.get("TESTING_ENV_VAR", "unset") + + +@job(big_data="data") +def add_big(a: float, b: float): + """Adds two numbers together 
and inflates the answer + to a large list and tries to store that within + the defined store. + + """ + result = a + b + big_array = [result] * 5_000 + return Response({"data": big_array, "result": a + b}) + + +@job(undefined_store="data") +def add_big_undefined_store(a: float, b: float): +    """Adds two numbers together and writes the answer to an artificially large file +    which is attempted to be stored in an undefined store.""" +    import pathlib + +    result = a + b +    with open("file.txt", "w") as f: +        f.writelines([f"{result}"] * int(1e5)) +    return Response({"data": pathlib.Path("file.txt"), "result": a + b}) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 903d0415..112d7589 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -196,7 +196,16 @@ def write_tmp_settings( "host": "localhost", "port": db_port, "collection_name": "docs", - } + }, + "additional_stores": { + "big_data": { + "type": "GridFSStore", + "database": store_database_name, + "host": "localhost", + "port": db_port, + "collection_name": "data", + }, + }, }, queue={ "store": { diff --git a/tests/integration/test_slurm.py b/tests/integration/test_slurm.py index 7dc849e9..ed9adbe0 100644 --- a/tests/integration/test_slurm.py +++ b/tests/integration/test_slurm.py @@ -234,3 +234,63 @@ def test_exec_config(worker, job_controller, random_project_name): job = job_controller.get_jobs({})[0] output = job_controller.jobstore.get_output(uuid=job["uuid"]) assert output == random_project_name + + +@pytest.mark.parametrize( + "worker", + ["test_local_worker", "test_remote_worker"], ) +def test_additional_stores(worker, job_controller): + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.runner import Runner + from jobflow_remote.jobs.state import FlowState, JobState + from jobflow_remote.testing import add_big + + job = add_big(100, 100) + flow = Flow(job) + submit_flow(flow, worker=worker) + + assert 
job_controller.count_jobs({}) == 1 + assert job_controller.count_flows({}) == 1 + + runner = Runner() + runner.run(ticks=10) + + doc = job_controller.get_jobs({})[0] + fs = job_controller.jobstore.additional_stores["big_data"] + assert fs.count({"job_uuid": doc["job"]["uuid"]}) == 1 + assert job_controller.count_jobs(state=JobState.COMPLETED) == 1 + assert job_controller.count_flows(state=FlowState.COMPLETED) == 1 + assert job_controller.jobstore.get_output(uuid=doc["job"]["uuid"])["result"] == 200 + blob_uuid = job_controller.jobstore.get_output(uuid=doc["job"]["uuid"])["data"][ + "blob_uuid" + ] + assert list(fs.query({"blob_uuid": blob_uuid}))[0]["job_uuid"] == doc["job"]["uuid"] + + +@pytest.mark.parametrize( + "worker", + ["test_local_worker", "test_remote_worker"], +) +def test_undefined_additional_stores(worker, job_controller): + from jobflow import Flow + + from jobflow_remote import submit_flow + from jobflow_remote.jobs.runner import Runner + from jobflow_remote.jobs.state import JobState + from jobflow_remote.testing import add_big_undefined_store + + job = add_big_undefined_store(100, 100) + flow = Flow(job) + submit_flow(flow, worker=worker) + + assert job_controller.count_jobs({}) == 1 + assert job_controller.count_flows({}) == 1 + + runner = Runner() + runner.run(ticks=10) + + # The job should fail, as the additional store is not defined + assert job_controller.count_jobs(state=JobState.REMOTE_ERROR) == 1