/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.aigenerations;

import com.dataiku.common.rpc.InternalAPIClient;
import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.aigenerations.AIRecipeGenerationService;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.futures.FuturePayload;
import com.dataiku.dip.futures.FutureResponse;
import com.dataiku.dip.futures.FutureService;
import com.dataiku.dip.futures.SimpleFutureThread;
import com.dataiku.dip.license.LicenseStatusService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.util.AIFeaturesUtil;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ExceptionUtils;
import com.google.gson.reflect.TypeToken;
import java.io.IOException;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.log4j.NDC;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class AISQLQueryGenerationService {
    private static final DKULogger logger = DKULogger.getLogger(AISQLQueryGenerationService.class);

    /**
     * Dialect ids that are forwarded unchanged to the AI server. Any other connection dialect
     * (or one that cannot be determined) is sent as the generic {@code "SQL"} dialect.
     */
    private static final Set<String> SUPPORTED_DIALECTS = Set.of(
            "AsterData", "Athena", "BigQuery", "DB2", "Databricks", "Exasol", "GreenPlum", "GreenPlum5",
            "H2", "H2V2", "Hive", "Impala", "KDB+", "MySQL", "MySQL8", "Netezza", "FabricWarehouse",
            "Oracle", "PostgreSQL", "Presto", "Redshift", "SAPHANA", "SQLServer", "Snowflake", "SparkSQL",
            "Sqream", "SybaseIQ", "Synapse", "Teradata", "Trino", "Vertica", "Yellowbrick");

    /** Runs the generation threads submitted through {@code startGeneration}. */
    @Autowired
    private FutureService futureService;
    /** NOTE(review): not referenced anywhere in this file; kept as-is to preserve Spring wiring. */
    @Autowired
    private TransactionService transactionService;
    /** Source of the licensing status checked before allowing AI SQL generation. */
    @Autowired
    private LicenseStatusService licenseStatusService;

    /**
     * Submits a SQL generation thread to the future service.
     *
     * @param futureThread the generation work to run
     * @return the future response returned by {@code FutureService.runFuture}
     * @throws Exception whatever {@code FutureService.runFuture} throws
     */
    public FutureResponse<AISqlQueryGenerationFrontendResponse> startGeneration(AbstractAISqlQueryGenerationFutureThread futureThread) throws Exception {
        // Anonymous subclass captures the full generic type for response (de)serialization.
        TypeToken<FutureResponse<AISqlQueryGenerationFrontendResponse>> responseType =
                new TypeToken<FutureResponse<AISqlQueryGenerationFrontendResponse>>(){};
        // NOTE(review): the meaning of the 0L argument is defined by FutureService; not visible here.
        return this.futureService.runFuture(futureThread, 0L, responseType);
    }

    /**
     * Verifies that AI SQL generation may be used on this instance.
     *
     * @throws IllegalArgumentException if the licensing status is unavailable or is a community (Free Edition) license
     * @throws IllegalStateException    if AI SQL generation is disabled in the AI-driven analytics settings
     */
    public void checkUserCanUseAISQLGeneration() {
        LicenseStatusService.LicensingStatus status = this.licenseStatusService.getLicensingStatus();
        boolean freeEditionOrUnknown = status == null || status.community;
        if (freeEditionOrUnknown) {
            throw new IllegalArgumentException("AI services are not available with Dataiku Free Edition");
        }
        GeneralSettingsDAO.GeneralSettings settings = ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN();
        boolean sqlGenerationEnabled = settings.aiDrivenAnalyticsSettings.aiGenerateSQLEnabled;
        if (!sqlGenerationEnabled) {
            throw new IllegalStateException("AI SQL generation is not enabled");
        }
    }

    public static abstract class AbstractAISqlQueryGenerationFutureThread
    extends SimpleFutureThread<AISqlQueryGenerationFrontendResponse> {
        protected final String projectKey;
        protected final List<String> sqlTableSchemas;
        protected final String query;
        protected final QueryOrigin requestOrigin;
        protected final GeneralSettingsDAO.AIDrivenAnalyticsSettings aiDrivenAnalyticsSettings;
        protected final AuthCtx authCtx;
        protected final AbstractSQLConnection connection;
        private final LicenseStatusService.LicensingStatus licensingStatus;
        private final GeneralSettingsDAO.GeneralSettings generalSettings;

        public AbstractAISqlQueryGenerationFutureThread(AuthCtx owner, String projectKey, List<String> tableToSQLSchemas, String query, AbstractSQLConnection connection, LicenseStatusService.LicensingStatus licensingStatus, QueryOrigin requestOrigin) {
            super(owner);
            this.projectKey = projectKey;
            this.sqlTableSchemas = tableToSQLSchemas;
            this.aiDrivenAnalyticsSettings = ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN().aiDrivenAnalyticsSettings;
            this.query = query;
            this.authCtx = owner;
            this.connection = connection;
            this.licensingStatus = licensingStatus;
            this.requestOrigin = requestOrigin;
            this.generalSettings = ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN();
        }

        public FuturePayload getPayload() {
            return FuturePayload.newSimple((String)"ai_sql_query_generation", (String)"AI SQL Query Generation");
        }

        @Override
        protected AISqlQueryGenerationFrontendResponse compute() {
            NDC.push((String)("ai-sqlquerygeneration: " + this.query));
            try {
                AISqlQueryGenerationBackendQueryBase sqlGenerationQuery = this.createSqlQueryGenerationQuery();
                AISqlQueryGenerationFrontendResponse aISqlQueryGenerationFrontendResponse = this.requestAIServerBackendAndProcessResponseForSqlQuery(sqlGenerationQuery);
                return aISqlQueryGenerationFrontendResponse;
            }
            catch (Exception e) {
                AISqlQueryGenerationFrontendResponse aISqlQueryGenerationFrontendResponse = this.handleFailedRequestFromAIServer(e, "Unfortunately, AI was not able to suggest a query from your instructions. You may want to update them and try again");
                return aISqlQueryGenerationFrontendResponse;
            }
            finally {
                NDC.pop();
            }
        }

        private AISqlQueryGenerationFrontendResponse requestAIServerBackendAndProcessResponseForSqlQuery(AISqlQueryGenerationBackendQueryBase sqlGenerationQuery) {
            AISqlQueryGenerationFrontendResponse aISqlQueryGenerationFrontendResponse;
            block12: {
                String requestId = null;
                InternalAPIClient apiClient = AIFeaturesUtil.getAiServerAPIClient(this.authCtx, this.generalSettings, GeneralSettingsDAO.LocalAIServerSettings.AiAssistant.aiGenerateSQL, AIFeaturesUtil.CONNECTION_TIMEOUT, AIFeaturesUtil.SOCKET_TIMEOUT);
                try {
                    AISqlQueryGenerationBackendResponse resp = (AISqlQueryGenerationBackendResponse)apiClient.postObject("/text2sql/sql-generation", AISqlQueryGenerationBackendResponse.class, (Object)sqlGenerationQuery);
                    requestId = resp.request_id;
                    AISqlQueryGenerationFrontendResponse finalResp = new AISqlQueryGenerationFrontendResponse();
                    finalResp.ok = true;
                    finalResp.sqlQuery = resp.query;
                    finalResp.queryName = resp.query_name;
                    finalResp.reasoning = resp.reasoning;
                    finalResp.requestId = requestId;
                    if (resp.has_context_overflow) {
                        finalResp.messages = List.of(new AIRecipeGenerationService.CreationMessage(AIRecipeGenerationService.CreationMessage.Level.WARNING, "The query includes a large number of tables and columns, so we've applied restrictions to optimize the SQL generation process."));
                    }
                    this.validateQuery(finalResp);
                    aISqlQueryGenerationFrontendResponse = finalResp;
                    if (apiClient == null) break block12;
                }
                catch (Throwable resp) {
                    try {
                        if (apiClient != null) {
                            try {
                                apiClient.close();
                            }
                            catch (Throwable throwable) {
                                resp.addSuppressed(throwable);
                            }
                        }
                        throw resp;
                    }
                    catch (IOException e) {
                        return this.handleFailedRequestFromAIServer(e, "Error while generating the SQL query");
                    }
                    catch (DKUSecurityException e) {
                        logger.error((Object)"Security exception while generating SQL query", (Throwable)e);
                        AISqlQueryGenerationFrontendResponse finalResp = new AISqlQueryGenerationFrontendResponse();
                        finalResp.ok = false;
                        finalResp.messages = List.of(new AIRecipeGenerationService.CreationMessage(AIRecipeGenerationService.CreationMessage.Level.ERROR, ExceptionUtils.getMessageWithCauses((Throwable)e)));
                        return finalResp;
                    }
                    catch (IllegalArgumentException e) {
                        logger.error((Object)"Configuration error while connecting to AI server", (Throwable)e);
                        AISqlQueryGenerationFrontendResponse finalResp = new AISqlQueryGenerationFrontendResponse();
                        finalResp.ok = false;
                        finalResp.messages = List.of(new AIRecipeGenerationService.CreationMessage(AIRecipeGenerationService.CreationMessage.Level.ERROR, ExceptionUtils.getMessageWithCauses((Throwable)e)));
                        return finalResp;
                    }
                    catch (QueryValidationError e) {
                        AISqlQueryGenerationFrontendResponse resp2 = this.handleFailedQueryValidation(e, e.messageLevel.equals((Object)AIRecipeGenerationService.CreationMessage.Level.ERROR) ? "Failed to validate SQL query" : "SQL query validation warning", e.messageLevel, requestId);
                        this.logError(resp2, requestId);
                        return resp2;
                    }
                }
                apiClient.close();
            }
            return aISqlQueryGenerationFrontendResponse;
        }

        private void logError(AISqlQueryGenerationFrontendResponse resp, String requestId) {
            if (this.generalSettings.localAIServerSettings.isLocalAiAssistantEnabled(GeneralSettingsDAO.LocalAIServerSettings.AiAssistant.aiGenerateSQL)) {
                logger.debug((Object)"Validation error not sent because local AI server is used for feature generate SQL");
                return;
            }
            String licenseId = this.licensingStatus != null && this.licensingStatus.licenseContent != null ? this.licensingStatus.licenseContent.licenseId : null;
            try (InternalAPIClient apiClient = AIFeaturesUtil.getAiServerAPIClient(this.authCtx, this.generalSettings, GeneralSettingsDAO.LocalAIServerSettings.AiAssistant.aiGenerateSQL, AIFeaturesUtil.CONNECTION_TIMEOUT, AIFeaturesUtil.SOCKET_TIMEOUT);){
                apiClient.postObject("/text2sql/log-validation", AISqlQueryGenerationBackendResponse.class, (Object)new AISqlQueryResultBackend(requestId, resp.messages.stream().map(m -> m.content).collect(Collectors.joining(", ")), licenseId, this.aiDrivenAnalyticsSettings.aiGenerateSQLTelemetryEnabled));
            }
            catch (Exception exception) {
                // empty catch block
            }
        }

        protected abstract void validateQuery(AISqlQueryGenerationFrontendResponse var1) throws QueryValidationError;

        private AISqlQueryGenerationBackendQueryBase createSqlQueryGenerationQuery() {
            AISqlQueryGenerationBackendQueryBase backendQueryBase = new AISqlQueryGenerationBackendQueryBase();
            backendQueryBase.licenseId = this.licensingStatus != null && this.licensingStatus.licenseContent != null ? this.licensingStatus.licenseContent.licenseId : null;
            backendQueryBase.telemetryEnabled = this.aiDrivenAnalyticsSettings.aiGenerateSQLTelemetryEnabled;
            backendQueryBase.requestOrigin = this.requestOrigin;
            backendQueryBase.query = this.query;
            backendQueryBase.sqlTableSchemas = this.sqlTableSchemas;
            backendQueryBase.dialect = this.getDialect();
            return backendQueryBase;
        }

        private String getDialect() {
            String dialect;
            try {
                dialect = this.connection.getDialect().getId();
                if (!SUPPORTED_DIALECTS.contains(dialect)) {
                    dialect = "SQL";
                }
            }
            catch (Exception e) {
                dialect = "SQL";
            }
            return dialect;
        }

        private AISqlQueryGenerationFrontendResponse handleFailedRequestFromAIServer(Exception e, String message) {
            logger.error((Object)"Exception while generating SQL query", (Throwable)e);
            AISqlQueryGenerationFrontendResponse finalResp = new AISqlQueryGenerationFrontendResponse();
            finalResp.ok = false;
            finalResp.messages = List.of(new AIRecipeGenerationService.CreationMessage(AIRecipeGenerationService.CreationMessage.Level.ERROR, message));
            return finalResp;
        }

        private AISqlQueryGenerationFrontendResponse handleFailedQueryValidation(QueryValidationError e, String message, AIRecipeGenerationService.CreationMessage.Level level, String requestId) {
            logger.error((Object)"Failed to validate SQL query", (Throwable)e);
            AISqlQueryGenerationFrontendResponse finalResp = new AISqlQueryGenerationFrontendResponse();
            finalResp.ok = false;
            finalResp.sqlQuery = e.query;
            finalResp.queryName = e.queryName;
            finalResp.reasoning = e.reasoning;
            finalResp.requestId = requestId;
            finalResp.messages = List.of(new AIRecipeGenerationService.CreationMessage(level, message), new AIRecipeGenerationService.CreationMessage(level, e.getMessage()));
            return finalResp;
        }
    }

    /**
     * Result of a SQL generation future.
     * NOTE(review): presumably serialized by field name for the frontend; keep names and order stable.
     */
    public static class AISqlQueryGenerationFrontendResponse {
        /** True only when the AI server answered and validateQuery did not throw; false otherwise (validation warnings included). */
        public Boolean ok;
        /** User-facing messages; null on a plain success, a WARNING on context overflow, ERROR/WARNING entries on failures. */
        public List<AIRecipeGenerationService.CreationMessage> messages;
        /** Reasoning text returned by the AI server (kept when validation fails). */
        public String reasoning;
        /** Generated SQL query (kept when validation fails). */
        public String sqlQuery;
        /** Query name returned by the AI server (kept when validation fails). */
        public String queryName;
        /** Request id returned by the AI server; null when no backend response was received. */
        public String requestId;
    }

    /**
     * Response of the AI server's {@code /text2sql/sql-generation} endpoint.
     * NOTE(review): snake_case names presumably mirror the backend JSON keys; do not rename.
     */
    public static class AISqlQueryGenerationBackendResponse {
        // Id of the request on the AI server side; echoed back when reporting validation results
        String request_id;
        // Generated SQL query
        String query;
        // Reasoning text accompanying the query
        String reasoning;
        // Name proposed for the query
        String query_name;
        // When true, a WARNING message about restricted table/column context is shown to the user
        boolean has_context_overflow;
    }

    /**
     * Request body posted to the AI server's {@code /text2sql/sql-generation} endpoint.
     * NOTE(review): presumably serialized by field name; keep names stable.
     */
    public static class AISqlQueryGenerationBackendQueryBase {
        // License id, or null when licensing information is unavailable
        String licenseId;
        // The user's instructions
        String query;
        // Table schema descriptions given as context
        List<String> sqlTableSchemas;
        // One of SUPPORTED_DIALECTS, or "SQL" as a generic fallback
        String dialect;
        // Copied from the aiGenerateSQLTelemetryEnabled setting
        boolean telemetryEnabled;
        // Where the request originated
        QueryOrigin requestOrigin;
    }

    /** Request body posted to the AI server's {@code /text2sql/log-validation} endpoint. */
    public static class AISqlQueryResultBackend {
        /** Id of the generation request the validation result refers to. */
        String requestId;
        /** Validation messages joined with {@code ", "}. */
        String validationText;
        /** License id, or null when unavailable. */
        String licenseId;
        /** Copied from the aiGenerateSQLTelemetryEnabled setting. */
        boolean telemetryEnabled;

        public AISqlQueryResultBackend(String requestId, String validationText, String licenseId, boolean telemetryEnabled) {
            this.telemetryEnabled = telemetryEnabled;
            this.licenseId = licenseId;
            this.validationText = validationText;
            this.requestId = requestId;
        }
    }

    /** Where a SQL generation request comes from; forwarded to the AI server with each request. */
    public enum QueryOrigin {
        SQL_NOTEBOOK,
        SQL_RECIPE_QUERY,
        SQL_RECIPE_SCRIPT
    }

    /**
     * Checked exception declared by {@code validateQuery} to report a finding on a generated query.
     * It carries the generated query, its name and reasoning so they can still be shown, plus the
     * level used for the user-facing messages.
     */
    public static class QueryValidationError extends Exception {
        private static final long serialVersionUID = 1L;

        /** Generated SQL query that was validated. */
        public final String query;
        /** Query name returned with the generated query. */
        public final String queryName;
        /** Reasoning returned with the generated query. */
        public final String reasoning;
        /** Level of the messages built from this error (ERROR selects the "Failed to validate" headline). */
        public final AIRecipeGenerationService.CreationMessage.Level messageLevel;

        public QueryValidationError(String message, String query, String queryName, String reasoning, AIRecipeGenerationService.CreationMessage.Level messageLevel) {
            super(message);
            this.messageLevel = messageLevel;
            this.reasoning = reasoning;
            this.queryName = queryName;
            this.query = query;
        }
    }
}

