diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java index 5de3352ed8..d368086e71 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.server.authorization.oidc.web; import java.io.IOException; +import java.util.function.BiConsumer; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; @@ -29,7 +30,6 @@ import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; -import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -42,6 +42,7 @@ import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken; import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter; import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcClientRegistrationAuthenticationConverter; +import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; @@ -81,7 +82,19 @@ public final class OidcClientRegistrationEndpointFilter extends OncePerRequestFi new OAuth2ErrorHttpMessageConverter(); private AuthenticationConverter authenticationConverter = new OidcClientRegistrationAuthenticationConverter(); private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendClientRegistrationResponse; - private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse; + BiConsumer authenticationFailureHttpResponseCustomizer = (e, httpResponse) -> { + HttpStatus httpStatus = HttpStatus.BAD_REQUEST; + if (OAuth2ErrorCodes.INVALID_TOKEN.equals(e.getError().getErrorCode())) { + httpStatus = HttpStatus.UNAUTHORIZED; + } else if (OAuth2ErrorCodes.INSUFFICIENT_SCOPE.equals(e.getError().getErrorCode())) { + httpStatus = HttpStatus.FORBIDDEN; + } else if (OAuth2ErrorCodes.INVALID_CLIENT.equals(e.getError().getErrorCode())) { + httpStatus = HttpStatus.UNAUTHORIZED; + } + httpResponse.setStatusCode(httpStatus); + }; + + private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler(); /** * Constructs an {@code OidcClientRegistrationEndpointFilter} using the provided parameters. @@ -90,6 +103,7 @@ public final class OidcClientRegistrationEndpointFilter extends OncePerRequestFi */ public OidcClientRegistrationEndpointFilter(AuthenticationManager authenticationManager) { this(authenticationManager, DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI); + ((OAuth2ErrorAuthenticationFailureHandler) authenticationFailureHandler).setHttpResponseCustomizer(authenticationFailureHttpResponseCustomizer); } /** @@ -206,20 +220,4 @@ private void sendClientRegistrationResponse(HttpServletRequest request, HttpServ this.clientRegistrationHttpMessageConverter.write(clientRegistration, null, httpResponse); } - private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, - AuthenticationException authenticationException) throws IOException { - OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError(); - HttpStatus httpStatus = HttpStatus.BAD_REQUEST; - if (OAuth2ErrorCodes.INVALID_TOKEN.equals(error.getErrorCode())) { - httpStatus = HttpStatus.UNAUTHORIZED; - } else if (OAuth2ErrorCodes.INSUFFICIENT_SCOPE.equals(error.getErrorCode())) { - httpStatus = HttpStatus.FORBIDDEN; - } else if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) { - httpStatus = HttpStatus.UNAUTHORIZED; - } - ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - httpResponse.setStatusCode(httpStatus); - this.errorHttpResponseConverter.write(error, null, httpResponse); - } - } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java index 610b387104..106f471809 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.server.authorization.oidc.web; import java.io.IOException; +import java.util.function.BiConsumer; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; @@ -29,7 +30,6 @@ import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; -import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -39,6 +39,7 @@ import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken; import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcUserInfoHttpMessageConverter; +import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; @@ -74,7 +75,18 @@ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter { new OAuth2ErrorHttpMessageConverter(); private AuthenticationConverter authenticationConverter = this::createAuthentication; private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendUserInfoResponse; - private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse; + + BiConsumer authenticationFailureHttpResponseCustomizer = (e, httpResponse) -> { + HttpStatus httpStatus = HttpStatus.BAD_REQUEST; + if (e.getError().getErrorCode().equals(OAuth2ErrorCodes.INVALID_TOKEN)) { + httpStatus = HttpStatus.UNAUTHORIZED; + } else if (e.getError().getErrorCode().equals(OAuth2ErrorCodes.INSUFFICIENT_SCOPE)) { + httpStatus = HttpStatus.FORBIDDEN; + } + httpResponse.setStatusCode(httpStatus); + }; + + private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler(); /** * Constructs an {@code OidcUserInfoEndpointFilter} using the provided parameters. @@ -83,6 +95,7 @@ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter { */ public OidcUserInfoEndpointFilter(AuthenticationManager authenticationManager) { this(authenticationManager, DEFAULT_OIDC_USER_INFO_ENDPOINT_URI); + ((OAuth2ErrorAuthenticationFailureHandler) authenticationFailureHandler).setHttpResponseCustomizer(authenticationFailureHttpResponseCustomizer); } /** @@ -184,18 +197,4 @@ private void sendUserInfoResponse(HttpServletRequest request, HttpServletRespons this.userInfoHttpMessageConverter.write(userInfoAuthenticationToken.getUserInfo(), null, httpResponse); } - private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, - AuthenticationException authenticationException) throws IOException { - OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError(); - HttpStatus httpStatus = HttpStatus.BAD_REQUEST; - if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_TOKEN)) { - httpStatus = HttpStatus.UNAUTHORIZED; - } else if (error.getErrorCode().equals(OAuth2ErrorCodes.INSUFFICIENT_SCOPE)) { - httpStatus = HttpStatus.FORBIDDEN; - } - ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - httpResponse.setStatusCode(httpStatus); - this.errorHttpResponseConverter.write(error, null, httpResponse); - } - } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java index 6926314d82..b527a81392 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java @@ -17,6 +17,8 @@ import java.io.IOException; import java.util.Arrays; +import java.util.function.BiConsumer; +import java.util.function.Function; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; @@ -25,7 +27,6 @@ import org.springframework.core.log.LogMessage; import org.springframework.http.HttpStatus; -import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AuthenticationDetailsSource; @@ -37,7 +38,6 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; import org.springframework.security.oauth2.server.authorization.authentication.ClientSecretAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.JwtClientAssertionAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; @@ -46,6 +46,7 @@ import org.springframework.security.oauth2.server.authorization.web.authentication.ClientSecretPostAuthenticationConverter; import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter; import org.springframework.security.oauth2.server.authorization.web.authentication.JwtClientAssertionAuthenticationConverter; +import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler; import org.springframework.security.oauth2.server.authorization.web.authentication.PublicClientAuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; @@ -75,13 +76,32 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter { private final AuthenticationManager authenticationManager; private final RequestMatcher requestMatcher; - private final HttpMessageConverter errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter(); private final AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); private AuthenticationConverter authenticationConverter; private AuthenticationSuccessHandler authenticationSuccessHandler = this::onAuthenticationSuccess; private AuthenticationFailureHandler authenticationFailureHandler = this::onAuthenticationFailure; + private final OAuth2ErrorAuthenticationFailureHandler failureHandler = new OAuth2ErrorAuthenticationFailureHandler(); + + // We don't want to reveal too much information to the caller so just return the error code + Function errorCustomizer = (e) -> new OAuth2Error(e.getError().getErrorCode()); + + BiConsumer authenticationFailureHttpResponseCustomizer = (e, httpResponse) -> { + // TODO + // The authorization server MAY return an HTTP 401 (Unauthorized) status code + // to indicate which HTTP authentication schemes are supported. + // If the client attempted to authenticate via the "Authorization" request header field, + // the authorization server MUST respond with an HTTP 401 (Unauthorized) status code and + // include the "WWW-Authenticate" response header field + // matching the authentication scheme used by the client. + if (OAuth2ErrorCodes.INVALID_CLIENT.equals(e.getError().getErrorCode())) { + httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED); + } else { + httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); + } + }; + /** * Constructs an {@code OAuth2ClientAuthenticationFilter} using the provided parameters. * @@ -100,6 +120,8 @@ public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationMana new ClientSecretBasicAuthenticationConverter(), new ClientSecretPostAuthenticationConverter(), new PublicClientAuthenticationConverter())); + failureHandler.setErrorCustomizer(errorCustomizer); + failureHandler.setHttpResponseCustomizer(authenticationFailureHttpResponseCustomizer); } @Override @@ -178,28 +200,9 @@ private void onAuthenticationSuccess(HttpServletRequest request, HttpServletResp } private void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, - AuthenticationException exception) throws IOException { - + AuthenticationException exception) throws IOException, ServletException { SecurityContextHolder.clearContext(); - - // TODO - // The authorization server MAY return an HTTP 401 (Unauthorized) status code - // to indicate which HTTP authentication schemes are supported. - // If the client attempted to authenticate via the "Authorization" request header field, - // the authorization server MUST respond with an HTTP 401 (Unauthorized) status code and - // include the "WWW-Authenticate" response header field - // matching the authentication scheme used by the client. - - OAuth2Error error = ((OAuth2AuthenticationException) exception).getError(); - ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) { - httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED); - } else { - httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); - } - // We don't want to reveal too much information to the caller so just return the error code - OAuth2Error errorResponse = new OAuth2Error(error.getErrorCode()); - this.errorHttpResponseConverter.write(errorResponse, null, httpResponse); + failureHandler.onAuthenticationFailure(request, response, exception); } private static void validateClientIdentifier(Authentication authentication) { diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ErrorAuthenticationFailureHandler.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ErrorAuthenticationFailureHandler.java index a5d350a65f..5ff7ebf8d5 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ErrorAuthenticationFailureHandler.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ErrorAuthenticationFailureHandler.java @@ -16,6 +16,8 @@ package org.springframework.security.oauth2.server.authorization.web.authentication; import java.io.IOException; +import java.util.function.BiConsumer; +import java.util.function.Function; import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; @@ -47,14 +49,18 @@ public final class OAuth2ErrorAuthenticationFailureHandler implements Authentica private final Log logger = LogFactory.getLog(getClass()); private HttpMessageConverter errorResponseConverter = new OAuth2ErrorHttpMessageConverter(); + private BiConsumer httpResponseCustomizer = (exception, httpResponse) -> httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); + + private Function errorCustomizer = OAuth2AuthenticationException::getError; + @Override public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, AuthenticationException authenticationException) throws IOException, ServletException { - ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); + if (authenticationException instanceof OAuth2AuthenticationException exception) { + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + httpResponseCustomizer.accept(exception, httpResponse); - if (authenticationException instanceof OAuth2AuthenticationException) { - OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError(); + OAuth2Error error = errorCustomizer.apply(exception); this.errorResponseConverter.write(error, null, httpResponse); } else { if (this.logger.isWarnEnabled()) { @@ -75,4 +81,11 @@ public void setErrorResponseConverter(HttpMessageConverter errorRes this.errorResponseConverter = errorResponseConverter; } + public void setHttpResponseCustomizer(BiConsumer httpResponseCustomizer) { + this.httpResponseCustomizer = httpResponseCustomizer; + } + + public void setErrorCustomizer(Function errorCustomizer) { + this.errorCustomizer = errorCustomizer; + } }