Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

migrate entrypoints handling to standard library importlib.metadata #589

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions fs/opener/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import collections
import contextlib
import pkg_resources
import sys

from ..errors import ResourceReadOnly
from .base import Opener
Expand All @@ -21,6 +21,30 @@
from ..base import FS


if sys.version_info >= (3, 8):
import importlib.metadata

if sys.version_info >= (3, 10):

def entrypoints(group, name=None):
ep = importlib.metadata.entry_points(group=group, name=name)
return tuple(n for n in ep)

else:

def entrypoints(group, name=None):
ep = importlib.metadata.entry_points()
if name:
return tuple(n for n in ep.get(group, ()) if n.name == name)
return ep.get(group, ())

else:
import pkg_resources

def entrypoints(group, name=None):
return tuple(pkg_resources.iter_entry_points(group, name))


class Registry(object):
"""A registry for `Opener` instances."""

Expand Down Expand Up @@ -74,10 +98,7 @@ def protocols(self):
"""`list`: the list of supported protocols."""
_protocols = list(self._protocols)
if self.load_extern:
_protocols.extend(
entry_point.name
for entry_point in pkg_resources.iter_entry_points("fs.opener")
)
_protocols.extend(n.name for n in entrypoints("fs.opener"))
_protocols = list(collections.OrderedDict.fromkeys(_protocols))
return _protocols

Expand All @@ -101,10 +122,9 @@ def get_opener(self, protocol):
"""
protocol = protocol or self.default_opener

if self.load_extern:
entry_point = next(
pkg_resources.iter_entry_points("fs.opener", protocol), None
)
ep = entrypoints("fs.opener", protocol)
if self.load_extern and ep:
entry_point = ep[0]
else:
entry_point = None

Expand Down
74 changes: 58 additions & 16 deletions tests/test_opener.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import sys

import os
import pkg_resources
import shutil
import tempfile
import unittest
Expand All @@ -21,6 +20,11 @@
except ImportError:
import mock

if sys.version_info >= (3, 8):
import importlib.metadata
else:
import pkg_resources


class TestParse(unittest.TestCase):
def test_registry_repr(self):
Expand Down Expand Up @@ -111,14 +115,25 @@ def test_protocols(self):

def test_registry_protocols(self):
# Check registry.protocols list the names of all available extension
extensions = [
pkg_resources.EntryPoint("proto1", "mod1"),
pkg_resources.EntryPoint("proto2", "mod2"),
]
m = mock.MagicMock(return_value=extensions)
with mock.patch.object(
sys.modules["pkg_resources"], "iter_entry_points", new=m
):
if sys.version_info >= (3, 8):
extensions = (
importlib.metadata.EntryPoint("proto1", "mod1", "fs.opener"),
importlib.metadata.EntryPoint("proto2", "mod2", "fs.opener"),
)
if sys.version_info >= (3, 10):
m = mock.MagicMock(return_value=extensions)
else:
m = mock.MagicMock(return_value={"fs.opener": extensions})
patch = mock.patch("importlib.metadata.entry_points", m)
else:
extensions = [
pkg_resources.EntryPoint("proto1", "mod1"),
pkg_resources.EntryPoint("proto2", "mod2"),
]
m = mock.MagicMock(return_value=extensions)
patch = mock.patch("pkg_resources.iter_entry_points", m)

with patch:
self.assertIn("proto1", opener.registry.protocols)
self.assertIn("proto2", opener.registry.protocols)

Expand All @@ -129,11 +144,19 @@ def test_unknown_protocol(self):
def test_entry_point_load_error(self):

entry_point = mock.MagicMock()
entry_point.name = "test"
entry_point.load.side_effect = ValueError("some error")

iter_entry_points = mock.MagicMock(return_value=iter([entry_point]))

with mock.patch("pkg_resources.iter_entry_points", iter_entry_points):
if sys.version_info >= (3, 8):
if sys.version_info >= (3, 10):
entry_points = mock.MagicMock(return_value=tuple([entry_point]))
else:
entry_points = mock.MagicMock(return_value={"fs.opener": [entry_point]})
patch = mock.patch("importlib.metadata.entry_points", entry_points)
else:
iter_entry_points = mock.MagicMock(return_value=iter([entry_point]))
patch = mock.patch("pkg_resources.iter_entry_points", iter_entry_points)
with patch:
with self.assertRaises(errors.EntryPointError) as ctx:
opener.open_fs("test://")
self.assertEqual(
Expand All @@ -145,10 +168,19 @@ class NotAnOpener(object):
pass

entry_point = mock.MagicMock()
entry_point.name = "test"
entry_point.load = mock.MagicMock(return_value=NotAnOpener)
iter_entry_points = mock.MagicMock(return_value=iter([entry_point]))

with mock.patch("pkg_resources.iter_entry_points", iter_entry_points):
if sys.version_info >= (3, 8):
if sys.version_info >= (3, 10):
entry_points = mock.MagicMock(return_value=tuple([entry_point]))
else:
entry_points = mock.MagicMock(return_value={"fs.opener": [entry_point]})
patch = mock.patch("importlib.metadata.entry_points", entry_points)
else:
iter_entry_points = mock.MagicMock(return_value=iter([entry_point]))
patch = mock.patch("pkg_resources.iter_entry_points", iter_entry_points)
with patch:
with self.assertRaises(errors.EntryPointError) as ctx:
opener.open_fs("test://")
self.assertEqual("entry point did not return an opener", str(ctx.exception))
Expand All @@ -162,10 +194,20 @@ def open_fs(self, *args, **kwargs):
pass

entry_point = mock.MagicMock()
entry_point.name = "test"
entry_point.load = mock.MagicMock(return_value=BadOpener)
iter_entry_points = mock.MagicMock(return_value=iter([entry_point]))

with mock.patch("pkg_resources.iter_entry_points", iter_entry_points):
if sys.version_info >= (3, 8):
if sys.version_info >= (3, 10):
entry_points = mock.MagicMock(return_value=tuple([entry_point]))
else:
entry_points = mock.MagicMock(return_value={"fs.opener": [entry_point]})
patch = mock.patch("importlib.metadata.entry_points", entry_points)
else:
iter_entry_points = mock.MagicMock(return_value=iter([entry_point]))
patch = mock.patch("pkg_resources.iter_entry_points", iter_entry_points)

with patch:
with self.assertRaises(errors.EntryPointError) as ctx:
opener.open_fs("test://")
self.assertEqual(
Expand Down