Skip to content

Commit

Permalink
Add an explicit cache on Python entry points
Browse files Browse the repository at this point in the history
Whenever we enumerate Python entry points to load colcon extension
points, we're re-parsing metadata for every Python package found on the
system. Worse yet, accessing attributes on
importlib.metadata.Distribution typically results in re-reading the
metadata each time, so we're hitting the disk pretty hard.

We don't generally expect the entry points available to change, so we
should cache that information once and parse each package's metadata a
single time.
  • Loading branch information
cottsay committed Feb 22, 2024
1 parent df3deec commit fbe2360
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 28 deletions.
91 changes: 63 additions & 28 deletions colcon_core/extension_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# Licensed under the Apache License, Version 2.0

from collections import defaultdict
from itertools import chain
import os
import sys
import traceback

try:
Expand All @@ -26,7 +28,6 @@

logger = colcon_logger.getChild(__name__)


"""
The group name for entry points identifying colcon extension points.
Expand All @@ -36,6 +37,8 @@
"""
EXTENSION_POINT_GROUP_NAME = 'colcon_core.extension_point'

_ENTRY_POINTS_CACHE = []


def _get_unique_distributions():
seen = set()
Expand All @@ -46,6 +49,44 @@ def _get_unique_distributions():
yield dist


def _get_entry_points():
for dist in _get_unique_distributions():
for entry_point in dist.entry_points:
# Modern EntryPoint instances should already have this set
if not hasattr(entry_point, 'dist'):
entry_point.dist = dist
yield entry_point


def _get_cached_entry_points():
if not _ENTRY_POINTS_CACHE:
if sys.version_info >= (3, 10):
# We prefer using importlib.metadata.entry_points because it
# has an internal optimization which allows us to load the entry
# points without reading the individual PKG-INFO files, while
# still visiting each unique distribution only once.
all_entry_points = entry_points()
if isinstance(all_entry_points, dict):
# Prior to Python 3.12, entry_points returned a (deprecated)
# dict. Unfortunately, the "future-proof" recommended
# pattern is to add filter parameters, but we actually
# want to cache everything so that doesn't work here.
all_entry_points = chain.from_iterable(
all_entry_points.values())
_ENTRY_POINTS_CACHE.extend(all_entry_points)
else:
# If we don't have Python 3.10, we must read each PKG-INFO to
# get the name of the distribution so that we can skip the
# "shadowed" distributions properly.
_ENTRY_POINTS_CACHE.extend(_get_entry_points())
return _ENTRY_POINTS_CACHE


def clear_entry_point_cache():
"""Purge the entry point cache."""
_ENTRY_POINTS_CACHE.clear()


def get_all_extension_points():
"""
Get all extension points related to `colcon` and any of its extensions.
Expand All @@ -59,23 +100,24 @@ def get_all_extension_points():
colcon_extension_points = get_extension_points(EXTENSION_POINT_GROUP_NAME)
colcon_extension_points.setdefault(EXTENSION_POINT_GROUP_NAME, None)

entry_points = defaultdict(dict)
for dist in _get_unique_distributions():
for entry_point in dist.entry_points:
# skip groups which are not registered as extension points
if entry_point.group not in colcon_extension_points:
continue

if entry_point.name in entry_points[entry_point.group]:
previous = entry_points[entry_point.group][entry_point.name]
logger.error(
f"Entry point '{entry_point.group}.{entry_point.name}' is "
f"declared multiple times, '{entry_point.value}' "
f"from '{dist._path}' "
f"overwriting '{previous}'")
entry_points[entry_point.group][entry_point.name] = \
(entry_point.value, dist.metadata['Name'], dist.version)
return entry_points
extension_points = defaultdict(dict)
for entry_point in _get_cached_entry_points():
if entry_point.group not in colcon_extension_points:
continue

dist_metadata = entry_point.dist.metadata
ep_tuple = (
entry_point.value,
dist_metadata['Name'], dist_metadata['Version'],
)
if entry_point.name in extension_points[entry_point.group]:
previous = extension_points[entry_point.group][entry_point.name]
logger.error(
f"Entry point '{entry_point.group}.{entry_point.name}' is "
f"declared multiple times, '{ep_tuple}' "
f"overwriting '{previous}'")
extension_points[entry_point.group][entry_point.name] = ep_tuple
return extension_points


def get_extension_points(group):
Expand All @@ -87,16 +129,9 @@ def get_extension_points(group):
:rtype: dict
"""
extension_points = {}
try:
# Python 3.10 and newer
query = entry_points(group=group)
except TypeError:
query = (
entry_point
for dist in _get_unique_distributions()
for entry_point in dist.entry_points
if entry_point.group == group)
for entry_point in query:
for entry_point in _get_cached_entry_points():
if entry_point.group != group:
continue
if entry_point.name in extension_points:
previous_entry_point = extension_points[entry_point.name]
logger.error(
Expand Down
1 change: 1 addition & 0 deletions test/spell_check.words
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ importlib
importorskip
isatty
iterdir
itertools
junit
levelname
libexec
Expand Down
8 changes: 8 additions & 0 deletions test/test_extension_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# TODO: Drop this with Python 3.7 support
from importlib_metadata import Distribution

from colcon_core.extension_point import clear_entry_point_cache
from colcon_core.extension_point import EntryPoint
from colcon_core.extension_point import EXTENSION_POINT_GROUP_NAME
from colcon_core.extension_point import get_all_extension_points
Expand Down Expand Up @@ -73,6 +74,8 @@ def test_all_extension_points():
'colcon_core.extension_point.distributions',
side_effect=_distributions
):
clear_entry_point_cache()

# successfully load a known entry point
extension_points = get_all_extension_points()
assert set(extension_points.keys()) == {
Expand All @@ -94,12 +97,14 @@ def test_extension_point_blocklist():
'colcon_core.extension_point.distributions',
side_effect=_distributions
):
clear_entry_point_cache()
extension_points = get_extension_points('group1')
assert 'extA' in extension_points.keys()
extension_point = extension_points['extA']
assert extension_point == 'eA'

with patch.object(EntryPoint, 'load', return_value=None) as load:
clear_entry_point_cache()
load_extension_point('extA', 'eA', 'group1')
assert load.call_count == 1

Expand All @@ -108,12 +113,14 @@ def test_extension_point_blocklist():
with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST=os.pathsep.join([
'group1.extB', 'group2.extC'])
):
clear_entry_point_cache()
load_extension_point('extA', 'eA', 'group1')
assert load.call_count == 1

# entry point in a blocked group can't be loaded
load.reset_mock()
with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST='group1'):
clear_entry_point_cache()
with pytest.raises(RuntimeError) as e:
load_extension_point('extA', 'eA', 'group1')
assert 'The entry point group name is listed in the environment ' \
Expand All @@ -124,6 +131,7 @@ def test_extension_point_blocklist():
with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST=os.pathsep.join([
'group1.extA', 'group1.extB'])
):
clear_entry_point_cache()
with pytest.raises(RuntimeError) as e:
load_extension_point('extA', 'eA', 'group1')
assert 'The entry point name is listed in the environment ' \
Expand Down

0 comments on commit fbe2360

Please sign in to comment.