diff --git a/src/cryo_challenge/data/_validation/config_validators.py b/src/cryo_challenge/data/_validation/config_validators.py index 3e1d06f..eedf16f 100644 --- a/src/cryo_challenge/data/_validation/config_validators.py +++ b/src/cryo_challenge/data/_validation/config_validators.py @@ -2,7 +2,7 @@ import numpy as np import pandas as pd import os -from pydantic import BaseModel, validator, root_validator +from pydantic import BaseModel, field_validator, model_validator from typing import Optional, List @@ -259,21 +259,21 @@ class SVDNormalizeParams(BaseModel): bfactor: float = None box_size_ds: Optional[int] = None - @validator("mask_path") + @field_validator("mask_path") def check_mask_path_exists(cls, value): if value is not None: if not os.path.exists(value): raise ValueError(f"Mask file {value} does not exist.") return value - @validator("bfactor") + @field_validator("bfactor") def check_bfactor(cls, value): if value is not None: if value < 0: raise ValueError("B-factor must be non-negative.") return value - @validator("box_size_ds") + @field_validator("box_size_ds") def check_box_size_ds(cls, value): if value is not None: if value < 0: @@ -285,7 +285,7 @@ class SVDGtParams(BaseModel): gt_vols_file: str skip_vols: int = 1 - @validator("gt_vols_file") + @field_validator("gt_vols_file") def check_mask_path_exists(cls, value): if not os.path.exists(value): raise ValueError(f"Could not find file {value}.") @@ -300,7 +300,7 @@ def check_mask_path_exists(cls, value): ) return value - @validator("skip_vols") + @field_validator("skip_vols") def check_skip_vols(cls, value): if value is not None: if value < 0: @@ -327,10 +327,10 @@ class SVDConfig(BaseModel): gt_params: Optional[SVDGtParams] = None output_params: SVDOutputParams - @root_validator - def check_path_to_submissions(cls, values): - path_to_submissions = values.get("path_to_submissions") - excluded_submissions = values.get("excluded_submissions") + @model_validator(mode="after") + def check_path_to_submissions(self): + path_to_submissions = self.path_to_submissions + excluded_submissions = self.excluded_submissions if not os.path.exists(path_to_submissions): raise ValueError(f"Could not find path {path_to_submissions}.") @@ -354,21 +354,21 @@ def check_path_to_submissions(cls, values): f"No submission files found after excluding {excluded_submissions}." ) - return values + return self - @validator("dtype") + @field_validator("dtype") def check_dtype(cls, value): if value not in ["float32", "float64"]: raise ValueError(f"Invalid dtype {value}.") return value - @validator("svd_max_rank") + @field_validator("svd_max_rank") def check_svd_max_rank(cls, value): if value < 1 and value is not None: raise ValueError("Max rank must be at least 1.") return value - @validator("voxel_size") + @field_validator("voxel_size") def check_voxel_size(cls, value): if value <= 0: raise ValueError("Voxel size must be positive.")