diff --git a/docs/modules/ROOT/pages/protocol-endpoints.adoc b/docs/modules/ROOT/pages/protocol-endpoints.adoc index 94a55f9a72..ecb4b3f914 100644 --- a/docs/modules/ROOT/pages/protocol-endpoints.adoc +++ b/docs/modules/ROOT/pages/protocol-endpoints.adoc @@ -167,7 +167,7 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h * `*AuthenticationConverter*` -- An `OAuth2DeviceAuthorizationRequestAuthenticationConverter`. * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OAuth2DeviceAuthorizationRequestAuthenticationProvider`. * `*AuthenticationSuccessHandler*` -- An internal implementation that handles an "`authenticated`" `OAuth2DeviceAuthorizationRequestAuthenticationToken` and returns the `OAuth2DeviceAuthorizationResponse`. -* `*AuthenticationFailureHandler*` -- An internal implementation that uses the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response. +* `*AuthenticationFailureHandler*` -- An `OAuth2ErrorAuthenticationFailureHandler` instance that handles the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response. [[oauth2-device-verification-endpoint]] == OAuth2 Device Verification Endpoint @@ -264,7 +264,7 @@ The supported https://datatracker.ietf.org/doc/html/rfc6749#section-1.3[authoriz * `*AuthenticationConverter*` -- A `DelegatingAuthenticationConverter` composed of `OAuth2AuthorizationCodeAuthenticationConverter`, `OAuth2RefreshTokenAuthenticationConverter`, `OAuth2ClientCredentialsAuthenticationConverter`, and `OAuth2DeviceCodeAuthenticationConverter`. * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OAuth2AuthorizationCodeAuthenticationProvider`, `OAuth2RefreshTokenAuthenticationProvider`, `OAuth2ClientCredentialsAuthenticationProvider`, and `OAuth2DeviceCodeAuthenticationProvider`. * `*AuthenticationSuccessHandler*` -- An internal implementation that handles an `OAuth2AccessTokenAuthenticationToken` and returns the `OAuth2AccessTokenResponse`. -* `*AuthenticationFailureHandler*` -- An internal implementation that uses the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response. +* `*AuthenticationFailureHandler*` -- An `OAuth2ErrorAuthenticationFailureHandler` instance that handles the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response. [[oauth2-token-introspection-endpoint]] == OAuth2 Token Introspection Endpoint @@ -311,7 +311,7 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h * `*AuthenticationConverter*` -- An `OAuth2TokenIntrospectionAuthenticationConverter`. * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OAuth2TokenIntrospectionAuthenticationProvider`. * `*AuthenticationSuccessHandler*` -- An internal implementation that handles an "`authenticated`" `OAuth2TokenIntrospectionAuthenticationToken` and returns the `OAuth2TokenIntrospection` response. -* `*AuthenticationFailureHandler*` -- An internal implementation that uses the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response. +* `*AuthenticationFailureHandler*` -- An `OAuth2ErrorAuthenticationFailureHandler` instance that handles the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response. [[oauth2-token-revocation-endpoint]] == OAuth2 Token Revocation Endpoint @@ -358,7 +358,7 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h * `*AuthenticationConverter*` -- An `OAuth2TokenRevocationAuthenticationConverter`. * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OAuth2TokenRevocationAuthenticationProvider`. * `*AuthenticationSuccessHandler*` -- An internal implementation that handles an "`authenticated`" `OAuth2TokenRevocationAuthenticationToken` and returns the OAuth2 revocation response. -* `*AuthenticationFailureHandler*` -- An internal implementation that uses the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response. +* `*AuthenticationFailureHandler*` -- An `OAuth2ErrorAuthenticationFailureHandler` instance that handles the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response. [[oauth2-authorization-server-metadata-endpoint]] == OAuth2 Authorization Server Metadata Endpoint diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceAuthorizationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceAuthorizationEndpointFilter.java index f5473f9232..9dd8f44e42 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceAuthorizationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceAuthorizationEndpointFilter.java @@ -24,14 +24,12 @@ import org.springframework.core.log.LogMessage; import org.springframework.http.HttpMethod; -import org.springframework.http.HttpStatus; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; -import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2DeviceCode; @@ -44,6 +42,7 @@ import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceAuthorizationRequestAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceAuthorizationRequestAuthenticationToken; import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder; +import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler; import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2DeviceAuthorizationRequestAuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; @@ -76,13 +75,11 @@ public final class OAuth2DeviceAuthorizationEndpointFilter extends OncePerReques private final RequestMatcher deviceAuthorizationEndpointMatcher; private final HttpMessageConverter deviceAuthorizationHttpResponseConverter = new OAuth2DeviceAuthorizationResponseHttpMessageConverter(); - private final HttpMessageConverter errorHttpResponseConverter = - new OAuth2ErrorHttpMessageConverter(); private AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); private AuthenticationConverter authenticationConverter; private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendDeviceAuthorizationResponse; - private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse; + private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler(); private String verificationUri = OAuth2DeviceVerificationEndpointFilter.DEFAULT_DEVICE_VERIFICATION_ENDPOINT_URI; /** @@ -225,13 +222,4 @@ private void sendDeviceAuthorizationResponse(HttpServletRequest request, HttpSer this.deviceAuthorizationHttpResponseConverter.write(deviceAuthorizationResponse, null, httpResponse); } - private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, - AuthenticationException authenticationException) throws IOException { - - OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError(); - ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); - this.errorHttpResponseConverter.write(error, null, httpResponse); - } - } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java index e44106dd3e..4cea0e9ad1 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java @@ -27,14 +27,12 @@ import org.springframework.core.log.LogMessage; import org.springframework.http.HttpMethod; -import org.springframework.http.HttpStatus; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; -import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; @@ -51,6 +49,7 @@ import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceCodeAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2RefreshTokenAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler; import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter; import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeAuthenticationConverter; import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ClientCredentialsAuthenticationConverter; @@ -107,13 +106,11 @@ public final class OAuth2TokenEndpointFilter extends OncePerRequestFilter { private final RequestMatcher tokenEndpointMatcher; private final HttpMessageConverter accessTokenHttpResponseConverter = new OAuth2AccessTokenResponseHttpMessageConverter(); - private final HttpMessageConverter errorHttpResponseConverter = - new OAuth2ErrorHttpMessageConverter(); private AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); private AuthenticationConverter authenticationConverter; private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAccessTokenResponse; - private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse; + private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler(); /** * Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters. @@ -250,15 +247,6 @@ private void sendAccessTokenResponse(HttpServletRequest request, HttpServletResp this.accessTokenHttpResponseConverter.write(accessTokenResponse, null, httpResponse); } - private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, - AuthenticationException exception) throws IOException { - - OAuth2Error error = ((OAuth2AuthenticationException) exception).getError(); - ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); - this.errorHttpResponseConverter.write(error, null, httpResponse); - } - private static void throwError(String errorCode, String parameterName) { OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, DEFAULT_ERROR_URI); throw new OAuth2AuthenticationException(error); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenIntrospectionEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenIntrospectionEndpointFilter.java index edf623cab1..b471febe86 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenIntrospectionEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenIntrospectionEndpointFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,12 +24,10 @@ import org.springframework.core.log.LogMessage; import org.springframework.http.HttpMethod; -import org.springframework.http.HttpStatus; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; -import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -38,6 +36,7 @@ import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenIntrospectionAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenIntrospectionAuthenticationToken; import org.springframework.security.oauth2.server.authorization.http.converter.OAuth2TokenIntrospectionHttpMessageConverter; +import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler; import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2TokenIntrospectionAuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; @@ -69,9 +68,8 @@ public final class OAuth2TokenIntrospectionEndpointFilter extends OncePerRequest private AuthenticationConverter authenticationConverter; private final HttpMessageConverter tokenIntrospectionHttpResponseConverter = new OAuth2TokenIntrospectionHttpMessageConverter(); - private final HttpMessageConverter errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter(); private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendIntrospectionResponse; - private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse; + private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler(); /** * Constructs an {@code OAuth2TokenIntrospectionEndpointFilter} using the provided parameters. @@ -166,12 +164,4 @@ private void sendIntrospectionResponse(HttpServletRequest request, HttpServletRe this.tokenIntrospectionHttpResponseConverter.write(tokenClaims, null, httpResponse); } - private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, - AuthenticationException exception) throws IOException { - OAuth2Error error = ((OAuth2AuthenticationException) exception).getError(); - ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); - this.errorHttpResponseConverter.write(error, null, httpResponse); - } - } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java index 5f9fda4591..991772713d 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,17 +25,15 @@ import org.springframework.core.log.LogMessage; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; -import org.springframework.http.converter.HttpMessageConverter; -import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; -import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler; import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2TokenRevocationAuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; @@ -65,10 +63,8 @@ public final class OAuth2TokenRevocationEndpointFilter extends OncePerRequestFil private final AuthenticationManager authenticationManager; private final RequestMatcher tokenRevocationEndpointMatcher; private AuthenticationConverter authenticationConverter; - private final HttpMessageConverter errorHttpResponseConverter = - new OAuth2ErrorHttpMessageConverter(); private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendRevocationSuccessResponse; - private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse; + private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler(); /** * Constructs an {@code OAuth2TokenRevocationEndpointFilter} using the provided parameters. @@ -157,12 +153,4 @@ private void sendRevocationSuccessResponse(HttpServletRequest request, HttpServl response.setStatus(HttpStatus.OK.value()); } - private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, - AuthenticationException exception) throws IOException { - OAuth2Error error = ((OAuth2AuthenticationException) exception).getError(); - ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); - this.errorHttpResponseConverter.write(error, null, httpResponse); - } - } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ErrorAuthenticationFailureHandler.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ErrorAuthenticationFailureHandler.java new file mode 100644 index 0000000000..3923cbee22 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ErrorAuthenticationFailureHandler.java @@ -0,0 +1,64 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web.authentication; + +import java.io.IOException; + +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; + +/** + * A default implementation of an {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} + * and returning the {@link OAuth2Error Error Response}. + * + * @author Dmitriy Dubson + * @see AuthenticationFailureHandler + * @since 1.2.0 + */ +public final class OAuth2ErrorAuthenticationFailureHandler implements AuthenticationFailureHandler { + private HttpMessageConverter errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter(); + + @Override + public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, AuthenticationException authenticationException) throws IOException, ServletException { + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); + + if(authenticationException instanceof OAuth2AuthenticationException) { + OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError(); + this.errorHttpResponseConverter.write(error, null, httpResponse); + } else { + this.errorHttpResponseConverter.write(new OAuth2Error("a warning?"), null, httpResponse); + } + } + + /** + * Sets OAuth error HTTP message converter to write to upon authentication failure + * + * @param errorHttpResponseConverter the error HTTP message converter to set + */ + public void setErrorHttpResponseConverter(HttpMessageConverter errorHttpResponseConverter) { + this.errorHttpResponseConverter = errorHttpResponseConverter; + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ErrorAuthenticationFailureHandlerTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ErrorAuthenticationFailureHandlerTests.java new file mode 100644 index 0000000000..4a8531513f --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ErrorAuthenticationFailureHandlerTests.java @@ -0,0 +1,82 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web.authentication; + +import java.io.IOException; + +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.Test; +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.BadCredentialsException; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link OAuth2ErrorAuthenticationFailureHandler} + * + * @author Dmitriy Dubson + */ +public class OAuth2ErrorAuthenticationFailureHandlerTests { + + @Test + @SuppressWarnings("unchecked") + public void onAuthenticationFailure() throws IOException, ServletException { + HttpMessageConverter errorHttpMessageConverter = (HttpMessageConverter) mock(HttpMessageConverter.class); + HttpServletRequest request = new MockHttpServletRequest(); + HttpServletResponse response = new MockHttpServletResponse(); + + OAuth2Error invalidRequestError = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST); + AuthenticationException authenticationException = new OAuth2AuthenticationException(invalidRequestError); + OAuth2ErrorAuthenticationFailureHandler handler = new OAuth2ErrorAuthenticationFailureHandler(); + handler.setErrorHttpResponseConverter(errorHttpMessageConverter); + + handler.onAuthenticationFailure(request, response, authenticationException); + + verify(errorHttpMessageConverter).write(eq(invalidRequestError), isNull(), any(ServletServerHttpResponse.class)); + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + } + + @Test + public void onAuthenticationFailure_ifExceptionProvidedIsNotOAuth2AuthenticationException() throws ServletException, IOException { + HttpMessageConverter errorHttpMessageConverter = (HttpMessageConverter) mock(HttpMessageConverter.class); + HttpServletRequest request = new MockHttpServletRequest(); + HttpServletResponse response = new MockHttpServletResponse(); + + OAuth2ErrorAuthenticationFailureHandler handler = new OAuth2ErrorAuthenticationFailureHandler(); + handler.setErrorHttpResponseConverter(errorHttpMessageConverter); + + handler.onAuthenticationFailure(request, response, new BadCredentialsException("Not a valid exception.")); + + verify(errorHttpMessageConverter).write(any(OAuth2Error.class), isNull(), any(ServletServerHttpResponse.class)); + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + } + +}