Skip to content

Commit

Permalink
Fix lazy imports of objects on Python 3.6.
Browse files Browse the repository at this point in the history
  • Loading branch information
aaugustin committed Nov 30, 2020
1 parent 42f0e2c commit 965f8ec
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 21 deletions.
34 changes: 25 additions & 9 deletions src/websockets/imports.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import sys
import warnings
from typing import Any, Dict, Iterable, Optional
Expand All @@ -7,6 +6,27 @@
__all__ = ["lazy_import"]


def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any:
"""
Import <name> from <source> in <namespace>.
There are two cases:
- <name> is an object defined in <source>
- <name> is a submodule of source
Neither __import__ nor importlib.import_module does exactly this.
__import__ is closer to the intended behavior.
"""
level = 0
while source[level] == ".":
level += 1
assert level < len(source), "importing from parent isn't supported"
module = __import__(source[level:], namespace, None, [name], level)
return getattr(module, name)


def lazy_import(
namespace: Dict[str, Any],
aliases: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -58,8 +78,7 @@ def __getattr__(name: str) -> Any:
except KeyError:
pass
else:
module = importlib.import_module(source, package)
return getattr(module, name)
return import_name(name, source, namespace)

assert deprecated_aliases is not None # mypy cannot figure this out
try:
Expand All @@ -72,8 +91,7 @@ def __getattr__(name: str) -> Any:
DeprecationWarning,
stacklevel=2,
)
module = importlib.import_module(source, package)
return getattr(module, name)
return import_name(name, source, namespace)

raise AttributeError(f"module {package!r} has no attribute {name!r}")

Expand All @@ -87,9 +105,7 @@ def __dir__() -> Iterable[str]:
else: # pragma: no cover

for name, source in aliases.items():
module = importlib.import_module(source, package)
namespace[name] = getattr(module, name)
namespace[name] = import_name(name, source, namespace)

for name, source in deprecated_aliases.items():
module = importlib.import_module(source, package)
namespace[name] = getattr(module, name)
namespace[name] = import_name(name, source, namespace)
39 changes: 27 additions & 12 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import types
import unittest
import warnings
Expand All @@ -11,18 +12,30 @@


class ImportsTests(unittest.TestCase):
def setUp(self):
self.mod = types.ModuleType("tests.test_imports.test_alias")
self.mod.__package__ = self.mod.__name__

def test_get_alias(self):
mod = types.ModuleType("tests.test_imports.test_alias")
lazy_import(vars(mod), aliases={"foo": ".."})
lazy_import(
vars(self.mod),
aliases={"foo": "...test_imports"},
)

self.assertEqual(mod.foo, foo)
self.assertEqual(self.mod.foo, foo)

def test_get_deprecated_alias(self):
mod = types.ModuleType("tests.test_imports.test_alias")
lazy_import(vars(mod), deprecated_aliases={"bar": ".."})
lazy_import(
vars(self.mod),
deprecated_aliases={"bar": "...test_imports"},
)

with warnings.catch_warnings(record=True) as recorded_warnings:
self.assertEqual(mod.bar, bar)
self.assertEqual(self.mod.bar, bar)

# No warnings raised on pre-PEP 526 Python.
if sys.version_info[:2] < (3, 7): # pragma: no cover
return

self.assertEqual(len(recorded_warnings), 1)
warning = recorded_warnings[0].message
Expand All @@ -32,20 +45,22 @@ def test_get_deprecated_alias(self):
self.assertEqual(type(warning), DeprecationWarning)

def test_dir(self):
mod = types.ModuleType("tests.test_imports.test_alias")
lazy_import(vars(mod), aliases={"foo": ".."}, deprecated_aliases={"bar": ".."})
lazy_import(
vars(self.mod),
aliases={"foo": "...test_imports"},
deprecated_aliases={"bar": "...test_imports"},
)

self.assertEqual(
[item for item in dir(mod) if not item[:2] == item[-2:] == "__"],
[item for item in dir(self.mod) if not item[:2] == item[-2:] == "__"],
["bar", "foo"],
)

def test_attribute_error(self):
mod = types.ModuleType("tests.test_imports.test_alias")
lazy_import(vars(mod))
lazy_import(vars(self.mod))

with self.assertRaises(AttributeError) as raised:
mod.foo
self.mod.foo

self.assertEqual(
str(raised.exception),
Expand Down

0 comments on commit 965f8ec

Please sign in to comment.