From a077df44daed57c4e88bd5296e24286eaba45989 Mon Sep 17 00:00:00 2001 From: Guido Petretto Date: Fri, 29 Sep 2023 00:48:04 +0200 Subject: [PATCH 1/2] pydantic2 updates --- src/jobflow_remote/config/base.py | 54 ++++++++++++++------------- src/jobflow_remote/config/settings.py | 14 +++---- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/jobflow_remote/config/base.py b/src/jobflow_remote/config/base.py index 0317a61c..2477c361 100644 --- a/src/jobflow_remote/config/base.py +++ b/src/jobflow_remote/config/base.py @@ -8,7 +8,7 @@ from typing import Annotated, Literal from jobflow import JobStore -from pydantic import BaseModel, Extra, Field, validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from qtoolkit.io import BaseSchedulerIO, scheduler_mapping from jobflow_remote.fireworks.launchpad import RemoteLaunchPad @@ -73,8 +73,7 @@ def get_delta_retry(self, step_attempts: int) -> int: ind = min(step_attempts, len(self.delta_retry)) - 1 return self.delta_retry[ind] - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class LogLevel(str, Enum): @@ -133,11 +132,9 @@ class WorkerBase(BaseModel): description="Timeout for the execution of the commands in the worker " "(e.g. submitting a job)", ) + model_config = ConfigDict(extra="forbid") - class Config: - extra = Extra.forbid - - @validator("scheduler_type", always=True) + @field_validator("scheduler_type") def check_scheduler_type(cls, scheduler_type: str, values: dict) -> str: """ Validator to set the default of scheduler_type @@ -146,7 +143,7 @@ def check_scheduler_type(cls, scheduler_type: str, values: dict) -> str: raise ValueError(f"Unknown scheduler type {scheduler_type}") return scheduler_type - @validator("work_dir", always=True) + @field_validator("work_dir") def check_work_dir(cls, v) -> Path: if not v.is_absolute(): raise ValueError("`work_dir` must be an absolute path") @@ -232,8 +229,8 @@ class RemoteWorker(WorkerBase): "remote", description="The discriminator field to determine the worker type" ) host: str = Field(description="The host to which to connect") - user: str = Field(None, description="Login username") - port: int = Field(None, description="Port number") + user: str | None = Field(None, description="Login username") + port: int | None = Field(None, description="Port number") password: str | None = Field(None, description="Login password") key_filename: str | list[str] | None = Field( None, @@ -243,18 +240,20 @@ class RemoteWorker(WorkerBase): passphrase: str | None = Field( None, description="Passphrase used for decrypting private keys" ) - gateway: str = Field( + gateway: str | None = Field( None, description="A shell command string to use as a proxy or gateway" ) - forward_agent: bool = Field( + forward_agent: bool | None = Field( None, description="Whether to enable SSH agent forwarding" ) - connect_timeout: int = Field(None, description="Connection timeout, in seconds") - connect_kwargs: dict = Field( + connect_timeout: int | None = Field( + None, description="Connection timeout, in seconds" + ) + connect_kwargs: dict | None = Field( None, description="Other keyword arguments passed to paramiko.client.SSHClient.connect", ) - inline_ssh_env: bool = Field( + inline_ssh_env: bool | None = Field( None, description="Whether to send environment variables 'inline' as prefixes in " "front of command strings", @@ -336,9 +335,7 @@ class ExecutionConfig(BaseModel): post_run: str | None = Field( None, description="Commands to be executed after the execution of a job" ) - - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class Project(BaseModel): @@ -351,19 +348,23 @@ class Project(BaseModel): None, description="The base directory containing the project related files. Default " "is a folder with the project name inside the projects folder", + validate_default=True, ) tmp_dir: str | None = Field( None, description="Folder where remote files are copied. Default a 'tmp' folder in base_dir", + validate_default=True, ) log_dir: str | None = Field( None, description="Folder containing all the logs. Default a 'log' folder in base_dir", + validate_default=True, ) daemon_dir: str | None = Field( None, description="Folder containing daemon related files. Default to a 'daemon' " "folder in base_dir", + validate_default=True, ) log_level: LogLevel = Field(LogLevel.INFO, description="The level set for logging") runner: RunnerOptions = Field( @@ -379,6 +380,7 @@ class Project(BaseModel): description="Dictionary describing a maggma Store used for the queue data. " "Can contain the monty serialized dictionary or a dictionary with a 'type' " "specifying the Store subclass", + validate_default=True, ) exec_config: dict[str, ExecutionConfig] = Field( default_factory=dict, @@ -389,6 +391,7 @@ class Project(BaseModel): default_factory=lambda: dict(DEFAULT_JOBSTORE), description="The JobStore used for the input. Can contain the monty " "serialized dictionary or the Store int the Jobflow format", + validate_default=True, ) metadata: dict | None = Field( None, description="A dictionary with metadata associated to the project" @@ -429,7 +432,7 @@ def get_launchpad(self) -> RemoteLaunchPad: """ return RemoteLaunchPad(self.get_queue_store()) - @validator("base_dir", always=True) + @field_validator("base_dir") def check_base_dir(cls, base_dir: str, values: dict) -> str: """ Validator to set the default of base_dir based on the project name @@ -440,7 +443,7 @@ def check_base_dir(cls, base_dir: str, values: dict) -> str: return str(Path(SETTINGS.projects_folder, values["name"])) return base_dir - @validator("tmp_dir", always=True) + @field_validator("tmp_dir") def check_tmp_dir(cls, tmp_dir: str, values: dict) -> str: """ Validator to set the default of tmp_dir based on the base_dir @@ -449,7 +452,7 @@ def check_tmp_dir(cls, tmp_dir: str, values: dict) -> str: return str(Path(values["base_dir"], "tmp")) return tmp_dir - @validator("log_dir", always=True) + @field_validator("log_dir") def check_log_dir(cls, log_dir: str, values: dict) -> str: """ Validator to set the default of log_dir based on the base_dir @@ -458,7 +461,7 @@ def check_log_dir(cls, log_dir: str, values: dict) -> str: return str(Path(values["base_dir"], "log")) return log_dir - @validator("daemon_dir", always=True) + @field_validator("daemon_dir") def check_daemon_dir(cls, daemon_dir: str, values: dict) -> str: """ Validator to set the default of daemon_dir based on the base_dir @@ -467,7 +470,7 @@ def check_daemon_dir(cls, daemon_dir: str, values: dict) -> str: return str(Path(values["base_dir"], "daemon")) return daemon_dir - @validator("jobstore", always=True) + @field_validator("jobstore") def check_jobstore(cls, jobstore: dict, values: dict) -> dict: """ Check that the jobstore configuration could be converted to a JobStore. @@ -484,7 +487,7 @@ def check_jobstore(cls, jobstore: dict, values: dict) -> dict: ) from e return jobstore - @validator("queue", always=True) + @field_validator("queue") def check_queue(cls, queue: dict, values: dict) -> dict: """ Check that the queue configuration could be converted to a Store. @@ -498,8 +501,7 @@ def check_queue(cls, queue: dict, values: dict) -> dict: ) from e return queue - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class ConfigError(Exception): diff --git a/src/jobflow_remote/config/settings.py b/src/jobflow_remote/config/settings.py index 4b0610b1..1ed64db3 100644 --- a/src/jobflow_remote/config/settings.py +++ b/src/jobflow_remote/config/settings.py @@ -2,7 +2,8 @@ from pathlib import Path -from pydantic import BaseSettings, Field, root_validator +from pydantic import Field, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict DEFAULT_PROJECTS_FOLDER = Path("~/.jfremote").expanduser().as_posix() @@ -18,7 +19,7 @@ class JobflowRemoteSettings(BaseSettings): projects_folder: str = Field( DEFAULT_PROJECTS_FOLDER, description="Location of the projects files." ) - project: str = Field(None, description="The name of the project used.") + project: str | None = Field(None, description="The name of the project used.") cli_full_exc: bool = Field( False, description="If True prints the full stack trace of the exception when raised in the CLI.", @@ -26,13 +27,10 @@ class JobflowRemoteSettings(BaseSettings): cli_suggestions: bool = Field( True, description="If True prints some suggestions in the CLI commands." ) + model_config = SettingsConfigDict(env_prefix="jfremote_") - class Config: - """Pydantic config settings.""" - - env_prefix = "jfremote_" - - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def load_default_settings(cls, values): """ Load settings from file or environment variables. From 5fe779f90fa490686ee81da69c1d0dfd8daab408 Mon Sep 17 00:00:00 2001 From: Guido Petretto Date: Thu, 12 Oct 2023 01:00:13 +0200 Subject: [PATCH 2/2] fix settings for python 3.9 --- src/jobflow_remote/config/settings.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/jobflow_remote/config/settings.py b/src/jobflow_remote/config/settings.py index f64d5d43..ca73c4a9 100644 --- a/src/jobflow_remote/config/settings.py +++ b/src/jobflow_remote/config/settings.py @@ -1,6 +1,5 @@ -from __future__ import annotations - from pathlib import Path +from typing import Optional from pydantic import Field, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -18,7 +17,7 @@ class JobflowRemoteSettings(BaseSettings): projects_folder: str = Field( DEFAULT_PROJECTS_FOLDER, description="Location of the projects files." ) - project: str | None = Field(None, description="The name of the project used.") + project: Optional[str] = Field(None, description="The name of the project used.") cli_full_exc: bool = Field( False, description="If True prints the full stack trace of the exception when raised in the CLI.",