Skip to content

Commit

Permalink
draft: initial implementation of customization functions in auth erro…
Browse files Browse the repository at this point in the history
…r response handler

fixes spring-projectsgh-1369
  • Loading branch information
ddubson committed Jan 19, 2024
1 parent 01441f9 commit f2e0434
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.security.oauth2.server.authorization.oidc.web;

import java.io.IOException;
import java.util.function.BiConsumer;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
Expand All @@ -29,7 +30,6 @@
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
Expand All @@ -42,6 +42,7 @@
import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcClientRegistrationAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
Expand Down Expand Up @@ -81,7 +82,19 @@ public final class OidcClientRegistrationEndpointFilter extends OncePerRequestFi
new OAuth2ErrorHttpMessageConverter();
private AuthenticationConverter authenticationConverter = new OidcClientRegistrationAuthenticationConverter();
private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendClientRegistrationResponse;
private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
BiConsumer<OAuth2AuthenticationException, ServletServerHttpResponse> authenticationFailureHttpResponseCustomizer = (e, httpResponse) -> {
HttpStatus httpStatus = HttpStatus.BAD_REQUEST;
if (OAuth2ErrorCodes.INVALID_TOKEN.equals(e.getError().getErrorCode())) {
httpStatus = HttpStatus.UNAUTHORIZED;
} else if (OAuth2ErrorCodes.INSUFFICIENT_SCOPE.equals(e.getError().getErrorCode())) {
httpStatus = HttpStatus.FORBIDDEN;
} else if (OAuth2ErrorCodes.INVALID_CLIENT.equals(e.getError().getErrorCode())) {
httpStatus = HttpStatus.UNAUTHORIZED;
}
httpResponse.setStatusCode(httpStatus);
};

private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler();

/**
* Constructs an {@code OidcClientRegistrationEndpointFilter} using the provided parameters.
Expand All @@ -90,6 +103,7 @@ public final class OidcClientRegistrationEndpointFilter extends OncePerRequestFi
*/
public OidcClientRegistrationEndpointFilter(AuthenticationManager authenticationManager) {
this(authenticationManager, DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI);
((OAuth2ErrorAuthenticationFailureHandler) authenticationFailureHandler).setHttpResponseCustomizer(authenticationFailureHttpResponseCustomizer);
}

/**
Expand Down Expand Up @@ -206,20 +220,4 @@ private void sendClientRegistrationResponse(HttpServletRequest request, HttpServ
this.clientRegistrationHttpMessageConverter.write(clientRegistration, null, httpResponse);
}

private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authenticationException) throws IOException {
OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
HttpStatus httpStatus = HttpStatus.BAD_REQUEST;
if (OAuth2ErrorCodes.INVALID_TOKEN.equals(error.getErrorCode())) {
httpStatus = HttpStatus.UNAUTHORIZED;
} else if (OAuth2ErrorCodes.INSUFFICIENT_SCOPE.equals(error.getErrorCode())) {
httpStatus = HttpStatus.FORBIDDEN;
} else if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) {
httpStatus = HttpStatus.UNAUTHORIZED;
}
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
httpResponse.setStatusCode(httpStatus);
this.errorHttpResponseConverter.write(error, null, httpResponse);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.security.oauth2.server.authorization.oidc.web;

import java.io.IOException;
import java.util.function.BiConsumer;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
Expand All @@ -29,7 +30,6 @@
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
Expand All @@ -39,6 +39,7 @@
import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcUserInfoHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
Expand Down Expand Up @@ -74,7 +75,18 @@ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter {
new OAuth2ErrorHttpMessageConverter();
private AuthenticationConverter authenticationConverter = this::createAuthentication;
private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendUserInfoResponse;
private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;

BiConsumer<OAuth2AuthenticationException, ServletServerHttpResponse> authenticationFailureHttpResponseCustomizer = (e, httpResponse) -> {
HttpStatus httpStatus = HttpStatus.BAD_REQUEST;
if (e.getError().getErrorCode().equals(OAuth2ErrorCodes.INVALID_TOKEN)) {
httpStatus = HttpStatus.UNAUTHORIZED;
} else if (e.getError().getErrorCode().equals(OAuth2ErrorCodes.INSUFFICIENT_SCOPE)) {
httpStatus = HttpStatus.FORBIDDEN;
}
httpResponse.setStatusCode(httpStatus);
};

private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler();

/**
* Constructs an {@code OidcUserInfoEndpointFilter} using the provided parameters.
Expand All @@ -83,6 +95,7 @@ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter {
*/
public OidcUserInfoEndpointFilter(AuthenticationManager authenticationManager) {
this(authenticationManager, DEFAULT_OIDC_USER_INFO_ENDPOINT_URI);
((OAuth2ErrorAuthenticationFailureHandler) authenticationFailureHandler).setHttpResponseCustomizer(authenticationFailureHttpResponseCustomizer);
}

