/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.connections;

import com.dataiku.dip.CodedRuntimeException;
import com.dataiku.dip.DKUApp;
import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.connections.AbstractCloudStorageConnection;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.connections.AccessTokenInjectingConnection;
import com.dataiku.dip.connections.AutoFastPathConnection;
import com.dataiku.dip.connections.AzureConnection;
import com.dataiku.dip.connections.ConnectionCredentialUtils;
import com.dataiku.dip.connections.ConnectionUtils;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.connections.ConnectionWithEncryptedFields;
import com.dataiku.dip.connections.ConnectionWithPerUserOAuth2Credentials;
import com.dataiku.dip.connections.DSSConnection;
import com.dataiku.dip.connections.EC2Connection;
import com.dataiku.dip.connections.GCSConnection;
import com.dataiku.dip.connections.SQLConnectionProvider;
import com.dataiku.dip.connections.SimpleSQLDSSConnectionWithBasicCredential;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.exceptions.ConfValidators;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.security.model.ICredentialsService;
import com.dataiku.dip.security.model.OAuth2Client;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.connections.ConnectionCodes;
import com.dataiku.dip.sql.TrinoSQLDialect;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.variables.VariablesContext;
import com.dataiku.dip.variables.VariablesService;
import com.dataiku.dss.shadelib.com.nimbusds.oauth2.sdk.ParseException;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Properties;
import java.util.stream.Collectors;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.text.StrSubstitutor;
import org.apache.log4j.Logger;

/**
 * DSS connection to a Trino cluster over JDBC.
 *
 * <p>Two authentication modes are supported (see {@link AuthType}):
 * <ul>
 *   <li>{@code PASSWORD}: a user/password pair is passed as JDBC properties;</li>
 *   <li>{@code OAUTH2}: an OAuth2 access token is passed as the {@code accessToken} JDBC property.
 *       In per-user credentials mode the token is obtained from the user's refresh token,
 *       otherwise through the client-credentials grant. The token of an already-open JDBC
 *       connection can be replaced through {@link #buildAccessTokenInjector(Connection, Properties)}.</li>
 * </ul>
 */
