diff --git a/src/jobflow_remote/remote/data.py b/src/jobflow_remote/remote/data.py index 15b80059..f8be4fa6 100644 --- a/src/jobflow_remote/remote/data.py +++ b/src/jobflow_remote/remote/data.py @@ -4,16 +4,34 @@ import logging import os from pathlib import Path -from typing import Any +from typing import Any, Callable import orjson +from jobflow.core.job import Job, JobConfig from jobflow.core.store import JobStore from maggma.stores.mongolike import JSONStore from monty.json import jsanitize +from pydantic import BaseModel from jobflow_remote.utils.data import uuid_to_path +class JobWrapper(Job): + + function: Callable + function_args: tuple[Any, ...] + function_kwargs: dict[str, Any] + output_schema: type[BaseModel] + uuid: str + index: int + name: str + metadata: dict[str, Any] + config: JobConfig + hosts: list[str] + metadata_updates: list[dict[str, Any]] + config_updates: list[dict[str, Any]] + + def get_job_path( job_id: str, index: int | None, base_path: str | Path | None = None ) -> str: @@ -165,4 +183,15 @@ def resolve_job_dict_args(job_dict: dict, store: JobStore) -> dict: # substitution is in place job_dict["function_args"] = resolved_args job_dict["function_kwargs"] = resolved_kwargs + + expected_fields = set(JobWrapper.__annotations__) + additional_store_names = set(job_dict.keys()) - expected_fields + for additional_store in additional_store_names: + if additional_store.startswith("@"): + continue + if additional_store not in store.additional_stores: + raise ValueError( + f"Additional store {additional_store!r} is not configured." + ) + return job_dict diff --git a/tests/integration/test_slurm.py b/tests/integration/test_slurm.py index 5150fc79..4d373ba2 100644 --- a/tests/integration/test_slurm.py +++ b/tests/integration/test_slurm.py @@ -279,7 +279,7 @@ def test_undefined_additional_stores(worker, job_controller): from jobflow_remote import submit_flow from jobflow_remote.jobs.runner import Runner - from jobflow_remote.jobs.state import FlowState, JobState + from jobflow_remote.jobs.state import JobState from jobflow_remote.testing import add_big_undefined_store job = add_big_undefined_store(100, 100) @@ -293,5 +293,4 @@ def test_undefined_additional_stores(worker, job_controller): runner.run(ticks=10) # The job should fail, as the additional store is not defined - assert job_controller.count_jobs(state=JobState.FAILED) == 1 - assert job_controller.count_flows(state=FlowState.FAILED) == 1 + assert job_controller.count_jobs(state=JobState.REMOTE_ERROR) == 1