Skip to content

Commit

Permalink
Merge pull request #92 from victorskl/sso-session-refresh
Browse files Browse the repository at this point in the history
Re-engineered AWS SSO accessToken expiry mechanism
  • Loading branch information
victorskl authored Feb 18, 2024
2 parents f6b70fc + 5994a25 commit caabaa1
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 82 deletions.
4 changes: 2 additions & 2 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
log_cli = 1
log_cli_level = INFO
log_cli_format = %(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)
log_cli_date_format=%Y-%m-%d %H:%M:%S
addopts = --cov-config=.coveragerc --no-cov-on-fail --cov-report term-missing --cov=yawsso tests/
log_cli_date_format = %Y-%m-%d %H:%M:%S
addopts = --code-highlight=no --cov-config=.coveragerc --no-cov-on-fail --cov-report term-missing --cov=yawsso tests/
37 changes: 7 additions & 30 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,10 @@ def setUp(self) -> None:
"startUrl": "https://petshop.awsapps.com/start",
"region": "ap-southeast-2",
"accessToken": "longTextA.AverylOngText",
"expiresAt": f"{str((datetime.utcnow() + timedelta(hours=3)).isoformat())[:-7]}UTC"
"expiresAt": f"{str((datetime.utcnow() + timedelta(hours=3)).isoformat())[:-7]}UTC",
"clientId": "longTextA",
"clientSecret": "longTextA", # pragma: allowlist secret
"refreshToken": "longTextA" # pragma: allowlist secret
}
self.sso_cache_json.write(json.dumps(cache_json).encode('utf-8'))
self.sso_cache_json.seek(0)
Expand Down Expand Up @@ -359,31 +362,6 @@ def test_sso_cache_not_found(self):
cli.main()
self.assertEqual(x.exception.code, 1)

def test_sso_cache_expires(self):
"""
python -m unittest tests.test_cli.CLIUnitTests.test_sso_cache_expires
"""
with ArgvContext(program, '-p', 'dev', '-t'), self.assertRaises(SystemExit) as x:
# clean up as going to mutate this
self.sso_cache_json.close()
os.unlink(self.sso_cache_json.name)
self.sso_cache_dir.cleanup()
# start new test case
self.sso_cache_dir = tempfile.TemporaryDirectory()
self.sso_cache_json = tempfile.NamedTemporaryFile(dir=self.sso_cache_dir.name, suffix='.json', delete=False)
cache_json = {
"startUrl": "https://petshop.awsapps.com/start",
"region": "ap-southeast-2",
"accessToken": "longTextA.AverylOngText",
"expiresAt": f"{str((datetime.utcnow()).isoformat())[:-7]}UTC"
}
self.sso_cache_json.write(json.dumps(cache_json).encode('utf-8'))
self.sso_cache_json.seek(0)
self.sso_cache_json.read()
cli.core.aws_sso_cache_path = self.sso_cache_dir.name
cli.main()
self.assertEqual(x.exception.code, 1)

