/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.local;

import com.dataiku.dip.connections.HuggingFaceLocalConnection;
import com.dataiku.dip.io.SimplePythonKernel;
import com.dataiku.dip.llm.io.PythonRequestUtils;
import com.dataiku.dip.llm.io.commands.ProcessSingleEmbeddingCommand;
import com.dataiku.dip.llm.io.commands.ProcessSinglePromptCommand;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.DKUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.polyjson.Mapping;
import com.dataiku.dip.utils.polyjson.PolyJSON;
import com.dataiku.dss.shadelib.org.apache.commons.codec.digest.DigestUtils;
import com.dataiku.j2py.annotations.PyModel;
import com.google.gson.JsonObject;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;

public class HuggingFaceKernelClient
implements Closeable {
    static final String VLLM_VERSION = "0.10.1.1";
    private final SimplePythonKernel kernel;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.huggingface.kernel.client");

    /**
     * Creates a client bound to the given Python kernel.
     *
     * @param kernel kernel this client delegates to; stored as-is, no null check is performed here
     */
    public HuggingFaceKernelClient(SimplePythonKernel kernel) {
        this.kernel = kernel;
    }

    /**
     * Reports whether a kernel is attached and declares itself alive.
     *
     * @return {@code false} when no kernel is attached, otherwise whatever the kernel's own
     *         {@code isAlive()} answers
     */
    public boolean isAlive() {
        if (this.kernel == null) {
            return false;
        }
        return this.kernel.isAlive();
    }

    /**
     * Returns the identifier reported by the underlying kernel.
     *
     * <p>Unlike {@link #isAlive()}, this method does not guard against a {@code null} kernel.
     *
     * @return the kernel's id
     */
    public String getKernelId() {
        return this.kernel.getId();
    }

    /**
     * Returns the log tail builder exposed by the underlying kernel.
     *
     * <p>Pure delegation; no null check on the kernel is performed.
     *
     * @return the kernel's {@code SmartLogTailBuilder}
     */
    public DKUtils.SmartLogTailBuilder getSmartLogTailBuilder() {
        return this.kernel.getSmartLogTailBuilder();
    }

    /**
     * Closes the underlying kernel, if any.
     *
     * <p>This is best effort: any failure raised while closing the kernel is logged and
     * swallowed, so this method does not propagate it to the caller.
     *
     * @throws IOException declared by {@link Closeable}; the current body never throws it
     */
    @Override
    public void close() throws IOException {
        if (this.kernel == null) {
            // Nothing to release.
            return;
        }
        try {
            this.kernel.close();
        }
        catch (Throwable failure) {
            // Deliberately broad: shutting down must not fail the caller.
            logger.error((Object)"Failed to kill kernel", failure);
        }
    }

    /**
     * Starts a streamed completion on the kernel.
     *
     * <p>Delegates to {@code PythonRequestUtils.asyncStreamRequest} using the kernel's async link.
     *
     * @param query    completion query to run
     * @param settings completion settings forwarded with the query
     * @param consumer receiver handed to the streaming helper
     * @return the future produced by the streaming helper
     *         (NOTE(review): meaning of the {@code Integer} result is not visible here - confirm)
     */
    public CompletableFuture<Integer> asyncStreamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) {
        return PythonRequestUtils.asyncStreamRequest(this.kernel.getAsyncLink(), query, settings, consumer);
    }

    /**
     * Sends a {@link CollectTrackingData} command to the kernel and returns the
     * {@code trackingData} payload of its reply.
     *
     * @return the {@code trackingData} field of the kernel response
     */
    public JsonObject collectTrackingData() {
        CollectTrackingDataResponse reply = (CollectTrackingDataResponse)this.kernel.getAsyncLink().request((Object)new CollectTrackingData(), CollectTrackingDataResponse.class);
        return reply.trackingData;
    }

    /**
     * Sends a single, non-streamed completion request to the kernel.
     *
     * <p>The kernel reply is read as a {@link ProcessSinglePromptResponse} and converted with
     * {@code toSimpleCompletionResponse}.
     *
     * @param query    completion query to run
     * @param settings completion settings forwarded with the query
     * @return future of the converted completion response
     */
    public CompletableFuture<LLMClient.SimpleCompletionResponse> asyncComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings) {
        ProcessSinglePromptCommand command = new ProcessSinglePromptCommand(query, settings, false);
        return this.kernel.getAsyncLink().asyncSendRequest(
                (Object)command,
                ProcessSinglePromptResponse.class,
                ProcessSinglePromptResponse::toSimpleCompletionResponse);
    }

    /**
     * Runs several completion queries and returns their responses.
     *
     * <p>Every query is submitted through {@link #asyncComplete} before the futures are handed to
     * {@code DKUCompletableFuture.collectResponses}, so all requests are in flight before results
     * are gathered.
     *
     * @param queries  queries to run; must not be {@code null}
     * @param settings completion settings shared by all queries
     * @return the responses gathered by {@code DKUCompletableFuture.collectResponses}
     * @throws Exception whatever the collection helper propagates
     */
    public List<LLMClient.SimpleCompletionResponse> complete(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws Exception {
        // Parameterized instead of a raw List so the compiler checks the element type end to end.
        List<CompletableFuture<LLMClient.SimpleCompletionResponse>> futures = queries.stream()
                .map(query -> this.asyncComplete(query, settings))
                .collect(Collectors.toList());
        return DKUCompletableFuture.collectResponses(futures);
    }

    /**
     * Sends a single embedding request to the kernel.
     *
     * <p>The request goes through the kernel's streaming request API; the returned future is
     * built from {@code last().toFuture()} on that stream.
     *
     * @param query    embedding query to run
     * @param settings embedding settings forwarded with the query
     * @return future of the embedding response
     */
    public CompletableFuture<LLMClient.SimpleEmbeddingResponse> asyncEmbed(LLMClient.EmbeddingQuery query, LLMClient.EmbeddingSettings settings) {
        return this.kernel.getAsyncLink()
                .asyncStreamRequest((Object)new ProcessSingleEmbeddingCommand(query, settings), LLMClient.SimpleEmbeddingResponse.class)
                .last()
                .toFuture();
    }

    /**
     * Runs several embedding queries and returns their responses.
     *
     * <p>Every query is submitted through {@link #asyncEmbed} before the futures are handed to
     * {@code DKUCompletableFuture.collectResponses}, so all requests are in flight before results
     * are gathered.
     *
     * @param queries  queries to run; must not be {@code null}
     * @param settings embedding settings shared by all queries
     * @return the responses gathered by {@code DKUCompletableFuture.collectResponses}
     * @throws Exception whatever the collection helper propagates
     */
    public List<LLMClient.SimpleEmbeddingResponse> embed(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws Exception {
        // Parameterized instead of a raw List so the compiler checks the element type end to end.
        List<CompletableFuture<LLMClient.SimpleEmbeddingResponse>> futures = queries.stream()
                .map(query -> this.asyncEmbed(query, settings))
                .collect(Collectors.toList());
        return DKUCompletableFuture.collectResponses(futures);
    }

    /**
     * Sends an image generation request to the kernel.
     *
     * <p>The kernel reply is read as an {@link ImageGenerationKernelResponse} and converted with
     * {@code toPublicApiResponse}.
     *
     * @param query image generation query to translate into a kernel command
     * @return future of the public image generation response
     */
    public CompletableFuture<LLMClient.ImageGenerationResponse> asyncGenerateImages(LLMClient.ImageGenerationQuery query) {
        ProcessSingleImageGenerationCommand kernelCommand = new ProcessSingleImageGenerationCommand(query);
        return this.kernel.getAsyncLink().asyncSendRequest(
                (Object)kernelCommand,
                ImageGenerationKernelResponse.class,
                ImageGenerationKernelResponse::toPublicApiResponse);
    }

    /**
     * Blocking variant of {@link #asyncGenerateImages}: submits the query and waits for the result.
     *
     * @param query image generation query
     * @return the image generation response
     * @throws Exception anything raised while waiting on the future (for instance interruption
     *                   or an execution failure)
     */
    public LLMClient.ImageGenerationResponse generateImages(LLMClient.ImageGenerationQuery query) throws Exception {
        CompletableFuture<LLMClient.ImageGenerationResponse> pending = this.asyncGenerateImages(query);
        return pending.get();
    }

    /**
     * Parameterless kernel command used by {@code collectTrackingData()}.
     *
     * <p>Registered under the type {@code "collect-env"} in the {@code @PolyJSON} mapping on
     * {@link HuggingFaceKernelCommand}.
     */
    public static class CollectTrackingData
    extends HuggingFaceKernelCommand {
    }

    /**
     * Kernel reply to a {@link CollectTrackingData} command.
     */
    @PyModel
    public static class CollectTrackingDataResponse {
        // Free-form JSON payload returned to callers of collectTrackingData().
        public JsonObject trackingData;
    }

    /**
     * Kernel reply to a single prompt command. Every field is optional; only the parts present
     * are copied into the public {@code SimpleCompletionResponse}.
     */
    @PyModel
    private static class ProcessSinglePromptResponse {
        @Nullable
        String text;
        @Nullable
        ZeroShotClassificationResponse classification;
        @Nullable
        UsageData usage;
        @Nullable
        List<LLMClient.DetailedLogProb> logProbs;
        @Nullable
        LLMClient.FinishReason finishReason;
        @Nullable
        public List<LLMClient.AbstractToolCall> toolCalls;

        private ProcessSinglePromptResponse() {
        }

        /**
         * Converts this kernel reply into a {@code SimpleCompletionResponse}.
         *
         * <ul>
         *   <li>Token usage is copied when present; the total is prompt + completion tokens.</li>
         *   <li>Log probabilities are only copied when a text is present.</li>
         *   <li>For a classification, the first label becomes the predicted class and one
         *       probability entry is produced per label; a score is attached only when one
         *       exists at the same index.</li>
         *   <li>A classification without labels sets the response text to
         *       {@code "Missing labels"}, replacing any text copied earlier.</li>
         * </ul>
         *
         * @return the populated response
         */
        private LLMClient.SimpleCompletionResponse toSimpleCompletionResponse() {
            LLMClient.SimpleCompletionResponse result = new LLMClient.SimpleCompletionResponse();

            if (this.usage != null) {
                result.promptTokens = this.usage.promptTokens;
                result.completionTokens = this.usage.completionTokens;
                result.totalTokens = this.usage.promptTokens + this.usage.completionTokens;
            }

            if (this.finishReason != null) {
                result.finishReason = this.finishReason;
            }

            if (this.text != null) {
                result.text = this.text;
                result.logProbs = this.logProbs;
            }

            if (this.classification != null) {
                List<String> labels = this.classification.labels;
                List<Double> scores = this.classification.scores;
                if (labels == null || labels.isEmpty()) {
                    result.text = "Missing labels";
                } else {
                    result.predictedClass = labels.get(0);
                    result.predictedClassProbas = new ArrayList<LLMClient.PredictedClassProba>();
                    int position = 0;
                    for (String label : labels) {
                        LLMClient.PredictedClassProba entry = new LLMClient.PredictedClassProba();
                        entry.className = label;
                        // Scores may be missing or shorter than labels; leave proba untouched then.
                        if (scores != null && position < scores.size()) {
                            entry.proba = scores.get(position);
                        }
                        result.predictedClassProbas.add(entry);
                        ++position;
                    }
                }
            }

            if (this.toolCalls != null) {
                result.toolCalls = this.toolCalls;
            }
            return result;
        }
    }

    /**
     * Kernel command requesting image generation, built from a public
     * {@code ImageGenerationQuery}.
     *
     * <p>Registered under the type {@code "process-image-generation-query"} in the
     * {@code @PolyJSON} mapping on {@link HuggingFaceKernelCommand}.
     */
    private static class ProcessSingleImageGenerationCommand
    extends ProcessSingleCommand {
        List<LLMClient.ImageGenerationPrompt> promptTexts;
        List<LLMClient.ImageGenerationPrompt> negativePromptTexts;
        // Only set when the query's quality is a plain digit string; null otherwise.
        Integer numInferenceSteps;
        Integer seed;
        Double fidelity;
        Integer numImagesPerPrompt;
        Integer height;
        Integer width;

        /**
         * Copies the relevant fields of the query into the command.
         *
         * @param query public image generation query; must not be {@code null}
         */
        private ProcessSingleImageGenerationCommand(LLMClient.ImageGenerationQuery query) {
            this.promptTexts = query.prompts;
            this.negativePromptTexts = query.negativePrompts;
            this.numInferenceSteps = ProcessSingleImageGenerationCommand.parseNumInferenceSteps(query.quality);
            this.seed = query.seed;
            this.fidelity = query.fidelity;
            this.numImagesPerPrompt = query.nbImagesToGenerate;
            this.height = query.height;
            this.width = query.width;
        }

        private ProcessSingleImageGenerationCommand() {
        }

        /**
         * Interprets the query's quality as a number of inference steps.
         *
         * <p>commons-lang 2.x {@code StringUtils.isNumeric("")} returns {@code true}, so the
         * empty string has to be rejected explicitly before parsing. Digit strings too large for
         * an {@code int} are treated like any other non-numeric quality instead of failing the
         * whole request with a {@link NumberFormatException}.
         *
         * @param quality raw quality value from the query, may be {@code null}
         * @return the parsed step count, or {@code null} when the value is absent, empty,
         *         not made only of digits, or out of {@code int} range
         */
        @Nullable
        private static Integer parseNumInferenceSteps(@Nullable String quality) {
            if (quality == null || quality.isEmpty() || !StringUtils.isNumeric((String)quality)) {
                return null;
            }
            try {
                return Integer.valueOf(Integer.parseInt(quality));
            }
            catch (NumberFormatException overflow) {
                // Only digits but does not fit in an int: fall back to "no explicit step count".
                return null;
            }
        }
    }

    /**
     * Kernel reply to an image generation command: a list of images encoded as strings.
     */
    @PyModel
    private static class ImageGenerationKernelResponse {
        List<String> images = new ArrayList<String>();

        private ImageGenerationKernelResponse() {
        }

        /**
         * Wraps each image string in an {@code ImageGenerationImage} and appends it to a new
         * public response, preserving order.
         *
         * @return the public response holding one entry per kernel image
         */
        private LLMClient.ImageGenerationResponse toPublicApiResponse() {
            LLMClient.ImageGenerationResponse converted = new LLMClient.ImageGenerationResponse();
            for (String encodedImage : this.images) {
                converted.images.add(new LLMClient.ImageGenerationImage((String)encodedImage));
            }
            return converted;
        }
    }

    /**
     * Classification part of a prompt response: labels with optional scores.
     *
     * <p>{@code ProcessSinglePromptResponse.toSimpleCompletionResponse()} uses the first label as
     * the predicted class and pairs {@code labels[i]} with {@code scores[i]} when that score exists.
     */
    private static class ZeroShotClassificationResponse {
        // Candidate labels; the consumer treats index 0 as the predicted class.
        public List<String> labels;
        // Scores matched to labels by index; may be null or shorter than labels.
        public List<Double> scores;

        private ZeroShotClassificationResponse() {
        }
    }

    /**
     * Token usage reported by the kernel for a prompt; the consumer sums both counters to get
     * the total token count.
     */
    @PyModel
    private static class UsageData {
        // Tokens counted for the prompt.
        int promptTokens;
        // Tokens counted for the completion.
        int completionTokens;

        private UsageData() {
        }
    }

    /**
     * Command that starts the kernel: carries connection-level options, inference settings and
     * the description of the model to load.
     *
     * <p>Registered under the type {@code "start"} in the {@code @PolyJSON} mapping on
     * {@link HuggingFaceKernelCommand}. The model origin is filled in afterwards through
     * {@link #setSavedModelInfo} or {@link #setHuggingFaceModelInfo}.
     */
    public static class StartCommand
    extends HuggingFaceKernelCommand {
        // NOTE(review): same literal as HuggingFaceKernelClient.VLLM_VERSION - keep both in sync.
        private final String vllmVersion = "0.10.1.1";
        @Nullable
        public String hfApiKey;
        public boolean useDSSModelCache;
        public ModelOrigin modelOrigin;
        public String hfModelName;
        public String hfModelPath;
        public String savedModelVersionPath;
        public String savedModelProjectKey;
        public String savedModelId;
        public String baseModelName;
        public String baseModelPath;
        public HuggingFaceLocalConnection.HuggingFaceHandlingMode hfHandlingMode;
        public HuggingFaceLocalConnection.InferenceSettings modelSettings;
        public boolean supportsImageInputs;
        public Integer batchSize;
        boolean fakeLLMServer;

        /**
         * Builds a start command from the connection and the per-model options.
         *
         * @param hfHandlingMode      handling mode to use
         * @param connection          connection providing the API key and the model cache flag
         * @param modelSettings       inference settings for the model
         * @param batchSize           batch size, stored as given
         * @param supportsImageInputs whether the model accepts image inputs
         * @param fakeLLMServer       stored as given
         *                            (NOTE(review): presumably a test/debug switch - confirm)
         */
        public StartCommand(HuggingFaceLocalConnection.HuggingFaceHandlingMode hfHandlingMode, HuggingFaceLocalConnection connection, HuggingFaceLocalConnection.InferenceSettings modelSettings, Integer batchSize, boolean supportsImageInputs, boolean fakeLLMServer) {
            // Values read from the connection.
            this.hfApiKey = connection.params.apiKey;
            this.useDSSModelCache = connection.params.useDSSModelCache;
            // Values passed straight through.
            this.hfHandlingMode = hfHandlingMode;
            this.modelSettings = modelSettings;
            this.supportsImageInputs = supportsImageInputs;
            this.batchSize = batchSize;
            this.fakeLLMServer = fakeLLMServer;
        }

        // No-arg constructor; NOTE(review): presumably kept for deserialization - confirm.
        private StartCommand() {
        }

        /**
         * Marks the command as targeting a Hugging Face model and records its name.
         *
         * @param modelName Hugging Face model name
         */
        public void setHuggingFaceModelInfo(String modelName) {
            this.hfModelName = modelName;
            this.modelOrigin = ModelOrigin.HUGGINGFACE_MODEL;
        }

        /**
         * Marks the command as targeting a saved model version and records its coordinates.
         *
         * @param projectKey       project key of the saved model
         * @param id               saved model id
         * @param modelVersionPath path of the saved model version
         * @param baseModelName    name of the base model
         */
        public void setSavedModelInfo(String projectKey, String id, String modelVersionPath, String baseModelName) {
            this.savedModelProjectKey = projectKey;
            this.savedModelId = id;
            this.savedModelVersionPath = modelVersionPath;
            this.baseModelName = baseModelName;
            this.modelOrigin = ModelOrigin.SAVED_MODEL_VERSION;
        }

        /**
         * Records the model path and the base model path.
         *
         * @param modelPath     path stored in {@code hfModelPath}
         * @param baseModelPath path stored in {@code baseModelPath}
         */
        public void setModelPath(String modelPath, String baseModelPath) {
            this.baseModelPath = baseModelPath;
            this.hfModelPath = modelPath;
        }
    }

    /**
     * Base class for kernel commands that process a single item; in this file it is extended
     * by {@code ProcessSingleImageGenerationCommand}.
     */
    public static abstract class ProcessSingleCommand
    extends HuggingFaceKernelCommand {
    }

    /**
     * Root of the kernel command hierarchy.
     *
     * <p>The {@code @PolyJSON} mapping associates the type names {@code "start"},
     * {@code "collect-env"} and {@code "process-image-generation-query"} with
     * {@link StartCommand}, {@link CollectTrackingData} and
     * {@code ProcessSingleImageGenerationCommand} respectively.
     */
    @PolyJSON(value={@Mapping(value=StartCommand.class, type="start"), @Mapping(value=CollectTrackingData.class, type="collect-env"), @Mapping(value=ProcessSingleImageGenerationCommand.class, type="process-image-generation-query")})
    @PyModel
    public static abstract class HuggingFaceKernelCommand {
    }

    /**
     * Immutable description of a kernel: where it runs (connection, code env, container
     * configuration, cluster), optional extra environment variables, and the start command.
     */
    public static class KernelConfig {
        public final String hfConnectionName;
        public final String codeEnvName;
        public final String containerConfName;
        public final String clusterId;
        @Nullable
        public final Map<String, String> extraEnv;
        public final StartCommand startCommand;

        /**
         * @param hfConnectionName  name of the Hugging Face connection
         * @param codeEnvName       name of the code environment
         * @param containerConfName name of the container configuration
         * @param clusterId         cluster identifier
         * @param startCommand      command used to start the kernel
         * @param extraEnv          additional environment variables, may be {@code null}
         */
        public KernelConfig(String hfConnectionName, String codeEnvName, String containerConfName, String clusterId, StartCommand startCommand, @Nullable Map<String, String> extraEnv) {
            this.startCommand = startCommand;
            this.extraEnv = extraEnv;
            this.clusterId = clusterId;
            this.containerConfName = containerConfName;
            this.codeEnvName = codeEnvName;
            this.hfConnectionName = hfConnectionName;
        }

        /**
         * Builds a short identifier for this configuration: the first 10 hex characters of the
         * SHA-256 of this object's JSON form, followed by {@code "/"} and the start command's
         * {@code hfModelName}.
         *
         * <p>{@code hfModelName} is appended as-is, so a start command that never had it set
         * yields a trailing {@code "null"}.
         *
         * @return the short hash string
         */
        public String toShortHash() {
            String serialized = (String)JSON.json((Object)this);
            String digestPrefix = DigestUtils.sha256Hex((String)serialized).substring(0, 10);
            return digestPrefix + "/" + this.startCommand.hfModelName;
        }
    }

    /**
     * Where the model of a {@link StartCommand} comes from.
     */
    public static enum ModelOrigin {
        // Set by StartCommand.setSavedModelInfo(...).
        SAVED_MODEL_VERSION,
        // Set by StartCommand.setHuggingFaceModelInfo(...).
        HUGGINGFACE_MODEL;

    }
}

