/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.server.controllers;

import com.dataiku.dip.aigenerations.AIRecipeGenerationService;
import com.dataiku.dip.aigenerations.AISQLQueryGenerationService;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.DSSConnection;
import com.dataiku.dip.connections.SQLConnectionProvider;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.coremodel.SerializedRecipe;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.datasets.sql.AbstractSQLDatasetHandler;
import com.dataiku.dip.futures.FutureResponse;
import com.dataiku.dip.license.LicenseStatusService;
import com.dataiku.dip.queries.QueryRunResult;
import com.dataiku.dip.recipes.RecipeRegistry;
import com.dataiku.dip.recipes.code.CodeBasedRecipeStatus;
import com.dataiku.dip.recipes.code.sql.SQLQueryRecipeStatusComputer;
import com.dataiku.dip.recipes.common.RecipeStatusComputer;
import com.dataiku.dip.recipes.consistency.RecipeCodes;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.Privileges;
import com.dataiku.dip.security.auth.UIAuthService;
import com.dataiku.dip.server.controllers.AuditedCall;
import com.dataiku.dip.server.controllers.DIPInternalControllerBase;
import com.dataiku.dip.server.services.ConnectionsService;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.sql.SQLDialect;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ExceptionUtils;
import com.dataiku.dip.utils.JSON;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.ResponseBody;

/**
 * Internal API controller for SQL recipes.
 *
 * <p>Exposes an endpoint that asks the AI SQL generation service to write a query for a SQL recipe,
 * giving it the {@code CREATE TABLE} statements of the recipe's SQL inputs as context.
 */
