/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.snowflakecortex;

import com.dataiku.dip.connections.SQLConnectionProvider;
import com.dataiku.dip.connections.SnowflakeConnection;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettingsValidator;
import com.dataiku.dip.llm.online.snowflakecortex.AbstractRawSnowflakeCortexLLMClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.sql.SQLDialect;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.com.google.common.base.Joiner;
import com.google.gson.annotations.SerializedName;
import com.google.gson.reflect.TypeToken;
import java.io.IOException;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import org.apache.commons.collections4.CollectionUtils;

/**
 * Snowflake Cortex LLM client that sends its requests as SQL statements
 * ({@code SNOWFLAKE.CORTEX.COMPLETE} and {@code SNOWFLAKE.CORTEX.EMBED_TEXT_<size>})
 * through a {@link SnowflakeConnection}.
 *
 * <p>Every call obtains a connection wrapper from {@link #createSQLWrapper()} and closes it
 * before returning. Tool calling and streaming are rejected with an
 * {@link IllegalArgumentException}.
 */
public class RawSnowflakeCortexLLMSQLClient extends AbstractRawSnowflakeCortexLLMClient {
    private final AuthCtx authCtx;
    private final SnowflakeConnection sfConnection;
    private final String projectKey;

    /** {@link String#format} template for a completion query: quoted model, message array, options object. */
    public final String SNOWFLAKE_CORTEX_STANDARD_COMPLETION_QUERY = "SELECT SNOWFLAKE.CORTEX.COMPLETE(\n   %s,\n   %s,\n   %s\n);";
    /** {@link String#format} template for an embedding query: embedding size, quoted model, one {@code (text)} row per input. */
    public final String SNOWFLAKE_CORTEX_STANDARD_EMBED_TEXT_EMBED_SIZE_QUERY = "SELECT SNOWFLAKE.CORTEX.EMBED_TEXT_%d(\n   %s,\n   text\n)\nFROM VALUES\n    %s\n   AS t(text);";

    private static final CoreCompletionSettingsValidator completionValidator = new CoreCompletionSettingsValidator("Snowflake Cortex (SQL)").allowMaxTokens().allowTemperature().allowTopP();
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.snowflakecortex-llm.client");

    /**
     * @param authCtx      authentication context used when resolving and opening the connection
     * @param projectKey   project key passed along when resolving and opening the connection
     * @param sfConnection Snowflake connection the queries are run against
     */
    public RawSnowflakeCortexLLMSQLClient(AuthCtx authCtx, String projectKey, SnowflakeConnection sfConnection) {
        this.authCtx = authCtx;
        this.sfConnection = sfConnection;
        this.projectKey = projectKey;
    }

    /**
     * Renders the chat history as the SQL text of an array of {@code {'role': ..., 'content': ...}}
     * objects. Role and content are quoted by the dialect so they can be embedded in the statement.
     */
    private String convertChatMessagesToString(SQLDialect sqlDialect, List<LLMClient.ChatMessage> messages) {
        List<String> renderedMessages = new ArrayList<>();
        for (LLMClient.ChatMessage message : messages) {
            String quotedRole = sqlDialect.quoteString(message.role);
            String quotedContent = sqlDialect.quoteString(message.getText());
            renderedMessages.add(String.format("{\n\t'role': %s,\t'content': %s\n}", quotedRole, quotedContent));
        }
        return "[\n" + Joiner.on(",\n").join(renderedMessages) + "]";
    }

    /**
     * Renders the completion settings that are set (max tokens, temperature, top-p) as the SQL
     * text of the options object. An empty object is produced when none is set.
     */
    private String buildCompletionOptions(CoreCompletionSettings ccs) {
        List<String> options = new ArrayList<>();
        // Always format numbers with a fixed locale: the default locale may use non-ASCII digits
        // or a decimal comma, either of which would yield invalid SQL.
        if (ccs.maxTokens != null) {
            options.add(String.format(Locale.ENGLISH, "'max_tokens': %d", ccs.maxTokens));
        }
        if (ccs.temperature != null) {
            options.add(String.format(Locale.ENGLISH, "'temperature': %f", ccs.temperature));
        }
        if (ccs.topP != null) {
            options.add(String.format(Locale.ENGLISH, "'top_p': %f", ccs.topP));
        }
        return "{\n" + Joiner.on(",\n").join(options) + "\n}";
    }