public class TrinoConnection
extends SimpleSQLDSSConnectionWithBasicCredential
implements ConnectionWithEncryptedFields,
ConnectionWithPerUserOAuth2Credentials,
AutoFastPathConnection,
AccessTokenInjectingConnection {
    public static final String connectionType = "Trino";
    /** Driver class used when no custom driver class is configured. */
    private static final String DEFAULT_DRIVER = "io.trino.jdbc.TrinoDriver";
    /** Prefix under which JWT payload claims are exposed to user-name variable substitution. */
    private static final String JWT_VARIABLE_PREFIX = "jwt.";
    private static final Logger logger = Logger.getLogger("dku.trino");
    /** Lazily created dialect, shared by all Trino connections. */
    private static TrinoSQLDialect dialect;
    protected Params params = new Params();

    @Override
    public AbstractSQLConnection.AbstractSQLParams getParams() {
        return this.params;
    }

    @Override
    public String getType() {
        return connectionType;
    }

    /**
     * Validates the connection settings: a non-blank URL when {@code useURL} is set, otherwise a
     * non-blank host.
     */
    @Override
    public void checkConfiguration() throws CodedException, DKUSecurityException {
        if (this.params.useURL) {
            ConfValidators.checkNotBlank(this.params.url, ConnectionCodes.ERR_CONNECTION_INVALID_CONFIG, "Url");
        } else {
            ConfValidators.checkNotBlank(this.params.host, ConnectionCodes.ERR_CONNECTION_INVALID_CONFIG, "Host");
        }
    }

    @Override
    public TrinoSQLDialect getDialect() {
        if (dialect == null) {
            dialect = new TrinoSQLDialect();
        }
        return dialect;
    }

    /** Returns the configured driver class, or the Trino driver when none is set. */
    @Override
    String getDriver() {
        return StringUtils.isNotBlank(this.params.driver) ? this.params.driver : DEFAULT_DRIVER;
    }

    /**
     * Returns the JDBC URL: the user-provided URL when {@code useURL} is set, otherwise
     * {@code jdbc:trino://host:port/db} built from the individual fields.
     *
     * @throws IllegalArgumentException if {@code useURL} is set but the URL is blank
     */
    @Override
    String getJdbcUrl() {
        if (this.params.useURL) {
            if (StringUtils.isBlank(this.params.url)) {
                throw ErrorContext.iae("Trino connection JDBC URL is not set");
            }
            return this.params.url;
        }
        return "jdbc:trino://" + this.params.host + ":" + this.params.port + "/" + this.params.db;
    }

    /** Returns the directory holding the JDBC driver jars. */
    @Override
    String getJarsDirectory() {
        if (StringUtils.isNotBlank(this.params.driver)) {
            // A custom driver class always uses the configured jars directory, whatever driverMode says.
            return this.params.jarsDirectory;
        }
        switch (this.params.driverMode) {
            case MANAGED:
                return DKUApp.getInstallFile(new String[]{"lib", "ivy", "jdbc-trino"}).getAbsolutePath();
            case CUSTOM:
                return this.params.jarsDirectory;
            default:
                throw new IllegalStateException("Unknown Trino driver mode: " + this.params.driverMode);
        }
    }

    /** Returns the URL to show to users: {@code displayedUrl} when set in URL mode, else the JDBC URL. */
    @Override
    String getDisplayableJdbcUrl() {
        return this.params.useURL && StringUtils.isNotBlank(this.params.displayedUrl) ? this.params.displayedUrl : this.getJdbcUrl();
    }

    @Override
    public ICredentialsService.BasicCredential getGlobalCredential() {
        return new ICredentialsService.BasicCredential(this.params.user, this.params.password);
    }

    /** Encrypts the password and the OAuth2 app secret in place, unless already encrypted or empty. */
    @Override
    public void encryptFields(PasswordEncryptionService cryptoService, GeneralSettingsDAO.SecuritySettings unused) {
        this.params.password = cryptoService.encryptIfNotEncryptedOrEmpty(this.params.password);
        this.params.appSecret = cryptoService.encryptIfNotEncryptedOrEmpty(this.params.appSecret);
    }

    /** Decrypts the password and the OAuth2 app secret in place, if they are encrypted. */
    @Override
    public void decryptFields(PasswordEncryptionService cryptoService) {
        this.params.password = cryptoService.decryptIfEncrypted(this.params.password);
        this.params.appSecret = cryptoService.decryptIfEncrypted(this.params.appSecret);
    }

    @Override
    public List<String> getKnownDriverJars() {
        return super.getKnownDriverJarsFromJarsDirectory();
    }

    @Override
    public InfoMessage.InfoMessages canHaveSparkIntegration() {
        return new InfoMessage.InfoMessages();
    }

    @Override
    public boolean useAutoFastConnection() {
        return this.params.useAutoFastPath;
    }

    @Override
    public String getAutoFastPathConnection() {
        return this.params.autoFastPathConnection;
    }

    @Override
    public String getAutoFastPathConnectionPath() {
        return this.params.autoFastPathConnectionPath;
    }

    /** Auto fast-path is available from S3 (EC2), Azure and GCS connections. */
    @Override
    public boolean isSupportedCloudStorage(AbstractCloudStorageConnection connection) {
        return connection instanceof EC2Connection || connection instanceof AzureConnection || connection instanceof GCSConnection;
    }

    @Override
    public boolean actuallyHasBasicCredential() {
        return this.params.authType == AuthType.PASSWORD;
    }

    @Override
    public boolean actuallyHasPerUserOAuth2Credential() {
        return this.credentialsMode == DSSConnection.CredentialsMode.PER_USER && this.params.authType == AuthType.OAUTH2;
    }

    @Override
    public boolean hasRefreshTokenRotation() {
        return this.params.refreshTokenRotation;
    }

    @Override
    public boolean mustResolveOnBackend() {
        return this.hasRefreshTokenRotation() || super.mustResolveOnBackend();
    }

    /**
     * Resolves the credentials to use for this connection into a {@link SerializableTrinoCredentials}.
     *
     * <p>For {@code PASSWORD}, the decrypted basic credential is returned. For {@code OAUTH2}, a
     * fresh access token (with its expiry in epoch millis, or -1 when unknown) is returned, plus the
     * configured user name when one is set. The user name may contain {@code ${...}} placeholders,
     * which are expanded from DSS variables and from the token's JWT claims ({@code ${jwt.<claim>}}).
     */
    @Override
    protected <T> T getFullyResolvedCredentials_internal(ConnectionWithBasicCredential.CredentialResolutionContext ctx, Class<T> clazz) throws DKUSecurityException, IOException, SQLException {
        assert (clazz.isAssignableFrom(SerializableTrinoCredentials.class));
        SerializableTrinoCredentials creds = new SerializableTrinoCredentials();
        creds.authType = this.params.authType;
        if (creds.authType == AuthType.PASSWORD) {
            ICredentialsService.BasicCredential basicCreds = ConnectionCredentialUtils.getDecryptedBasicCredential_autoTXN(this, ctx.authCtx);
            creds.user = basicCreds.user;
            creds.password = basicCreds.password;
        } else if (creds.authType == AuthType.OAUTH2) {
            OAuth2Client.AccessTokenResult accessTokenResult = this.getAccessToken(ctx.authCtx, this.getProxySettings());
            creds.accessToken = accessTokenResult.getAccessToken();
            creds.accessTokenExpiresOn = accessTokenResult.getExpiresOn() == null ? -1L : accessTokenResult.getExpiresOn().getTime();
            if (StringUtils.isNotBlank(this.params.user)) {
                creds.user = this.resolveOAuth2User(ctx, creds.accessToken);
            }
        }
        return clazz.cast(creds);
    }

    /**
     * Expands {@code ${...}} placeholders in the configured user name, using the connection/project
     * variables plus the JWT claims of {@code accessToken}. Returns the user name unchanged when it
     * contains no placeholder.
     */
    private String resolveOAuth2User(ConnectionWithBasicCredential.CredentialResolutionContext ctx, String accessToken) throws DKUSecurityException, IOException, SQLException {
        String user = this.params.user;
        if (!user.contains("${")) {
            return user;
        }
        VariablesContext vc = ((VariablesService)SpringUtils.getBean(VariablesService.class)).getForConnectionAndProject(this, ctx.authCtx, ctx.projectKey);
        // Work on a private copy: the map returned by the variables context may be shared or
        // unmodifiable, and must not be polluted with per-token jwt.* entries.
        Map<String, Object> substitutions = new HashMap<>(vc.getAllVariables());
        putJwtClaims(substitutions, accessToken);
        return new StrSubstitutor(substitutions, "${", "}").replace(user);
    }

    /**
     * Best-effort decoding of the payload segment of a JWT access token: each primitive claim is
     * added to {@code target} as {@code "jwt." + claimName}. Tokens that are blank or have no
     * payload segment are ignored; decoding failures are logged at WARN and otherwise ignored.
     */
    private static void putJwtClaims(Map<String, Object> target, String accessToken) {
        if (StringUtils.isBlank(accessToken)) {
            return;
        }
        String[] parts = accessToken.split("\\.");
        if (parts.length < 2) {
            return;
        }
        try {
            String payloadStr = new String(Base64.decodeBase64(parts[1]), StandardCharsets.UTF_8);
            JsonObject payloadObj = (JsonObject)JSON.parse(payloadStr, JsonObject.class);
            for (Map.Entry<String, JsonElement> claim : payloadObj.entrySet()) {
                JsonElement value = claim.getValue();
                if (value.isJsonPrimitive()) {
                    target.put(JWT_VARIABLE_PREFIX + claim.getKey(), value.getAsString());
                }
            }
        }
        catch (Exception e) {
            logger.warn("Unable to get payload info from token", e);
        }
    }

    @Override
    public OAuth2Client buildOAuth2Client(ProxySettings proxySettings, AuthCtx authCtx) throws DKUSecurityException {
        boolean useCache = this.getDkuPropertiesAsParams().getBoolParam("dku.connection.oauth.enableCache", true);
        return new OAuth2Client.Builder()
                .authorizationEndpoint(this.params.authorizationEndpoint)
                .tokenEndpoint(this.params.tokenEndpoint)
                .clientId(this.params.appId)
                .clientSecret(this.params.appSecret)
                .scope(this.params.scope)
                .usePkce(true)
                .proxy(proxySettings)
                .useAccessTokenCache(useCache)
                .build();
    }

    /**
     * Builds the JDBC connection data, adding {@code user}/{@code password} properties for password
     * auth, or the {@code accessToken} property (and {@code user} when resolved) for OAuth2.
     */
    @Override
    public SQLConnectionProvider.SQLConnectionData getConnectionData_NT(AuthCtx authCtx, String projectKey) throws DKUSecurityException, SQLException {
        SerializableTrinoCredentials creds = this.getFullyResolvedCredentials_sqlLike(new ConnectionWithBasicCredential.CredentialResolutionContext(authCtx, projectKey), SerializableTrinoCredentials.class);
        // Locale.ROOT: with e.g. a Turkish default locale, "Trino".toUpperCase() yields "TRİNO" and valueOf fails.
        ConnectionUtils.SQLConnectionType sqlType = ConnectionUtils.SQLConnectionType.valueOf(this.type.toUpperCase(Locale.ROOT));
        SQLConnectionProvider.GenericSQLConnectionData cd = this.makeInjectingSQLConnectionData(sqlType, this.getDialect(), this, this.getDriver(), this.getJdbcUrl(), this.getJarsDirectory());
        this.fillConnectionData(cd);
        if (creds.authType == AuthType.PASSWORD) {
            cd.withProperty(new AbstractSQLConnection.CustomDatabaseProperty("user", creds.user, false));
            cd.withProperty(new AbstractSQLConnection.CustomDatabaseProperty("password", creds.password, true));
        } else if (creds.authType == AuthType.OAUTH2) {
            if (StringUtils.isNotBlank(creds.user)) {
                cd.withProperty(new AbstractSQLConnection.CustomDatabaseProperty("user", creds.user, false));
            }
            cd.withProperty(new AbstractSQLConnection.CustomDatabaseProperty("accessToken", creds.accessToken, true));
        }
        return cd;
    }

    @Override
    public ICredentialsService.OAuth2Credential getResolvedOAuth2Credential(AuthCtx authCtx) {
        return new ICredentialsService.OAuth2Credential(this.getAccessToken(authCtx, this.getProxySettings()).getAccessToken());
    }

    /**
     * Obtains an OAuth2 access token: from the user's refresh token in per-user credentials mode,
     * otherwise through the client-credentials grant.
     *
     * @throws CodedRuntimeException if the token cannot be obtained
     */
    public OAuth2Client.AccessTokenResult getAccessToken(AuthCtx authCtx, ProxySettings proxySettings) {
        PasswordEncryptionService cryptoService = (PasswordEncryptionService)SpringUtils.getBean(PasswordEncryptionService.class);
        // NOTE(review): this decrypts the secret fields of this connection object in place so the
        // OAuth2 client gets the clear app secret; confirm the instance is not persisted afterwards.
        this.decryptFields(cryptoService);
        boolean useCache = this.getDkuPropertiesAsParams().getBoolParam("dku.connection.oauth.enableCache", true);
        logger.info("Exchanging user's refresh token for an access token");
        try {
            OAuth2Client oAuth2Client = this.buildOAuth2Client(proxySettings, authCtx);
            if (this.credentialsMode == DSSConnection.CredentialsMode.PER_USER) {
                return this.getAccessTokenFromRefreshTokenAndUpdateIfNeeded(authCtx, oAuth2Client, false);
            }
            return oAuth2Client.acquireAccessTokenResultWithClientCredentialsGrant(useCache);
        }
        catch (DKUSecurityException e) {
            throw new CodedRuntimeException(e.getCode(), "Failed to get OAuth2 access token", (Throwable)e);
        }
        catch (ParseException | IOException | URISyntaxException e) {
            throw new CodedRuntimeException((InfoMessage.MessageCode)ConnectionCodes.ERR_CONNECTION_INVALID_CONFIG, "Failed to get OAuth2 access token", e);
        }
    }

    /**
     * Locates, by reflection into the Trino driver's HTTP client, the interceptor field that holds
     * the access token the connection was opened with (the {@code accessToken} property), and
     * returns an injector that overwrites that field with a new token.
     *
     * <p>The HTTP client is read from the connection's {@code httpCallFactory} field when present,
     * otherwise from {@code httpClient}; a wrapping {@code okHttpClient} field is unwrapped. If
     * several interceptor fields hold the current token, the last one found is used.
     *
     * @throws SQLException if the reflection walk fails or no field holds the current token
     */
    @Override
    public AccessTokenInjectingConnection.AccessTokenInjector buildAccessTokenInjector(Connection connection, Properties properties) throws SQLException {
        String current = properties.getProperty("accessToken");
        Object tokenAuthInterceptor = null;
        Field accessTokenField = null;
        try {
            String clientFieldName = hasDeclaredField(connection.getClass(), "httpCallFactory") ? "httpCallFactory" : "httpClient";
            Object httpClient = readDeclaredField(connection, clientFieldName);
            if (hasDeclaredField(httpClient.getClass(), "okHttpClient")) {
                httpClient = readDeclaredField(httpClient, "okHttpClient");
            }
            List<?> interceptors = (List<?>)readDeclaredField(httpClient, "interceptors");
            for (Object interceptor : interceptors) {
                for (Field candidate : interceptor.getClass().getDeclaredFields()) {
                    candidate.setAccessible(true);
                    Object value = candidate.get(interceptor);
                    if (value != null && value.equals(current)) {
                        tokenAuthInterceptor = interceptor;
                        accessTokenField = candidate;
                    }
                }
            }
        }
        catch (Exception e) {
            throw new SQLException("Cannot dig token injection point", e);
        }
        if (tokenAuthInterceptor == null) {
            throw new SQLException("Cannot find token injection point");
        }
        final Object tokenAuthInterceptorFinal = tokenAuthInterceptor;
        final Field accessTokenFieldFinal = accessTokenField;
        return new AccessTokenInjectingConnection.AccessTokenInjector(){

            @Override
            public void inject(String accessToken) throws SQLException, InterruptedException, IOException, DKUSecurityException {
                try {
                    accessTokenFieldFinal.set(tokenAuthInterceptorFinal, accessToken);
                }
                catch (IllegalAccessException e) {
                    throw new SQLException("Unable to swap access token", e);
                }
            }
        };
    }

    /** Returns whether {@code type} itself (not its superclasses) declares a field named {@code name}. */
    private static boolean hasDeclaredField(Class<?> type, String name) {
        return Arrays.stream(type.getDeclaredFields()).anyMatch(f -> name.equals(f.getName()));
    }

    /** Reads the field {@code name} declared by the runtime class of {@code target}, bypassing access checks. */
    private static Object readDeclaredField(Object target, String name) throws ReflectiveOperationException {
        Field field = target.getClass().getDeclaredField(name);
        field.setAccessible(true);
        return field.get(target);
    }

    @Override
    public boolean needsAccessTokenInjection() {
        return this.params.authType == AuthType.OAUTH2;
    }

    /** Trino-specific connection parameters. */
    public static class Params
    extends AbstractSQLConnection.AbstractSQLParamsWithStdFields {
        public int port = 8080;
        public TrinoDriverMode driverMode = TrinoDriverMode.MANAGED;
        /** Custom driver class; when non-blank, the configured jars directory is used. */
        public String driver;
        public AuthType authType = AuthType.PASSWORD;
        // OAuth2 client settings (used when authType == OAUTH2).
        public String appId;
        /** Stored encrypted; see encryptFields/decryptFields. */
        public String appSecret;
        public String authorizationEndpoint;
        public String tokenEndpoint;
        public String scope;
        public boolean refreshTokenRotation;
        // Auto fast-path settings.
        public boolean useAutoFastPath;
        public String autoFastPathConnection;
        public String autoFastPathConnectionPath;
        public String autoFastPathCatalog;
        public String autoFastPathSchema;
    }

    /** Where the JDBC driver jars come from: bundled with DSS, or a user-provided directory. */
    public static enum TrinoDriverMode {
        MANAGED,
        CUSTOM;

    }

    public static enum AuthType {
        PASSWORD,
        OAUTH2;

    }

    /**
     * Resolved credentials for a Trino connection. Static nested so instances carry no hidden
     * reference to the enclosing connection.
     */
    static class SerializableTrinoCredentials
    implements ICredentialsService.BasicCredentialConvertible,
    OAuth2Client.AccessTokenCredentialConvertible {
        public AuthType authType;
        public String user;
        public String password;
        public String accessToken;
        /** Access token expiry in epoch milliseconds, or -1 when unknown. */
        public long accessTokenExpiresOn;

        SerializableTrinoCredentials() {
        }

        @Override
        public ICredentialsService.BasicCredential toBasicCredential() {
            return new ICredentialsService.BasicCredential(this.user, this.password);
        }

        @Override
        public OAuth2Client.SerializableAccessTokenResult toSerializableAccessTokenResult() {
            return new OAuth2Client.SerializableAccessTokenResult(this.accessToken, this.accessTokenExpiresOn);
        }
    }
}

