/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.local;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.HuggingFaceLocalConnection;
import com.dataiku.dip.io.KubernetesSimplePythonKernel;
import com.dataiku.dip.io.SimplePythonKernel;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.local.HuggingFaceKernelBuilder;
import com.dataiku.dip.llm.local.HuggingFaceKernelClient;
import com.dataiku.dip.llm.online.AbstractLLMClient;
import com.dataiku.dip.llm.online.LLMChatMessageUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.reports.IReflectedEventsService;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.DKUtils;
import com.dataiku.dip.utils.ExceptionUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.SmartLogTail;
import com.dataiku.dss.shadelib.org.apache.commons.lang3.exception.ExceptionUtils;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

/**
 * {@link LLMClient} that serves a HuggingFace model through a local Python kernel.
 *
 * <p>{@link #startKernel()} creates the kernel via {@link HuggingFaceKernelBuilder} and wraps it in a
 * {@link HuggingFaceKernelClient}; the query methods (completion, streaming, embedding, image
 * generation) then delegate to that kernel client, so they require {@link #startKernel()} to have
 * completed successfully first.
 *
 * <p>The class also holds the static prompt-formatting logic that turns a list of chat messages into a
 * single model-specific prompt string for each {@link HuggingFaceLocalConnection.HuggingFaceHandlingMode}.
 *
 * <p>NOTE(review): this file imports two different types named {@code ExceptionUtils}, which is a
 * compile error; the body below fully qualifies both usages so that the redundant import can be removed.
 */
public class HuggingFaceLocalClient extends AbstractLLMClient implements LLMClient {
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.huggingface");

    /** Maximum length of the JSON sample of an image-generation query written to the logs. */
    private static final int MAX_JSON_SAMPLE_LOG_LENGTH = 120;

    /** Number of kernel log lines retained once the kernel's own log tail is in use. */
    private static final int KERNEL_LOG_MAX_LINES = 500;

    private final HuggingFaceLocalConnection connection;
    private final HuggingFaceLocalConnection.HFLocalModel model;
    private final HuggingFaceKernelBuilder kernelBuilder;
    private final AuthCtx authCtx;
    private final String projectKey;
    private final HuggingFaceKernelClient.KernelConfig kernelConfig;
    /** Randomly generated kernel identifier, prefixed with "llm-hf-". */
    private final String kernelId;
    /** Dedicated thread on which the kernel is created and started; close() interrupts it. */
    private final ExecutorService startThread = Executors.newSingleThreadExecutor();
    /** Reported as "reserved_capacity" in the kernel-started tracking event. */
    private final boolean forReservedCapacity;

    // Set by startKernel(); null until the kernel client has been created.
    private HuggingFaceKernelClient kernelClient;
    // Replaced by the kernel's own log tail builder as soon as the kernel has been created.
    private DKUtils.SmartLogTailBuilder smartLogTailBuilder = new DKUtils.SmartLogTailBuilder();
    // Only set when the kernel is a KubernetesSimplePythonKernel.
    private String podName;
    private String usedEngine;

    /**
     * Creates a client; no kernel is started until {@link #startKernel()} is called.
     *
     * @param authCtx authentication context used to create the kernel
     * @param projectKey project on whose behalf the kernel is created
     * @param connection the HuggingFace local connection this client belongs to
     * @param modelHandle handle providing the model definition and its enriched reference
     * @param kernelBuilder builder used to create and start the Python kernel
     * @param kernelConfig configuration passed to the kernel builder
     * @param forReservedCapacity reported in the "hf-kernel-started" tracking event
     */
    public HuggingFaceLocalClient(AuthCtx authCtx, String projectKey, HuggingFaceLocalConnection connection, LLMModelHandle<HuggingFaceLocalConnection.HFLocalModel> modelHandle, HuggingFaceKernelBuilder kernelBuilder, HuggingFaceKernelClient.KernelConfig kernelConfig, boolean forReservedCapacity) {
        super(modelHandle.getEnrichedRef());
        this.kernelId = "llm-hf-" + SecretKeyGenerator.generateSmall();
        this.authCtx = authCtx;
        this.connection = connection;
        this.model = modelHandle.getModel();
        this.kernelBuilder = kernelBuilder;
        this.projectKey = projectKey;
        this.kernelConfig = kernelConfig;
        this.forReservedCapacity = forReservedCapacity;
    }

