Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] pydantic2 updates #29

Merged
merged 4 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
requires-python = ">=3.9"
dependencies =[
"jobflow[strict]",
"pydantic<2",
"pydantic>=2.0.1",
"fireworks",
"fabric",
"tomlkit",
Expand Down
54 changes: 28 additions & 26 deletions src/jobflow_remote/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Annotated, Literal, Optional, Union

from jobflow import JobStore
from pydantic import BaseModel, Extra, Field, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from qtoolkit.io import BaseSchedulerIO, scheduler_mapping

from jobflow_remote.fireworks.launchpad import RemoteLaunchPad
Expand Down Expand Up @@ -71,8 +71,7 @@ def get_delta_retry(self, step_attempts: int) -> int:
ind = min(step_attempts, len(self.delta_retry)) - 1
return self.delta_retry[ind]

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class LogLevel(str, Enum):
Expand Down Expand Up @@ -131,11 +130,9 @@ class WorkerBase(BaseModel):
description="Timeout for the execution of the commands in the worker "
"(e.g. submitting a job)",
)
model_config = ConfigDict(extra="forbid")

class Config:
extra = Extra.forbid

@validator("scheduler_type", always=True)
@field_validator("scheduler_type")
def check_scheduler_type(cls, scheduler_type: str, values: dict) -> str:
"""
Validator to set the default of scheduler_type
Expand All @@ -144,7 +141,7 @@ def check_scheduler_type(cls, scheduler_type: str, values: dict) -> str:
raise ValueError(f"Unknown scheduler type {scheduler_type}")
return scheduler_type

@validator("work_dir", always=True)
@field_validator("work_dir")
def check_work_dir(cls, v) -> Path:
if not v.is_absolute():
raise ValueError("`work_dir` must be an absolute path")
Expand Down Expand Up @@ -230,8 +227,8 @@ class RemoteWorker(WorkerBase):
"remote", description="The discriminator field to determine the worker type"
)
host: str = Field(description="The host to which to connect")
user: str = Field(None, description="Login username")
port: int = Field(None, description="Port number")
user: Optional[str] = Field(None, description="Login username")
port: Optional[int] = Field(None, description="Port number")
password: Optional[str] = Field(None, description="Login password")
key_filename: Optional[Union[str, list[str]]] = Field(
None,
Expand All @@ -241,18 +238,20 @@ class RemoteWorker(WorkerBase):
passphrase: Optional[str] = Field(
None, description="Passphrase used for decrypting private keys"
)
gateway: str = Field(
gateway: Optional[str] = Field(
None, description="A shell command string to use as a proxy or gateway"
)
forward_agent: bool = Field(
forward_agent: Optional[bool] = Field(
None, description="Whether to enable SSH agent forwarding"
)
connect_timeout: int = Field(None, description="Connection timeout, in seconds")
connect_kwargs: dict = Field(
connect_timeout: Optional[int] = Field(
None, description="Connection timeout, in seconds"
)
connect_kwargs: Optional[dict] = Field(
None,
description="Other keyword arguments passed to paramiko.client.SSHClient.connect",
)
inline_ssh_env: bool = Field(
inline_ssh_env: Optional[bool] = Field(
None,
description="Whether to send environment variables 'inline' as prefixes in "
"front of command strings",
Expand Down Expand Up @@ -336,9 +335,7 @@ class ExecutionConfig(BaseModel):
post_run: Optional[str] = Field(
None, description="Commands to be executed after the execution of a job"
)

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class Project(BaseModel):
Expand All @@ -351,19 +348,23 @@ class Project(BaseModel):
None,
description="The base directory containing the project related files. Default "
"is a folder with the project name inside the projects folder",
validate_default=True,
)
tmp_dir: Optional[str] = Field(
None,
description="Folder where remote files are copied. Default a 'tmp' folder in base_dir",
validate_default=True,
)
log_dir: Optional[str] = Field(
None,
description="Folder containing all the logs. Default a 'log' folder in base_dir",
validate_default=True,
)
daemon_dir: Optional[str] = Field(
None,
description="Folder containing daemon related files. Default to a 'daemon' "
"folder in base_dir",
validate_default=True,
)
log_level: LogLevel = Field(LogLevel.INFO, description="The level set for logging")
runner: RunnerOptions = Field(
Expand All @@ -379,6 +380,7 @@ class Project(BaseModel):
description="Dictionary describing a maggma Store used for the queue data. "
"Can contain the monty serialized dictionary or a dictionary with a 'type' "
"specifying the Store subclass",
validate_default=True,
)
exec_config: dict[str, ExecutionConfig] = Field(
default_factory=dict,
Expand All @@ -389,6 +391,7 @@ class Project(BaseModel):
default_factory=lambda: dict(DEFAULT_JOBSTORE),
description="The JobStore used for the input. Can contain the monty "
"serialized dictionary or the Store int the Jobflow format",
validate_default=True,
)
metadata: Optional[dict] = Field(
None, description="A dictionary with metadata associated to the project"
Expand Down Expand Up @@ -429,7 +432,7 @@ def get_launchpad(self) -> RemoteLaunchPad:
"""
return RemoteLaunchPad(self.get_queue_store())

