RevocationAuthenticationProvider.java

/*
 * Copyright 2022 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.gringlobal.custom.security;

import org.gringlobal.custom.security.service.JwtTokenIdExtractor;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;

/**
 * {@link AuthenticationProvider} handling {@link OAuth2TokenRevocationAuthenticationToken} requests.
 *
 * <p>The presented token is resolved to its {@link OAuth2Authorization} via
 * {@link OAuth2AuthorizationService#findByToken}. The specific stored token is then located by
 * comparing the id returned by {@link JwtTokenIdExtractor#getJwtTokenId} with the stored token
 * values (access, refresh and OIDC id token). The matched token is marked invalidated. When a
 * refresh token is revoked, the access token and any not-yet-invalidated authorization code of
 * the same authorization are invalidated as well.
 */
public class RevocationAuthenticationProvider implements AuthenticationProvider {
	private final OAuth2AuthorizationService authorizationService;
	private final JwtTokenIdExtractor jwtTokenIdExtractor;

	/**
	 * @param authorizationService used to look up and persist authorizations; must not be {@code null}
	 * @param jwtTokenIdExtractor extracts the token id from a presented token value; must not be {@code null}
	 */
	public RevocationAuthenticationProvider(OAuth2AuthorizationService authorizationService, JwtTokenIdExtractor jwtTokenIdExtractor) {
		Assert.notNull(authorizationService, "authorizationService cannot be null");
		Assert.notNull(jwtTokenIdExtractor, "jwtTokenIdExtractor cannot be null");
		this.authorizationService = authorizationService;
		this.jwtTokenIdExtractor = jwtTokenIdExtractor;
	}

	/**
	 * Revokes the token carried by the revocation request.
	 *
	 * @param authentication an {@link OAuth2TokenRevocationAuthenticationToken} whose principal is the
	 *        authenticated client
	 * @return the original request when the token cannot be resolved; otherwise a new
	 *         {@link OAuth2TokenRevocationAuthenticationToken} carrying the revoked token
	 * @throws OAuth2AuthenticationException with {@link OAuth2ErrorCodes#INVALID_CLIENT} when the client
	 *         is not authenticated or the token belongs to a different registered client
	 */
	@Override
	public Authentication authenticate(Authentication authentication) throws AuthenticationException {
		OAuth2TokenRevocationAuthenticationToken tokenRevocationAuthentication =
			(OAuth2TokenRevocationAuthenticationToken) authentication;

		OAuth2ClientAuthenticationToken clientPrincipal = getAuthenticatedClient(authentication);
		RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();

		OAuth2Authorization authorization = this.authorizationService.findByToken(
			tokenRevocationAuthentication.getToken(), null);
		if (authorization == null) {
			// Return the authentication request when token not found
			return tokenRevocationAuthentication;
		}

		if (!registeredClient.getId().equals(authorization.getRegisteredClientId())) {
			throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT);
		}

		OAuth2Authorization.Token<OAuth2Token> token = getToken(tokenRevocationAuthentication.getToken(), authorization);
		if (token == null) {
			// No token id could be extracted, or no stored token of this authorization matches it:
			// treat like "token not found" instead of dereferencing null below.
			return tokenRevocationAuthentication;
		}

		this.authorizationService.save(invalidate(authorization, token.getToken()));

		return new OAuth2TokenRevocationAuthenticationToken(token.getToken(), clientPrincipal);
	}

	/**
	 * Returns the authenticated client principal of the request.
	 *
	 * @throws OAuth2AuthenticationException with {@link OAuth2ErrorCodes#INVALID_CLIENT} if the principal is
	 *         missing, not an {@link OAuth2ClientAuthenticationToken}, or not authenticated
	 */
	private static OAuth2ClientAuthenticationToken getAuthenticatedClient(Authentication authentication) {
		Object principal = authentication.getPrincipal();
		// instanceof is null-safe, unlike principal.getClass()
		if (principal instanceof OAuth2ClientAuthenticationToken) {
			OAuth2ClientAuthenticationToken clientPrincipal = (OAuth2ClientAuthenticationToken) principal;
			if (clientPrincipal.isAuthenticated()) {
				return clientPrincipal;
			}
		}
		throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT);
	}

	/**
	 * Builds a copy of {@code authorization} with {@code revokedToken} marked invalidated. When the
	 * revoked token is a refresh token, the access token and any not-yet-invalidated authorization
	 * code are marked invalidated too.
	 */
	private static OAuth2Authorization invalidate(OAuth2Authorization authorization, OAuth2Token revokedToken) {
		// @formatter:off
		OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization)
			.token(revokedToken,
				(metadata) ->
					metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));

		// Test the wrapped token itself: the OAuth2Authorization.Token holder is never an OAuth2RefreshToken.
		if (revokedToken instanceof OAuth2RefreshToken) {
			OAuth2Authorization.Token<?> accessToken = authorization.getAccessToken();
			if (accessToken != null) {
				authorizationBuilder.token(
					accessToken.getToken(),
					(metadata) ->
						metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
			}

			OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
				authorization.getToken(OAuth2AuthorizationCode.class);
			if (authorizationCode != null && !authorizationCode.isInvalidated()) {
				authorizationBuilder.token(
					authorizationCode.getToken(),
					(metadata) ->
						metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
			}
		}
		// @formatter:on

		return authorizationBuilder.build();
	}

	/**
	 * Finds the stored token of {@code authorization} whose value equals the id extracted from
	 * {@code tokenValue}. The access token is checked first, then the refresh token, then the OIDC
	 * id token.
	 *
	 * @param tokenValue the presented token value; must contain text
	 * @param authorization the authorization to search
	 * @return the matching token, or {@code null} if no id could be extracted or nothing matches
	 * @throws IllegalArgumentException if {@code tokenValue} is empty
	 */
	@SuppressWarnings("unchecked") // caller only treats the result as OAuth2Authorization.Token<OAuth2Token>
	private <T extends OAuth2Token> OAuth2Authorization.Token<T> getToken(String tokenValue, OAuth2Authorization authorization) {
		Assert.hasText(tokenValue, "tokenValue cannot be empty");
		String tokenId = jwtTokenIdExtractor.getJwtTokenId(tokenValue);
		if (tokenId == null) {
			return null;
		}

		OAuth2Authorization.Token<T> token = (OAuth2Authorization.Token<T>) authorization.getAccessToken();
		if (token != null && token.getToken().getTokenValue().equals(tokenId)) {
			return token;
		}

		token = (OAuth2Authorization.Token<T>) authorization.getToken(OAuth2RefreshToken.class);
		if (token != null && token.getToken().getTokenValue().equals(tokenId)) {
			return token;
		}

		token = (OAuth2Authorization.Token<T>) authorization.getToken(OidcIdToken.class);
		if (token != null && token.getToken().getTokenValue().equals(tokenId)) {
			return token;
		}

		return null;
	}

	@Override
	public boolean supports(Class<?> authentication) {
		return OAuth2TokenRevocationAuthenticationToken.class.isAssignableFrom(authentication);
	}
}