    /**
     * Copies the first choice's text and the token usage from the parsed SQL response.
     * Missing choices or usage leave the corresponding fields of the result untouched.
     */
    private LLMClient.SimpleCompletionResponse toCompletionResponse(SnowflakeCortexCompletionSQLResponse resp) {
        LLMClient.SimpleCompletionResponse ret = new LLMClient.SimpleCompletionResponse();
        if (CollectionUtils.isNotEmpty(resp.choices)) {
            ret.text = resp.choices.get(0).messages;
        }
        if (resp.usage != null) {
            ret.completionTokens = resp.usage.completionTokens;
            ret.promptTokens = resp.usage.promptTokens;
            ret.totalTokens = resp.usage.totalTokens;
        }
        return ret;
    }

    /**
     * Runs a chat completion through {@code SNOWFLAKE.CORTEX.COMPLETE}.
     *
     * @param model    model whose id is passed to Cortex
     * @param messages chat history, sent in order
     * @param ccs      completion settings; only max tokens, temperature and top-p are allowed
     * @return the text of the first choice (if any) and the reported token usage (if any)
     * @throws IllegalArgumentException if tools are specified
     * @throws SQLException             if the query fails or returns no row
     * @throws IOException              if the returned row holds no parsable response
     */
    @Override
    public LLMClient.SimpleCompletionResponse chatComplete(LLMModelHandle.Model model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) throws IOException, SQLException, DKUSecurityException, InterruptedException {
        completionValidator.validate(ccs);
        if (CollectionUtils.isNotEmpty(ccs.tools)) {
            throw new IllegalArgumentException("Tools specified, but not supported when querying a Snowflake Cortex model in SQL");
        }
        try (SQLConnectionProvider.SQLConnectionWrapper sqlWrapper = this.createSQLWrapper()) {
            SQLDialect sqlDialect = sqlWrapper.getDialect();
            String quotedModel = sqlDialect.quoteString(model.getId());
            String promptOrHistory = this.convertChatMessagesToString(sqlDialect, messages);
            String optionsAsString = this.buildCompletionOptions(ccs);
            String query = String.format(this.SNOWFLAKE_CORTEX_STANDARD_COMPLETION_QUERY, quotedModel, promptOrHistory, optionsAsString);
            logger.debugV("Querying Snowflake Cortex for completion using query: %s", new Object[]{query});
            try (PreparedStatement ps = sqlWrapper.prepareStatement(query);
                 ResultSet rs = ps.executeQuery()) {
                if (!rs.next()) {
                    throw new SQLException("Snowflake Cortex completion query returned no rows");
                }
                String jsonString = rs.getString(1);
                logger.debugV("Received %s", new Object[]{jsonString});
                SnowflakeCortexCompletionSQLResponse resp = (SnowflakeCortexCompletionSQLResponse) JSON.parse(jsonString, SnowflakeCortexCompletionSQLResponse.class);
                if (resp == null) {
                    throw new IOException("Snowflake Cortex completion query returned an empty response");
                }
                return this.toCompletionResponse(resp);
            }
        }
    }

