/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.lambda.auth;

import com.dataiku.common.audit.AuditContextBase;
import com.dataiku.common.audit.AuthError;
import com.dataiku.dip.DKUApp;
import com.dataiku.dip.exceptions.NotAuthenticatedException;
import com.dataiku.dip.exceptions.UnauthorizedException;
import com.dataiku.dip.security.jwt.JwtVerificationService;
import com.dataiku.dip.util.HTTPClientBaseUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dss.shadelib.com.google.common.annotations.VisibleForTesting;
import com.dataiku.dss.shadelib.com.nimbusds.jose.JOSEException;
import com.dataiku.dss.shadelib.com.nimbusds.jose.JWSAlgorithm;
import com.dataiku.dss.shadelib.com.nimbusds.jose.JWSHeader;
import com.dataiku.dss.shadelib.com.nimbusds.jose.KeySourceException;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.ECKey;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.JWK;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.JWKMatcher;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.JWKSelector;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.RSAKey;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.source.ImmutableJWKSet;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.source.JWKSource;
import com.dataiku.dss.shadelib.com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.dataiku.dss.shadelib.com.nimbusds.jose.proc.BadJOSEException;
import com.dataiku.dss.shadelib.com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.dataiku.dss.shadelib.com.nimbusds.jose.proc.SecurityContext;
import com.dataiku.dss.shadelib.com.nimbusds.jose.util.DefaultResourceRetriever;
import com.dataiku.dss.shadelib.com.nimbusds.jose.util.ResourceRetriever;
import com.dataiku.dss.shadelib.com.nimbusds.jwt.JWTClaimsSet;
import com.dataiku.dss.shadelib.com.nimbusds.jwt.SignedJWT;
import com.dataiku.dss.shadelib.com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.dataiku.dss.shadelib.com.nimbusds.jwt.util.DateUtils;
import com.dataiku.lambda.auth.AuthMethod;
import com.dataiku.lambda.auth.BadAccessTokenException;
import com.dataiku.lambda.auth.DataikuKeySourceException;
import com.dataiku.lambda.auth.TemporaryAPIKeyService;
import com.dataiku.lambda.model.serverconfig.OAuth2Config;
import com.dataiku.lambda.model.serverconfig.QueryAPIKey;
import jakarta.servlet.http.HttpServletRequest;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.Key;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.cert.X509Certificate;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import org.apache.commons.lang.StringUtils;

/**
 * {@link AuthMethod} that authenticates requests with an OAuth2 bearer access token (a signed JWT).
 *
 * <p>A request is accepted when either:
 * <ul>
 *   <li>it is an internal call carrying an API key accepted by the process-wide
 *       {@link TemporaryAPIKeyService}, or</li>
 *   <li>its bearer token is a JWS whose signature verifies against the configured keys (a static
 *       JWK set or a remote JWKs URI) and whose exp/iat/nbf/iss/aud/scope claims satisfy the
 *       {@link OAuth2Config}.</li>
 * </ul>
 *
 * <p>Remote JWK sources are cached per JWKs URI in a static map shared by every instance. An entry
 * is rebuilt when the configuration stored with it is no longer equal to the requesting one.
 */
