Skip to content

Commit

Permalink
Improve addon detection to include modules (#1929)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontanez24 authored Apr 18, 2024
1 parent 86e0036 commit e484969
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
9 changes: 8 additions & 1 deletion sdv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@
import warnings
from importlib.metadata import entry_points
from operator import attrgetter
from types import ModuleType

from sdv import (
constraints, data_processing, datasets, evaluation, lite, metadata, metrics, multi_table,
constraints, data_processing, datasets, evaluation, io, lite, metadata, metrics, multi_table,
sampling, sequential, single_table, version)

__all__ = [
'constraints',
'data_processing',
'datasets',
'evaluation',
'io',
'lite',
'metadata',
'metrics',
Expand Down Expand Up @@ -105,6 +107,11 @@ def _find_addons():
warnings.warn(msg)
continue

if isinstance(addon, ModuleType):
addon_module_name = f'{addon_target.__name__}.{addon_name}'
if addon_module_name not in sys.modules:
sys.modules[addon_module_name] = addon

setattr(addon_target, addon_name, addon)


Expand Down
1 change: 1 addition & 0 deletions sdv/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""I/O module."""
8 changes: 6 additions & 2 deletions tests/unit/test___init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from types import ModuleType
from unittest.mock import Mock, patch

import pytest
Expand All @@ -11,6 +12,7 @@
def mock_sdv():
sdv_module = sys.modules['sdv']
sdv_mock = Mock()
sdv_mock.submodule.__name__ = 'sdv.submodule'
sys.modules['sdv'] = sdv_mock
yield sdv_mock
sys.modules['sdv'] = sdv_module
Expand All @@ -20,17 +22,19 @@ def mock_sdv():
def test__find_addons_module(entry_points_mock, mock_sdv):
"""Test loading an add-on."""
# Setup
add_on_mock = Mock(spec=ModuleType)
entry_point = Mock()
entry_point.name = 'sdv.submodule.entry_name'
entry_point.load.return_value = 'entry_point'
entry_point.load.return_value = add_on_mock
entry_points_mock.return_value = [entry_point]

# Run
_find_addons()

# Assert
entry_points_mock.assert_called_once_with(group='sdv_modules')
assert mock_sdv.submodule.entry_name == 'entry_point'
assert mock_sdv.submodule.entry_name == add_on_mock
assert sys.modules['sdv.submodule.entry_name'] == add_on_mock


@patch.object(sdv, 'entry_points')
Expand Down

0 comments on commit e484969

Please sign in to comment.