/**
Expand Down Expand Up @@ -184,18 +197,4 @@ private void sendUserInfoResponse(HttpServletRequest request, HttpServletRespons
this.userInfoHttpMessageConverter.write(userInfoAuthenticationToken.getUserInfo(), null, httpResponse);
}

private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authenticationException) throws IOException {
OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
HttpStatus httpStatus = HttpStatus.BAD_REQUEST;
if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_TOKEN)) {
httpStatus = HttpStatus.UNAUTHORIZED;
} else if (error.getErrorCode().equals(OAuth2ErrorCodes.INSUFFICIENT_SCOPE)) {
httpStatus = HttpStatus.FORBIDDEN;
}
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
httpResponse.setStatusCode(httpStatus);
this.errorHttpResponseConverter.write(error, null, httpResponse);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.function.BiConsumer;
import java.util.function.Function;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
Expand All @@ -25,7 +27,6 @@

import org.springframework.core.log.LogMessage;
import org.springframework.http.HttpStatus;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationDetailsSource;
Expand All @@ -37,7 +38,6 @@
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.authentication.ClientSecretAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.JwtClientAssertionAuthenticationProvider;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
Expand All @@ -46,6 +46,7 @@
import org.springframework.security.oauth2.server.authorization.web.authentication.ClientSecretPostAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.JwtClientAssertionAuthenticationConverter;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler;
import org.springframework.security.oauth2.server.authorization.web.authentication.PublicClientAuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
Expand Down Expand Up @@ -75,13 +76,32 @@
public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter {
private final AuthenticationManager authenticationManager;
private final RequestMatcher requestMatcher;
private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter();
private final AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource =
new WebAuthenticationDetailsSource();
private AuthenticationConverter authenticationConverter;
private AuthenticationSuccessHandler authenticationSuccessHandler = this::onAuthenticationSuccess;
private AuthenticationFailureHandler authenticationFailureHandler = this::onAuthenticationFailure;

private final OAuth2ErrorAuthenticationFailureHandler failureHandler = new OAuth2ErrorAuthenticationFailureHandler();

// We don't want to reveal too much information to the caller so just return the error code
Function<OAuth2AuthenticationException, OAuth2Error> errorCustomizer = (e) -> new OAuth2Error(e.getError().getErrorCode());

BiConsumer<OAuth2AuthenticationException, ServletServerHttpResponse> authenticationFailureHttpResponseCustomizer = (e, httpResponse) -> {
// TODO
// The authorization server MAY return an HTTP 401 (Unauthorized) status code
// to indicate which HTTP authentication schemes are supported.
// If the client attempted to authenticate via the "Authorization" request header field,
// the authorization server MUST respond with an HTTP 401 (Unauthorized) status code and
// include the "WWW-Authenticate" response header field
// matching the authentication scheme used by the client.
if (OAuth2ErrorCodes.INVALID_CLIENT.equals(e.getError().getErrorCode())) {
httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED);
} else {
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
}
};

/**
* Constructs an {@code OAuth2ClientAuthenticationFilter} using the provided parameters.
*
Expand All @@ -100,6 +120,8 @@ public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationMana
new ClientSecretBasicAuthenticationConverter(),
new ClientSecretPostAuthenticationConverter(),
new PublicClientAuthenticationConverter()));
failureHandler.setErrorCustomizer(errorCustomizer);
failureHandler.setHttpResponseCustomizer(authenticationFailureHttpResponseCustomizer);
}

@Override
Expand Down Expand Up @@ -178,28 +200,9 @@ private void onAuthenticationSuccess(HttpServletRequest request, HttpServletResp
}

private void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
AuthenticationException exception) throws IOException {

AuthenticationException exception) throws IOException, ServletException {
SecurityContextHolder.clearContext();

// TODO
// The authorization server MAY return an HTTP 401 (Unauthorized) status code
// to indicate which HTTP authentication schemes are supported.
// If the client attempted to authenticate via the "Authorization" request header field,
// the authorization server MUST respond with an HTTP 401 (Unauthorized) status code and
// include the "WWW-Authenticate" response header field
// matching the authentication scheme used by the client.

OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) {
httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED);
} else {
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
}
// We don't want to reveal too much information to the caller so just return the error code
OAuth2Error errorResponse = new OAuth2Error(error.getErrorCode());
this.errorHttpResponseConverter.write(errorResponse, null, httpResponse);
failureHandler.onAuthenticationFailure(request, response, exception);
}

private static void validateClientIdentifier(Authentication authentication) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package org.springframework.security.oauth2.server.authorization.web.authentication;

import java.io.IOException;
import java.util.function.BiConsumer;
import java.util.function.Function;

import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
Expand Down Expand Up @@ -47,14 +49,18 @@ public final class OAuth2ErrorAuthenticationFailureHandler implements Authentica
private final Log logger = LogFactory.getLog(getClass());
private HttpMessageConverter<OAuth2Error> errorResponseConverter = new OAuth2ErrorHttpMessageConverter();

private BiConsumer<OAuth2AuthenticationException, ServletServerHttpResponse> httpResponseCustomizer = (exception, httpResponse) -> httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);

private Function<OAuth2AuthenticationException, OAuth2Error> errorCustomizer = OAuth2AuthenticationException::getError;

@Override
public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
AuthenticationException authenticationException) throws IOException, ServletException {
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
if (authenticationException instanceof OAuth2AuthenticationException exception) {
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
httpResponseCustomizer.accept(exception, httpResponse);

if (authenticationException instanceof OAuth2AuthenticationException) {
OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
OAuth2Error error = errorCustomizer.apply(exception);
this.errorResponseConverter.write(error, null, httpResponse);
} else {
if (this.logger.isWarnEnabled()) {
Expand All @@ -75,4 +81,11 @@ public void setErrorResponseConverter(HttpMessageConverter<OAuth2Error> errorRes
this.errorResponseConverter = errorResponseConverter;
}

public void setHttpResponseCustomizer(BiConsumer<OAuth2AuthenticationException, ServletServerHttpResponse> httpResponseCustomizer) {
this.httpResponseCustomizer = httpResponseCustomizer;
}

public void setErrorCustomizer(Function<OAuth2AuthenticationException, OAuth2Error> errorCustomizer) {
this.errorCustomizer = errorCustomizer;
}
}

0 comments on commit f2e0434

Please sign in to comment.