From 0ff051a7279fd7da2a2271b6601d13ce9eb0463c Mon Sep 17 00:00:00 2001 From: arxyzan Date: Fri, 10 Nov 2023 19:51:57 +0330 Subject: [PATCH] Refactor import style and disallow direct imports --- hezar/__init__.py | 64 ++++++++++++++++++++++++----- hezar/data/datasets/__init__.py | 3 +- hezar/embeddings/__init__.py | 3 +- hezar/metrics/__init__.py | 3 +- hezar/models/__init__.py | 3 +- hezar/preprocessors/__init__.py | 3 +- hezar/preprocessors/preprocessor.py | 3 +- hezar/trainer/__init__.py | 2 +- hezar/utils/registry_utils.py | 5 +++ 9 files changed, 72 insertions(+), 17 deletions(-) diff --git a/hezar/__init__.py b/hezar/__init__.py index e443ea32..d087c611 100644 --- a/hezar/__init__.py +++ b/hezar/__init__.py @@ -1,13 +1,57 @@ -from .registry import * -from .builders import * -from .configs import * -from .data import * -from .embeddings import * -from .metrics import * -from .models import * -from .preprocessors import * -from .trainer import * -from .utils import * +""" +Direct importing from hezar's root is no longer supported nor recommended since version 0.33.0. The following is just a +workaround for backward compatibility. Any class, functions, etc. must be imported from its main submodule under hezar. +""" +import warnings __version__ = "0.32.1" + + +def _warn_on_import(name: str, submodule: str): + warnings.warn( + f"Importing {name} from hezar root is deprecated and will be removed soon. " + f"Please use `from {submodule} import {name}`" + ) + + +def __getattr__(name: str): + if name == "Model": + from hezar.models import Model + _warn_on_import(name, "hezar.models") + return Model + elif name == "Dataset": + from .data import Dataset + _warn_on_import(name, "hezar.data") + return Dataset + elif name == "Trainer": + from .trainer import Trainer + _warn_on_import(name, "hezar.trainer") + return Trainer + elif name == "Embedding": + from .embeddings import Embedding + _warn_on_import(name, "hezar.embeddings") + return Embedding + elif name == "Preprocessor": + from .preprocessors import Preprocessor + _warn_on_import(name, "hezar.preprocessors") + return Preprocessor + elif name == "Metric": + from .metrics import Metric + _warn_on_import(name, "hezar.metrics") + return Metric + elif "Config" in name: + from .configs import Config + _warn_on_import(name, "hezar.configs") + return Config + + +__all__ = [ + "Config", + "Model", + "Dataset", + "Trainer", + "Preprocessor", + "Embedding", + "Metric", +] diff --git a/hezar/data/datasets/__init__.py b/hezar/data/datasets/__init__.py index 8f1ce618..91acdbab 100644 --- a/hezar/data/datasets/__init__.py +++ b/hezar/data/datasets/__init__.py @@ -1,4 +1,5 @@ -from .dataset import Dataset +from ...registry import register_dataset # noqa +from .dataset import Dataset, DatasetConfig # noqa from .ocr_dataset import OCRDataset, OCRDatasetConfig from .sequence_labeling_dataset import SequenceLabelingDataset, SequenceLabelingDatasetConfig from .text_classification_dataset import TextClassificationDataset, TextClassificationDatasetConfig diff --git a/hezar/embeddings/__init__.py b/hezar/embeddings/__init__.py index 55b1e03d..048a7edf 100644 --- a/hezar/embeddings/__init__.py +++ b/hezar/embeddings/__init__.py @@ -1,3 +1,4 @@ -from .embedding import Embedding +from ..registry import register_embedding # noqa +from .embedding import Embedding, EmbeddingConfig # noqa from .fasttext import FastText, FastTextConfig from .word2vec import Word2Vec, Word2VecConfig diff --git a/hezar/metrics/__init__.py b/hezar/metrics/__init__.py index 3f484a1f..6aca50f9 100644 --- a/hezar/metrics/__init__.py +++ b/hezar/metrics/__init__.py @@ -1,4 +1,5 @@ -from .metric import Metric +from ..registry import register_metric # noqa +from .metric import Metric, MetricConfig # noqa from .accuracy import Accuracy, AccuracyConfig from .bleu import BLEU, BLEUConfig from .cer import CER, CERConfig diff --git a/hezar/models/__init__.py b/hezar/models/__init__.py index 1045c094..542dba38 100644 --- a/hezar/models/__init__.py +++ b/hezar/models/__init__.py @@ -1,4 +1,5 @@ -from .model import Model +from ..registry import register_model # noqa +from .model import Model, ModelConfig # noqa from .audio_classification import * from .backbone import * from .image2text import * diff --git a/hezar/preprocessors/__init__.py b/hezar/preprocessors/__init__.py index dcd3790b..c7e0944d 100644 --- a/hezar/preprocessors/__init__.py +++ b/hezar/preprocessors/__init__.py @@ -1,4 +1,5 @@ -from .preprocessor import Preprocessor, PreprocessorsContainer +from ..registry import register_preprocessor # noqa +from .preprocessor import Preprocessor, PreprocessorConfig, PreprocessorsContainer # noqa from .audio_feature_extractor import AudioFeatureExtractor, AudioFeatureExtractorConfig from .image_processor import ImageProcessor, ImageProcessorConfig from .text_normalizer import TextNormalizer, TextNormalizerConfig diff --git a/hezar/preprocessors/preprocessor.py b/hezar/preprocessors/preprocessor.py index c1bcb7fc..32d3045d 100644 --- a/hezar/preprocessors/preprocessor.py +++ b/hezar/preprocessors/preprocessor.py @@ -5,6 +5,7 @@ from huggingface_hub import hf_hub_download from omegaconf import OmegaConf +from ..configs import PreprocessorConfig from ..constants import DEFAULT_PREPROCESSOR_SUBFOLDER, Backends, RegistryType, RepoType, HEZAR_CACHE_DIR from ..utils import get_module_class, list_repo_files, verify_dependencies @@ -21,7 +22,7 @@ class Preprocessor: preprocessor_subfolder = DEFAULT_PREPROCESSOR_SUBFOLDER - def __init__(self, config, **kwargs): + def __init__(self, config: PreprocessorConfig, **kwargs): verify_dependencies(self, self.required_backends) # Check if all the required dependencies are installed self.config = config.update(kwargs) diff --git a/hezar/trainer/__init__.py b/hezar/trainer/__init__.py index ddfbc5e4..da09b85e 100644 --- a/hezar/trainer/__init__.py +++ b/hezar/trainer/__init__.py @@ -1,3 +1,3 @@ -from .trainer import Trainer +from .trainer import Trainer, TrainerConfig # noqa from .trainer_utils import * from .metrics_handlers import * diff --git a/hezar/utils/registry_utils.py b/hezar/utils/registry_utils.py index 5593208e..d7d73425 100644 --- a/hezar/utils/registry_utils.py +++ b/hezar/utils/registry_utils.py @@ -47,26 +47,31 @@ def list_available_embeddings(): def _get_registry_from_type(registry_type: RegistryType): if registry_type == RegistryType.MODEL: + from ..models import Model # noqa from ..registry import models_registry # noqa registry = models_registry elif registry_type == RegistryType.PREPROCESSOR: + from ..preprocessors import Preprocessor # noqa from ..registry import preprocessors_registry # noqa registry = preprocessors_registry elif registry_type == RegistryType.DATASET: + from ..data import Dataset # noqa from ..registry import datasets_registry # noqa registry = datasets_registry elif registry_type == RegistryType.EMBEDDING: + from ..embeddings import Embedding # noqa from ..registry import embeddings_registry # noqa registry = embeddings_registry elif registry_type == RegistryType.METRIC: + from ..metrics import Metric # noqa from ..registry import metrics_registry # noqa registry = metrics_registry