def test_aws_cli_v1(self):
"""
python -m unittest tests.test_cli.CLIUnitTests.test_aws_cli_v1
Expand Down Expand Up @@ -442,7 +420,7 @@ def test_sso_cache_not_json(self):
"""
python -m unittest tests.test_cli.CLIUnitTests.test_sso_cache_not_json
"""
with ArgvContext(program, '-p', 'dev', '-t'), self.assertRaises(SystemExit) as x:
with ArgvContext(program, '-p', 'dev', '-t'):
# clean up as going to mutate this
self.sso_cache_json.close()
os.unlink(self.sso_cache_json.name)
Expand All @@ -454,13 +432,12 @@ def test_sso_cache_not_json(self):
self.sso_cache_json.read()
cli.core.aws_sso_cache_path = self.sso_cache_dir.name
cli.main()
self.assertEqual(x.exception.code, 1)

def test_not_equal_sso_start_url(self):
"""
python -m unittest tests.test_cli.CLIUnitTests.test_not_equal_sso_start_url
"""
with ArgvContext(program, '-p', 'dev', '-t'), self.assertRaises(SystemExit) as x:
with ArgvContext(program, '-p', 'dev', '-t'):
# clean up as going to mutate this
self.config.close()
os.unlink(self.config.name)
Expand All @@ -480,7 +457,6 @@ def test_not_equal_sso_start_url(self):
self.config.read()
cli.core.aws_config_file = self.config.name
cli.main()
self.assertEqual(x.exception.code, 1)

def test_not_equal_sso_region(self):
"""
Expand Down Expand Up @@ -537,6 +513,7 @@ def test_sso_get_role_credentials_fail(self):
python -m unittest tests.test_cli.CLIUnitTests.test_sso_get_role_credentials_fail
"""
when(cli.utils).invoke(contains('aws sso get-role-credentials')).thenReturn((False, 'does-not-matter'))
when(cli.utils).invoke(contains('aws sso-oidc create-token')).thenReturn((False, 'does-not-matter'))
cred = cli.core.update_profile("dev", cli.utils.read_config(self.config.name))
self.assertIsNone(cred)

Expand Down
26 changes: 22 additions & 4 deletions tests/test_cmd.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import os
import tempfile
from datetime import datetime
from io import StringIO
from unittest.mock import patch

Expand Down Expand Up @@ -64,9 +64,8 @@ def test_auto_command_login_expires(self):
when(mock_poll).start(...).thenReturn(mock_poll)
when(mock_poll).resolve(...).thenReturn(True)

when(cli.cmd.AutoCommand).get_sso_cached_login(...).thenReturn({
'expiresAt': f"{str((datetime.utcnow()).isoformat())[:-7]}UTC"
})
when(cli.cmd.AutoCommand).session_cached(...).thenReturn((False, 'does-not-matter'))
when(cli.cmd.AutoCommand).session_refresh(...).thenReturn((False, 'does-not-matter'))

with ArgvContext(program, '-t', 'auto', '--profile', 'dev'):
cli.main()
Expand All @@ -76,6 +75,25 @@ def test_auto_command_login_expires(self):
self.assertEqual(new_tok, 'VeryLongBase664String==')
verify(cli.utils, times=1).Poll(...)

def test_auto_command_login_expires2(self):
"""
python -m unittest tests.test_cmd.AutoCommandUnitTests.test_auto_command_login_expires2
"""
when(cli.cmd.AutoCommand).session_cached(...).thenReturn((False, 'does-not-matter'))
when(cli.utils).invoke(contains('aws sso-oidc create-token')).thenReturn((True, json.dumps({
"accessToken": "does-not-matter",
"tokenType": "Bearer",
"expiresIn": 3600,
"refreshToken": "does-not-matter"
})))

with ArgvContext(program, '-t', 'auto', '--profile', 'dev'):
cli.main()
cred = cli.utils.read_config(self.credentials.name)
new_tok = cred['dev']['aws_session_token']
self.assertNotEqual(new_tok, 'tok')
self.assertEqual(new_tok, 'VeryLongBase664String==')

def test_auto_login_not_sso_profile(self):
"""
python -m unittest tests.test_cmd.AutoCommandUnitTests.test_auto_login_not_sso_profile
Expand Down
38 changes: 18 additions & 20 deletions yawsso/cmd.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import sys
from abc import ABC, abstractmethod
from datetime import datetime

from yawsso import TRACE, Constant, logger, core, utils

Expand Down Expand Up @@ -133,34 +132,33 @@ class AutoCommand(LoginCommand):
def __init__(self, co):
super(AutoCommand, self).__init__(co)

def get_sso_cached_login(self, profile):
cached_login = core.get_aws_cli_v2_sso_cached_login(profile)

if cached_login is None:
utils.halt(f"Can not find valid AWS CLI v2 SSO login cache in {core.aws_sso_cache_path} "
f"for profile {self.login_profile}.")

return cached_login

def is_sso_cached_login_expired(self, cached_login):
expires_utc = core.parse_sso_cached_login_expiry(cached_login)
def session_cached(self, profile, cached_login):
return core.session_cached(self.login_profile, profile, cached_login)

if datetime.utcnow() > expires_utc:
logger.log(TRACE, f"Current cached SSO login is expired since {expires_utc.astimezone().isoformat()}. "
f"Performing auto login for profile {self.login_profile}.")
return True

return False
def session_refresh(self, profile, cached_login):
return core.session_refresh(self.login_profile, profile, cached_login)

def perform(self):
profile = core.load_profile_from_config(self.login_profile, self.co.config)

if not core.is_sso_profile(profile):
utils.halt(f"Login profile is not an AWS SSO profile. Abort auto syncing profile `{self.login_profile}`")

cached_login = self.get_sso_cached_login(profile)
cached_login = core.get_aws_cli_v2_sso_cached_login(profile)
if cached_login is None:
utils.halt(f"Can not find SSO login session cache in {core.aws_sso_cache_path} "
f"for ({profile['sso_start_url']}) profile `{self.login_profile})`.")

# try 1: attempt using cached accessToken
role_cred_success, _ = self.session_cached(profile, cached_login)

# try 2: attempt using refreshToken to generate accessToken
if not role_cred_success:
role_cred_success, _ = self.session_refresh(profile, cached_login)

if self.is_sso_cached_login_expired(cached_login=cached_login):
# try 3: attempt aws sso login
if not role_cred_success:
logger.log(TRACE, f"Your SSO login session ({profile['sso_start_url']}) has expired. Attempt auto login...")
super(AutoCommand, self).perform()

return self
Expand Down
74 changes: 48 additions & 26 deletions yawsso/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ def get_aws_cli_v2_sso_cached_login(profile):

def update_aws_cli_v1_credentials(profile_name, profile, credentials):
if credentials is None:
logger.warning(f"No appropriate credentials found for profile '{profile_name}'. "
f"Skip syncing it. Use --trace flag to see possible error causes.")
logger.warning(f"No appropriate credentials found. Skip syncing profile `{profile_name}`")
return

# region = profile.get("region", aws_default_region)
Expand All @@ -53,14 +52,6 @@ def update_aws_cli_v1_credentials(profile_name, profile, credentials):
logger.debug(f"Done syncing AWS CLI v1 credentials using AWS CLI v2 SSO login session for profile `{profile_name}`")


