diff --git a/pyproject.toml b/pyproject.toml index 3d9bd4bf..2107d4a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ requires-python = ">=3.9" dependencies =[ "jobflow[strict]", - "pydantic<2", + "pydantic>=2.0.1", "fireworks", "fabric", "tomlkit", diff --git a/src/jobflow_remote/config/base.py b/src/jobflow_remote/config/base.py index c0818e70..d9326df5 100644 --- a/src/jobflow_remote/config/base.py +++ b/src/jobflow_remote/config/base.py @@ -6,7 +6,7 @@ from typing import Annotated, Literal, Optional, Union from jobflow import JobStore -from pydantic import BaseModel, Extra, Field, validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from qtoolkit.io import BaseSchedulerIO, scheduler_mapping from jobflow_remote.fireworks.launchpad import RemoteLaunchPad @@ -71,8 +71,7 @@ def get_delta_retry(self, step_attempts: int) -> int: ind = min(step_attempts, len(self.delta_retry)) - 1 return self.delta_retry[ind] - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class LogLevel(str, Enum): @@ -131,11 +130,9 @@ class WorkerBase(BaseModel): description="Timeout for the execution of the commands in the worker " "(e.g. submitting a job)", ) + model_config = ConfigDict(extra="forbid") - class Config: - extra = Extra.forbid - - @validator("scheduler_type", always=True) + @field_validator("scheduler_type") def check_scheduler_type(cls, scheduler_type: str, values: dict) -> str: """ Validator to set the default of scheduler_type @@ -144,7 +141,7 @@ def check_scheduler_type(cls, scheduler_type: str, values: dict) -> str: raise ValueError(f"Unknown scheduler type {scheduler_type}") return scheduler_type - @validator("work_dir", always=True) + @field_validator("work_dir") def check_work_dir(cls, v) -> Path: if not v.is_absolute(): raise ValueError("`work_dir` must be an absolute path") @@ -230,8 +227,8 @@ class RemoteWorker(WorkerBase): "remote", description="The discriminator field to determine the worker type" ) host: str = Field(description="The host to which to connect") - user: str = Field(None, description="Login username") - port: int = Field(None, description="Port number") + user: Optional[str] = Field(None, description="Login username") + port: Optional[int] = Field(None, description="Port number") password: Optional[str] = Field(None, description="Login password") key_filename: Optional[Union[str, list[str]]] = Field( None, @@ -241,18 +238,20 @@ class RemoteWorker(WorkerBase): passphrase: Optional[str] = Field( None, description="Passphrase used for decrypting private keys" ) - gateway: str = Field( + gateway: Optional[str] = Field( None, description="A shell command string to use as a proxy or gateway" ) - forward_agent: bool = Field( + forward_agent: Optional[bool] = Field( None, description="Whether to enable SSH agent forwarding" ) - connect_timeout: int = Field(None, description="Connection timeout, in seconds") - connect_kwargs: dict = Field( + connect_timeout: Optional[int] = Field( + None, description="Connection timeout, in seconds" + ) + connect_kwargs: Optional[dict] = Field( None, description="Other keyword arguments passed to paramiko.client.SSHClient.connect", ) - inline_ssh_env: bool = Field( + inline_ssh_env: Optional[bool] = Field( None, description="Whether to send environment variables 'inline' as prefixes in " "front of command strings", @@ -336,9 +335,7 @@ class ExecutionConfig(BaseModel): post_run: Optional[str] = Field( None, description="Commands to be executed after the execution of a job" ) - - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class Project(BaseModel): @@ -351,19 +348,23 @@ class Project(BaseModel): None, description="The base directory containing the project related files. Default " "is a folder with the project name inside the projects folder", + validate_default=True, ) tmp_dir: Optional[str] = Field( None, description="Folder where remote files are copied. Default a 'tmp' folder in base_dir", + validate_default=True, ) log_dir: Optional[str] = Field( None, description="Folder containing all the logs. Default a 'log' folder in base_dir", + validate_default=True, ) daemon_dir: Optional[str] = Field( None, description="Folder containing daemon related files. Default to a 'daemon' " "folder in base_dir", + validate_default=True, ) log_level: LogLevel = Field(LogLevel.INFO, description="The level set for logging") runner: RunnerOptions = Field( @@ -379,6 +380,7 @@ class Project(BaseModel): description="Dictionary describing a maggma Store used for the queue data. " "Can contain the monty serialized dictionary or a dictionary with a 'type' " "specifying the Store subclass", + validate_default=True, ) exec_config: dict[str, ExecutionConfig] = Field( default_factory=dict, @@ -389,6 +391,7 @@ class Project(BaseModel): default_factory=lambda: dict(DEFAULT_JOBSTORE), description="The JobStore used for the input. Can contain the monty " "serialized dictionary or the Store int the Jobflow format", + validate_default=True, ) metadata: Optional[dict] = Field( None, description="A dictionary with metadata associated to the project" @@ -429,7 +432,7 @@ def get_launchpad(self) -> RemoteLaunchPad: """ return RemoteLaunchPad(self.get_queue_store()) - @validator("base_dir", always=True) + @field_validator("base_dir") def check_base_dir(cls, base_dir: str, values: dict) -> str: """ Validator to set the default of base_dir based on the project name @@ -440,7 +443,7 @@ def check_base_dir(cls, base_dir: str, values: dict) -> str: return str(Path(SETTINGS.projects_folder, values["name"])) return base_dir - @validator("tmp_dir", always=True) + @field_validator("tmp_dir") def check_tmp_dir(cls, tmp_dir: str, values: dict) -> str: """ Validator to set the default of tmp_dir based on the base_dir @@ -449,7 +452,7 @@ def check_tmp_dir(cls, tmp_dir: str, values: dict) -> str: return str(Path(values["base_dir"], "tmp")) return tmp_dir - @validator("log_dir", always=True) + @field_validator("log_dir") def check_log_dir(cls, log_dir: str, values: dict) -> str: """ Validator to set the default of log_dir based on the base_dir @@ -458,7 +461,7 @@ def check_log_dir(cls, log_dir: str, values: dict) -> str: return str(Path(values["base_dir"], "log")) return log_dir - @validator("daemon_dir", always=True) + @field_validator("daemon_dir") def check_daemon_dir(cls, daemon_dir: str, values: dict) -> str: """ Validator to set the default of daemon_dir based on the base_dir @@ -467,7 +470,7 @@ def check_daemon_dir(cls, daemon_dir: str, values: dict) -> str: return str(Path(values["base_dir"], "daemon")) return daemon_dir - @validator("jobstore", always=True) + @field_validator("jobstore") def check_jobstore(cls, jobstore: dict, values: dict) -> dict: """ Check that the jobstore configuration could be converted to a JobStore. @@ -484,7 +487,7 @@ def check_jobstore(cls, jobstore: dict, values: dict) -> dict: ) from e return jobstore - @validator("queue", always=True) + @field_validator("queue") def check_queue(cls, queue: dict, values: dict) -> dict: """ Check that the queue configuration could be converted to a Store. @@ -498,8 +501,7 @@ def check_queue(cls, queue: dict, values: dict) -> dict: ) from e return queue - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class ConfigError(Exception): diff --git a/src/jobflow_remote/config/settings.py b/src/jobflow_remote/config/settings.py index 82096b1c..ca73c4a9 100644 --- a/src/jobflow_remote/config/settings.py +++ b/src/jobflow_remote/config/settings.py @@ -1,8 +1,8 @@ -from __future__ import annotations - from pathlib import Path +from typing import Optional -from pydantic import BaseSettings, Field, root_validator +from pydantic import Field, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict DEFAULT_PROJECTS_FOLDER = Path("~/.jfremote").expanduser().as_posix() @@ -17,7 +17,7 @@ class JobflowRemoteSettings(BaseSettings): projects_folder: str = Field( DEFAULT_PROJECTS_FOLDER, description="Location of the projects files." ) - project: str = Field(None, description="The name of the project used.") + project: Optional[str] = Field(None, description="The name of the project used.") cli_full_exc: bool = Field( False, description="If True prints the full stack trace of the exception when raised in the CLI.", @@ -25,13 +25,10 @@ class JobflowRemoteSettings(BaseSettings): cli_suggestions: bool = Field( True, description="If True prints some suggestions in the CLI commands." ) + model_config = SettingsConfigDict(env_prefix="jfremote_") - class Config: - """Pydantic config settings.""" - - env_prefix = "jfremote_" - - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def load_default_settings(cls, values): """ Load settings from file or environment variables.