diff --git a/config.py b/config.py index 3bf4bae..e26c67a 100644 --- a/config.py +++ b/config.py @@ -1,26 +1,32 @@ import os -import json import time -import inspect import shutil from pathlib import Path + from filelock import FileLock -from types import SimpleNamespace -from dataclasses import dataclass, field from abc import ABC, abstractmethod -from typing import Callable, Any, MutableSequence - - -@dataclass -class AlltalkConfigTheme: +from typing import Callable, MutableSequence + +from pydantic import BaseModel, ConfigDict, AliasGenerator, AliasChoices, Field + +class AlltalkConfigTheme(BaseModel): + # Map 'class' to 'clazz' and vice versa + model_config = ConfigDict( + alias_generator=AliasGenerator( + validation_alias=lambda field_name: + { + "clazz": AliasChoices("clazz", "class"), + }.get(field_name, None), + serialization_alias=lambda field_name: "class" if field_name == "clazz" else field_name, + ) + ) file: str | None = None - clazz = "" + clazz: str = "gradio/base" -@dataclass -class AlltalkConfigRvcSettings: +class AlltalkConfigRvcSettings(BaseModel): rvc_enabled: bool = False rvc_char_model_file: str = "Disabled" - rvc_narr_model_file: str = "Disabled" + rvc_narr_model_file: str = "Disabled" split_audio: bool = True autotune: bool = False pitch: int = 0 @@ -33,8 +39,7 @@ class AlltalkConfigRvcSettings: embedder_model: str = "hubert" training_data_size: int = 45000 -@dataclass -class AlltalkConfigTgwUi: +class AlltalkConfigTgwUi(BaseModel): tgwui_activate_tts: bool = True tgwui_autoplay_tts: bool = True tgwui_narrator_enabled: str = "false" @@ -54,8 +59,7 @@ class AlltalkConfigTgwUi: tgwui_rvc_narr_voice: str = "Disabled" tgwui_rvc_narr_pitch: int = 0 -@dataclass -class AlltalkConfigApiDef: +class AlltalkConfigApiDef(BaseModel): api_port_number: int = 7851 api_allowed_filter: str = "[^a-zA-Z0-9\\s.,;:!?\\-\\'\"$\\u0400-\\u04FF\\u00C0-\\u017F\\u0150\\u0151\\u0170\\u0171\\u011E\\u011F\\u0130\\u0131\\u0900-\\u097F\\u2018\\u2019\\u201C\\u201D\\u3001\\u3002\\u3040-\\u309F\\u30A0-\\u30FF\\u4E00-\\u9FFF\\u3400-\\u4DBF\\uF900-\\uFAFF\\u0600-\\u06FF\\u0750-\\u077F\\uFB50-\\uFDFF\\uFE70-\\uFEFF\\uAC00-\\uD7A3\\u1100-\\u11FF\\u3130-\\u318F\\uFF01\\uFF0c\\uFF1A\\uFF1B\\uFF1F]" api_length_stripping: int = 3 @@ -71,8 +75,7 @@ class AlltalkConfigApiDef: api_autoplay: bool = False api_autoplay_volume: float = 0.5 -@dataclass -class AlltalkConfigDebug: +class AlltalkConfigDebug(BaseModel): debug_transcode: bool = False debug_tts: bool = False debug_openai: bool = False @@ -87,8 +90,7 @@ class AlltalkConfigDebug: debug_transcribe: bool = False debug_proxy: bool = False -@dataclass -class AlltalkConfigGradioPages: +class AlltalkConfigGradioPages(BaseModel): Generate_Help_page: bool = True Voice2RVC_page: bool = True TTS_Generator_page: bool = True @@ -96,58 +98,29 @@ class AlltalkConfigGradioPages: alltalk_documentation_page: bool = True api_documentation_page: bool = True +class AlltalkAvailableEngine(BaseModel): + name: str = "" + selected_model: str = "" -@dataclass -class AlltalkAvailableEngine: - name = "" - selected_model = "" - -@dataclass -class AlltalkConfigProxyEndpoint: +class AlltalkConfigProxyEndpoint(BaseModel): enabled: bool = False external_port: int = 0 external_ip: str = "0.0.0.0" cert_name: str = "" - def to_dict(self): - return { - "enabled": self.enabled, - "external_port": self.external_port, - "external_ip": self.external_ip, - "cert_name": self.cert_name - } - -@dataclass -class AlltalkConfigProxySettings: +class AlltalkConfigProxySettings(BaseModel): proxy_enabled: bool = False start_on_startup: bool = False - gradio_endpoint: AlltalkConfigProxyEndpoint = field(default_factory=lambda: AlltalkConfigProxyEndpoint(external_port=444)) - api_endpoint: AlltalkConfigProxyEndpoint = field(default_factory=lambda: AlltalkConfigProxyEndpoint(external_port=443)) + gradio_endpoint: AlltalkConfigProxyEndpoint = AlltalkConfigProxyEndpoint(external_port=444) + api_endpoint: AlltalkConfigProxyEndpoint = AlltalkConfigProxyEndpoint(external_port=443) cert_validation: bool = True logging_enabled: bool = True log_level: str = "INFO" - def __post_init__(self): - # Handle conversion from dict to proper objects - if isinstance(self.gradio_endpoint, dict): - self.gradio_endpoint = AlltalkConfigProxyEndpoint(**self.gradio_endpoint) - if isinstance(self.api_endpoint, dict): - self.api_endpoint = AlltalkConfigProxyEndpoint(**self.api_endpoint) - - def to_dict(self): - return { - "proxy_enabled": self.proxy_enabled, - "start_on_startup": self.start_on_startup, - "gradio_endpoint": self.gradio_endpoint.to_dict(), - "api_endpoint": self.api_endpoint.to_dict(), - "cert_validation": self.cert_validation, - "logging_enabled": self.logging_enabled, - "log_level": self.log_level - } - class AbstractJsonConfig(ABC): - def __init__(self, config_path: Path | str, file_check_interval: int): + super().__init__() + self.__delegate = None self.__config_path = Path(config_path) if type(config_path) is str else config_path self.__last_read_time = 0 # Track when we last read the file self.__file_check_interval = file_check_interval @@ -159,14 +132,6 @@ def reload(self): self._load_config() return self - def to_dict(self): - # Remove private fields: - without_private_fields = {} - for attr, value in self.__dict__.items(): - if not attr.startswith("_"): - without_private_fields[attr] = value - return without_private_fields - def _reload_on_change(self): # Check if config file has been modified and reload if needed if time.time() - self.__last_read_time >= self.__file_check_interval: @@ -181,32 +146,19 @@ def _load_config(self): self.__last_read_time = self.get_config_path().stat().st_mtime def __load(): with open(self.get_config_path(), "r") as configfile: - data = json.load(configfile, object_hook=self._object_hook()) - self._handle_loaded_config(data) + json_string = configfile.read() + # The delegate is the actual config loaded: + self.__delegate = self._handle_loaded_config(json_string) self.__with_lock_and_backup(self.get_config_path(), False, __load) - def _object_hook(self) -> Callable[[dict[Any, Any]], Any] | None: - return lambda d: SimpleNamespace(**d) - def _save_file(self, path: Path | None | str, default=None, indent=4): file_path = (Path(path) if type(path) is str else path) if path is not None else self.get_config_path() - - def custom_default(o): - if isinstance(o, Path): - return str(o) # Convert Path objects to strings - elif hasattr(o, '__dict__'): - return o.__dict__ # Use the object's __dict__ if it exists - else: - raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable") - - default = default or custom_default - + def __save(): with open(file_path, 'w') as file: - json.dump(self.to_dict(), file, indent=indent, default=default) - - self.__with_lock_and_backup(file_path, True, __save) + file.write(self.__delegate.model_dump_json(indent=indent, by_alias=True)) + self.__with_lock_and_backup(file_path, True, __save) def __with_lock_and_backup(self, path: Path, backup: bool, callable: Callable[[], None]): lock_path = path.with_suffix('.lock') @@ -226,34 +178,71 @@ def __with_lock_and_backup(self, path: Path, backup: bool, callable: Callable[[] raise Exception(f"Failed to save config: {e}") finally: # Cleanup lock and backup files: - if lock_path.exists(): # Only try to delete if it exists - try: - lock_path.unlink() - except FileNotFoundError: - pass # Ignore if file doesn't exist - + if lock_path.exists(): + lock_path.unlink() + if backup and backup_path and backup_path.exists(): - try: - backup_path.unlink() - except FileNotFoundError: - pass # Ignore if file doesn't exist + backup_path.unlink() @abstractmethod - def _handle_loaded_config(self, data): + def _handle_loaded_config(self, json_string: str): pass -class AlltalkNewEnginesConfig(AbstractJsonConfig): + def to_dict(self): + return self.__delegate.model_dump(by_alias=True) + + # Delegation to the loaded config data + @property + def _xtra(self): return [o for o in dir(self.__delegate) if not o.startswith('_')] + + def __getattr__(self, key): + if key in self._xtra: return getattr(self.__delegate, key) + raise AttributeError(key) + + def __getattribute__(self, key): + if key.startswith("_"): + # Private and protected attributes always are on self: + return super().__getattribute__(key) + if key in self._xtra: + # Delegate attribute + return getattr(self.__delegate, key) + # Get own attribute: + return super().__getattribute__(key) + + def __setattr__(self, key, value): + if key.startswith("_"): + # Private and protected attributes always are on self: + super().__setattr__(key, value) + elif key in dir(self.__delegate): + # Delegate attribute + setattr(self.__delegate, key, value) + else: + # Set own attribute: + super().__setattr__(key, value) + + def __dir__(self): + def custom_dir(c, add): return dir(type(c)) + list(c.__dict__.keys()) + add + return custom_dir(self, self._xtra) + +class AlltalkNewEnginesConfigFields: + engines_available: MutableSequence[AlltalkAvailableEngine] = Field(default_factory=list) + +class AlltalkNewEnginesConfigModel(BaseModel, AlltalkNewEnginesConfigFields): + + def get_engine_names_available(self): + return [engine.name for engine in self.engines_available] + + def get_engines_matching(self, condition: Callable[[AlltalkAvailableEngine], bool]): + return [x for x in self.engines_available if condition(x)] + +class AlltalkNewEnginesConfig(AbstractJsonConfig, AlltalkNewEnginesConfigFields): __instance = None __this_dir = Path(__file__).parent.resolve() def __init__(self, config_path: Path | str = os.path.join(__this_dir, "system", "tts_engines", "new_engines.json")): super().__init__(config_path, 5) - self.engines_available: MutableSequence[AlltalkAvailableEngine] = [] self._load_config() - def get_engine_names_available(self): - return [engine.name for engine in self.engines_available] - @staticmethod def get_instance(): if AlltalkNewEnginesConfig.__instance is None: @@ -261,27 +250,40 @@ def get_instance(): AlltalkNewEnginesConfig.__instance._reload_on_change() return AlltalkNewEnginesConfig.__instance - def _handle_loaded_config(self, data): - self.engines_available = data.engines_available + def _handle_loaded_config(self, json_string: str): + return AlltalkNewEnginesConfigModel.model_validate_json(json_string) - def get_engines_matching(self, condition: Callable[[AlltalkAvailableEngine], bool]): - return [x for x in self.engines_available if condition(x)] +class AlltalkTTSEnginesConfigFields: + engines_available: MutableSequence[AlltalkAvailableEngine] = Field(default_factory=list) + engine_loaded: str = "" + selected_model: str = "" + +class AlltalkTTSEnginesConfigModel(BaseModel, AlltalkTTSEnginesConfigFields): + def get_engine_names_available(self): + return [engine.name for engine in self.engines_available] -class AlltalkTTSEnginesConfig(AbstractJsonConfig): + def is_valid_engine(self, engine_name): + return engine_name in self.get_engine_names_available() + + def change_engine(self, requested_engine): + if requested_engine == self.engine_loaded: + return self + for engine in self.engines_available: + if engine.name == requested_engine: + self.engine_loaded = requested_engine + self.selected_model = engine.selected_model + return self + return self + +class AlltalkTTSEnginesConfig(AbstractJsonConfig, AlltalkTTSEnginesConfigFields): __instance = None __this_dir = Path(__file__).parent.resolve() def __init__(self, config_path: Path | str = os.path.join(__this_dir, "system", "tts_engines", "tts_engines.json")): super().__init__(config_path, 5) - self.engines_available: MutableSequence[AlltalkAvailableEngine] = [] - self.engine_loaded = "" - self.selected_model = "" self._load_config() - def get_engine_names_available(self): - return [engine.name for engine in self.engines_available] - @staticmethod def get_instance(force_reload = False): if AlltalkTTSEnginesConfig.__instance is None: @@ -294,16 +296,12 @@ def get_instance(force_reload = False): AlltalkTTSEnginesConfig.__instance._reload_on_change() return AlltalkTTSEnginesConfig.__instance - def _handle_loaded_config(self, data): - # List of the available TTS engines: - self.engines_available = self.__handle_loaded_config_engines(data) - - # The currently set TTS engine from tts_engines.json - self.engine_loaded = data.engine_loaded - self.selected_model = data.selected_model + def _handle_loaded_config(self, json_string: str): + cfg = AlltalkTTSEnginesConfigModel.model_validate_json(json_string) + cfg.engines_available = self.__handle_loaded_config_engines(cfg.engines_available) + return cfg - def __handle_loaded_config_engines(self, data): - available_engines = data.engines_available + def __handle_loaded_config_engines(self, available_engines): available_engine_names = [engine.name for engine in available_engines] # Getting the engines that are not already part of the available engines: @@ -316,42 +314,36 @@ def __handle_loaded_config_engines(self, data): def save(self, path: Path | str | None = None): self._save_file(path) - def is_valid_engine(self, engine_name): - return engine_name in self.get_engine_names_available() +class AlltalkConfigFields: + branding: str = "AllTalk " + delete_output_wavs: str = "Disabled" + gradio_interface: bool = True + output_folder: str = "outputs" + gradio_port_number: int = 7852 + firstrun_model: bool = True + firstrun_splash: bool = True + launch_gradio: bool = True + transcode_audio_format: str = "Disabled" + theme: AlltalkConfigTheme = AlltalkConfigTheme() + rvc_settings: AlltalkConfigRvcSettings = AlltalkConfigRvcSettings() + tgwui: AlltalkConfigTgwUi = AlltalkConfigTgwUi() + api_def: AlltalkConfigApiDef = AlltalkConfigApiDef() + debugging: AlltalkConfigDebug = AlltalkConfigDebug() + gradio_pages: AlltalkConfigGradioPages = AlltalkConfigGradioPages() + proxy_settings: AlltalkConfigProxySettings = AlltalkConfigProxySettings() + +class AlltalkConfigModel(BaseModel, AlltalkConfigFields): + __this_dir = Path(__file__).parent.resolve() - def change_engine(self, requested_engine): - if requested_engine == self.engine_loaded: - return self - for engine in self.engines_available: - if engine.name == requested_engine: - self.engine_loaded = requested_engine - self.selected_model = engine.selected_model - return self - return self + def get_output_directory(self): + return self.__this_dir / self.output_folder -@dataclass -class AlltalkConfig(AbstractJsonConfig): +class AlltalkConfig(AbstractJsonConfig, AlltalkConfigFields): __instance = None __this_dir = Path(__file__).parent.resolve() def __init__(self, config_path: Path | str = __this_dir / "confignew.json"): super().__init__(config_path, 5) - self.branding = "" - self.delete_output_wavs = "" - self.gradio_interface = False - self.output_folder = "" - self.gradio_port_number = 0 - self.firstrun_model = False - self.firstrun_splash = False - self.launch_gradio = False - self.transcode_audio_format = "" - self.theme = AlltalkConfigTheme() - self.rvc_settings = AlltalkConfigRvcSettings() - self.tgwui = AlltalkConfigTgwUi() - self.api_def = AlltalkConfigApiDef() - self.debugging = AlltalkConfigDebug() - self.gradio_pages = AlltalkConfigGradioPages() - self.proxy_settings = AlltalkConfigProxySettings() self._load_config() @staticmethod @@ -370,114 +362,10 @@ def get_instance(force_reload = False): AlltalkConfig.__instance._reload_on_change() return AlltalkConfig.__instance - def get_output_directory(self): - return self.__this_dir / self.output_folder - def save(self, path: Path | str | None = None): self._save_file(path) - def _handle_loaded_config(self, data): - from dataclasses import fields, is_dataclass, asdict - debug_me = False - if debug_me: - print("=== Loading Config ===") - print(f"Initial data state: {vars(data)}") - - # Create new instances with defaults - default_instances = { - 'debugging': AlltalkConfigDebug(), - 'rvc_settings': AlltalkConfigRvcSettings(), - 'tgwui': AlltalkConfigTgwUi(), - 'api_def': AlltalkConfigApiDef(), - 'theme': AlltalkConfigTheme(), - 'gradio_pages': AlltalkConfigGradioPages(), - 'proxy_settings': AlltalkConfigProxySettings() - } - - if debug_me: - print("\nDefault values for each class:") - - for name, instance in default_instances.items(): - if debug_me: - print(f"{name}: {asdict(instance)}") - print(f"Default values: {[(f.name, getattr(instance, f.name)) for f in fields(instance)]}") - - if hasattr(data, name): - source = getattr(data, name) - if debug_me: - print(f"Source data for {name}: {vars(source) if hasattr(source, '__dict__') else source}") - - # Special handling for proxy_settings due to nested structure - if name == 'proxy_settings' and hasattr(source, '__dict__'): - new_instance = AlltalkConfigProxySettings() - for field in fields(new_instance): - if hasattr(source, field.name): - field_value = getattr(source, field.name) - if field.name in ['gradio_endpoint', 'api_endpoint']: - # Handle nested endpoint objects - if isinstance(field_value, dict): - setattr(new_instance, field.name, - AlltalkConfigProxyEndpoint(**field_value)) - elif hasattr(field_value, '__dict__'): - setattr(new_instance, field.name, - AlltalkConfigProxyEndpoint(**field_value.__dict__)) - else: - setattr(new_instance, field.name, field_value) - instance = new_instance - else: - # Standard handling for other dataclasses - for field in fields(instance): - if hasattr(source, field.name): - setattr(instance, field.name, getattr(source, field.name)) - - if debug_me: - print(f"Processed instance {name}: {asdict(instance)}") - - setattr(self, name, instance) - - # Handle non-dataclass fields - for n, v in inspect.getmembers(data): - if hasattr(self, n) and not n.startswith("__") and not is_dataclass(type(getattr(self, n))): - setattr(self, n, v) - - # Special handling for theme class/clazz as before - self.theme.clazz = data.theme.__dict__.get("class", data.theme.__dict__.get("clazz", "")) - self.get_output_directory().mkdir(parents=True, exist_ok=True) - - def to_dict(self): - from dataclasses import is_dataclass, asdict - debug_me = False - if debug_me: - print("=== Converting to dict ===") - result = {} - - for key, value in vars(self).items(): - if not key.startswith('_'): - # print(f"\nProcessing {key}:") - if key == 'proxy_settings' and is_dataclass(value): - # Special handling for proxy settings - result[key] = value.to_dict() - elif is_dataclass(value): - # print(f"Dataclass value before conversion: {vars(value)}") - result[key] = asdict(value) - # print(f"Converted to dict: {result[key]}") - elif isinstance(value, SimpleNamespace): - # print(f"SimpleNamespace value: {value.__dict__}") - result[key] = value.__dict__ - else: - # print(f"Regular value: {value}") - result[key] = value - - # Maintain existing theme handling exactly as before - if 'theme' in result: - if debug_me: - print("\nProcessing theme:") - print(f"Before class handling: {result['theme']}") - result['theme']['class'] = self.theme.clazz - result['theme'].pop('clazz', None) - if debug_me: - print(f"After class handling: {result['theme']}") - if debug_me: - print(f"\nFinal dict: {result}") - return result - + def _handle_loaded_config(self, json_string: str): + model = AlltalkConfigModel.model_validate_json(json_string) + model.get_output_directory().mkdir(parents=True, exist_ok=True) + return model diff --git a/confignew.json b/confignew.json index 865bf49..33547f4 100644 --- a/confignew.json +++ b/confignew.json @@ -44,7 +44,9 @@ "tgwui_show_text": true, "tgwui_character_voice": "female_01.wav", "tgwui_rvc_char_voice": "Disabled", - "tgwui_rvc_narr_voice": "Disabled" + "tgwui_rvc_char_pitch": 0, + "tgwui_rvc_narr_voice": "Disabled", + "tgwui_rvc_narr_pitch": 0 }, "api_def": { "api_port_number": 7851, @@ -68,7 +70,14 @@ "debug_openai": false, "debug_concat": false, "debug_tts_variables": false, - "debug_rvc": false + "debug_rvc": false, + "debug_func": false, + "debug_api": false, + "debug_fullttstext": false, + "debug_narrator": false, + "debug_gradio_IP": false, + "debug_transcribe": false, + "debug_proxy": false }, "gradio_pages": { "Generate_Help_page": true, @@ -77,5 +86,24 @@ "TTS_Engines_Settings_page": true, "alltalk_documentation_page": true, "api_documentation_page": true + }, + "proxy_settings": { + "proxy_enabled": false, + "start_on_startup": false, + "gradio_endpoint": { + "enabled": false, + "external_port": 444, + "external_ip": "0.0.0.0", + "cert_name": "" + }, + "api_endpoint": { + "enabled": false, + "external_port": 443, + "external_ip": "0.0.0.0", + "cert_name": "" + }, + "cert_validation": true, + "logging_enabled": true, + "log_level": "INFO" } -} +} \ No newline at end of file diff --git a/test/confignew_partial.json b/test/confignew_partial.json new file mode 100644 index 0000000..fc00d18 --- /dev/null +++ b/test/confignew_partial.json @@ -0,0 +1,15 @@ +{ + "branding": "Another AllTalk ", + "theme": { + "file": null + }, + "rvc_settings": { + "rvc_narr_model_file": "another/file" + }, + "tgwui": { + "tgwui_narrator_voice": "another_female_01.wav" + }, + "api_def": { + "api_output_file_name": "another_myoutputfile" + } +} diff --git a/test/empty.json b/test/empty.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/test/empty.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/test/test_config.py b/test/test_config.py index 2845414..34648b3 100644 --- a/test/test_config.py +++ b/test/test_config.py @@ -1,17 +1,57 @@ import os.path import tempfile -import unittest +from unittest import TestCase from pathlib import Path - from config import AlltalkConfig, AlltalkTTSEnginesConfig, AlltalkNewEnginesConfig -class TestAlltalkConfig(unittest.TestCase): +class TestAlltalkConfig(TestCase): def setUp(self): self.config = AlltalkConfig.get_instance() self.config.reload() + def test_default_values_loaded(self): + cfg = AlltalkConfig(Path(__file__).parent.resolve() / 'empty.json') + + # Since the default config file is expected to also contain the + # default values, we can simply check that both dictionaries are identical: + self.assertEqual(self.config.to_dict(), cfg.to_dict()) + + def test_no_default_values_missing(self): + # Loading the empty JSON will populate the config with defaults from the code: + cfg = AlltalkConfig(Path(__file__).parent.resolve() / 'empty.json') + with tempfile.NamedTemporaryFile(suffix=".json") as tmp: + cfg.save(tmp.name) + # Compare the defaults from code (when written to a file) to the actual default config file: + with open(tmp.name, "r") as file1: + with open(self.config.get_config_path(), "r") as file2: + self.assertEqual(file1.read(), file2.read()) + + def test_values_merged_with_defaults(self): + cfg = AlltalkConfig(Path(__file__).parent.resolve() / 'confignew_partial.json') + + # Check some values that are missing in the JSON: + self.assertEqual(cfg.gradio_port_number, 7852) + self.assertEqual(cfg.theme.clazz, "gradio/base") + self.assertEqual(cfg.rvc_settings.index_rate, 0.75) + self.assertEqual(cfg.rvc_settings.embedder_model, "hubert") + self.assertEqual(cfg.tgwui.tgwui_language, "English") + self.assertEqual(cfg.tgwui.tgwui_repetitionpenalty_set, 10) + self.assertEqual(cfg.api_def.api_port_number, 7851) + self.assertEqual(cfg.api_def.api_text_filtering, "standard") + self.assertFalse(cfg.debugging.debug_rvc) + self.assertFalse(cfg.debugging.debug_openai) + + def test_loading_values(self): + cfg = AlltalkConfig(Path(__file__).parent.resolve() / 'confignew_partial.json') + + # Check that some values that are in the JSON: + self.assertEqual(cfg.branding, "Another AllTalk ") + self.assertEqual(cfg.rvc_settings.rvc_narr_model_file, "another/file") + self.assertEqual(cfg.tgwui.tgwui_narrator_voice, "another_female_01.wav") + self.assertEqual(cfg.api_def.api_output_file_name, "another_myoutputfile") + def test_default_config_path(self): expected_config_path = Path(__file__).parent.parent.resolve() / "confignew.json" self.assertEqual(self.config.get_config_path(), expected_config_path) @@ -119,15 +159,23 @@ def test_gradio_pages(self): def test_save_config(self): with tempfile.NamedTemporaryFile(suffix=".json") as tmp: self.config.branding = "foo" + self.config.theme.clazz = "bar" self.config.save(tmp.name) new_config = AlltalkConfig(tmp.name) self.assertEqual(new_config.branding, "foo") + self.assertEqual(new_config.theme.clazz, "bar") + + # Test serialization of field 'clazz' to field "class" + with open(tmp.name, "r") as file: + json = file.read() + self.assertTrue("class" in json) + self.assertFalse("clazz" in json) def test_no_private_fields(self): for attr in self.config.to_dict().keys(): self.assertTrue(not attr.startswith("_")) -class TestAlltalkTTSEnginesConfig(unittest.TestCase): +class TestAlltalkTTSEnginesConfig(TestCase): def setUp(self): self.tts_engines_config = AlltalkTTSEnginesConfig.get_instance() @@ -178,7 +226,7 @@ def test_no_private_fields(self): self.assertTrue(not attr.startswith("_")) -class TestAlltalkNewEnginesConfig(unittest.TestCase): +class TestAlltalkNewEnginesConfig(TestCase): def setUp(self): self.new_engines_config = AlltalkNewEnginesConfig.get_instance()