diff --git a/modelkit/core/settings.py b/modelkit/core/settings.py index c996cc36..2e8d53be 100644 --- a/modelkit/core/settings.py +++ b/modelkit/core/settings.py @@ -95,16 +95,15 @@ class NativeCacheSettings(CacheSettings): def cache_settings(): s = CacheSettings() - if s.cache_provider is None: + + if s.cache_provider == "none": return None - try: + elif s.cache_provider == "redis": return RedisSettings() - except pydantic.ValidationError: - pass - try: + elif s.cache_provider == "native": return NativeCacheSettings() - except pydantic.ValidationError: - pass + else: + return None def _get_library_settings_cache_provider(v: Optional[str]) -> str: @@ -134,8 +133,5 @@ class LibrarySettings(ModelkitSettings): Annotated[None, pydantic.Tag("none")], ], pydantic.Discriminator(_get_library_settings_cache_provider), - ] = pydantic.Field( - default_factory=cache_settings, - union_mode="left_to_right", - ) + ] = pydantic.Field(default_factory=cache_settings) model_config = pydantic.ConfigDict(extra="allow") diff --git a/pyproject.toml b/pyproject.toml index dacbd0fc..b51b9679 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,7 +126,7 @@ addopts = """ --failed-first --durations 10 --color=yes -tests""" +""" [tool.black] target-version = ['py38'] diff --git a/tests/test_settings.py b/tests/test_settings.py index b6944364..29e17b01 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -6,7 +6,12 @@ from modelkit.assets.drivers.gcs import GCSStorageDriverSettings from modelkit.assets.drivers.local import LocalStorageDriverSettings from modelkit.assets.drivers.s3 import S3StorageDriverSettings -from modelkit.core.settings import ModelkitSettings +from modelkit.core.settings import ( + LibrarySettings, + ModelkitSettings, + NativeCacheSettings, + RedisSettings, +) def test_modelkit_settings_working(monkeypatch): @@ -48,3 +53,24 @@ def test_storage_driver_settings(Settings, monkeypatch): assert Settings(bucket="bar").bucket == "bar" with pytest.raises(pydantic.ValidationError): _ = Settings() + + +def test_cache_provider_settings(monkeypatch): + monkeypatch.setenv("MODELKIT_CACHE_PROVIDER", "redis") + lib_settings = LibrarySettings() + assert isinstance(lib_settings.cache, RedisSettings) + assert lib_settings.cache.cache_provider == "redis" + + monkeypatch.setenv("MODELKIT_CACHE_PROVIDER", "native") + lib_settings = LibrarySettings() + assert isinstance(lib_settings.cache, NativeCacheSettings) + assert lib_settings.cache.cache_provider == "native" + + monkeypatch.setenv("MODELKIT_CACHE_PROVIDER", "none") + assert LibrarySettings().cache is None + + monkeypatch.setenv("MODELKIT_CACHE_PROVIDER", "not supported") + assert LibrarySettings().cache is None + + monkeypatch.delenv("MODELKIT_CACHE_PROVIDER") + assert LibrarySettings().cache is None