diff --git a/colcon_core/extension_point.py b/colcon_core/extension_point.py index c724d221..2d33b42a 100644 --- a/colcon_core/extension_point.py +++ b/colcon_core/extension_point.py @@ -9,12 +9,10 @@ try: from importlib.metadata import distributions from importlib.metadata import EntryPoint - from importlib.metadata import entry_points except ImportError: # TODO: Drop this with Python 3.7 support from importlib_metadata import distributions from importlib_metadata import EntryPoint - from importlib_metadata import entry_points from colcon_core.environment_variable import EnvironmentVariable from colcon_core.logging import colcon_logger @@ -26,7 +24,6 @@ logger = colcon_logger.getChild(__name__) - """ The group name for entry points identifying colcon extension points. @@ -36,14 +33,27 @@ """ EXTENSION_POINT_GROUP_NAME = 'colcon_core.extension_point' +_cached_extension_points = [] + -def _get_unique_distributions(): - seen = set() - for dist in distributions(): - dist_name = dist.metadata['Name'] - if dist_name not in seen: +def _get_cached_extension_points(): + if not _cached_extension_points: + seen = set() + for dist in distributions(): + dist_meta = dist.metadata + dist_name = dist_meta['Name'] + if dist_name in seen: + continue seen.add(dist_name) - yield dist + dist_tuple = (dist_name, dist_meta['Version']) + for entry_point in dist.entry_points: + _cached_extension_points.append((entry_point, *dist_tuple)) + return _cached_extension_points + + +def clear_extension_point_cache(): + """Purge the extension point cache.""" + _cached_extension_points.clear() def get_all_extension_points(): @@ -60,21 +70,18 @@ def get_all_extension_points(): colcon_extension_points.setdefault(EXTENSION_POINT_GROUP_NAME, None) entry_points = defaultdict(dict) - for dist in _get_unique_distributions(): - for entry_point in dist.entry_points: - # skip groups which are not registered as extension points - if entry_point.group not in colcon_extension_points: - continue + for entry_point, dist_name, dist_version in _get_cached_extension_points(): + if entry_point.group not in colcon_extension_points: + continue - if entry_point.name in entry_points[entry_point.group]: - previous = entry_points[entry_point.group][entry_point.name] - logger.error( - f"Entry point '{entry_point.group}.{entry_point.name}' is " - f"declared multiple times, '{entry_point.value}' " - f"from '{dist._path}' " - f"overwriting '{previous}'") - entry_points[entry_point.group][entry_point.name] = \ - (entry_point.value, dist.metadata['Name'], dist.version) + ep_tuple = (entry_point.value, dist_name, dist_version) + if entry_point.name in entry_points[entry_point.group]: + previous = entry_points[entry_point.group][entry_point.name] + logger.error( + f"Entry point '{entry_point.group}.{entry_point.name}' is " + f"declared multiple times, '{ep_tuple}' " + f"overwriting '{previous}'") + entry_points[entry_point.group][entry_point.name] = ep_tuple return entry_points @@ -87,16 +94,9 @@ def get_extension_points(group): :rtype: dict """ extension_points = {} - try: - # Python 3.10 and newer - query = entry_points(group=group) - except TypeError: - query = ( - entry_point - for dist in _get_unique_distributions() - for entry_point in dist.entry_points - if entry_point.group == group) - for entry_point in query: + for entry_point, *_ in _get_cached_extension_points(): + if entry_point.group != group: + continue if entry_point.name in extension_points: previous_entry_point = extension_points[entry_point.name] logger.error( diff --git a/test/test_extension_point.py b/test/test_extension_point.py index 96e58a0d..f7adcb59 100644 --- a/test/test_extension_point.py +++ b/test/test_extension_point.py @@ -6,6 +6,7 @@ from unittest.mock import DEFAULT from unittest.mock import patch +from colcon_core.extension_point import clear_extension_point_cache from colcon_core.extension_point import EntryPoint from colcon_core.extension_point import EXTENSION_POINT_GROUP_NAME from colcon_core.extension_point import get_all_extension_points @@ -25,10 +26,11 @@ class Dist(): - version = '0.0.0' - def __init__(self, entry_points): - self.metadata = {'Name': f'dist-{id(self)}'} + self.metadata = { + 'Name': f'dist-{id(self)}', + 'Version': '0.0.0', + } self._entry_points = entry_points @property @@ -39,17 +41,9 @@ def entry_points(self): def name(self): return self.metadata['Name'] - -def iter_entry_points(*, group=None): - if group == EXTENSION_POINT_GROUP_NAME: - return [Group1, Group2] - elif group == Group1.name: - return [ExtA, ExtB] - assert not group - return { - EXTENSION_POINT_GROUP_NAME: [Group1, Group2], - Group1.name: [ExtA, ExtB], - } + @property + def version(self): + return self.metadata['Version'] def distributions(): @@ -62,40 +56,36 @@ def distributions(): def test_all_extension_points(): with patch( - 'colcon_core.extension_point.entry_points', - side_effect=iter_entry_points + 'colcon_core.extension_point.distributions', + side_effect=distributions ): - with patch( - 'colcon_core.extension_point.distributions', - side_effect=distributions - ): - # successfully load a known entry point - extension_points = get_all_extension_points() - assert set(extension_points.keys()) == { - EXTENSION_POINT_GROUP_NAME, - 'group1', - 'group2', - } - assert set(extension_points['group1'].keys()) == {'extA', 'extB'} - assert extension_points['group1']['extA'][0] == 'eA' + clear_extension_point_cache() + + # successfully load a known entry point + extension_points = get_all_extension_points() + assert set(extension_points.keys()) == { + EXTENSION_POINT_GROUP_NAME, + 'group1', + 'group2', + } + assert set(extension_points['group1'].keys()) == {'extA', 'extB'} + assert extension_points['group1']['extA'][0] == 'eA' def test_extension_point_blocklist(): # successful loading of extension point without a blocklist with patch( - 'colcon_core.extension_point.entry_points', - side_effect=iter_entry_points + 'colcon_core.extension_point.distributions', + side_effect=distributions ): - with patch( - 'colcon_core.extension_point.distributions', - side_effect=distributions - ): - extension_points = get_extension_points('group1') + clear_extension_point_cache() + extension_points = get_extension_points('group1') assert 'extA' in extension_points.keys() extension_point = extension_points['extA'] assert extension_point == 'eA' with patch.object(EntryPoint, 'load', return_value=None) as load: + clear_extension_point_cache() load_extension_point('extA', 'eA', 'group1') assert load.call_count == 1 @@ -104,12 +94,14 @@ def test_extension_point_blocklist(): with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST=os.pathsep.join([ 'group1.extB', 'group2.extC']) ): + clear_extension_point_cache() load_extension_point('extA', 'eA', 'group1') assert load.call_count == 1 # entry point in a blocked group can't be loaded load.reset_mock() with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST='group1'): + clear_extension_point_cache() with pytest.raises(RuntimeError) as e: load_extension_point('extA', 'eA', 'group1') assert 'The entry point group name is listed in the environment ' \ @@ -120,6 +112,7 @@ def test_extension_point_blocklist(): with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST=os.pathsep.join([ 'group1.extA', 'group1.extB']) ): + clear_extension_point_cache() with pytest.raises(RuntimeError) as e: load_extension_point('extA', 'eA', 'group1') assert 'The entry point name is listed in the environment ' \