/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.custom;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.CustomLLMConnection;
import com.dataiku.dip.custom.PluginSettingsResolver;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.io.JavaBlockLink;
import com.dataiku.dip.kernel.KernelPool;
import com.dataiku.dip.kernel.KernelPoolThreadFactory;
import com.dataiku.dip.kernel.KernelScalingStrategyBuilder;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.custom.CustomLLMClient;
import com.dataiku.dip.llm.custom.CustomPythonLLMClient;
import com.dataiku.dip.llm.custom.CustomPythonLLMsService;
import com.dataiku.dip.llm.custom.LoadedPythonLLM;
import com.dataiku.dip.llm.online.AbstractLLMClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMQueryRunner;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.dataiku.dss.shadelib.org.apache.commons.codec.digest.DigestUtils;
import com.dataiku.dss.shadelib.org.apache.commons.lang3.StringUtils;
import com.google.common.base.Stopwatch;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nonnull;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Spring service owning the {@link KernelPool} of {@link CustomPythonLLMClient} kernels used by custom
 * (plugin-provided) Python LLMs, and building the {@link AbstractLLMClient} facades that route completion,
 * streaming, embedding and image-generation requests through that pool.
 */
@Service
public class CustomPythonLLMKernelPool {
    @Autowired
    CustomPythonLLMsService customPythonLLMsService;
    /** Pool through which every request is dispatched; built in the constructor. */
    private final KernelPool<CustomPythonLLMClient, KernelDesc, KernelDesc> manager;
    /** Executor on which kernel {@code init} and {@code close} calls are run. */
    private final ExecutorService executorService = Executors.newCachedThreadPool(new ThreadFactoryBuilder().setNameFormat("custompython-kernel-pool-startstop-%d").build());
    /**
     * Counter used by {@code isOutdated} to re-resolve plugin settings only once every N checks.
     * Plain {@code int} with no synchronization in this class.
     * NOTE(review): confirm which threads of the pool call {@code isOutdated}.
     */
    private int skipSettingsResolve = 0;
    private static final Integer GLOBAL_MAX_KERNEL_COUNT = 128;
    private static final Integer KERNEL_TTL_SECONDS = 600;
    private static final Integer MAX_PARALLEL_REQUESTS = 256;
    /** Default for {@code dku.llm.custompython.resolveSettingsInterval}. */
    private static final Integer RESOLVE_SETTINGS_INTERVAL = 60;
    /** Value of {@code pythonExceptionPath} that marks a kernel failure as retryable. */
    private static final String RETRYABLE_EXCEPTION_PYTHON_PATH = "dataiku.llm.python.exception.RetryableException";
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.custompythonpool");

    /**
     * Builds the descriptor identifying the kernel that must serve requests for the given model.
     *
     * @throws RuntimeException wrapping the {@link DKUSecurityException} or {@link IOException} raised while
     *                          expanding the plugin settings
     */
    private KernelDesc createKernelDesc(AuthCtx authCtx, CustomLLMConnection connection, String projectKey, CustomLLMConnection.LLMModel customModel, LoadedPythonLLM loaded, EnrichedLLMStructuredRef enrichedRef, CustomLLMConnection.CustomLLMModel customLLMModel) {
        KernelDesc desc = new KernelDesc();
        desc.authCtx = authCtx;
        desc.connection = connection;
        desc.projectKey = projectKey;
        desc.capability = customModel.capability;
        desc.loaded = loaded;
        desc.enrichedRef = enrichedRef;
        try {
            desc.settings = this.customPythonLLMsService.getExpandedPluginSettings(loaded.getType(), authCtx, projectKey, customModel.customConfig);
        }
        catch (DKUSecurityException | IOException e) {
            throw new RuntimeException(e);
        }
        desc.maxParallelism = customLLMModel.maxParallelism;
        desc.retrySettings = customLLMModel.getRetrySettings();
        return desc;
    }