public class OAuth2AuthMethod
implements AuthMethod<QueryAPIKey> {
    /** Tolerated clock skew, in seconds, when checking the exp, iat and nbf claims. */
    public static final long CLOCK_SKEW = 5L;
    private static final DKULogger logger = DKULogger.getLogger("dku.apinode.auth.service");
    public static final String AUDIT_LOG_REF_JWT = "jwtAccessToken";
    public static final String AUDIT_LOG_AUTH_ERROR = "authErrorType";
    public static final String AUDIT_LOG_REF_JWT_CLIENT_ID = "jwtClientId";
    public static final String AUDIT_LOG_REF_JWT_SUB = "jwtSub";
    public static final String AUDIT_LOG_IS_INTERNAL_CALL = "isInternalCall";
    public static final String JWK_URL_CACHE_PROPERTY_TIME_TO_LIVE_FOR_TEST = "dku.oauth2.jwkurl.cache.timeToLive.fortest";
    /**
     * Remote JWK sources keyed by JWKs URI. Shared by ALL instances, so every access must be
     * guarded by synchronizing on this map (an instance-level lock would not protect it).
     */
    private static final Map<String, JwkSourceCacheEntry> globalJwkSetPerJwkUri = new HashMap<>();
    private static final TemporaryAPIKeyService globalTemporaryAPIKeyService = new TemporaryAPIKeyService();
    /** Formatter for dates in debug logs; immutable, hence safe to share between concurrent requests. */
    private static final DateTimeFormatter LOG_DATE_FORMAT =
            DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSSZ").withZone(ZoneId.systemDefault());
    private final JwtVerificationService jwtVerificationService;
    private final String serviceId;
    private final OAuth2Config oauth2Config;

    /**
     * @param jwtVerificationService service applying the default DSS configuration to remote JWK sources
     * @param serviceId identifier of the protected service, used in audit/log messages
     * @param oAuth2Config OAuth2 settings (keys, issuer, audience, scope) the tokens are checked against
     */
    protected OAuth2AuthMethod(JwtVerificationService jwtVerificationService, String serviceId, OAuth2Config oAuth2Config) {
        this.jwtVerificationService = jwtVerificationService;
        this.serviceId = serviceId;
        this.oauth2Config = oAuth2Config;
    }

    /**
     * Authenticates {@code request}, first as an internal call and then through its bearer token.
     *
     * @throws NotAuthenticatedException if the request is not a valid internal call and carries no
     *     bearer token
     * @throws UnauthorizedException if the internal API key is rejected by the temporary API key
     *     service and no bearer token is present, or if the bearer token fails verification (the
     *     error type and a signature-stripped copy of the token are then added to the audit context)
     */
    @Override
    public void validate(HttpServletRequest request) throws NotAuthenticatedException, UnauthorizedException {
        logger.debugV("OAuth2 configured, lets check first if it is an internal call using the internal API key", new Object[0]);
        String accessToken = HTTPClientBaseUtils.decodeBearerAuth(request);
        try {
            if (this.isValidInternalCall(request)) {
                logger.debugV("The request is an internal call using a valid key", new Object[0]);
                return;
            }
        } catch (UnauthorizedException e) {
            // The internal API key was rejected: still give the bearer token a chance when one is present.
            if (accessToken == null) {
                throw e;
            }
        }
        logger.debugV("The request doesn't contain an API key, it's not an internal call.", new Object[0]);
        logger.debugV("Let's now verify the access token", new Object[0]);
        if (accessToken == null) {
            throw new NotAuthenticatedException("Access token not provided", AuthError.TOKEN_NOT_FOUND.name());
        }
        try {
            this.verifyAccessToken(this.serviceId, accessToken);
        } catch (UnauthorizedException e) {
            AuditContextBase.addCustom(AUDIT_LOG_AUTH_ERROR, e.getType());
            AuditContextBase.addCustom(AUDIT_LOG_REF_JWT, OAuth2AuthMethod.makeAccessTokenSafeToLog(accessToken));
            throw e;
        }
    }

    /**
     * Checks whether {@code request} is an internal call authenticated by the temporary API key service,
     * recording the outcome under {@link #AUDIT_LOG_IS_INTERNAL_CALL}.
     *
     * @return true if the key was accepted, false if the service reported the request as not authenticated
     * @throws UnauthorizedException propagated from the temporary API key service
     */
    private boolean isValidInternalCall(HttpServletRequest request) throws UnauthorizedException {
        AuditContextBase.addCustom(AUDIT_LOG_IS_INTERNAL_CALL, "false");
        try {
            globalTemporaryAPIKeyService.validate(request);
        } catch (NotAuthenticatedException e) {
            return false;
        }
        AuditContextBase.addCustom(AUDIT_LOG_IS_INTERNAL_CALL, "true");
        return true;
    }

    /** Returns the API key currently attached to {@code req}, or else a key suitable for internal calls. */
    @Override
    public QueryAPIKey getApiKeyForInternalCalls(HttpServletRequest req) {
        QueryAPIKey currentAPIKey = globalTemporaryAPIKeyService.getCurrentAPIKey(req);
        if (currentAPIKey != null) {
            return currentAPIKey;
        }
        return globalTemporaryAPIKeyService.getApiKeyForInternalCalls(req);
    }

    /**
     * Verifies the signature and claims of {@code accessToken}, mapping every verification failure
     * to an {@link UnauthorizedException} whose type is the matching {@link AuthError} name.
     */
    private void verifyAccessToken(String serviceId, String accessToken) throws UnauthorizedException {
        DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
        jwtProcessor.setJWSKeySelector((jwsHeader, securityContext) -> this.selectJWSKeys(this.getKeySource(), jwsHeader, securityContext));
        jwtProcessor.setJWTClaimsSetVerifier((jwtClaimsSet, securityContext) -> this.verifyClaims(serviceId, jwtClaimsSet));
        // The catch order matters: the Dataiku-specific exceptions must be handled before the generic
        // nimbus ones they are thrown through.
        try {
            jwtProcessor.process(accessToken, null);
        } catch (BadAccessTokenException e) {
            throw new UnauthorizedException(e.getMessage(), e.getAuthError().name(), e);
        } catch (DataikuKeySourceException e) {
            throw new UnauthorizedException(e.getMessage(), e.getAuthError().name(), e);
        } catch (ParseException e) {
            throw new UnauthorizedException("Invalid Access token format", AuthError.TOKEN_INVALID_FORMAT.name(), e);
        } catch (BadJOSEException e) {
            throw new UnauthorizedException(e.getMessage(), AuthError.JWT_INVALID_SIGNATURE.name(), e);
        } catch (JOSEException e) {
            throw new UnauthorizedException(e.getMessage(), "UNKNOWN", e);
        }
    }

    /**
     * Selects the candidate verification keys for a JWS header. Only EC and RSA algorithms are accepted.
     * Keys are looked up by {@code kid} when the header has one, otherwise by algorithm.
     *
     * @throws DataikuKeySourceException if the algorithm is unsupported or no matching key is found
     */
    private List<Key> selectJWSKeys(JWKSource<SecurityContext> keySource, JWSHeader jwsHeader, SecurityContext securityContext) throws KeySourceException {
        JWSAlgorithm algorithm = jwsHeader.getAlgorithm();
        if (algorithm == null || !JWSAlgorithm.Family.EC.contains(algorithm) && !JWSAlgorithm.Family.RSA.contains(algorithm)) {
            logger.warnV("The JWT algorithm %s is not supported. Only EC or RSA is supported currently", new Object[]{algorithm});
            throw new DataikuKeySourceException("The JWT algorithm '" + algorithm + "' is not supported. Only EC or RSA is supported currently", AuthError.JWT_ALGO_NOT_SUPPORTED);
        }
        String keyId = jwsHeader.getKeyID();
        if (StringUtils.isBlank(keyId)) {
            List<Key> keysForAlgorithm = new JWSVerificationKeySelector<SecurityContext>(algorithm, keySource).selectJWSKeys(jwsHeader, securityContext);
            if (keysForAlgorithm.isEmpty()) {
                logger.warnV("JWT has no KID and we couldn't find the JWK corresponding to the algorithm %s in remote jwk set", new Object[]{algorithm});
                throw new DataikuKeySourceException("Failed to find JWK for the following alg='" + algorithm + "'", AuthError.JWK_NOT_FOUND);
            }
            return keysForAlgorithm;
        }
        JWKSelector selector = new JWKSelector(new JWKMatcher.Builder().keyID(keyId).build());
        List<Key> keysForKid = keySource.get(selector, null).stream()
                .map(OAuth2AuthMethod::toPublicKey)
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
        if (keysForKid.isEmpty()) {
            logger.warnV("Couldn't find the JWK corresponding to kid %s in remote jwk set", new Object[]{keyId});
            throw new DataikuKeySourceException("Failed to find JWK for the following kid='" + keyId + "'", AuthError.JWK_NOT_FOUND);
        }
        return keysForKid;
    }

    /** Converts an EC or RSA JWK to its public key; logs and returns null for any other or unusable key. */
    private static Key toPublicKey(JWK jwk) {
        try {
            if (jwk instanceof ECKey) {
                return ((ECKey) jwk).toECPublicKey();
            }
            if (jwk instanceof RSAKey) {
                return ((RSAKey) jwk).toRSAPublicKey();
            }
            logger.warnV("The JWK found '%s' is neither a RSA or EC key, JWK instance of '%s'. We skip this key", new Object[]{jwk.toJSONString(), jwk.getClass()});
            return null;
        } catch (JOSEException e) {
            logger.warnV((Throwable) e, "Could not get public key from key '%s' of type '%s'. We skip this key", new Object[]{jwk.toJSONString(), jwk.getClass()});
            return null;
        }
    }

    /**
     * Returns the key source matching the configured keys format: the static JWK set for
     * {@code STATIC_JWKS}, otherwise the (cached) remote JWKs URI source.
     */
    private JWKSource<SecurityContext> getKeySource() throws DataikuKeySourceException {
        switch (this.oauth2Config.keysFormat) {
            case STATIC_JWKS:
                try {
                    return new ImmutableJWKSet<>(this.oauth2Config.getJwksSet());
                } catch (ParseException e) {
                    logger.errorV("Invalid JWKs Set format '%s'", new Object[]{this.oauth2Config.jwksSet});
                    throw new DataikuKeySourceException("Invalid JWKs Set format '" + this.oauth2Config.jwksSet + "'", AuthError.JWKS_SET_INVALID_FORMAT, e);
                }
            default:
                return this.getJWKSource(this.oauth2Config);
        }
    }

    /**
     * Builds an SSL socket factory trusting every certificate chain, or returns null if the TLS
     * context cannot be created.
     *
     * <p>NOTE(review): this only replaces the trust managers; hostname verification is left to the
     * HTTP connection used by the resource retriever — confirm this matches the intended behavior.
     */
    private SSLSocketFactory getSSLSocketFactoryWithNoSSLCheck() {
        TrustManager trustEverything = new X509TrustManager() {
            @Override
            public void checkClientTrusted(X509Certificate[] chain, String authType) {
            }

            @Override
            public void checkServerTrusted(X509Certificate[] chain, String authType) {
            }

            @Override
            public X509Certificate[] getAcceptedIssuers() {
                return new X509Certificate[0];
            }
        };
        try {
            SSLContext sslContext = SSLContext.getInstance("TLS");
            sslContext.init(null, new TrustManager[]{trustEverything}, new SecureRandom());
            return sslContext.getSocketFactory();
        } catch (KeyManagementException | NoSuchAlgorithmException e) {
            logger.warn((Object) "Could not create a SSL socket factory, continuing with SSL validation on", (Throwable) e);
            return null;
        }
    }

    /**
     * Returns the cached remote JWK source for {@code oAuth2Config.jwksUri}, building (and caching)
     * a new one when none exists or when the cached one was built from a different configuration.
     */
    private JWKSource<SecurityContext> getJWKSource(OAuth2Config oAuth2Config) throws DataikuKeySourceException {
        // The cache is static: lock the map itself so that all instances are serialized.
        synchronized (globalJwkSetPerJwkUri) {
            JwkSourceCacheEntry cached = globalJwkSetPerJwkUri.get(oAuth2Config.jwksUri);
            if (cached != null && cached.oAuth2Config.equals(oAuth2Config)) {
                return cached.jwkSource;
            }
            JWKSource<SecurityContext> jwkSource = this.buildRemoteJWKSource(oAuth2Config);
            globalJwkSetPerJwkUri.put(oAuth2Config.jwksUri, new JwkSourceCacheEntry(oAuth2Config, jwkSource));
            return jwkSource;
        }
    }

    /**
     * Builds a JWK source fetching {@code oAuth2Config.jwksUri} with the configured timeouts and size limit.
     * When the test-only TTL property is set, rate limiting and refresh-ahead caching are disabled and
     * that TTL is used instead.
     *
     * @throws DataikuKeySourceException if the JWKs URI is malformed
     */
    private JWKSource<SecurityContext> buildRemoteJWKSource(OAuth2Config oAuth2Config) throws DataikuKeySourceException {
        ResourceRetriever resourceRetriever;
        if (oAuth2Config.disableSSLCertificatesCheck) {
            logger.info((Object) "Disable SSL check for the JWK_URI");
            resourceRetriever = new DefaultResourceRetriever(oAuth2Config.jwkUriConnectTimeout, oAuth2Config.jwkUriReadTimeout, oAuth2Config.jwkUriSizeLimit, true, this.getSSLSocketFactoryWithNoSSLCheck());
        } else {
            resourceRetriever = new DefaultResourceRetriever(oAuth2Config.jwkUriConnectTimeout, oAuth2Config.jwkUriReadTimeout, oAuth2Config.jwkUriSizeLimit);
        }
        try {
            JWKSourceBuilder<SecurityContext> jwkSourceBuilder = JWKSourceBuilder.create(new URL(oAuth2Config.jwksUri), resourceRetriever);
            this.jwtVerificationService.dkuDefaultJwkSrcConfiguration(jwkSourceBuilder, oAuth2Config.jwksUri);
            if (DKUApp.getProperty(JWK_URL_CACHE_PROPERTY_TIME_TO_LIVE_FOR_TEST, null) != null) {
                int timeToLive = DKUApp.getProperty(JWK_URL_CACHE_PROPERTY_TIME_TO_LIVE_FOR_TEST, 300000);
                jwkSourceBuilder.rateLimited(false);
                jwkSourceBuilder.refreshAheadCache(false).cache((long) timeToLive, 15000L);
            }
            return jwkSourceBuilder.build();
        } catch (MalformedURLException e) {
            logger.errorV("Malformed JWKs URI '%s'", new Object[]{oAuth2Config.jwksUri});
            throw new DataikuKeySourceException("Malformed JWKs URI '" + oAuth2Config.jwksUri + "'", AuthError.JWKS_SET_INVALID_FORMAT, e);
        }
    }

    /**
     * Validates the claims of a signature-verified token. Checks run in this order: time claims,
     * then issuer, audience and scope. The client id and subject are recorded in the audit context
     * first, even for tokens that end up rejected.
     *
     * @throws BadAccessTokenException describing the first failing check, or TOKEN_INVALID_FORMAT
     *     when a claim cannot be parsed
     */
    private void verifyClaims(String serviceId, JWTClaimsSet jwtClaimsSet) throws BadAccessTokenException {
        try {
            this.auditLogJWT(serviceId, jwtClaimsSet);
            this.verifyTimeClaims(jwtClaimsSet);
            this.verifyIssuerAndAudience(jwtClaimsSet);
            this.verifyScope(serviceId, jwtClaimsSet);
        } catch (ParseException e) {
            throw new BadAccessTokenException("Invalid Access token format", AuthError.TOKEN_INVALID_FORMAT, e);
        }
    }

    /** Requires an unexpired exp claim, and an iat/nbf not in the future when present, with {@link #CLOCK_SKEW} tolerance. */
    private void verifyTimeClaims(JWTClaimsSet jwtClaimsSet) throws BadAccessTokenException {
        Date expirationTime = jwtClaimsSet.getExpirationTime();
        if (expirationTime == null) {
            logger.debugV("The token has no expiration time", new Object[0]);
            throw new BadAccessTokenException("No expired time in access token", AuthError.JWT_NO_EXPIRED_TIME);
        }
        Date now = new Date();
        if (!DateUtils.isAfter(expirationTime, now, CLOCK_SKEW)) {
            logger.debugV("The token has an expiration time '%s' + %s sec < now(%s)", new Object[]{formatForLog(expirationTime), CLOCK_SKEW, formatForLog(now)});
            throw new BadAccessTokenException("Expired access token", AuthError.JWT_EXPIRED);
        }
        Date issueTime = jwtClaimsSet.getIssueTime();
        if (issueTime != null && !DateUtils.isBefore(issueTime, now, CLOCK_SKEW)) {
            logger.debugV("The token has an issuer time '%s' - %s sec > now(%s)", new Object[]{formatForLog(issueTime), CLOCK_SKEW, formatForLog(now)});
            throw new BadAccessTokenException("Invalid IAT for access token", AuthError.JWT_IAT_INVALID);
        }
        Date notBeforeTime = jwtClaimsSet.getNotBeforeTime();
        if (notBeforeTime != null && !DateUtils.isBefore(notBeforeTime, now, CLOCK_SKEW)) {
            logger.debugV("The token is marked to be not consumed before time '%s' - %s sec > now(%s)", new Object[]{formatForLog(notBeforeTime), CLOCK_SKEW, formatForLog(now)});
            throw new BadAccessTokenException("Invalid NBT for access token", AuthError.JWT_NBT_INVALID);
        }
    }

    /** Requires the configured issuer and, when one is configured, the configured audience. */
    private void verifyIssuerAndAudience(JWTClaimsSet jwtClaimsSet) throws BadAccessTokenException {
        String issuer = jwtClaimsSet.getIssuer();
        if (!this.oauth2Config.issuer.equals(issuer)) {
            logger.debugV("Invalid issuer '%s', expecting '%s'", new Object[]{issuer, this.oauth2Config.issuer});
            throw new BadAccessTokenException("Invalid issuer '" + issuer + "', expecting '" + this.oauth2Config.issuer + "'", AuthError.JWT_INVALID_ISSUER);
        }
        String expectedAudience = this.oauth2Config.audience;
        if (StringUtils.isBlank(expectedAudience)) {
            return;
        }
        List<String> audiences = jwtClaimsSet.getAudience();
        if (!audiences.contains(expectedAudience)) {
            String joinedAudiences = String.join(",", audiences);
            logger.debugV("Invalid audiences '%s', expecting '%s'", new Object[]{joinedAudiences, expectedAudience});
            throw new BadAccessTokenException("Invalid audiences '" + joinedAudiences + "', expecting '" + expectedAudience + "'", AuthError.JWT_INVALID_AUDIENCE);
        }
    }

    /** When a scope is configured, requires the token's scope claim to be non-empty and to contain it. */
    private void verifyScope(String serviceId, JWTClaimsSet jwtClaimsSet) throws BadAccessTokenException, ParseException {
        String requiredScope = this.oauth2Config.scope;
        if (StringUtils.isBlank(requiredScope)) {
            logger.infoV("No scope requirement defined for this API %s", new Object[]{serviceId});
            return;
        }
        List<String> currentScopes = this.readScopes(jwtClaimsSet);
        if (currentScopes.isEmpty()) {
            logger.debugV("No scope present in the access token, expecting scope '%s'", new Object[]{requiredScope});
            throw new BadAccessTokenException("No scope found", AuthError.JWT_NO_SCOPE);
        }
        if (!currentScopes.contains(requiredScope)) {
            logger.debugV("Invalid scopes '%s', doesn't contain '%s'", new Object[]{currentScopes, requiredScope});
            throw new BadAccessTokenException("Invalid scopes '" + currentScopes + "', doesn't contain '" + requiredScope + "'", AuthError.JWT_OAUTH2_INVALID_SCOPE);
        }
    }

    /**
     * Reads the scopes from the configured claim (default {@code "scope"}). With the ARRAY format the
     * claim is a string array; otherwise it is a whitespace-separated string. An absent claim
     * yields an empty list rather than an exception.
     */
    private List<String> readScopes(JWTClaimsSet jwtClaimsSet) throws ParseException {
        String scopeClaim = StringUtils.defaultIfBlank(this.oauth2Config.scopeClaimKey, "scope");
        if (this.oauth2Config.scopeClaimFormat == OAuth2Config.ScopeClaimFormat.ARRAY) {
            String[] scopeArray = jwtClaimsSet.getStringArrayClaim(scopeClaim);
            return scopeArray == null ? new ArrayList<>() : Arrays.asList(scopeArray);
        }
        String scopeString = jwtClaimsSet.getStringClaim(scopeClaim);
        // StringUtils.split tokenizes on whitespace and drops empty tokens ("" -> no scopes).
        return scopeString == null ? new ArrayList<>() : Arrays.asList(StringUtils.split(scopeString));
    }

    /** Formats a date for debug logs, tolerating null. */
    private static String formatForLog(Date date) {
        return date == null ? "null" : LOG_DATE_FORMAT.format(date.toInstant());
    }

    /**
     * Records the client id (from the configured claim) and the subject of the token in the audit
     * context. Missing or unreadable client ids are only logged as warnings.
     */
    private void auditLogJWT(String serviceId, JWTClaimsSet jwtClaimsSet) {
        String clientIdClaimKey = this.oauth2Config.clientIdClaimKey;
        if (clientIdClaimKey == null) {
            logger.warnV("The client ID claim is not setup for service '%s'. We recommend setting up the client ID claim key to help you troubleshoot any authorization issue.", new Object[]{serviceId});
        } else {
            try {
                String clientId = jwtClaimsSet.getStringClaim(clientIdClaimKey);
                if (clientId == null) {
                    logger.warnV("Client ID claim '%s' not defined in the access token claims", new Object[]{clientIdClaimKey});
                } else {
                    AuditContextBase.addCustom(AUDIT_LOG_REF_JWT_CLIENT_ID, clientId);
                }
            } catch (ParseException e) {
                logger.warnV((Throwable) e, "Could not read the client id claim '%s' from the access token", new Object[]{clientIdClaimKey});
            }
        }
        String sub = jwtClaimsSet.getSubject();
        if (sub != null) {
            AuditContextBase.addCustom(AUDIT_LOG_REF_JWT_SUB, sub);
        }
    }

    /**
     * Returns the header and payload of a signed JWT with the signature replaced by a marker, so the
     * token can be logged without being replayable. Returns "" for null or non-JWS input.
     */
    @VisibleForTesting
    static String makeAccessTokenSafeToLog(String accessToken) {
        if (accessToken == null) {
            return "";
        }
        try {
            SignedJWT jwt = SignedJWT.parse(accessToken);
            return jwt.getHeader().toBase64URL() + "." + jwt.getPayload().toBase64URL() + ".STRIP_SIGNATURE_FOR_SAFETY";
        } catch (ParseException e) {
            logger.debug((Object) "Access token not a JWT, not printing it for safety reason.");
            return "";
        }
    }

    /** Cached remote JWK source together with the configuration it was built from. */
    private static class JwkSourceCacheEntry {
        final OAuth2Config oAuth2Config;
        final JWKSource<SecurityContext> jwkSource;

        JwkSourceCacheEntry(OAuth2Config oAuth2Config, JWKSource<SecurityContext> jwkSource) {
            this.oAuth2Config = oAuth2Config;
            this.jwkSource = jwkSource;
        }
    }
}