    /**
     * Creates and starts the Python kernel hosting the model, then publishes a
     * {@code "hf-kernel-started"} reflected event carrying the kernel's tracking data.
     *
     * <p>The work runs on the dedicated start thread and this method blocks until it finishes. On
     * failure, the root-cause stack trace is appended to the kernel log, the client is closed, and the
     * failure is rethrown. A failure while closing is attached as a suppressed exception instead of
     * masking the start failure. If the wait was interrupted, the interrupt status is restored.
     *
     * @throws IOException if the kernel could not be started; its cause is the original failure
     */
    public void startKernel() throws Exception {
        try {
            this.startThread.submit(() -> {
                SimplePythonKernel kernel = this.kernelBuilder.createKernel(this.authCtx, this.projectKey, this.kernelConfig, this.kernelId);
                // From here on, diagnostics go to the kernel's own log tail.
                this.smartLogTailBuilder = kernel.getSmartLogTailBuilder();
                this.smartLogTailBuilder.setMaxLines(KERNEL_LOG_MAX_LINES);
                if (kernel instanceof KubernetesSimplePythonKernel) {
                    this.podName = ((KubernetesSimplePythonKernel) kernel).getPodName();
                }
                this.kernelBuilder.startKernel(kernel, this.kernelConfig);
                this.kernelClient = new HuggingFaceKernelClient(kernel);

                JsonObject trackingData = this.kernelClient.collectTrackingData();
                this.usedEngine = this.kernelClient.getUsedEngine();
                trackingData.addProperty("exec", kernel.getType());
                trackingData.addProperty("reserved_capacity", Boolean.valueOf(this.forReservedCapacity));
                IReflectedEventsService.ReflectedEvent event = new IReflectedEventsService.ReflectedEvent("hf-kernel-started", trackingData);
                ((IReflectedEventsService) SpringUtils.getBean(IReflectedEventsService.class)).publish(event);
                return null;
            }).get();
        } catch (InterruptedException | ExecutionException e) {
            this.appendRootCauseStackTrace(e);
            Throwable failure = e instanceof ExecutionException ? e.getCause() : e;
            IOException startFailure = new IOException("Failed to start HuggingFace LLM", failure);
            boolean interrupted = e instanceof InterruptedException;
            try {
                this.close();
            } catch (Exception closeFailure) {
                // Keep the start failure as the primary error; a cleanup failure must not mask it.
                startFailure.addSuppressed(closeFailure);
                interrupted |= closeFailure instanceof InterruptedException;
            }
            if (interrupted) {
                // Restored only after close(): a pending interrupt would make its awaitTermination fail at once.
                Thread.currentThread().interrupt();
            }
            throw startFailure;
        }
    }

    /**
     * Interrupts any kernel start still in progress and then closes the kernel client, if one exists.
     *
     * <p>Waits for the start thread to terminate for at most
     * {@code dku.hflocalclient.startThreadTimeout} seconds (default 120). On timeout, it logs a warning
     * and proceeds anyway.
     */
    @Override
    public void close() throws Exception {
        this.startThread.shutdownNow();
        long timeoutSeconds = DKUApp.getParams().getIntParam("dku.hflocalclient.startThreadTimeout", Integer.valueOf(120));
        if (!this.startThread.awaitTermination(timeoutSeconds, TimeUnit.SECONDS)) {
            logger.warn((Object) "Timeout while waiting for kernel starting thread to terminate, closing HF local client anyway but resources may have not yet been freed");
        }
        HuggingFaceKernelClient client = this.kernelClient;
        if (client != null) {
            client.close();
        }
    }

    /** Returns {@code true} if the kernel client exists and reports itself alive. */
    public boolean isAlive() {
        return this.kernelClient != null && this.kernelClient.isAlive();
    }

    @Override
    public boolean supportNativeBatch() {
        return false;
    }

    @Override
    public boolean requiresCostLimiting() {
        return false;
    }

    @Override
    public String getProviderId() {
        return "HuggingFaceLocal";
    }

