Skip to content

Commit

Permalink
Fix entry point discovery on Python < 3.10 (#604)
Browse files Browse the repository at this point in the history
It seems that the behavior of the importlib.metadata.entry_points
function changed in Python 3.10 to automatically de-duplicate
distributions, but prior to that the "shadowed" distributions were also
enumerated. This change specifically ignores "shadowed" distributions so
that they aren't identified as extension point overwrites.
  • Loading branch information
cottsay authored Jan 17, 2024
1 parent d29f38d commit 500dec0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
24 changes: 16 additions & 8 deletions colcon_core/extension_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@
EXTENSION_POINT_GROUP_NAME = 'colcon_core.extension_point'


def _get_unique_distributions():
seen = set()
for dist in distributions():
dist_name = dist.metadata['Name']
if dist_name not in seen:
seen.add(dist_name)
yield dist


def get_all_extension_points():
"""
Get all extension points related to `colcon` and any of its extensions.
Expand All @@ -51,12 +60,7 @@ def get_all_extension_points():
colcon_extension_points.setdefault(EXTENSION_POINT_GROUP_NAME, None)

entry_points = defaultdict(dict)
seen = set()
for dist in distributions():
dist_name = dist.metadata['Name']
if dist_name in seen:
continue
seen.add(dist_name)
for dist in _get_unique_distributions():
for entry_point in dist.entry_points:
# skip groups which are not registered as extension points
if entry_point.group not in colcon_extension_points:
Expand All @@ -70,7 +74,7 @@ def get_all_extension_points():
f"from '{dist._path}' "
f"overwriting '{previous}'")
entry_points[entry_point.group][entry_point.name] = \
(entry_point.value, dist_name, dist.version)
(entry_point.value, dist.metadata['Name'], dist.version)
return entry_points


Expand All @@ -87,7 +91,11 @@ def get_extension_points(group):
# Python 3.10 and newer
query = entry_points(group=group)
except TypeError:
query = entry_points().get(group, ())
query = (
entry_point
for dist in _get_unique_distributions()
for entry_point in dist.entry_points
if entry_point.group == group)
for entry_point in query:
if entry_point.name in extension_points:
previous_entry_point = extension_points[entry_point.name]
Expand Down
10 changes: 7 additions & 3 deletions test/test_extension_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def iter_entry_points(*, group=None):

def distributions():
return [
Dist(iter_entry_points(group='group1')),
Dist([EntryPoint('extC', 'eC', Group2.name)]),
Dist([Group1, ExtA, ExtB]),
Dist([Group2, EntryPoint('extC', 'eC', Group2.name)]),
Dist([EntryPoint('extD', 'eD', 'groupX')]),
]

Expand All @@ -71,7 +71,11 @@ def test_all_extension_points():
):
# successfully load a known entry point
extension_points = get_all_extension_points()
assert set(extension_points.keys()) == {'group1', 'group2'}
assert set(extension_points.keys()) == {
EXTENSION_POINT_GROUP_NAME,
'group1',
'group2',
}
assert set(extension_points['group1'].keys()) == {'extA', 'extB'}
assert extension_points['group1']['extA'][0] == 'eA'

Expand Down

0 comments on commit 500dec0

Please sign in to comment.