Skip to content

Commit

Permalink
Refactor import style and disallow direct imports
Browse files Browse the repository at this point in the history
  • Loading branch information
arxyzan committed Nov 10, 2023
1 parent 080680b commit 0ff051a
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 17 deletions.
64 changes: 54 additions & 10 deletions hezar/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,57 @@
from .registry import *
from .builders import *
from .configs import *
from .data import *
from .embeddings import *
from .metrics import *
from .models import *
from .preprocessors import *
from .trainer import *
from .utils import *
"""
Direct importing from hezar's root is no longer supported nor recommended since version 0.33.0. The following is just a
workaround for backward compatibility. Any class, functions, etc. must be imported from its main submodule under hezar.
"""
import warnings


__version__ = "0.32.1"


def _warn_on_import(name: str, submodule: str):
warnings.warn(
f"Importing {name} from hezar root is deprecated and will be removed soon. "
f"Please use `from {submodule} import {name}`"
)


def __getattr__(name: str):
if name == "Model":
from hezar.models import Model
_warn_on_import(name, "hezar.models")
return Model
elif name == "Dataset":
from .data import Dataset
_warn_on_import(name, "hezar.data")
return Dataset
elif name == "Trainer":
from .trainer import Trainer
_warn_on_import(name, "hezar.trainer")
return Trainer
elif name == "Embedding":
from .embeddings import Embedding
_warn_on_import(name, "hezar.embeddings")
return Embedding
elif name == "Preprocessor":
from .preprocessors import Preprocessor
_warn_on_import(name, "hezar.preprocessors")
return Preprocessor
elif name == "Metric":
from .metrics import Metric
_warn_on_import(name, "hezar.metrics")
return Metric
elif "Config" in name:
from .configs import Config
_warn_on_import(name, "hezar.configs")
return Config


__all__ = [
"Config",
"Model",
"Dataset",
"Trainer",
"Preprocessor",
"Embedding",
"Metric",
]
3 changes: 2 additions & 1 deletion hezar/data/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .dataset import Dataset
from ...registry import register_dataset # noqa
from .dataset import Dataset, DatasetConfig # noqa
from .ocr_dataset import OCRDataset, OCRDatasetConfig
from .sequence_labeling_dataset import SequenceLabelingDataset, SequenceLabelingDatasetConfig
from .text_classification_dataset import TextClassificationDataset, TextClassificationDatasetConfig
Expand Down
3 changes: 2 additions & 1 deletion hezar/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .embedding import Embedding
from ..registry import register_embedding # noqa
from .embedding import Embedding, EmbeddingConfig # noqa
from .fasttext import FastText, FastTextConfig
from .word2vec import Word2Vec, Word2VecConfig
3 changes: 2 additions & 1 deletion hezar/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .metric import Metric
from ..registry import register_metric # noqa
from .metric import Metric, MetricConfig # noqa
from .accuracy import Accuracy, AccuracyConfig
from .bleu import BLEU, BLEUConfig
from .cer import CER, CERConfig
Expand Down
3 changes: 2 additions & 1 deletion hezar/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .model import Model
from ..registry import register_model # noqa
from .model import Model, ModelConfig # noqa
from .audio_classification import *
from .backbone import *
from .image2text import *
Expand Down
3 changes: 2 additions & 1 deletion hezar/preprocessors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .preprocessor import Preprocessor, PreprocessorsContainer
from ..registry import register_preprocessor # noqa
from .preprocessor import Preprocessor, PreprocessorConfig, PreprocessorsContainer # noqa
from .audio_feature_extractor import AudioFeatureExtractor, AudioFeatureExtractorConfig
from .image_processor import ImageProcessor, ImageProcessorConfig
from .text_normalizer import TextNormalizer, TextNormalizerConfig
Expand Down
3 changes: 2 additions & 1 deletion hezar/preprocessors/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

from ..configs import PreprocessorConfig
from ..constants import DEFAULT_PREPROCESSOR_SUBFOLDER, Backends, RegistryType, RepoType, HEZAR_CACHE_DIR
from ..utils import get_module_class, list_repo_files, verify_dependencies

Expand All @@ -21,7 +22,7 @@ class Preprocessor:

preprocessor_subfolder = DEFAULT_PREPROCESSOR_SUBFOLDER

def __init__(self, config, **kwargs):
def __init__(self, config: PreprocessorConfig, **kwargs):
verify_dependencies(self, self.required_backends) # Check if all the required dependencies are installed

self.config = config.update(kwargs)
Expand Down
2 changes: 1 addition & 1 deletion hezar/trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .trainer import Trainer
from .trainer import Trainer, TrainerConfig # noqa
from .trainer_utils import *
from .metrics_handlers import *
5 changes: 5 additions & 0 deletions hezar/utils/registry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,26 +47,31 @@ def list_available_embeddings():

def _get_registry_from_type(registry_type: RegistryType):
if registry_type == RegistryType.MODEL:
from ..models import Model # noqa
from ..registry import models_registry # noqa

registry = models_registry

elif registry_type == RegistryType.PREPROCESSOR:
from ..preprocessors import Preprocessor # noqa
from ..registry import preprocessors_registry # noqa

registry = preprocessors_registry

elif registry_type == RegistryType.DATASET:
from ..data import Dataset # noqa
from ..registry import datasets_registry # noqa

registry = datasets_registry

elif registry_type == RegistryType.EMBEDDING:
from ..embeddings import Embedding # noqa
from ..registry import embeddings_registry # noqa

registry = embeddings_registry

elif registry_type == RegistryType.METRIC:
from ..metrics import Metric # noqa
from ..registry import metrics_registry # noqa

registry = metrics_registry
Expand Down

0 comments on commit 0ff051a

Please sign in to comment.