    @Override
    public HuggingFaceLocalConnection getConnection() {
        return this.connection;
    }

    /** Queries are sent one at a time to this client. */
    @Override
    public int getMaxParallelism() {
        return 1;
    }

    /** This client receives one query per batch, whatever the query type. */
    @Override
    public int getBatchSize(AbstractLLMConnection.QueryType queryType, LLMStructuredRef llmRef) {
        return 1;
    }

    /**
     * Returns the batch size configured for {@code model}. If none is configured, returns 8 for
     * text/image embedding queries and 2 for every other query type.
     */
    public static int getBatchSize(HuggingFaceLocalConnection.HFLocalModel model, AbstractLLMConnection.QueryType queryType) {
        if (model.getBatchSize().isPresent()) {
            logger.info((Object) "Using user defined batch size");
            return model.getBatchSize().getAsInt();
        }
        switch (queryType) {
            case textEmbedding:
            case imageEmbedding:
                return 8;
            default:
                return 2;
        }
    }

    /** Returns the engine reported by the kernel client, or {@code null} before the kernel is started. */
    public String getUsedEngine() {
        return this.usedEngine;
    }

    public String getKernelId() {
        return this.kernelId;
    }

    /** Returns the Kubernetes pod name, or {@code null} when the kernel does not run on Kubernetes. */
    public String getPodName() {
        return this.podName;
    }

