/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.databricks;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.ConnectionUtils;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.DatabricksLLMConnection;
import com.dataiku.dip.connections.DatabricksModelDeploymentConnection;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.exceptions.UnauthorizedException;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.AbstractLLMClient;
import com.dataiku.dip.llm.online.LLMChatMessageUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.databricks.RawDatabricksLLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.IsolationLevel;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.common.base.Stopwatch;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * {@link LLMClient} implementation that sends chat-completion and embedding
 * requests for a {@link DatabricksLLMConnection} through a
 * {@link RawDatabricksLLMClient}.
 *
 * <p>Queries of a batch are sent one at a time; elapsed time and the token
 * counts reported by each response are accumulated in an internal usage
 * record that {@link #getTotalCRU} exposes.
 */
public class DatabricksLLMClient
extends AbstractLLMClient
implements LLMClient {
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.databricks-llm");

    private final DatabricksLLMConnection connection;
    private final RawDatabricksLLMClient raw;
    private final DatabricksLLMConnection.DatabricksLLMModel model;
    /** Running totals (time, tokens) for every request issued by this client. */
    private final ComputeResourceUsage.InternalLLMUsageData usageData = new ComputeResourceUsage.LLMUsageData();

    /**
     * Builds a client for the given connection and model.
     *
     * <p>The model deployment connection named by
     * {@code connection.params.modelDeploymentConnection} is looked up inside a
     * read transaction and must be freely usable by {@code authCtx}; its host
     * base URL and auth token are then handed to the raw client.
     *
     * @param authCtx     caller identity used for the connection lookup, the
     *                    usability check and the auth token retrieval
     * @param connection  the Databricks LLM connection this client serves
     * @param modelHandle handle providing the enriched reference and the model
     * @throws IOException          declared by the original signature
     * @throws DKUSecurityException declared by the original signature
     * @throws UnauthorizedException if the deployment connection is not freely
     *                              usable by {@code authCtx}
     */
    public DatabricksLLMClient(AuthCtx authCtx, DatabricksLLMConnection connection, LLMModelHandle<DatabricksLLMConnection.DatabricksLLMModel> modelHandle) throws IOException, DKUSecurityException {
        super(modelHandle.getEnrichedRef());
        this.connection = connection;

        DatabricksModelDeploymentConnection deplConnection;
        TransactionService transactionService = (TransactionService)SpringUtils.getBean(TransactionService.class);
        // The transaction handle itself is never used; it only scopes the lookup and the check.
        try (Transaction ignoredTx = transactionService.retrieveOrBeginRead(IsolationLevel.YOLO)) {
            deplConnection = ConnectionsDAO.get().getMandatoryConnectionAs(authCtx, connection.params.modelDeploymentConnection, DatabricksModelDeploymentConnection.class);
            boolean usable = deplConnection.isFreelyUsableBy(authCtx);
            if (!usable) {
                throw new UnauthorizedException("You may not use the databricks model deployment connection " + deplConnection.name, "denied");
            }
        }

        // Optional per-connection property, defaulting to false when absent.
        boolean forceContentLength = ConnectionUtils.getParamsFromProperties(connection.getDkuProperties())
                .getBoolParam("dku.connection.llm.forceContentLength", false);
        this.raw = new RawDatabricksLLMClient(
                deplConnection.params.getHostBaseURL(),
                deplConnection.getAuthToken(authCtx),
                connection.params.networkSettings,
                connection.getProxySettings(),
                forceContentLength);
        this.model = modelHandle.getModel();
    }

    /** Closes the underlying raw client. */
    @Override
    public void close() {
        raw.close();
    }

    /** @return always {@code false}; batches are processed query by query */
    @Override
    public boolean supportNativeBatch() {
        return false;
    }

    /** @return always {@code true} */
    @Override
    public boolean requiresCostLimiting() {
        return true;
    }

    /** @return the constant provider identifier {@code "DatabricksLLM"} */
    @Override
    public String getProviderId() {
        return "DatabricksLLM";
    }

    /** @return the {@code maxParallelism} value configured on the connection */
    @Override
    public int getMaxParallelism() {
        return connection.params.maxParallelism;
    }

    /** @return the connection this client was created for */
    @Override
    public AbstractLLMConnection getConnection() {
        return connection;
    }

    /**
     * Prepares chat messages for sending: checks for unsupported tool output
     * parts, converts extra system messages to user messages, then collapses
     * adjacent messages sharing a role.
     *
     * @param chatMessages the messages of a query
     * @return the list produced by the last transformation step
     */
    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
        LLMChatMessageUtils.throwIfUnsupportedToolOutputParts(chatMessages);
        List<LLMClient.ChatMessage> systemConverted = LLMChatMessageUtils.convertExtraSystemMessageToUser(chatMessages);
        return LLMChatMessageUtils.collapseAdjacentSameRoleMessages(systemConverted);
    }

    /**
     * Runs every completion query in order and returns the responses in the
     * same order.
     *
     * @param queries  the completion queries
     * @param settings completion settings, converted once for the whole batch
     * @return one response per query
     * @throws IOException if the raw client fails
     */
    @Override
    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws IOException {
        CoreCompletionSettings coreSettings = getCoreCompletionSettings(settings);
        List<LLMClient.SimpleCompletionResponse> responses = new ArrayList<>();
        for (LLMClient.SingleCompletionQuery singleQuery : queries) {
            responses.add(runDatabricksCompletion(singleQuery, coreSettings));
        }
        return responses;
    }

    /** Sends one completion query and records its duration and token counts. */
    private LLMClient.SimpleCompletionResponse runDatabricksCompletion(LLMClient.SingleCompletionQuery singleQuery, CoreCompletionSettings coreSettings) throws IOException {
        // Timer starts before logging and prompt formatting, so both are part of the measured time.
        Stopwatch timer = Stopwatch.createStarted();
        logger.info("Databricks LLM single complete query: " + JSON.log(singleQuery.getSafeForLoggingCopy()));
        List<LLMClient.ChatMessage> prompt = getFormattedPrompt(singleQuery.messages);
        LLMClient.SimpleCompletionResponse response = raw.chatComplete(model.getId(), prompt, coreSettings);
        usageData.incrementTotalComputationTimeMS(Long.valueOf(timer.elapsed(TimeUnit.MILLISECONDS)));
        usageData.incrementTotalPromptTokens(response.promptTokens);
        usageData.incrementTotalCompletionTokens(response.completionTokens);
        return response;
    }

    /**
     * Embeds every query text in order and returns the responses in the same
     * order. A warning is logged when the settings request truncation of long
     * texts; the requested mode is not forwarded to the raw client.
     *
     * @param queries  the embedding queries
     * @param settings embedding settings; only {@code textOverflowMode} is read
     * @return one response per query
     * @throws IOException if the raw client fails
     */
    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws IOException {
        boolean truncationRequested = LLMClient.TextOverflowMode.TRUNCATE.equals(settings.textOverflowMode);
        if (truncationRequested) {
            logger.warn("Truncation for long texts overflow is not supported yet for Databricks LLM, defaulting to Failure mode");
        }
        List<LLMClient.SimpleEmbeddingResponse> responses = new ArrayList<>();
        for (LLMClient.EmbeddingQuery embeddingQuery : queries) {
            responses.add(runDatabricksEmbedding(embeddingQuery));
        }
        return responses;
    }

    /** Sends one embedding query and records its duration and prompt tokens. */
    private LLMClient.SimpleEmbeddingResponse runDatabricksEmbedding(LLMClient.EmbeddingQuery embeddingQuery) throws IOException {
        Stopwatch timer = Stopwatch.createStarted();
        logger.info("Databricks LLM Embed: " + JSON.log(embeddingQuery.getSafeForLoggingCopy()));
        LLMClient.SimpleEmbeddingResponse response = raw.embed(model.getId(), embeddingQuery.text);
        usageData.incrementTotalComputationTimeMS(Long.valueOf(timer.elapsed(TimeUnit.MILLISECONDS)));
        usageData.incrementTotalPromptTokens(response.promptTokens);
        return response;
    }

    /**
     * Reranking is not available through this client.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public List<LLMClient.SingleRerankingResponse> rerankBatch(List<LLMClient.RerankingQuery> queries, LLMClient.RerankingSettings settings) throws Exception {
        throw new IllegalArgumentException("Rerankings not supported on this LLM");
    }

    /**
     * Builds a fresh {@link ComputeResourceUsage} describing the usage
     * accumulated so far by this client.
     *
     * @param usageType the kind of LLM usage to report
     * @param llmRef    structured reference supplying connection, type and id
     * @return a new usage object populated from the internal totals
     */
    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        ComputeResourceUsage totalUsage = new ComputeResourceUsage();
        String llmType = llmRef.type.toString();
        totalUsage.setupLLMUsage(usageType, llmRef.connection, llmType, llmRef.id);
        totalUsage.llmUsage.setFromInternal(usageData);
        return totalUsage;
    }
}

