Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlm committed Nov 3, 2023
1 parent 086fb71 commit 6c289ff
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 98 deletions.
2 changes: 2 additions & 0 deletions botocore/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,8 @@ def compute_endpoint_resolver_builtin_defaults(
# account ID is calculated later if account based routing is
# enabled and configured for the service
EPRBuiltins.AWS_ACCOUNT_ID: None,
# credential scope is calculated later if configured on the
# credentials
EPRBuiltins.AWS_CREDENTIAL_SCOPE: None,
}

Expand Down
67 changes: 21 additions & 46 deletions botocore/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,10 +699,6 @@ def refresh_needed(self, refresh_in=None):

class CachedCredentialFetcher:
DEFAULT_EXPIRY_WINDOW_SECONDS = 60 * 15
OPTIONAL_FIELD_MAPPING = {
'AccountId': 'account_id',
'CredentialScope': 'scope',
}

def __init__(self, cache=None, expiry_window_seconds=None):
if cache is None:
Expand Down Expand Up @@ -748,7 +744,12 @@ def _get_cached_credentials(self):
'token': creds['SessionToken'],
'expiry_time': expiration,
}
self._handle_optional_fields(creds, creds_dict)
account_id = creds.get('AccountId')
if account_id is not None:
creds_dict['account_id'] = account_id
scope = creds.get('CredentialScope')
if scope is not None:
creds_dict['scope'] = scope
return creds_dict

def _load_from_cache(self):
Expand All @@ -771,12 +772,6 @@ def _is_expired(self, credentials):
seconds = total_seconds(end_time - _local_now())
return seconds < self._expiry_window_seconds

def _handle_optional_fields(self, response, creds_dict):
"""Handle optional fields in the response."""
for response_key, creds_key in self.OPTIONAL_FIELD_MAPPING.items():
if response_key in response:
creds_dict[creds_key] = response[response_key]


class BaseAssumeRoleCredentialFetcher(CachedCredentialFetcher):
def __init__(
Expand Down Expand Up @@ -1052,16 +1047,6 @@ def _extract_creds_from_mapping(self, mapping, *key_names):

class ProcessProvider(CredentialProvider):
METHOD = 'custom-process'
OPTIONAL_FIELD_MAPPING = {
'AccountId': {
'profile_field': 'aws_account_id',
'creds_dict_field': 'account_id',
},
'CredentialScope': {
'profile_field': 'aws_credential_scope',
'creds_dict_field': 'scope',
},
}

def __init__(self, profile_name, load_config, popen=subprocess.Popen):
self._profile_name = profile_name
Expand Down Expand Up @@ -1125,18 +1110,24 @@ def _retrieve_credentials_using(self, credential_process):
provider=self.METHOD,
error_msg=f"Missing required key in response: {e}",
)
self._handle_optional_fields(parsed, creds_dict)
return creds_dict

account_id = self._resolve_account_id(parsed)
if account_id is not None:
creds_dict['account_id'] = account_id
scope = self._resolve_scope(parsed)
if scope is not None:
creds_dict['scope'] = scope

return creds_dict

def _resolve_account_id(self, parsed_response):
account_id = parsed_response.get('AccountId')
return account_id or self.profile_config.get('aws_account_id')

def _resolve_scope(self, parsed_response):
scope = parsed_response.get('CredentialScope')
return scope or self.profile_config.get('aws_credential_scope')

@property
def _credential_process(self):
return self.profile_config.get('credential_process')
Expand All @@ -1148,16 +1139,6 @@ def profile_config(self):
profiles = self._loaded_config.get('profiles', {})
return profiles.get(self._profile_name, {})

def _handle_optional_fields(self, response, creds_dict):
"""Handle optional fields in the response."""
for response_key, creds_config in self.OPTIONAL_FIELD_MAPPING.items():
prof_field = creds_config['profile_field']
resp_val = response.get(response_key)
optional_value = resp_val or self.profile_config.get(prof_field)
if optional_value is not None:
dict_field = creds_config['creds_dict_field']
creds_dict[dict_field] = optional_value


class InstanceMetadataProvider(CredentialProvider):
METHOD = 'iam-role'
Expand Down Expand Up @@ -1200,11 +1181,6 @@ class EnvProvider(CredentialProvider):
ACCOUNT_ID = 'AWS_ACCOUNT_ID'
SCOPE = 'AWS_CREDENTIAL_SCOPE'

OPTIONAL_FIELDS = (
'account_id',
'scope',
)

