Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for OIDC 'prompt=none' parameter value #1351

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
*/
private static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize";

private static final String OIDC_PROMPT_PARAMETER_NAME = "prompt";
private static final String OIDC_NO_PROMPT_PARAMETER_VALUE = "none";
private static final String OIDC_AUTH_ERROR_URL = "https://openid.net/specs/openid-connect-core-1_0.html#AuthError";

private final AuthenticationManager authenticationManager;
private final RequestMatcher authorizationEndpointMatcher;
private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
Expand Down Expand Up @@ -158,6 +162,9 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
return;
}

// Experimental support for OIDC 'prompt=none' parameter
boolean noPrompt = OIDC_NO_PROMPT_PARAMETER_VALUE.equals(request.getParameter(OIDC_PROMPT_PARAMETER_NAME));

try {
Authentication authentication = this.authenticationConverter.convert(request);
if (authentication instanceof AbstractAuthenticationToken) {
Expand All @@ -167,6 +174,13 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
Authentication authenticationResult = this.authenticationManager.authenticate(authentication);

if (!authenticationResult.isAuthenticated()) {
if (noPrompt) {
throw new OAuth2AuthorizationCodeRequestAuthenticationException(
new OAuth2Error("login_required", "User login is required", OIDC_AUTH_ERROR_URL),
(OAuth2AuthorizationCodeRequestAuthenticationToken) authentication
);
}

// If the Principal (Resource Owner) is not authenticated then
// pass through the chain with the expectation that the authentication process
// will commence via AuthenticationEntryPoint
Expand All @@ -178,6 +192,14 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
if (this.logger.isTraceEnabled()) {
this.logger.trace("Authorization consent is required");
}

if (noPrompt) {
throw new OAuth2AuthorizationCodeRequestAuthenticationException(
new OAuth2Error("consent_required", "Authorization consent is required", OIDC_AUTH_ERROR_URL),
(OAuth2AuthorizationCodeRequestAuthenticationToken) authentication
);
}

sendAuthorizationConsent(request, response,
(OAuth2AuthorizationCodeRequestAuthenticationToken) authentication,
(OAuth2AuthorizationConsentAuthenticationToken) authenticationResult);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.WebAuthenticationDetails;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand All @@ -86,6 +89,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
private static final String AUTHORIZATION_URI = "https://provider.com/oauth2/authorize";
private static final String STATE = "state";
private static final String REMOTE_ADDRESS = "remote-address";
private static final String OIDC_PROMPT_PARAMETER_NAME = "prompt";
private AuthenticationManager authenticationManager;
private OAuth2AuthorizationEndpointFilter filter;
private TestingAuthenticationToken principal;
Expand Down Expand Up @@ -470,6 +474,41 @@ public void doFilterWhenAuthorizationRequestPrincipalNotAuthenticatedThenCommenc
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}

@Test
public void doFilterWhenAuthorizationRequestPromptNonePrincipalNotAuthenticatedThenErrorResponse() throws Exception {
this.principal.setAuthenticated(false);
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult =
new OAuth2AuthorizationCodeRequestAuthenticationToken(
AUTHORIZATION_URI, registeredClient.getClientId(), principal,
registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes(), null);
authorizationCodeRequestAuthenticationResult.setAuthenticated(false);
when(this.authenticationManager.authenticate(any()))
.thenReturn(authorizationCodeRequestAuthenticationResult);

MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.addParameter(OIDC_PROMPT_PARAMETER_NAME, "none");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);

this.filter.doFilter(request, response, filterChain);

verifyNoInteractions(filterChain);

assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).isNotNull();
UriComponents uriComponents = UriComponentsBuilder.fromUriString(response.getRedirectedUrl()).build();

// check URI host and path
String withoutParams = UriComponentsBuilder.newInstance().uriComponents(uriComponents).replaceQuery("").toUriString();
assertThat(withoutParams).isEqualTo(request.getParameter(OAuth2ParameterNames.REDIRECT_URI));

// check URI query parameters in any order
MultiValueMap<String, String> parameters = uriComponents.getQueryParams();
assertMapHasSingleValue(parameters, OAuth2ParameterNames.ERROR, "login_required");
assertMapHasSingleValue(parameters, OAuth2ParameterNames.STATE, request.getParameter(OAuth2ParameterNames.STATE));
}

@Test
public void doFilterWhenAuthorizationRequestConsentRequiredWithCustomConsentUriThenRedirectConsentResponse() throws Exception {
Set<String> requestedScopes = new HashSet<>(Arrays.asList("scope1", "scope2"));
Expand Down Expand Up @@ -572,6 +611,41 @@ public void doFilterWhenAuthorizationRequestConsentRequiredWithPreviouslyApprove
}
}

@Test
public void doFilterWhenAuthorizationConsentRequiredWithPromptNoneThenErrorResponse() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthenticationResult =
new OAuth2AuthorizationConsentAuthenticationToken(
AUTHORIZATION_URI, registeredClient.getClientId(), principal,
STATE, new HashSet<>(), null);
authorizationConsentAuthenticationResult.setAuthenticated(true);
when(this.authenticationManager.authenticate(any()))
.thenReturn(authorizationConsentAuthenticationResult);

MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.addParameter(OIDC_PROMPT_PARAMETER_NAME, "none");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);

this.filter.doFilter(request, response, filterChain);

verify(this.authenticationManager).authenticate(any());
verifyNoInteractions(filterChain);

assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
assertThat(response.getRedirectedUrl()).isNotNull();
UriComponents uriComponents = UriComponentsBuilder.fromUriString(response.getRedirectedUrl()).build();

// check URI host and path
String withoutParams = UriComponentsBuilder.newInstance().uriComponents(uriComponents).replaceQuery("").toUriString();
assertThat(withoutParams).isEqualTo(request.getParameter(OAuth2ParameterNames.REDIRECT_URI));

// check URI query parameters in any order
MultiValueMap<String, String> parameters = uriComponents.getQueryParams();
assertMapHasSingleValue(parameters, OAuth2ParameterNames.ERROR, "consent_required");
assertMapHasSingleValue(parameters, OAuth2ParameterNames.STATE, request.getParameter(OAuth2ParameterNames.STATE));
}

@Test
public void doFilterWhenAuthorizationRequestAuthenticatedThenAuthorizationResponse() throws Exception {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
Expand Down Expand Up @@ -720,4 +794,10 @@ private static String disabledScopeCheckbox(String scope) {
);
}

private static void assertMapHasSingleValue(MultiValueMap<String, String> map, String key, String value) {
assertThat(map.containsKey(key)).isTrue();
assertThat(map.get(key).size()).isEqualTo(1);
assertThat(map.getFirst(key)).isEqualTo(value);
}

}