From 893cc79757d8f8db893ad9bfa8b2adb8d2cb5e56 Mon Sep 17 00:00:00 2001 From: Scott K Logan Date: Fri, 16 Feb 2024 14:14:58 -0600 Subject: [PATCH] Add an explicit cache on Python entry points Whenever we enumerate Python entry points to load colcon extension points, we're re-parsing metadata for every Python package found on the system. Worse yet, accessing attributes on importlib.metadata.Distribution typically results in re-reading the metadata each time, so we're hitting the disk pretty hard. We don't generally expect the entry points available to change, so we should cache that information once and parse each package's metadata a single time. --- colcon_core/extension_point.py | 66 +++++++++++++++++----------------- test/test_extension_point.py | 65 +++++++++++++++------------------ 2 files changed, 62 insertions(+), 69 deletions(-) diff --git a/colcon_core/extension_point.py b/colcon_core/extension_point.py index c724d221..2d33b42a 100644 --- a/colcon_core/extension_point.py +++ b/colcon_core/extension_point.py @@ -9,12 +9,10 @@ try: from importlib.metadata import distributions from importlib.metadata import EntryPoint - from importlib.metadata import entry_points except ImportError: # TODO: Drop this with Python 3.7 support from importlib_metadata import distributions from importlib_metadata import EntryPoint - from importlib_metadata import entry_points from colcon_core.environment_variable import EnvironmentVariable from colcon_core.logging import colcon_logger @@ -26,7 +24,6 @@ logger = colcon_logger.getChild(__name__) - """ The group name for entry points identifying colcon extension points. @@ -36,14 +33,27 @@ """ EXTENSION_POINT_GROUP_NAME = 'colcon_core.extension_point' +_cached_extension_points = [] + -def _get_unique_distributions(): - seen = set() - for dist in distributions(): - dist_name = dist.metadata['Name'] - if dist_name not in seen: +def _get_cached_extension_points(): + if not _cached_extension_points: + seen = set() + for dist in distributions(): + dist_meta = dist.metadata + dist_name = dist_meta['Name'] + if dist_name in seen: + continue seen.add(dist_name) - yield dist + dist_tuple = (dist_name, dist_meta['Version']) + for entry_point in dist.entry_points: + _cached_extension_points.append((entry_point, *dist_tuple)) + return _cached_extension_points + + +def clear_extension_point_cache(): + """Purge the extension point cache.""" + _cached_extension_points.clear() def get_all_extension_points(): @@ -60,21 +70,18 @@ def get_all_extension_points(): colcon_extension_points.setdefault(EXTENSION_POINT_GROUP_NAME, None) entry_points = defaultdict(dict) - for dist in _get_unique_distributions(): - for entry_point in dist.entry_points: - # skip groups which are not registered as extension points - if entry_point.group not in colcon_extension_points: - continue + for entry_point, dist_name, dist_version in _get_cached_extension_points(): + if entry_point.group not in colcon_extension_points: + continue - if entry_point.name in entry_points[entry_point.group]: - previous = entry_points[entry_point.group][entry_point.name] - logger.error( - f"Entry point '{entry_point.group}.{entry_point.name}' is " - f"declared multiple times, '{entry_point.value}' " - f"from '{dist._path}' " - f"overwriting '{previous}'") - entry_points[entry_point.group][entry_point.name] = \ - (entry_point.value, dist.metadata['Name'], dist.version) + ep_tuple = (entry_point.value, dist_name, dist_version) + if entry_point.name in entry_points[entry_point.group]: + previous = entry_points[entry_point.group][entry_point.name] + logger.error( + f"Entry point '{entry_point.group}.{entry_point.name}' is " + f"declared multiple times, '{ep_tuple}' " + f"overwriting '{previous}'") + entry_points[entry_point.group][entry_point.name] = ep_tuple return entry_points @@ -87,16 +94,9 @@ def get_extension_points(group): :rtype: dict """ extension_points = {} - try: - # Python 3.10 and newer - query = entry_points(group=group) - except TypeError: - query = ( - entry_point - for dist in _get_unique_distributions() - for entry_point in dist.entry_points - if entry_point.group == group) - for entry_point in query: + for entry_point, *_ in _get_cached_extension_points(): + if entry_point.group != group: + continue if entry_point.name in extension_points: previous_entry_point = extension_points[entry_point.name] logger.error( diff --git a/test/test_extension_point.py b/test/test_extension_point.py index 96e58a0d..f7adcb59 100644 --- a/test/test_extension_point.py +++ b/test/test_extension_point.py @@ -6,6 +6,7 @@ from unittest.mock import DEFAULT from unittest.mock import patch +from colcon_core.extension_point import clear_extension_point_cache from colcon_core.extension_point import EntryPoint from colcon_core.extension_point import EXTENSION_POINT_GROUP_NAME from colcon_core.extension_point import get_all_extension_points @@ -25,10 +26,11 @@ class Dist(): - version = '0.0.0' - def __init__(self, entry_points): - self.metadata = {'Name': f'dist-{id(self)}'} + self.metadata = { + 'Name': f'dist-{id(self)}', + 'Version': '0.0.0', + } self._entry_points = entry_points @property @@ -39,17 +41,9 @@ def entry_points(self): def name(self): return self.metadata['Name'] - -def iter_entry_points(*, group=None): - if group == EXTENSION_POINT_GROUP_NAME: - return [Group1, Group2] - elif group == Group1.name: - return [ExtA, ExtB] - assert not group - return { - EXTENSION_POINT_GROUP_NAME: [Group1, Group2], - Group1.name: [ExtA, ExtB], - } + @property + def version(self): + return self.metadata['Version'] def distributions(): @@ -62,40 +56,36 @@ def distributions(): def test_all_extension_points(): with patch( - 'colcon_core.extension_point.entry_points', - side_effect=iter_entry_points + 'colcon_core.extension_point.distributions', + side_effect=distributions ): - with patch( - 'colcon_core.extension_point.distributions', - side_effect=distributions - ): - # successfully load a known entry point - extension_points = get_all_extension_points() - assert set(extension_points.keys()) == { - EXTENSION_POINT_GROUP_NAME, - 'group1', - 'group2', - } - assert set(extension_points['group1'].keys()) == {'extA', 'extB'} - assert extension_points['group1']['extA'][0] == 'eA' + clear_extension_point_cache() + + # successfully load a known entry point + extension_points = get_all_extension_points() + assert set(extension_points.keys()) == { + EXTENSION_POINT_GROUP_NAME, + 'group1', + 'group2', + } + assert set(extension_points['group1'].keys()) == {'extA', 'extB'} + assert extension_points['group1']['extA'][0] == 'eA' def test_extension_point_blocklist(): # successful loading of extension point without a blocklist with patch( - 'colcon_core.extension_point.entry_points', - side_effect=iter_entry_points + 'colcon_core.extension_point.distributions', + side_effect=distributions ): - with patch( - 'colcon_core.extension_point.distributions', - side_effect=distributions - ): - extension_points = get_extension_points('group1') + clear_extension_point_cache() + extension_points = get_extension_points('group1') assert 'extA' in extension_points.keys() extension_point = extension_points['extA'] assert extension_point == 'eA' with patch.object(EntryPoint, 'load', return_value=None) as load: + clear_extension_point_cache() load_extension_point('extA', 'eA', 'group1') assert load.call_count == 1 @@ -104,12 +94,14 @@ def test_extension_point_blocklist(): with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST=os.pathsep.join([ 'group1.extB', 'group2.extC']) ): + clear_extension_point_cache() load_extension_point('extA', 'eA', 'group1') assert load.call_count == 1 # entry point in a blocked group can't be loaded load.reset_mock() with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST='group1'): + clear_extension_point_cache() with pytest.raises(RuntimeError) as e: load_extension_point('extA', 'eA', 'group1') assert 'The entry point group name is listed in the environment ' \ @@ -120,6 +112,7 @@ def test_extension_point_blocklist(): with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST=os.pathsep.join([ 'group1.extA', 'group1.extB']) ): + clear_extension_point_cache() with pytest.raises(RuntimeError) as e: load_extension_point('extA', 'eA', 'group1') assert 'The entry point name is listed in the environment ' \