diff --git a/sdv/__init__.py b/sdv/__init__.py index ab9655cdb..16f24fb17 100644 --- a/sdv/__init__.py +++ b/sdv/__init__.py @@ -13,9 +13,10 @@ import warnings from importlib.metadata import entry_points from operator import attrgetter +from types import ModuleType from sdv import ( - constraints, data_processing, datasets, evaluation, lite, metadata, metrics, multi_table, + constraints, data_processing, datasets, evaluation, io, lite, metadata, metrics, multi_table, sampling, sequential, single_table, version) __all__ = [ @@ -23,6 +24,7 @@ 'data_processing', 'datasets', 'evaluation', + 'io', 'lite', 'metadata', 'metrics', @@ -105,6 +107,11 @@ def _find_addons(): warnings.warn(msg) continue + if isinstance(addon, ModuleType): + addon_module_name = f'{addon_target.__name__}.{addon_name}' + if addon_module_name not in sys.modules: + sys.modules[addon_module_name] = addon + setattr(addon_target, addon_name, addon) diff --git a/sdv/io/__init__.py b/sdv/io/__init__.py new file mode 100644 index 000000000..913ecf279 --- /dev/null +++ b/sdv/io/__init__.py @@ -0,0 +1 @@ +"""I/O module.""" diff --git a/tests/unit/test___init__.py b/tests/unit/test___init__.py index 41fc19736..e94f3b214 100644 --- a/tests/unit/test___init__.py +++ b/tests/unit/test___init__.py @@ -1,4 +1,5 @@ import sys +from types import ModuleType from unittest.mock import Mock, patch import pytest @@ -11,6 +12,7 @@ def mock_sdv(): sdv_module = sys.modules['sdv'] sdv_mock = Mock() + sdv_mock.submodule.__name__ = 'sdv.submodule' sys.modules['sdv'] = sdv_mock yield sdv_mock sys.modules['sdv'] = sdv_module @@ -20,9 +22,10 @@ def mock_sdv(): def test__find_addons_module(entry_points_mock, mock_sdv): """Test loading an add-on.""" # Setup + add_on_mock = Mock(spec=ModuleType) entry_point = Mock() entry_point.name = 'sdv.submodule.entry_name' - entry_point.load.return_value = 'entry_point' + entry_point.load.return_value = add_on_mock entry_points_mock.return_value = [entry_point] # Run @@ -30,7 +33,8 @@ def test__find_addons_module(entry_points_mock, mock_sdv): # Assert entry_points_mock.assert_called_once_with(group='sdv_modules') - assert mock_sdv.submodule.entry_name == 'entry_point' + assert mock_sdv.submodule.entry_name == add_on_mock + assert sys.modules['sdv.submodule.entry_name'] == add_on_mock @patch.object(sdv, 'entry_points')