diff --git a/hta/configs/config.py b/hta/configs/config.py index 9f53489..dc0f4d3 100644 --- a/hta/configs/config.py +++ b/hta/configs/config.py @@ -1,7 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import importlib.util import json import logging import logging.config @@ -9,7 +8,6 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Set, Union -import hta from hta.configs.default_values import DEFAULT_CONFIG_FILENAME ConfigValue = Union[None, bool, int, float, str, Dict[str, Any], List[Any], Set[Any]] @@ -25,6 +23,8 @@ def setup_logger(config_file: str = "logging.config") -> None: logger = logging.getLogger("hta") +package_path: Path = Path(__file__).parent.parent + logger: logging.Logger = logging.getLogger("hta") @@ -55,11 +55,7 @@ def get_default_paths() -> List[str]: As a class method, this function does not return the user defined config file path. """ return [ - str( - Path(hta.__file__) - .parent.joinpath("configs") - .joinpath(DEFAULT_CONFIG_FILENAME) - ), + str(package_path.joinpath(DEFAULT_CONFIG_FILENAME)), str(Path.home().joinpath(".hta").joinpath(DEFAULT_CONFIG_FILENAME)), str(Path.cwd().joinpath(DEFAULT_CONFIG_FILENAME)), ] @@ -144,14 +140,7 @@ def get_config( def show(self): print(json.dumps(self.config, indent=4, sort_keys=True)) - @classmethod - def get_package_path(cls) -> str: - package_spec = importlib.util.find_spec(hta.__name__) - package_path = Path(package_spec.origin).parent - return str(package_path) - @classmethod def get_test_data_path(cls, dataset: str) -> str: - base_path = Path(cls.get_package_path()).parent - test_data_path = Path.joinpath(base_path, "tests/data/", dataset) + test_data_path = Path.joinpath(package_path, "tests/data/", dataset) return str(test_data_path)