@validator("base_dir", always=True)
@field_validator("base_dir")
def check_base_dir(cls, base_dir: str, values: dict) -> str:
"""
Validator to set the default of base_dir based on the project name
Expand All @@ -440,7 +443,7 @@ def check_base_dir(cls, base_dir: str, values: dict) -> str:
return str(Path(SETTINGS.projects_folder, values["name"]))
return base_dir

@validator("tmp_dir", always=True)
@field_validator("tmp_dir")
def check_tmp_dir(cls, tmp_dir: str, values: dict) -> str:
"""
Validator to set the default of tmp_dir based on the base_dir
Expand All @@ -449,7 +452,7 @@ def check_tmp_dir(cls, tmp_dir: str, values: dict) -> str:
return str(Path(values["base_dir"], "tmp"))
return tmp_dir

@validator("log_dir", always=True)
@field_validator("log_dir")
def check_log_dir(cls, log_dir: str, values: dict) -> str:
"""
Validator to set the default of log_dir based on the base_dir
Expand All @@ -458,7 +461,7 @@ def check_log_dir(cls, log_dir: str, values: dict) -> str:
return str(Path(values["base_dir"], "log"))
return log_dir

@validator("daemon_dir", always=True)
@field_validator("daemon_dir")
def check_daemon_dir(cls, daemon_dir: str, values: dict) -> str:
"""
Validator to set the default of daemon_dir based on the base_dir
Expand All @@ -467,7 +470,7 @@ def check_daemon_dir(cls, daemon_dir: str, values: dict) -> str:
return str(Path(values["base_dir"], "daemon"))
return daemon_dir

@validator("jobstore", always=True)
@field_validator("jobstore")
def check_jobstore(cls, jobstore: dict, values: dict) -> dict:
"""
Check that the jobstore configuration could be converted to a JobStore.
Expand All @@ -484,7 +487,7 @@ def check_jobstore(cls, jobstore: dict, values: dict) -> dict:
) from e
return jobstore

@validator("queue", always=True)
@field_validator("queue")
def check_queue(cls, queue: dict, values: dict) -> dict:
"""
Check that the queue configuration could be converted to a Store.
Expand All @@ -498,8 +501,7 @@ def check_queue(cls, queue: dict, values: dict) -> dict:
) from e
return queue

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class ConfigError(Exception):
Expand Down
17 changes: 7 additions & 10 deletions src/jobflow_remote/config/settings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from pathlib import Path
from typing import Optional

from pydantic import BaseSettings, Field, root_validator
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

DEFAULT_PROJECTS_FOLDER = Path("~/.jfremote").expanduser().as_posix()

Expand All @@ -17,21 +17,18 @@ class JobflowRemoteSettings(BaseSettings):
projects_folder: str = Field(
DEFAULT_PROJECTS_FOLDER, description="Location of the projects files."
)
project: str = Field(None, description="The name of the project used.")
project: Optional[str] = Field(None, description="The name of the project used.")
cli_full_exc: bool = Field(
False,
description="If True prints the full stack trace of the exception when raised in the CLI.",
)
cli_suggestions: bool = Field(
True, description="If True prints some suggestions in the CLI commands."
)
model_config = SettingsConfigDict(env_prefix="jfremote_")

class Config:
"""Pydantic config settings."""

env_prefix = "jfremote_"

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def load_default_settings(cls, values):
"""
Load settings from file or environment variables.
Expand Down