    /**
     * Embeds a batch of texts through {@code SNOWFLAKE.CORTEX.EMBED_TEXT_<embeddingSize>}.
     *
     * @param model         embedding model name
     * @param embeddingSize size suffix of the Cortex embedding function to call
     * @param batchTexts    texts to embed; an empty batch returns an empty list without running a query
     * @return one response per row returned by the query, in result-set order
     * @throws SQLException if the query fails
     * @throws IOException  if a returned embedding is NULL
     */
    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embed(String model, Integer embeddingSize, List<String> batchTexts) throws IOException, SQLException, DKUSecurityException, InterruptedException {
        List<LLMClient.SimpleEmbeddingResponse> ret = new ArrayList<>(batchTexts.size());
        if (batchTexts.isEmpty()) {
            // An empty VALUES clause is not valid SQL, and there is nothing to embed anyway.
            return ret;
        }
        try (SQLConnectionProvider.SQLConnectionWrapper sqlWrapper = this.createSQLWrapper()) {
            SQLDialect sqlDialect = sqlWrapper.getDialect();
            String quotedModel = sqlDialect.quoteString(model);
            String quotedTexts = batchTexts.stream()
                    .map(text -> "(" + sqlDialect.quoteString(text) + ")")
                    .collect(Collectors.joining(",\n"));
            // Fixed locale so the size suffix is always rendered with ASCII digits.
            String sqlQuery = String.format(Locale.ENGLISH, this.SNOWFLAKE_CORTEX_STANDARD_EMBED_TEXT_EMBED_SIZE_QUERY, embeddingSize, quotedModel, quotedTexts);
            logger.debugV("Querying Snowflake Cortex for embeddings using sqlQuery: %s", new Object[]{sqlQuery});
            TypeToken<List<Float>> floatListType = new TypeToken<List<Float>>() {};
            try (PreparedStatement ps = sqlWrapper.prepareStatement(sqlQuery);
                 ResultSet rs = ps.executeQuery()) {
                while (rs.next()) {
                    Object embeddingObj = rs.getObject(1);
                    if (embeddingObj == null) {
                        throw new IOException("Snowflake Cortex embedding query returned a NULL embedding");
                    }
                    // The vector is parsed from the column's string form as a JSON array of numbers.
                    @SuppressWarnings("unchecked")
                    List<Float> values = (List<Float>) JSON.parse(embeddingObj.toString(), floatListType);
                    LLMClient.SimpleEmbeddingResponse curRes = new LLMClient.SimpleEmbeddingResponse();
                    curRes.embedding = new double[values.size()];
                    for (int i = 0; i < values.size(); ++i) {
                        curRes.embedding[i] = values.get(i).floatValue();
                    }
                    ret.add(curRes);
                }
            }
        }
        return ret;
    }

    /**
     * Streaming is not available over SQL.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public void streamChatComplete(LLMClient.StreamedCompletionResponseConsumer consumer, LLMModelHandle.Model model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) throws Exception {
        throw new IllegalArgumentException("Streaming not supported when querying Snowflake Cortex model in SQL");
    }

    /**
     * Resolves the connection data of the configured Snowflake connection and asks
     * {@link SQLConnectionProvider} for a connection wrapper. The caller is responsible for closing it.
     */
    SQLConnectionProvider.SQLConnectionWrapper createSQLWrapper() throws SQLException, DKUSecurityException, InterruptedException {
        SQLConnectionProvider.SQLConnectionData connData = this.sfConnection.getConnectionData_NT(this.authCtx, this.projectKey);
        return SQLConnectionProvider.newConnection(connData, this.authCtx, this.projectKey);
    }

    /** Nothing to release: this client keeps no connection open between calls. */
    @Override
    public void close() {
    }

    /**
     * JSON payload returned by the completion query, as deserialized by {@link JSON#parse}.
     * Field names (or their {@link SerializedName}) are the JSON keys, so they must not be renamed.
     */
    static class SnowflakeCortexCompletionSQLResponse {
        List<Choice> choices;
        String created;
        String model;
        Usage usage;

        SnowflakeCortexCompletionSQLResponse() {
        }

        /** Token accounting reported with the completion. */
        public static class Usage {
            @SerializedName(value="completion_tokens")
            int completionTokens;
            @SerializedName(value="prompt_tokens")
            int promptTokens;
            @SerializedName(value="total_tokens")
            int totalTokens;
        }

        /** One generated answer; {@code messages} holds its text. */
        public static class Choice {
            String messages;
        }
    }
}