    /**
     * Tells whether a failure coming back from a kernel carries the Python-side retryable exception path.
     *
     * @param e the failure caught while collecting a kernel response
     * @return {@code true} when {@code e} is a {@link JavaBlockLink.RequestFailedException} whose
     *         {@code pythonExceptionPath} equals {@link #RETRYABLE_EXCEPTION_PYTHON_PATH}
     */
    private static boolean isRetryableKernelFailure(Exception e) {
        return e instanceof JavaBlockLink.RequestFailedException rfe
                && RETRYABLE_EXCEPTION_PYTHON_PATH.equals(rfe.pythonExceptionPath);
    }

    /**
     * Creates an LLM client whose requests are executed by pooled Python kernels.
     *
     * <p>Every request goes through an {@link LLMQueryRunner} configured with the model's retry settings and
     * retrying on {@link CustomLLMClient.RetryableException}. Kernel failures flagged as retryable by the
     * Python side are converted to that exception type. Usage figures of all requests issued through the
     * returned client are accumulated in one usage-data object exposed by {@code getTotalCRU}.
     *
     * @return a client bound to the kernel descriptor computed from the arguments; closing it cancels the
     *         futures it is still tracking
     */
    public AbstractLLMClient getClient(AuthCtx authCtx, CustomLLMConnection connection, String projectKey, CustomLLMConnection.LLMModel customModel, final LoadedPythonLLM loaded, EnrichedLLMStructuredRef enrichedRef) {
        final CustomLLMConnection.CustomLLMModel customLLMModel = connection.getModel(customModel.id);
        final KernelDesc desc = this.createKernelDesc(authCtx, connection, projectKey, customModel, loaded, enrichedRef, customLLMModel);
        final ComputeResourceUsage.InternalLLMUsageData usageData = new ComputeResourceUsage.InternalLLMUsageData();
        final String providerId = "CustomLLM:" + loaded.getOwnerPluginId();
        final LLMQueryRunner queryRunner = new LLMQueryRunner(providerId, enrichedRef, customModel, AbstractLLMConnection.HTTPBasedLLMNetworkSettings.of(customLLMModel.getRetrySettings()), throwable -> throwable instanceof CustomLLMClient.RetryableException);
        return new AbstractLLMClient(enrichedRef) {
            // Tracks the in-flight futures of this client so that close() can cancel them all.
            private final DKUCompletableFuture.FutureCancellationTracker futureCancellationTracker = new DKUCompletableFuture.FutureCancellationTracker();

            @Override
            public boolean supportNativeBatch() {
                return false;
            }

            @Override
            public boolean requiresCostLimiting() {
                return true;
            }

            @Override
            public String getProviderId() {
                return providerId;
            }

            @Override
            public AbstractLLMConnection getConnection() {
                return desc.connection;
            }

            @Override
            public int getMaxParallelism() {
                return customLLMModel.maxParallelism;
            }

            /** Submits one kernel request per query and collects the responses. */
            @Override
            public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) {
                try {
                    return queryRunner.run(() -> {
                        try {
                            return DKUCompletableFuture.collectResponsesNoException(queries.stream().map(query -> {
                                Stopwatch sw = Stopwatch.createStarted();
                                CompletableFuture<LLMClient.SimpleCompletionResponse> future = this.futureCancellationTracker.track(() -> CustomPythonLLMKernelPool.this.manager.handle(kernel -> kernel.asyncComplete(query, settings), desc, desc.toHash(), query));
                                // Usage accounting is a side branch: the future handed to the collector is the original one.
                                future.thenAccept(response -> {
                                    if (response != null) {
                                        response.includeInUsageData(usageData, sw.elapsed(TimeUnit.MILLISECONDS));
                                    }
                                });
                                return future;
                            }).toList());
                        }
                        catch (Exception e) {
                            if (CustomPythonLLMKernelPool.isRetryableKernelFailure(e)) {
                                throw new CustomLLMClient.RetryableException(e);
                            }
                            throw e;
                        }
                    });
                }
                catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }

            @Override
            public boolean supportsStream() {
                return true;
            }

            /**
             * Streams a completion through a pooled kernel, forwarding every event to {@code consumer}.
             * A retryable failure is only turned into a {@link CustomLLMClient.RetryableException} while no
             * chunk has been forwarded yet; afterwards it is reported as a {@link RuntimeException}.
             */
            @Override
            public void streamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, final LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
                queryRunner.run(() -> {
                    final Stopwatch sw = Stopwatch.createStarted();
                    final AtomicBoolean receivedOneChunk = new AtomicBoolean(false);
                    LLMClient.StreamedCompletionResponseConsumer wrappedConsumer = new LLMClient.StreamedCompletionResponseConsumer() {

                        @Override
                        public void onStreamStarted() throws Exception {
                            consumer.onStreamStarted();
                        }

                        @Override
                        public void onStreamChunk(LLMClient.StreamedCompletionResponseChunk chunk) throws Exception {
                            receivedOneChunk.set(true);
                            consumer.onStreamChunk(chunk);
                        }

                        @Override
                        public void onStreamComplete(LLMClient.StreamedCompletionResponseFooter footer) throws Exception {
                            footer.includeInUsageData(usageData, sw.elapsed(TimeUnit.MILLISECONDS));
                            consumer.onStreamComplete(footer);
                        }
                    };
                    CompletableFuture<Integer> cf = this.futureCancellationTracker.track(() -> CustomPythonLLMKernelPool.this.manager.handle(client -> client.streamCompleteAsync(query, settings, wrappedConsumer), desc, desc.toHash(), query));
                    try {
                        Integer nChunks = DKUCompletableFuture.collectResponse(cf);
                        logger.infoV("Fully received streamed completion response with %s chunks", nChunks);
                    }
                    catch (Exception e) {
                        if (CustomPythonLLMKernelPool.isRetryableKernelFailure(e)) {
                            if (receivedOneChunk.get()) {
                                // Keep the kernel failure as cause so the original error is not lost.
                                throw new RuntimeException("Throwing a RetryableException is not allowed once stream has started", e);
                            }
                            throw new CustomLLMClient.RetryableException(e);
                        }
                        throw e;
                    }
                    return null;
                });
            }

