/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.custom;

import com.dataiku.dip.code.CodeEnvSelector;
import com.dataiku.dip.connections.CustomLLMConnection;
import com.dataiku.dip.custom.PluginSettingsResolver;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.io.CustomPythonKernelException;
import com.dataiku.dip.io.SimplePythonKernel;
import com.dataiku.dip.io.SimplePythonKernelFactory;
import com.dataiku.dip.llm.custom.CustomPythonLLMsService;
import com.dataiku.dip.llm.custom.LoadedPythonLLM;
import com.dataiku.dip.llm.io.PythonRequestUtils;
import com.dataiku.dip.llm.io.commands.ProcessSingleEmbeddingCommand;
import com.dataiku.dip.llm.io.commands.ProcessSingleImageGenerationCommand;
import com.dataiku.dip.llm.io.commands.ProcessSinglePromptCommand;
import com.dataiku.dip.llm.io.commands.StartResponse;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.DKUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.j2py.annotations.PyModel;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Client for a custom LLM implemented by a Python plugin.
 *
 * <p>Each instance owns one dedicated Python kernel, identified by {@link #getKernelId()}, that runs
 * the {@code dataiku.llm.python.llm_plugin_server} module. {@link #init()} creates and starts that
 * kernel and sends it the plugin code, capability and settings. Completion, embedding and
 * image-generation requests are then forwarded to the kernel over its async link. The request
 * methods require a prior successful {@link #init()}.
 */
public class CustomPythonLLMClient {
    @Autowired
    private CustomPythonLLMsService service;
    private final String kernelId;
    private final LoadedPythonLLM loaded;
    private final CustomLLMConnection.Capability capability;
    private final PluginSettingsResolver.ResolvedSettings settings;
    private final String projectKey;
    private final AuthCtx authCtx;
    /** The started kernel; {@code null} until {@link #init()} has completed successfully. */
    private SimplePythonKernel kernel;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.plugin.python.client");

    /**
     * Creates a client. No kernel is created until {@link #init()} is called.
     *
     * @param authCtx     authentication context the kernel is prepared under (cast to {@link DSSAuthCtx})
     * @param projectKey  project the kernel is prepared for
     * @param capability  capability requested from the custom LLM server
     * @param settings    resolved plugin and component settings sent to the server on start
     * @param loaded      the loaded plugin component (provides type and owner plugin id)
     */
    public CustomPythonLLMClient(AuthCtx authCtx, String projectKey, CustomLLMConnection.Capability capability, PluginSettingsResolver.ResolvedSettings settings, LoadedPythonLLM loaded) {
        this.authCtx = authCtx;
        this.projectKey = projectKey;
        this.capability = capability;
        this.settings = settings;
        this.loaded = loaded;
        this.kernelId = "llm-custompython-" + SecretKeyGenerator.generateSmall();
        // Fills the @Autowired 'service' field, since this object is not created by Spring.
        SpringUtils.getInstance().autowire((Object)this);
    }

    /**
     * Prepares a Python kernel for this LLM. It is started separately by {@link #init()}.
     *
     * <p>The code env is selected from the owning plugin. The plugin library folder is exposed
     * under the {@code plugin-libs} key, and the component resource folder is passed in the
     * {@code DKU_CUSTOM_RESOURCE_FOLDER} environment entry.
     */
    private SimplePythonKernel createPythonKernel() throws IOException, CodedException, DKUSecurityException {
        String envName = new CodeEnvSelector().selectForCustomPythonRecipe(this.loaded.getOwnerPluginId());
        Map<String, String> pluginLib = Map.of("plugin-libs", this.service.getLibFolder(this.loaded.getType()));
        return SimplePythonKernelFactory.prepareKernel((DSSAuthCtx)this.authCtx, this.projectKey, GeneralSettingsDAO.CGrouppableProcessType.CUSTOM_PYTHON_DATA_ACCESS_COMPONENT, envName, "dataiku.llm.python.llm_plugin_server", false, false, pluginLib, null, this.kernelId, false, Map.of("DKU_CUSTOM_RESOURCE_FOLDER", this.service.getResourceFolder(this.loaded.getType())));
    }

    /**
     * Starts the Python kernel and the custom LLM server inside it, unless already done.
     *
     * <p>Calling this after a successful start is a no-op. If the server fails to start, the
     * freshly started kernel is closed before the failure propagates. {@link #getKernel()} keeps
     * returning {@code null} in that case, so a later call retries from scratch.
     *
     * @throws CustomPythonKernelException if the server failed to start and a log tail of the
     *     kernel was available
     * @throws Exception any other failure while creating, starting or talking to the kernel
     */
    void init() throws Exception {
        if (this.kernel != null) {
            return;
        }
        SimplePythonKernel newKernel = this.createPythonKernel();
        newKernel.start();
        boolean serverStarted = false;
        try {
            this.startLLMServer(newKernel);
            serverStarted = true;
        }
        finally {
            if (!serverStarted) {
                // The kernel is not yet stored in this.kernel, so close() could never reach it:
                // close it here or the Python process is leaked. This runs after the failure
                // (including its log tail) has been captured by startLLMServer.
                this.closeQuietly(newKernel);
            }
        }
        this.kernel = newKernel;
    }

    /**
     * Sends the start command to an already-started kernel and checks its response.
     *
     * <p>On any failure, the kernel's log tail is logged and attached to a
     * {@link CustomPythonKernelException} when available. Otherwise the original exception is
     * rethrown.
     */
    private void startLLMServer(SimplePythonKernel newKernel) throws Exception {
        try {
            StartCustomLLMServerCommand startCommand = new StartCustomLLMServerCommand(this.service.getCode(this.loaded.getType()), this.capability, this.settings.config, this.settings.pluginConfig);
            StartResponse response = (StartResponse)newKernel.getAsyncLink().request((Object)startCommand, StartResponse.class);
            // A missing response is treated like an explicit "not ok" instead of surfacing as an NPE.
            if (response == null || !response.ok) {
                throw new IllegalStateException("Kernel failed to start");
            }
        }
        catch (Exception e) {
            DKUtils.SmartLogTailBuilder sltb = newKernel.getSmartLogTailBuilder();
            logger.info((Object)("Start failed, with a SLTB: " + String.valueOf(sltb)));
            if (sltb != null) {
                logger.info((Object)("SLTB:" + System.lineSeparator() + JSON.json((Object)sltb.get())));
                throw new CustomPythonKernelException("Failed to start custom LLM server", e, sltb.get());
            }
            throw e;
        }
    }

    /**
     * Closes a kernel on a failure path, logging instead of throwing. This keeps a cleanup error
     * from replacing the exception that caused the cleanup.
     */
    private void closeQuietly(SimplePythonKernel toClose) {
        try {
            toClose.close();
        }
        catch (Exception closeFailure) {
            // Deliberately broad: this is best-effort cleanup while another exception is in flight.
            logger.warn((Object)("Failed to close custom LLM kernel " + this.kernelId + " after start failure"), (Throwable)closeFailure);
        }
    }

    /**
     * Returns the kernel used by this client.
     *
     * @return the started kernel, or {@code null} if {@link #init()} has not succeeded
     */
    public SimplePythonKernel getKernel() {
        return this.kernel;
    }

    /** Returns the id given to this client's kernel ({@code llm-custompython-} plus a random suffix). */
    public String getKernelId() {
        return this.kernelId;
    }

    /**
     * Returns the started kernel.
     *
     * @throws IllegalStateException if {@link #init()} has not completed successfully
     */
    private SimplePythonKernel requireKernel() {
        SimplePythonKernel current = this.kernel;
        if (current == null) {
            throw new IllegalStateException("Custom LLM kernel " + this.kernelId + " is not initialized, call init() first");
        }
        return current;
    }

    /**
     * Sends a single non-streamed completion request to the kernel.
     *
     * @return a future completed with the last message of the kernel's response stream
     * @throws IllegalStateException if the client has not been initialized
     */
    public CompletableFuture<LLMClient.SimpleCompletionResponse> asyncComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings) {
        ProcessSinglePromptCommand command = new ProcessSinglePromptCommand(query, settings, false);
        return this.requireKernel().getAsyncLink().asyncStreamRequest((Object)command, LLMClient.SimpleCompletionResponse.class).last().toFuture();
    }

    /**
     * Sends a streamed completion request to the kernel. Streamed chunks are delivered to
     * {@code consumer}.
     *
     * @return the future returned by {@link PythonRequestUtils#asyncStreamRequest}
     * @throws IllegalStateException if the client has not been initialized
     */
    public CompletableFuture<Integer> streamCompleteAsync(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) {
        return PythonRequestUtils.asyncStreamRequest(this.requireKernel().getAsyncLink(), query, settings, consumer);
    }

    /**
     * Sends a single embedding request to the kernel.
     *
     * @return a future completed with the last message of the kernel's response stream
     * @throws IllegalStateException if the client has not been initialized
     */
    public CompletableFuture<LLMClient.SimpleEmbeddingResponse> asyncEmbed(LLMClient.EmbeddingQuery query, LLMClient.EmbeddingSettings settings) {
        ProcessSingleEmbeddingCommand command = new ProcessSingleEmbeddingCommand(query, settings);
        return this.requireKernel().getAsyncLink().asyncStreamRequest((Object)command, LLMClient.SimpleEmbeddingResponse.class).last().toFuture();
    }

    /**
     * Sends a single image-generation request to the kernel.
     *
     * @return a future completed with the last message of the kernel's response stream
     * @throws IllegalStateException if the client has not been initialized
     */
    public CompletableFuture<LLMClient.ImageGenerationResponse> asyncGenerateImages(LLMClient.ImageGenerationQuery query) {
        ProcessSingleImageGenerationCommand command = new ProcessSingleImageGenerationCommand(query);
        return this.requireKernel().getAsyncLink().asyncStreamRequest((Object)command, LLMClient.ImageGenerationResponse.class).last().toFuture();
    }

    /**
     * Closes the kernel if one was started. Does nothing when {@link #init()} never succeeded.
     *
     * @throws IOException if closing the kernel fails
     */
    public void close() throws IOException {
        if (this.kernel != null) {
            this.kernel.close();
        }
    }

    /**
     * Command sent to the kernel to start the custom LLM server. Its {@code type} field is the
     * fixed discriminator {@code "start-custom-llm-server"}.
     */
    @PyModel
    private static class StartCustomLLMServerCommand {
        public final String type = "start-custom-llm-server";
        /** Plugin component code, as returned by {@code CustomPythonLLMsService.getCode}. */
        public String code;
        public CustomLLMConnection.Capability capability;
        /** Component-level settings. */
        public JsonObject config;
        /** Plugin-level settings. */
        public JsonObject pluginConfig;

        public StartCustomLLMServerCommand(String code, CustomLLMConnection.Capability capability, JsonObject config, JsonObject pluginConfig) {
            this.code = code;
            this.capability = capability;
            this.config = config;
            this.pluginConfig = pluginConfig;
        }

        // No-arg constructor kept as in the original; presumably needed by the (de)serializer.
        StartCustomLLMServerCommand() {
        }
    }
}

