-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #424 from hrushikesh-s/pydantic
Formalizing the JobStore document format as a pydantic model
- Loading branch information
Showing
4 changed files
with
163 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
"""A Pydantic model for Jobstore document.""" | ||
|
||
from typing import Generic, List, Optional, TypeVar

from monty.json import MontyDecoder
from pydantic import BaseModel, Field, field_validator
|
||
T = TypeVar("T") | ||
|
||
|
||
class JobStoreDocument(BaseModel, Generic[T]):
    """A Pydantic model for a JobStore document.

    Every field defaults to ``None``, so each annotation is ``Optional``. This
    keeps the declared types consistent with the defaults and lets an explicit
    ``None`` (for example from a previously dumped document with unset fields)
    pass validation.
    """

    uuid: Optional[str] = Field(
        None, description="A unique identifier for the job. Generated automatically."
    )
    index: Optional[int] = Field(
        None,
        description="The index of the job (number of times the job has been replaced).",
    )
    output: Optional[T] = Field(
        None,
        description="This is a reference to the future job output.",
    )
    completed_at: Optional[str] = Field(
        None, description="The time the job was completed."
    )
    metadata: Optional[dict] = Field(
        None,
        description="Metadata information supplied by the user.",
    )
    hosts: Optional[List[str]] = Field(
        None,
        description="The list of UUIDs of the hosts containing the job.",
    )
    name: Optional[str] = Field(
        None,
        description="The name of the job.",
    )

    @field_validator("output", mode="before")
    @classmethod
    def reserialize_output(cls, v):
        """
        Pre-validator for the 'output' field.

        If the input 'v' is a dictionary containing both the '@module' and
        '@class' keys, it is passed through ``MontyDecoder().process_decoded``
        to deserialize it. Any other value is returned untouched.

        Parameters
        ----------
        cls : Type[JobStoreDocument]
            The class this validator is applied to.
        v : Any
            The input value to validate.

        Returns
        -------
        Any
            The validated and potentially deserialized value.
        """
        if isinstance(v, dict) and "@module" in v and "@class" in v:
            v = MontyDecoder().process_decoded(v)
        return v
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from datetime import datetime | ||
|
||
import pytest | ||
|
||
|
||
@pytest.fixture
def sample_data():
    """Return a JobStoreDocument with all fields set and ``output`` left as None."""
    from jobflow.schemas.job_output_schema import JobStoreDocument

    # Collect the field values first, then hand them to the model in one call.
    field_values = {
        "uuid": "abc123",
        "index": 1,
        "output": None,
        "completed_at": datetime.now().isoformat(),
        "metadata": {"key": "value"},
        "hosts": ["host1", "host2"],
        "name": "my_job",
    }
    return JobStoreDocument(**field_values)
|
||
|
||
def test_job_store_document_model(sample_data):
    """Check that the model exposes the values it was constructed with."""
    # Test creating model
    data = sample_data

    assert data.uuid == "abc123"
    assert data.index == 1
    assert data.output is None

    # The timestamp must parse as ISO 8601 and be recent. Comparing elapsed
    # time instead of the ``hour`` attribute avoids a spurious failure when the
    # fixture runs just before an hour boundary and this assertion just after.
    completed_at = datetime.fromisoformat(data.completed_at)
    elapsed_seconds = (datetime.now() - completed_at).total_seconds()
    assert abs(elapsed_seconds) < 60

    assert data.metadata == {"key": "value"}
    assert data.hosts == ["host1", "host2"]
    assert data.name == "my_job"
|
||
|
||
def test_job_store_update(memory_jobstore, sample_data):
    """Store JobStoreDocument objects, singly and as a list, and read them back."""
    # Storing document as a JobStoreDocument
    from jobflow.schemas.job_output_schema import JobStoreDocument

    def check_stored(record, expected_uuid):
        # Shared assertions for every document read back from the store.
        assert record["index"] == 1
        assert record["uuid"] == expected_uuid
        assert record["metadata"] == {"key": "value"}
        assert record["hosts"] == ["host1", "host2"]
        assert record["name"] == "my_job"
        # The extra "e"/"d" keys given to the model must not reach the store.
        assert "e" not in record
        assert "d" not in record

    base_fields = {
        "index": 1,
        "uuid": "abc123",
        "metadata": {"key": "value"},
        "hosts": ["host1", "host2"],
        "name": "my_job",
        "e": 6,
        "d": 4,
    }
    first_doc = JobStoreDocument(**base_fields)
    memory_jobstore.update(first_doc)

    # Check document was inserted
    stored = memory_jobstore.query_one(criteria={"hosts": {"$exists": 1}})
    check_stored(stored, "abc123")

    # Further checks to see if two documents get inserted
    second_doc = JobStoreDocument(**{**base_fields, "uuid": "def456"})
    third_doc = JobStoreDocument(**{**base_fields, "uuid": "ghi789"})
    memory_jobstore.update([second_doc, third_doc])

    # Check if document second_doc is present in the store
    stored = memory_jobstore.query_one(criteria={"uuid": "def456"})
    check_stored(stored, "def456")

    # Check if document third_doc is present in the store
    stored = memory_jobstore.query_one(criteria={"uuid": "ghi789"})
    check_stored(stored, "ghi789")