@Controller
public class SQLRecipesController
extends DIPInternalControllerBase {
    @Autowired
    private UIAuthService authService;
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private ConnectionsService connectionsService;
    @Autowired
    private ConnectionsDAO connectionsDAO;
    @Autowired
    private AISQLQueryGenerationService aiSqlQueryGenerationService;
    @Autowired
    private DatasetsDAO datasetsDAO;
    @Autowired
    private LicenseStatusService licenseStatusService;
    @Autowired
    private ProjectsService projectsService;
    private static final DKULogger logger = DKULogger.getLogger(SQLRecipesController.class);

    // NOTE(review): the audit template below references ${id}, but this endpoint has no "id"
    // request parameter — confirm whether ${query} (or no value at all) was intended.
    /**
     * Starts an asynchronous AI generation of a SQL query for a SQL recipe.
     *
     * <p>The recipe's SQL input datasets (those with a non-blank table) must all live on the same
     * connection, which the caller must be able to use freely. Their CREATE TABLE statements are
     * passed to the generator as schema context.
     *
     * @param req the HTTP request, used to resolve the calling user
     * @param resp the HTTP response (not used)
     * @param projectKey key of the project holding the recipe; the caller needs WRITE_CONF on it
     * @param query the user's request to turn into SQL
     * @param recipe the recipe, serialized as JSON ({@link SerializedRecipe})
     * @param targetPartition optional target partition spec used when validating the generated query
     * @return a future tracking the generation, or an empty {@link FutureResponse} when no input
     *     dataset resolves to a SQL connection
     * @throws SecurityException if the caller may not freely use the inputs' connection
     * @throws IllegalArgumentException if the inputs span several connections, or the recipe type is
     *     neither {@code sql_script} nor {@code sql_query}
     */
    @AuditedCall(value={"msgType", "generate-sql-query", "projectKey", "${projectKey}", "query", "${id}"})
    @RequestMapping(value={"/api/flow/recipes/sql/generate-query"})
    @ResponseBody
    public FutureResponse<AISQLQueryGenerationService.AISqlQueryGenerationFrontendResponse> generateSQLQuery(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String query, @RequestParam String recipe, @RequestParam(required=false) String targetPartition) throws Exception {
        SerializedRecipe sr = (SerializedRecipe)JSON.parse((String)recipe, SerializedRecipe.class);
        AuthCtx liu;
        DSSConnection conn = null;
        AbstractSQLConnection sqlConnection = null;
        SQLDialect dialect = null;
        List<SerializedDataset> datasets = new ArrayList<SerializedDataset>();
        try (Transaction t = this.transactionService.beginRead();){
            liu = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(liu, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            this.aiSqlQueryGenerationService.checkUserCanUseAISQLGeneration();
            // NOTE(review): the inputs come from client-supplied recipe JSON and are read with
            // getMandatoryUnsafe, while only WRITE_CONF on projectKey is checked above. Confirm that
            // inputs resolving to other projects cannot be used to expose their table schemas.
            for (SerializedRecipe.RecipeInput recipeInput : sr.getFlatInputs()) {
                SerializedDataset sd = (SerializedDataset)this.datasetsDAO.getMandatoryUnsafe(recipeInput.getLoc(projectKey));
                if (!(sd.getParams() instanceof AbstractSQLDatasetHandler.AbstractSQLConfig)) {
                    continue;
                }
                AbstractSQLDatasetHandler.AbstractSQLConfig config = ((AbstractSQLDatasetHandler.AbstractSQLConfig)sd.getParams()).getResolved(projectKey);
                if (StringUtils.isBlank((String)config.table)) {
                    continue;
                }
                if (conn == null) {
                    // The first input with a non-blank connection fixes the connection used for generation.
                    if (StringUtils.isNotBlank((String)config.connection)) {
                        conn = this.connectionsDAO.getMandatoryConnection(liu, config.connection);
                        if (!conn.isFreelyUsableBy(liu)) {
                            throw new SecurityException("You may not list tables on connection " + config.connection);
                        }
                        sqlConnection = SQLConnectionProvider.getDSSConnection(liu, conn.name);
                        dialect = sqlConnection.getDialect();
                    }
                } else if (!conn.name.equals(config.connection)) {
                    // Compare from conn.name so a null config.connection is rejected here instead of NPE-ing.
                    throw new IllegalArgumentException("You may not use tables from multiple connections.");
                }
                if (sqlConnection != null && dialect != null) {
                    datasets.add(sd);
                }
            }
        }
        if (sqlConnection == null) {
            return new FutureResponse<>();
        }
        List<String> sqlTableSchemas = dialect != null
            ? SQLRecipesController.buildTableSchemas(liu, projectKey, sqlConnection, dialect, datasets)
            : new ArrayList<String>();
        RecipeAISqlQueryGenerationFutureThread generationThread = new RecipeAISqlQueryGenerationFutureThread(liu, projectKey, sqlTableSchemas, query, sqlConnection, sr, targetPartition, SQLRecipesController.queryOriginFor(sr));
        return this.aiSqlQueryGenerationService.startGeneration(generationThread);
    }

    /**
     * Renders a CREATE TABLE statement for each dataset.
     *
     * <p>Each dataset is rendered independently. One that fails with an IllegalArgumentException is
     * logged and skipped, so it does not drop the schemas of the other inputs.
     *
     * @return the rendered statements, in dataset order, without the skipped ones
     */
    private static List<String> buildTableSchemas(AuthCtx liu, String projectKey, AbstractSQLConnection sqlConnection, SQLDialect dialect, List<SerializedDataset> datasets) throws Exception {
        // NOTE(review): the result is discarded; presumably the call is made for its side effects — confirm.
        sqlConnection.getConnectionData_NT(liu, projectKey);
        List<String> schemas = new ArrayList<String>(datasets.size());
        for (SerializedDataset serializedDataset : datasets) {
            try {
                schemas.add(dialect.getCreateTableStatementSQL(sqlConnection, Dataset.fromSerialized(serializedDataset), new InfoMessage.InfoMessages(), false));
            }
            catch (IllegalArgumentException e) {
                logger.warn((Object)"Could not build the table schema of a recipe input, skipping it", (Throwable)e);
            }
        }
        return schemas;
    }

    /**
     * Maps a SQL recipe type to the origin reported to the generation service.
     *
     * @throws IllegalArgumentException if the type is null or neither {@code sql_script} nor {@code sql_query}
     */
    private static AISQLQueryGenerationService.QueryOrigin queryOriginFor(SerializedRecipe sr) {
        String type = sr.getType();
        if (type == null) {
            // A switch on a null String would throw NullPointerException; report it like any other bad type.
            throw new IllegalArgumentException("Invalid recipe type to generate SQL query");
        }
        return switch (type) {
            case "sql_script" -> AISQLQueryGenerationService.QueryOrigin.SQL_RECIPE_SCRIPT;
            case "sql_query" -> AISQLQueryGenerationService.QueryOrigin.SQL_RECIPE_QUERY;
            default -> throw new IllegalArgumentException("Invalid recipe type to generate SQL query");
        };
    }

    /**
     * Generation task for a SQL recipe.
     *
     * <p>A generated query is validated by computing the recipe's status with that query as its code.
     */
    private class RecipeAISqlQueryGenerationFutureThread
    extends AISQLQueryGenerationService.AbstractAISqlQueryGenerationFutureThread {
        private final SerializedRecipe serializedRecipe;
        private final String targetPartition;

        public RecipeAISqlQueryGenerationFutureThread(AuthCtx owner, String projectKey, List<String> sqlTableSchemas, String query, AbstractSQLConnection connection, SerializedRecipe serializedRecipe, String targetPartition, AISQLQueryGenerationService.QueryOrigin requestOrigin) {
            super(owner, projectKey, sqlTableSchemas, query, connection, SQLRecipesController.this.licenseStatusService.getLicensingStatus(), requestOrigin);
            this.serializedRecipe = serializedRecipe;
            this.targetPartition = targetPartition;
        }

        /**
         * Rejects a generated query whose recipe status contains fatal messages.
         *
         * <p>Fatal messages are tolerated when the status is a SQL query recipe status holding a
         * successful run result.
         *
         * @throws AISQLQueryGenerationService.QueryValidationError if the status cannot be built, or
         *     it reports fatal messages that are not excused by a successful run
         */
        @Override
        protected void validateQuery(AISQLQueryGenerationService.AISqlQueryGenerationFrontendResponse response) throws AISQLQueryGenerationService.QueryValidationError {
            RecipeStatusComputer computer;
            try {
                computer = RecipeRegistry.getMeta(this.serializedRecipe).buildStatusComputer(this.serializedRecipe, response.sqlQuery);
            }
            catch (Exception e) {
                logger.error((Object)"Failed to build recipe status", (Throwable)e);
                throw this.validationError("Failed to build recipe status", response);
            }
            if (computer == null) {
                logger.error((Object)("Recipes of type " + this.serializedRecipe.type + " do not have a computable status"));
                throw this.validationError("Failed to build recipe status", response);
            }
            logger.info((Object)("Computing status of recipe " + this.serializedRecipe.name + " on computer " + String.valueOf(computer)));
            CodeBasedRecipeStatus status;
            try {
                CodeBasedRecipeStatus.CodeBasedRecipeStatusRequest request = new CodeBasedRecipeStatus.CodeBasedRecipeStatusRequest();
                // A blank target partition means "no partition".
                request.targetPartitionSpec = StringUtils.isBlank((String)this.targetPartition) ? null : this.targetPartition;
                status = (CodeBasedRecipeStatus)computer.getFullStatus_NT(this.authCtx, JSON.json((Object)request));
            }
            catch (Exception e) {
                logger.error((Object)"Failed to compute recipe status", (Throwable)e);
                status = new CodeBasedRecipeStatus();
                status.topLevelMessages.withFatalV((InfoMessage.MessageCode)RecipeCodes.ERR_RECIPE_VALIDATION_FAILED, "Failed to validate recipe: %s", new Object[]{ExceptionUtils.getMessageWithCauses((Throwable)e)});
            }
            var allMessages = status.gatherAllMessages();
            if (!allMessages.anyFatal()) {
                return;
            }
            if (status instanceof SQLQueryRecipeStatusComputer.SQLQueryRecipeStatus) {
                QueryRunResult runResult = ((SQLQueryRecipeStatusComputer.SQLQueryRecipeStatus)status).runResult;
                if (runResult != null && runResult.success) {
                    return;
                }
            }
            String validationMessages = allMessages.messages.stream().map(m -> m.message).collect(Collectors.joining("\n"));
            throw this.validationError(validationMessages, response);
        }

        /** Builds an ERROR-level validation error carrying the generated query and its explanation. */
        private AISQLQueryGenerationService.QueryValidationError validationError(String message, AISQLQueryGenerationService.AISqlQueryGenerationFrontendResponse response) {
            return new AISQLQueryGenerationService.QueryValidationError(message, response.sqlQuery, response.queryName, response.reasoning, AIRecipeGenerationService.CreationMessage.Level.ERROR);
        }
    }
}