    /**
     * Normalizes chat messages for handling modes whose prompt format restricts roles.
     *
     * <p>Llama 2 and Llama Guard: extra system messages are converted to user messages. Mistral: all
     * system messages become user messages. In both cases, adjacent same-role messages are then
     * collapsed. Other modes receive the list unchanged.
     */
    private static List<LLMClient.ChatMessage> getFormattedChatMessages(List<LLMClient.ChatMessage> chatMessages, HuggingFaceLocalConnection.HuggingFaceHandlingMode handlingMode) {
        switch (handlingMode) {
            case TEXT_GENERATION_LLAMA_2:
            case TEXT_GENERATION_LLAMA_GUARD:
                return LLMChatMessageUtils.collapseAdjacentSameRoleMessages(LLMChatMessageUtils.convertExtraSystemMessageToUser(chatMessages));
            case TEXT_GENERATION_MISTRAL:
                return LLMChatMessageUtils.collapseAdjacentSameRoleMessages(LLMChatMessageUtils.convertMessageRole(chatMessages, "system", "user"));
            default:
                return chatMessages;
        }
    }

    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
        return HuggingFaceLocalClient.getFormattedPrompt(chatMessages, this.model.handlingMode);
    }

    /** Returns the kernel log tail, which includes any root-cause stack traces appended on failures. */
    @Override
    public SmartLogTail getKernelLog() {
        return this.smartLogTailBuilder.get();
    }

    public String getModelId() {
        return this.model.getId();
    }

    /**
     * Formats {@code chatMessages} into a single prompt for {@code handlingMode}.
     *
     * @return a new mutable list holding one message with role {@code "prompt"} and the formatted text
     * @throws IllegalArgumentException if {@code handlingMode} does not require prompt formatting
     */
    public static List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages, HuggingFaceLocalConnection.HuggingFaceHandlingMode handlingMode) {
        List<LLMClient.ChatMessage> normalized = HuggingFaceLocalClient.getFormattedChatMessages(chatMessages, handlingMode);
        ArrayList<LLMClient.ChatMessage> formattedPromptMessages = new ArrayList<LLMClient.ChatMessage>();
        formattedPromptMessages.add(new LLMClient.ChatMessage("prompt", HuggingFaceLocalClient.getFormattedPromptContent(normalized, handlingMode)));
        return formattedPromptMessages;
    }

    /**
     * Renders chat messages as one prompt string using the chat template of {@code handlingMode}.
     *
     * <p>Messages are read through their {@code role} ("system", "user", "assistant") and
     * {@code getText()}. For the generic and T5 modes, texts are simply joined with newlines.
     *
     * @throws IllegalArgumentException if {@code handlingMode} does not require prompt formatting
     */
    public static String getFormattedPromptContent(List<LLMClient.ChatMessage> chatMessages, HuggingFaceLocalConnection.HuggingFaceHandlingMode handlingMode) {
        StringBuilder sb = new StringBuilder();
        switch (handlingMode) {
            case TEXT_GENERATION_LLAMA_2:
            case TEXT_GENERATION_LLAMA_GUARD: {
                sb.append("<s>[INST] ");
                for (int idx = 0; idx < chatMessages.size(); ++idx) {
                    LLMClient.ChatMessage msg = chatMessages.get(idx);
                    if ("system".equals(msg.role)) {
                        sb.append("<<SYS>>\n").append(msg.getText()).append("\n<</SYS>>\n\n");
                    }
                    if ("user".equals(msg.role)) {
                        sb.append(msg.getText()).append(" [/INST]");
                    }
                    if ("assistant".equals(msg.role)) {
                        sb.append(" ").append(msg.getText()).append("</s>");
                        if (idx < chatMessages.size() - 1) {
                            // Open the next turn exactly like the first one, trailing space included.
                            sb.append("<s>[INST] ");
                        }
                    }
                }
                break;
            }
            case TEXT_GENERATION_LLAMA_3: {
                sb.append("<|begin_of_text|>");
                for (LLMClient.ChatMessage msg : chatMessages) {
                    sb.append("<|start_header_id|>").append(msg.role).append("<|end_header_id|>");
                    sb.append("\n\n");
                    sb.append(msg.getText());
                    sb.append("<|eot_id|>");
                }
                // The generation header ends with "\n\n", like every message header above.
                sb.append("<|start_header_id|>assistant<|end_header_id|>\n\n");
                break;
            }
            case TEXT_GENERATION_PHI_3: {
                sb.append("<s>");
                for (LLMClient.ChatMessage msg : chatMessages) {
                    sb.append("<|").append(msg.role).append("|>");
                    sb.append(msg.getText()).append("<|end|>\n");
                }
                sb.append("<|assistant|>");
                break;
            }
            case TEXT_GENERATION_MPT: {
                for (LLMClient.ChatMessage msg : chatMessages) {
                    if ("system".equals(msg.role)) {
                        sb.append(msg.getText());
                    }
                    if ("user".equals(msg.role)) {
                        sb.append("### Instruction:\n").append(msg.getText());
                    }
                    if ("assistant".equals(msg.role)) {
                        sb.append("### Response:\n").append(msg.getText());
                    }
                    sb.append("\n");
                }
                sb.append("### Response:\n");
                break;
            }
            case TEXT_GENERATION_FALCON: {
                sb.append(">>TITLE<<\nFlawless answer\n");
                for (LLMClient.ChatMessage msg : chatMessages) {
                    if ("system".equals(msg.role)) {
                        sb.append(">>CONTEXT<<").append(msg.getText());
                    }
                    if ("user".equals(msg.role)) {
                        sb.append(">>QUESTION<<").append(msg.getText());
                    }
                    if ("assistant".equals(msg.role)) {
                        sb.append(">>ANSWER<<").append(msg.getText());
                    }
                    sb.append("\n");
                }
                sb.append(">>ANSWER<<\n");
                break;
            }
            case TEXT_GENERATION_ZEPHYR: {
                for (LLMClient.ChatMessage msg : chatMessages) {
                    sb.append("<|").append(msg.role).append("|>\n");
                    sb.append(msg.getText());
                    sb.append("</s>\n");
                }
                sb.append("<|assistant|>\n");
                break;
            }
            case TEXT_GENERATION_MISTRAL: {
                sb.append("<s>[INST] ");
                for (int idx = 0; idx < chatMessages.size(); ++idx) {
                    LLMClient.ChatMessage msg = chatMessages.get(idx);
                    if ("system".equals(msg.role)) {
                        sb.append("\n").append(msg.getText()).append("\n");
                    }
                    if ("user".equals(msg.role)) {
                        sb.append(msg.getText()).append(" [/INST]");
                    }
                    if ("assistant".equals(msg.role)) {
                        sb.append(msg.getText()).append("</s> ");
                        if (idx < chatMessages.size() - 1) {
                            sb.append("[INST] ");
                        }
                    }
                }
                break;
            }
            case TEXT_GENERATION_GEMMA: {
                sb.append("<bos>");
                for (LLMClient.ChatMessage msg : chatMessages) {
                    // Gemma has no system role: system messages are rendered as user turns.
                    if ("system".equals(msg.role) || "user".equals(msg.role)) {
                        sb.append("<start_of_turn>user\n").append(msg.getText()).append("<end_of_turn>\n");
                    }
                    if ("assistant".equals(msg.role)) {
                        sb.append("<start_of_turn>model\n").append(msg.getText()).append("<end_of_turn>\n");
                    }
                }
                sb.append("<start_of_turn>model");
                break;
            }
            case TEXT_GENERATION_DEEPSEEK:
            case TEXT_GENERATION_GPT:
            case TEXT_GENERATION_QWEN:
            case TEXT_GENERATION_OPENBMB:
            case TEXT_GENERATION_GENERIC:
            case T5: {
                sb.append(chatMessages.stream().map(LLMClient.ChatMessage::getText).collect(Collectors.joining("\n")));
                break;
            }
            case TEXT_GENERATION_AUTO: {
                for (LLMClient.ChatMessage msg : chatMessages) {
                    if ("system".equals(msg.role)) {
                        sb.append("System message: ").append(msg.getText());
                    }
                    if ("user".equals(msg.role)) {
                        sb.append("User message: ").append(msg.getText());
                    }
                    if ("assistant".equals(msg.role)) {
                        sb.append("Assistant message: ").append(msg.getText());
                    }
                    sb.append("\n");
                }
                break;
            }
            default: {
                throw new IllegalArgumentException("Not a handling mode that needs to format a prompt: " + String.valueOf((Object) handlingMode));
            }
        }
        return sb.toString();
    }

    /**
     * Returns the kernel client.
     *
     * @throws IllegalStateException if the kernel has not been started, instead of an uninformative NPE
     */
    private HuggingFaceKernelClient requireKernelClient() {
        HuggingFaceKernelClient client = this.kernelClient;
        if (client == null) {
            throw new IllegalStateException("HuggingFace kernel " + this.kernelId + " has not been started");
        }
        return client;
    }

    /** Appends each line of the root-cause stack trace of {@code failure} to the kernel log. */
    private void appendRootCauseStackTrace(Throwable failure) {
        for (String line : com.dataiku.dss.shadelib.org.apache.commons.lang3.exception.ExceptionUtils.getRootCauseStackTrace(failure)) {
            this.smartLogTailBuilder.appendLine(line);
        }
    }

    /** Sends one completion query to the kernel asynchronously; no estimated cost is filled in here. */
    public CompletableFuture<LLMClient.SimpleCompletionResponse> asyncComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings) {
        return this.requireKernelClient().asyncComplete(query, settings);
    }

    /**
     * Runs completion queries on the kernel and fills each response's estimated cost from the model's
     * cost settings.
     *
     * @throws IOException wrapping any failure; the root-cause stack trace is also appended to the kernel log
     */
    @Override
    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws IOException {
        try {
            logger.infoV("Sending %s completion queries, settings: %s", new Object[]{queries.size(), JSON.json((Object) settings)});
            List<LLMClient.SimpleCompletionResponse> ret = this.requireKernelClient().complete(queries, settings);
            for (LLMClient.SimpleCompletionResponse response : ret) {
                response.estimatedCost = this.model.getEstimatedCompletionCost(response.promptTokens, response.completionTokens);
            }
            logger.infoV("Received %s completion responses", new Object[]{ret.size()});
            return ret;
        } catch (Exception e) {
            this.appendRootCauseStackTrace(e);
            throw new IOException("Failed to query HuggingFace LLM", e);
        }
    }

    @Override
    public boolean supportsStream() {
        return true;
    }

    /**
     * Streams a completion to {@code consumer} and blocks until the stream is fully received.
     *
     * <p>The consumer is wrapped so that the footer's estimated cost is filled in from its token counts.
     */
    @Override
    public void streamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
        logger.infoV("Sending stream completion query, settings: %s", new Object[]{JSON.json((Object) settings)});
        com.dataiku.dip.utils.ExceptionUtils.ThrowingConsumer<LLMClient.StreamedCompletionResponseFooter, Exception> fillFooterCost = footer -> {
            footer.estimatedCost = this.model.getEstimatedCompletionCost(footer.promptTokens, footer.completionTokens);
        };
        LLMClient.StreamedCompletionResponseConsumerProxy wrappedConsumer = new LLMClient.StreamedCompletionResponseConsumerProxy(consumer, fillFooterCost);
        CompletableFuture<Integer> streamFuture = this.requireKernelClient().asyncStreamComplete(query, settings, wrappedConsumer);
        Integer nChunks = (Integer) DKUCompletableFuture.collectResponse(streamFuture);
        logger.infoV("Fully received streamed completion response with %s chunks", new Object[]{nChunks});
    }

    /** Streams a completion directly to {@code consumer}; no cost-filling wrapper is added. */
    public CompletableFuture<Integer> asyncStreamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) {
        return this.requireKernelClient().asyncStreamComplete(query, settings, consumer);
    }

    public CompletableFuture<LLMClient.SimpleEmbeddingResponse> asyncEmbed(LLMClient.EmbeddingQuery query, LLMClient.EmbeddingSettings settings) {
        return this.requireKernelClient().asyncEmbed(query, settings);
    }

    public CompletableFuture<LLMClient.SingleRerankingResponse> asyncRerank(LLMClient.RerankingQuery query, LLMClient.RerankingSettings settings) {
        return this.requireKernelClient().asyncRerank(query, settings);
    }

    /**
     * Runs embedding queries on the kernel and fills each response's estimated cost.
     *
     * <p>Responses are matched to queries by index. A query with an image counts as one image for
     * costing.
     *
     * @throws IOException wrapping any failure; the root-cause stack trace is also appended to the kernel log
     */
    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws IOException {
        try {
            logger.infoV("Sending %s embedding queries, settings: %s", new Object[]{queries.size(), JSON.json((Object) settings)});
            List<LLMClient.SimpleEmbeddingResponse> ret = this.requireKernelClient().embed(queries, settings);
            for (int i = 0; i < ret.size(); ++i) {
                LLMClient.SimpleEmbeddingResponse response = ret.get(i);
                int imageCount = queries.get(i).hasImage() ? 1 : 0;
                response.estimatedCost = this.model.getEstimatedEmbeddingCost(response.promptTokens, imageCount);
            }
            logger.infoV("Received %s embedding responses", new Object[]{ret.size()});
            return ret;
        } catch (Exception e) {
            this.appendRootCauseStackTrace(e);
            throw new IOException("Failed to query HuggingFace LLM", e);
        }
    }

    /** Batch reranking is not supported by this client. */
    @Override
    public List<LLMClient.SingleRerankingResponse> rerankBatch(List<LLMClient.RerankingQuery> queries, LLMClient.RerankingSettings settings) throws Exception {
        throw new IllegalArgumentException("Reranking not supported on this LLM");
    }

    public CompletableFuture<LLMClient.ImageGenerationResponse> asyncGenerateImages(LLMClient.ImageGenerationQuery query) {
        return this.requireKernelClient().asyncGenerateImages(query);
    }

    /**
     * Generates images on the kernel and fills the response's estimated cost.
     *
     * @throws IOException wrapping any failure; the root-cause stack trace is also appended to the kernel log
     */
    @Override
    public LLMClient.ImageGenerationResponse generateImages(LLMClient.ImageGenerationQuery query) throws Exception {
        try {
            logger.infoV("Sending '%s' image generation query.", new Object[]{JSON.sampleJson((Object) query, MAX_JSON_SAMPLE_LOG_LENGTH)});
            LLMClient.ImageGenerationResponse response = this.requireKernelClient().generateImages(query);
            response.estimatedCost = this.model.getEstimatedImageGenerationCost(query);
            logger.infoV("Received %s generated images", new Object[]{response.images.size()});
            return response;
        } catch (Exception e) {
            this.appendRootCauseStackTrace(e);
            throw new IOException("Failed to query HuggingFace LLM", e);
        }
    }

    /** This client reports no compute resource usage; always returns {@code null}. */
    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        return null;
    }
}

