/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.langchain;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMRefEnricherService;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.langchain.PythonLLMServerAPI;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.Params;
import com.dataiku.dip.utils.SmartLogTail;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

public abstract class AbstractAgentLLMClient
implements LLMClient {
    protected PythonLLMServerAPI serverAPI;
    protected final DSSAuthCtx authCtx;
    protected final SavedModel savedModel;
    protected final SavedModel.SavedModelInlineVersion smiv;
    protected boolean devKernel;
    private static DKULogger logger = DKULogger.getLogger((String)"dku.llm.agent");

    public AbstractAgentLLMClient(DSSAuthCtx authCtx, SavedModel sm, SavedModel.SavedModelInlineVersion smiv, boolean devKernel) {
        this.authCtx = authCtx;
        this.savedModel = sm;
        this.smiv = smiv;
        this.devKernel = devKernel;
    }

    protected abstract void initOnce() throws IOException;

    public boolean isDevMode() {
        Params dkuParams = AbstractSQLConnection.CustomDatabaseProperty.toParams(this.savedModel.getAgentSettings((SavedModel.SavedModelInlineVersion)this.smiv).dkuProperties);
        return dkuParams.getBoolParam("dku.agents.python.devMode", false);
    }

    @Override
    public SmartLogTail getKernelLog() throws Exception {
        this.initOnce();
        return this.serverAPI.getKernelLog();
    }

    @Override
    public void close() {
        if (this.serverAPI != null) {
            logger.info((Object)"Closing PythonLLMClient");
            try {
                this.serverAPI.close();
            }
            catch (IOException e) {
                throw new RuntimeException("Failed to close", e);
            }
        }
    }

    @Override
    public boolean supportNativeBatch() {
        return false;
    }

    @Override
    public boolean requiresCostLimiting() {
        return true;
    }

    @Override
    public AbstractLLMConnection<?, ?, ?> getConnection() {
        return null;
    }

    @Override
    public int getMaxParallelism() {
        return this.savedModel.getAgentSettings((SavedModel.SavedModelInlineVersion)this.smiv).maxParallelRequestsPerProcess;
    }

    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
        ArrayList<LLMClient.ChatMessage> formattedPromptMessages = new ArrayList<LLMClient.ChatMessage>();
        LLMClient.ChatMessage formattedPrompt = new LLMClient.ChatMessage("prompt", chatMessages.stream().map(m -> m.getText()).collect(Collectors.joining("\n")));
        formattedPromptMessages.add(formattedPrompt);
        return formattedPromptMessages;
    }

    @Override
    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws IOException {
        this.initOnce();
        try {
            logger.infoV("Sending %s completion queries, settings: %s", new Object[]{queries.size(), JSON.json((Object)settings)});
            List futures = queries.stream().map(query -> this.serverAPI.processAsync((LLMClient.SingleCompletionQuery)query, settings)).collect(Collectors.toList());
            List ret = DKUCompletableFuture.collectResponses(futures);
            logger.infoV("Received %s completion responses", new Object[]{ret.size()});
            return ret;
        }
        catch (Exception e) {
            throw new IOException("Processing failed", e);
        }
    }

    @Override
    public boolean supportsStream() {
        return true;
    }

    @Override
    public void streamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
        this.initOnce();
        logger.infoV("Sending stream completion query, settings: %s", new Object[]{JSON.json((Object)settings)});
        CompletableFuture<Integer> streamFuture = this.serverAPI.streamProcessAsync(query, settings, consumer);
        Integer nChunks = (Integer)DKUCompletableFuture.collectResponse(streamFuture);
        logger.infoV("Fully received streamed completion response with %s chunks", new Object[]{nChunks});
    }

    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        return null;
    }

    @Override
    public EnrichedLLMStructuredRef getEnrichedRef() throws Exception {
        return ((LLMRefEnricherService)SpringUtils.getBean(LLMRefEnricherService.class)).getEnrichedLLMRefFromAgentSMVersion(this.authCtx, this.savedModel, this.smiv.versionId, this.savedModel.projectKey);
    }

    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws IOException {
        throw new IllegalArgumentException("Embeddings not supported on this LLM");
    }
}