            /** Submits one kernel request per embedding query and collects the responses. */
            @Override
            public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) {
                try {
                    return queryRunner.run(() -> {
                        try {
                            return DKUCompletableFuture.collectResponsesNoException(queries.stream().map(query -> {
                                Stopwatch sw = Stopwatch.createStarted();
                                CompletableFuture<LLMClient.SimpleEmbeddingResponse> future = this.futureCancellationTracker.track(() -> CustomPythonLLMKernelPool.this.manager.handle(kernel -> kernel.asyncEmbed(query, settings), desc, desc.toHash(), query));
                                future.thenAccept(response -> {
                                    if (response != null) {
                                        response.includeInUsageData(usageData);
                                        usageData.incrementTotalComputationTimeMS(Long.valueOf(sw.elapsed(TimeUnit.MILLISECONDS)));
                                    }
                                });
                                return future;
                            }).toList());
                        }
                        catch (Exception e) {
                            if (CustomPythonLLMKernelPool.isRetryableKernelFailure(e)) {
                                throw new CustomLLMClient.RetryableException(e);
                            }
                            throw e;
                        }
                    });
                }
                catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }

            /** Runs one image-generation request on a pooled kernel and records its duration and estimated cost. */
            @Override
            public LLMClient.ImageGenerationResponse generateImages(LLMClient.ImageGenerationQuery query) throws Exception {
                return queryRunner.run(() -> {
                    Stopwatch sw = Stopwatch.createStarted();
                    // NOTE(review): this is the only call passing the descriptor and its hash twice (six-argument
                    // handle); kept as found -- confirm against KernelPool that this overload is the intended one.
                    CompletableFuture<LLMClient.ImageGenerationResponse> future = this.futureCancellationTracker.track(() -> CustomPythonLLMKernelPool.this.manager.handle(kernel -> kernel.asyncGenerateImages(query), desc, desc.toHash(), desc, desc.toHash(), query));
                    future.thenAccept(response -> {
                        if (response != null) {
                            usageData.incrementTotalComputationTimeMS(Long.valueOf(sw.elapsed(TimeUnit.MILLISECONDS)));
                            usageData.incrementEstimatedCostUSD(Double.valueOf(response.estimatedCost));
                        }
                    });
                    try {
                        return DKUCompletableFuture.collectResponse(future);
                    }
                    catch (Exception e) {
                        if (CustomPythonLLMKernelPool.isRetryableKernelFailure(e)) {
                            throw new CustomLLMClient.RetryableException(e);
                        }
                        throw e;
                    }
                });
            }

            /** Reports the usage accumulated so far by this client as a {@link ComputeResourceUsage}. */
            @Override
            public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
                ComputeResourceUsage cru = new ComputeResourceUsage();
                cru.setupLLMUsage(usageType, llmRef.connection, llmRef.type.toString(), llmRef.id);
                cru.llmUsage.setFromInternal(usageData);
                return cru;
            }

            @Override
            public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
                return chatMessages;
            }

            @Override
            public void close() throws Exception {
                this.futureCancellationTracker.cancelAll("LLM client closed");
            }
        };
    }

    /**
     * Builds the kernel pool together with the controller telling it how to create, start, kill and
     * invalidate custom Python LLM kernels. Limits are read from the {@code dku.llm.custompython.*}
     * application parameters, with the class constants as defaults.
     */
    public CustomPythonLLMKernelPool() {
        this.manager = new KernelPool<>(new KernelPool.KernelController<CustomPythonLLMClient, KernelDesc, KernelDesc>() {

            @Nonnull
            public CustomPythonLLMClient createKernel(KernelDesc kernelDesc) {
                return new CustomPythonLLMClient(kernelDesc.authCtx, kernelDesc.projectKey, kernelDesc.capability, kernelDesc.settings, kernelDesc.loaded);
            }

            @Nonnull
            public CompletableFuture<Void> startKernel(CustomPythonLLMClient kernel, KernelDesc kernelDesc) {
                return DKUCompletableFuture.runAsync(kernel::init, (Executor) CustomPythonLLMKernelPool.this.executorService);
            }

            public int getGlobalMaxKernelCount() {
                return ApplicationConfigurator.getParams().getIntParam("dku.llm.custompython.globalMaxKernelCount", GLOBAL_MAX_KERNEL_COUNT);
            }

            public int getAutoscaleTimeWindowSeconds() {
                return ApplicationConfigurator.getParams().getIntParam("dku.llm.custompython.kernelTTLSeconds", KERNEL_TTL_SECONDS);
            }

            public int getHardMaxParallelRequests(KernelDesc kernelDesc) {
                return ApplicationConfigurator.getParams().getIntParam("dku.llm.custompython.maxParallelRequests", MAX_PARALLEL_REQUESTS);
            }

            public int getSoftMaxParallelRequests(KernelDesc kernelDesc) {
                return this.getHardMaxParallelRequests(kernelDesc);
            }

            @Nonnull
            public CompletableFuture<Void> killKernel(CustomPythonLLMClient kernel) {
                return DKUCompletableFuture.runAsync(kernel::close, (Executor) CustomPythonLLMKernelPool.this.executorService);
            }

            public boolean isAlive(CustomPythonLLMClient kernel) {
                return kernel.getKernel() != null && kernel.getKernel().isAlive();
            }

            public String getKernelId(CustomPythonLLMClient kernel) {
                return kernel.getKernelId();
            }

            public Integer getMaxKernelCount(KernelDesc kernelGroup) {
                return 1;
            }

            /**
             * A kernel is outdated when the model's retry settings or max parallelism changed on the
             * connection, when the loaded plugin instance key changed, or when the expanded plugin settings
             * differ. The settings comparison only runs once every
             * {@code dku.llm.custompython.resolveSettingsInterval} checks that reach it. Any failure while
             * assessing is logged and reported as outdated.
             */
            public boolean isOutdated(KernelDesc kernelDesc) {
                try {
                    String instanceCurrentlyLoaded = ((LoadedPythonLLM) CustomPythonLLMKernelPool.this.customPythonLLMsService.getLoadedDescByElementType(kernelDesc.loaded.getType())).getInstanceKey();
                    CustomLLMConnection newConnection = ConnectionsDAO.get().getMandatoryConnectionAs(kernelDesc.authCtx, kernelDesc.connection.name, CustomLLMConnection.class);
                    CustomLLMConnection.CustomLLMModel model = newConnection.getModel(kernelDesc.enrichedRef.model);
                    if (!kernelDesc.retrySettings.equals(model.getRetrySettings())) {
                        return true;
                    }
                    if (kernelDesc.maxParallelism != model.maxParallelism) {
                        return true;
                    }
                    if (!instanceCurrentlyLoaded.equals(kernelDesc.loaded.getInstanceKey())) {
                        return true;
                    }
                    // Clamp to 1: a zero interval would make the modulo below throw and flag every kernel as outdated.
                    int interval = Math.max(1, ApplicationConfigurator.getParams().getIntParam("dku.llm.custompython.resolveSettingsInterval", RESOLVE_SETTINGS_INTERVAL));
                    if (CustomPythonLLMKernelPool.this.skipSettingsResolve % interval == 0) {
                        JsonObject modelConfig = model.customConfig;
                        PluginSettingsResolver.ResolvedSettings settings = CustomPythonLLMKernelPool.this.customPythonLLMsService.getExpandedPluginSettings(kernelDesc.loaded.getType(), kernelDesc.authCtx, kernelDesc.projectKey, modelConfig);
                        if (!settings.equals(kernelDesc.settings)) {
                            return true;
                        }
                    }
                    // Parenthesized on purpose: "+= 1 % interval" adds (1 % interval) and never wraps the counter.
                    CustomPythonLLMKernelPool.this.skipSettingsResolve = (CustomPythonLLMKernelPool.this.skipSettingsResolve + 1) % interval;
                }
                catch (Exception e) {
                    logger.info(String.format("[%s] Failed to assess outdatedness of kernel", kernelDesc.toHash()), e);
                    return true;
                }
                return false;
            }
        }, new KernelScalingStrategyBuilder(), new KernelPoolThreadFactory("custompython"), logger);
    }

    /** Delegates to the pool's {@code dump}. */
    public KernelPool.PoolDump dump(boolean full) {
        return this.manager.dump(full);
    }

    /** Asks the pool to kill all its kernels, recording the given reason. */
    public void killAllKernels(KernelPool.DeathReason deathReason) {
        this.manager.killAllKernels(deathReason);
    }

    /** Asks the pool to kill all its requests. */
    public void killAllRequests() {
        this.manager.killAllRequests();
    }

    /**
     * Everything needed to create a kernel for one custom LLM model, plus the values compared by
     * {@code isOutdated} to decide whether a running kernel still matches the current configuration.
     */
    public static class KernelDesc {
        AuthCtx authCtx;
        CustomLLMConnection connection;
        String projectKey;
        CustomLLMConnection.Capability capability;
        PluginSettingsResolver.ResolvedSettings settings;
        LoadedPythonLLM loaded;
        EnrichedLLMStructuredRef enrichedRef;
        CustomLLMClient.RateLimitingRetrySettings retrySettings;
        int maxParallelism;

        /**
         * Computes a short identifier for this descriptor.
         *
         * @return the first 10 hex characters of the SHA-256 of the concatenated JSON forms of the LLM
         *         reference, capability, settings, retry settings and max parallelism, followed by the
         *         loaded plugin instance key; {@code authCtx}, {@code connection} and {@code projectKey}
         *         do not take part in it
         */
        public String toHash() {
            String[] parts = new String[]{
                JSON.json((Object) this.enrichedRef),
                JSON.json((Object) this.capability),
                JSON.json((Object) this.settings),
                JSON.json((Object) this.retrySettings),
                JSON.json((Object) this.maxParallelism),
                this.loaded.getInstanceKey()
            };
            return DigestUtils.sha256Hex(StringUtils.join((Object[]) parts)).substring(0, 10);
        }
    }
}