def __init__(self, environ=None, mapping=None):
"""
Expand Down Expand Up @@ -1326,19 +1302,18 @@ def fetch_credentials(require_expiry=True):
provider=method, cred_var=mapping['expiry_time']
)

self._handle_optional_fields(environ, mapping, credentials)
account_id = environ.get(mapping['account_id'])
if account_id is not None:
credentials['account_id'] = account_id

scope = environ.get(mapping['scope'])
if scope is not None:
credentials['scope'] = scope

return credentials

return fetch_credentials

def _handle_optional_fields(self, environ, mapping, creds_dict):
"""Handle optional fields in environment variables."""
for field_key in self.OPTIONAL_FIELDS:
field_value = environ.get(mapping[field_key])
if field_value is not None:
creds_dict[field_key] = field_value


class OriginalEC2Provider(CredentialProvider):
METHOD = 'ec2-credentials-file'
Expand Down
4 changes: 0 additions & 4 deletions botocore/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,9 +820,5 @@ class AccountIdNotFound(EndpointResolutionError):
fmt = '{msg}'


class CredentialScopeNotFound(EndpointResolutionError):
fmt = '{msg}'


class InvalidEndpointRegion(EndpointResolutionError):
fmt = '{msg}'
10 changes: 1 addition & 9 deletions botocore/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@
DEFAULT_URI_TEMPLATE = '{service}.{region}.{dnsSuffix}' # noqa
DEFAULT_SERVICE_DATA = {'endpoints': {}}
# Allowed values for the ``account_id_endpoint_mode`` config field.
VALID_ACCOUNT_ID_ENDPOINT_MODES = [
'preferred',
'disabled',
'required',
]


class BaseEndpointResolver:
Expand Down Expand Up @@ -477,15 +472,12 @@ class CredentialBuiltinResolver:
'required',
)

def __init__(
self, credentials, account_id_endpoint_mode, uses_builtin_data_path
):
def __init__(self, credentials, account_id_endpoint_mode):
self._credentials = credentials
if account_id_endpoint_mode is None:
account_id_endpoint_mode = self.DEFAULT_ACCOUNT_ID_ENDPOINT_MODE
self._account_id_endpoint_mode = account_id_endpoint_mode
self._validate_account_id_endpoint_mode()
self._uses_builtin_data_path = uses_builtin_data_path

def resolve_account_id_builtin(self, builtin_configured, builtin_value):
"""Resolve the ``AWS::Auth::AccountId`` builtin."""
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,12 @@ def test_account_id_set(self):
self.assertTrue(self.creds.refresh_needed())
self.assertEqual(self.creds.account_id, '123456789012')

def test_credential_scope_refresh(self):
def test_credentail_scope_unset(self):
self.mock_time.return_value = datetime.now(tzlocal())
self.assertTrue(self.creds.refresh_needed())
self.assertIsNone(self.creds.scope)

def test_credential_scope_set(self):
metadata = self.metadata.copy()
metadata['scope'] = 'us-west-2'
self.refresher.return_value = metadata
Expand Down
41 changes: 3 additions & 38 deletions tests/unit/test_endpoint_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
)
from botocore.exceptions import (
AccountIdNotFound,
CredentialScopeNotFound,
EndpointResolutionError,
InvalidConfigError,
MissingDependencyException,
Expand Down Expand Up @@ -608,7 +607,6 @@ def create_ruleset_resolver(
bulitins,
credentials,
account_id_endpoint_mode,
uses_builtin_data_path=True,
):
service_model = Mock()
service_model.client_context_parameters = []
Expand All @@ -626,7 +624,6 @@ def create_ruleset_resolver(
client_context=None,
event_emitter=Mock(),
builtin_resolver=builtin_resolver,
uses_builtin_data_path=uses_builtin_data_path,
)


Expand Down Expand Up @@ -754,85 +751,53 @@ def test_required_mode_no_account_id(


@pytest.mark.parametrize(
"builtins, credentials, auth_scheme, expected_url",
"builtins, credentials, expected_url",
[
# scope matches region
(
BUILTINS_WITH_UNRESOLVED_CREDENTIAL_SCOPE,
CREDENTIALS_WITH_SCOPE,
None,
URL_WITH_CREDENTIAL_SCOPE,
),
# pre-resolved scope
(
BUILTINS_WITH_RESOLVED_CREDENTIAL_SCOPE,
CREDENTIALS_WITH_SCOPE,
None,
URL_WITH_OTHER_CREDENTIAL_SCOPE,
),
# custom endpoint with scope
(
BUILTINS_WITH_UNRESOLVED_CREDENTIAL_SCOPE_CUSTOM_ENDPOINT,
CREDENTIALS_WITH_SCOPE,
None,
URL_WITH_OTHER_CREDENTIAL_SCOPE,
),
# no scope in credentials
(
BUILTINS_WITH_UNRESOLVED_CREDENTIAL_SCOPE,
CREDENTIALS_NO_SCOPE,
None,
URL_NO_SCOPE,
),
# unsigned request
(
BUILTINS_WITH_UNRESOLVED_CREDENTIAL_SCOPE,
CREDENTIALS_WITH_SCOPE,
UNSIGNED,
URL_NO_SCOPE,
),
# no credentials
(
BUILTINS_WITH_UNRESOLVED_CREDENTIAL_SCOPE,
None,
None,
URL_NO_SCOPE,
),
],
)
def test_credential_scope_builtin(
operation_model_empty_context_params,
builtins,
credential_scope_ruleset,
builtins,
credentials,
auth_scheme,
expected_url,
):
resolver = create_ruleset_resolver(
credential_scope_ruleset, builtins, credentials, auth_scheme, PREFERRED
credential_scope_ruleset, builtins, credentials, PREFERRED
)
endpoint = resolver.construct_endpoint(
operation_model=operation_model_empty_context_params,
request_context={},
call_args={},
)
assert endpoint.url == expected_url


def test_non_builtin_path_raises(
operation_model_empty_context_params, credential_scope_ruleset
):
resolver = create_ruleset_resolver(
credential_scope_ruleset,
BUILTINS_WITH_UNRESOLVED_CREDENTIAL_SCOPE_CUSTOM_ENDPOINT,
CREDENTIALS_WITH_SCOPE,
None,
PREFERRED,
uses_builtin_data_path=False,
)
with pytest.raises(CredentialScopeNotFound):
resolver.construct_endpoint(
operation_model=operation_model_empty_context_params,
request_context={},
call_args={},
)

0 comments on commit 6c289ff

Please sign in to comment.