def parse_sso_cached_login_expiry(cached_login):
# older versions of aws-cli might use non-standard format with `UTC` instead of `Z`
expires_at = cached_login["expiresAt"].replace('UTC', 'Z')
datetime_format_in_sso_cached_login = "%Y-%m-%dT%H:%M:%SZ"
expires_utc = datetime.strptime(expires_at, datetime_format_in_sso_cached_login)
return expires_utc


def parse_assume_role_credentials_expiry(dt_str):
datetime_format_in_assume_role_expiration = "%Y-%m-%dT%H:%M:%S+00:00"
expires_utc = datetime.strptime(dt_str, datetime_format_in_assume_role_expiration)
Expand All @@ -82,41 +73,73 @@ def append_cli_global_options(cmd: str, profile: dict):
ca_bundle = profile.get('ca_bundle', None)
if ca_bundle:
cmd = f"{cmd} --ca-bundle '{ca_bundle}'"

logger.log(TRACE, f"COMMAND: {cmd}")
return cmd


def check_sso_cached_login_expires(profile_name, profile):
cached_login = get_aws_cli_v2_sso_cached_login(profile)
if cached_login is None:
u.halt(f"Can not find valid AWS CLI v2 SSO login cache in {aws_sso_cache_path} for profile {profile_name}.")

expires_utc = parse_sso_cached_login_expiry(cached_login)
def create_access_token(cached_login):
cmd_create_token = f"{aws_bin} sso-oidc create-token " \
f"--output json " \
f"--client-id {cached_login['clientId']} " \
f"--client-secret {cached_login['clientSecret']} " \
f"--grant-type refresh_token " \
f"--refresh-token {cached_login['refreshToken']}"

if datetime.utcnow() > expires_utc:
u.halt(f"Current cached SSO login is expired since {expires_utc.astimezone().isoformat()}. Try login again.")
create_token_success, create_token_output = u.invoke(cmd_create_token)

return cached_login
if not create_token_success:
logger.log(TRACE, f"EXCEPTION: '{create_token_output}'")

return create_token_success, create_token_output

def fetch_credentials(profile_name, profile):
cached_login = check_sso_cached_login_expires(profile_name, profile)

def get_role_credentials(profile_name, profile, access_token):
cmd_get_role_cred = f"{aws_bin} sso get-role-credentials " \
f"--output json " \
f"--profile {profile_name} " \
f"--region {profile['sso_region']} " \
f"--role-name {profile['sso_role_name']} " \
f"--account-id {profile['sso_account_id']} " \
f"--access-token {cached_login['accessToken']}"
f"--access-token {access_token}"

cmd_get_role_cred = append_cli_global_options(cmd_get_role_cred, profile)

role_cred_success, role_cred_output = u.invoke(cmd_get_role_cred)

if not role_cred_success:
logger.log(TRACE, f"ERROR EXECUTING COMMAND: '{cmd_get_role_cred}'. EXCEPTION: '{role_cred_output}'")
logger.log(TRACE, f"EXCEPTION: '{role_cred_output}'")

return role_cred_success, role_cred_output


def session_cached(profile_name, profile, cached_login):
return get_role_credentials(profile_name, profile, cached_login['accessToken'])


def session_refresh(profile_name, profile, cached_login):
logger.log(TRACE, f"Attempt using SSO refreshToken to generate accessToken")
create_token_success, create_token_output = create_access_token(cached_login)
if create_token_success:
return get_role_credentials(profile_name, profile, json.loads(create_token_output)['accessToken'])
return False, create_token_output


def fetch_credentials(profile_name, profile):
cached_login = get_aws_cli_v2_sso_cached_login(profile)
if cached_login is None:
logger.warning(f"Can not find SSO login session cache in {aws_sso_cache_path} "
f"for ({profile['sso_start_url']}) profile `{profile_name}`.")
return

# try 1: attempt using cached accessToken
role_cred_success, role_cred_output = session_cached(profile_name, profile, cached_login)

# try 2: attempt using refreshToken to generate accessToken
if not role_cred_success:
role_cred_success, role_cred_output = session_refresh(profile_name, profile, cached_login)

# try 3: attempt aws sso login
if not role_cred_success:
logger.warning(f"Your SSO login session ({profile['sso_start_url']}) has expired. Try aws sso login again.")
return

return json.loads(role_cred_output)['roleCredentials']
Expand Down Expand Up @@ -250,7 +273,6 @@ def update_profile(profile_name, config, new_profile_name=""):
if not is_sso_profile(source_profile):
logger.warning(f"Your source_profile is not an AWS SSO profile. Skip syncing profile `{profile_name}`")
return
check_sso_cached_login_expires(source_profile_name, source_profile)
eager_sync_source_profile(source_profile_name, source_profile)
logger.log(TRACE, f"Fetching credentials using assume role for `{profile_name}`")
credentials = fetch_credentials_with_assume_role(profile_name, profile)
Expand Down

0 comments on commit caabaa1

Please sign in to comment.