/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.online.CompletionRecipeLLMMeshClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClientFactory;
import com.dataiku.dip.llm.online.ParallelLLMClient;
import com.dataiku.dip.llm.online.utils.BatchBufferProcessor;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.rpc.TicketBasedIntercomAPIClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.tickets.APITicketService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;

/**
 * {@link CompletionRecipeLLMMeshClient} implementation that does not call the LLM itself: it
 * groups completion queries into batches and posts each batch to the
 * {@code /dip/api/tintercom/llms/completions} endpoint through a ticket-authenticated
 * {@link TicketBasedIntercomAPIClient}.
 *
 * <p>Batches are dispatched on a fixed-size thread pool owned by this instance. The pool size and
 * the batch size are read once, at construction time, from a temporary {@link LLMClient}.
 *
 * <p>Instances own a thread pool, an API ticket and an HTTP client, and must therefore be closed.
 */
public class RedirectCompletionRecipeLLMMeshClient
implements CompletionRecipeLLMMeshClient {
    private final int parallelism;
    private final int batchSize;
    private final AbstractLLMConnection connection;
    private final EnrichedLLMStructuredRef llmRef;
    /** Pool on which batches are executed; owned by this instance and shut down in {@link #close()}. */
    private final ExecutorService llmPool;
    private final CompletionRedirectClient redirectClient;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.online.parallel.recipe");

    /**
     * Builds the client.
     *
     * <p>A temporary {@link LLMClient} is opened only to read the maximum parallelism, the
     * completion batch size, the connection and the enriched reference; it is closed before the
     * constructor goes any further. Parallelism and batch size are both clamped to at least 1.
     *
     * @param authCtx authentication context used to obtain the temporary LLM client
     * @param llmRef reference of the LLM to query
     * @param contextProjectKey project key sent with every batch
     * @param usedDataset dataset whose smart name is sent with every batch
     * @param jobContext job context providing the job id and activity sent with every batch
     * @param usageTimeGuardrails guardrails settings serialized and sent with every batch
     * @throws Exception if the temporary LLM client cannot be obtained, queried or closed
     */
    public RedirectCompletionRecipeLLMMeshClient(AuthCtx authCtx, LLMStructuredRef llmRef, String contextProjectKey, AnyLoc usedDataset, JobContext jobContext, GuardrailsPipelineSettings usageTimeGuardrails) throws Exception {
        try (LLMClient dummyLlmClient = LLMClientFactory.get(authCtx, contextProjectKey, llmRef);){
            this.parallelism = Math.max(1, dummyLlmClient.getMaxParallelism());
            this.batchSize = Math.max(1, dummyLlmClient.getBatchSize(AbstractLLMConnection.QueryType.completion, llmRef));
            this.connection = dummyLlmClient.getConnection();
            this.llmRef = dummyLlmClient.getEnrichedRef();
        }
        logger.info((Object)String.format("Using parallelism=%s and batchSize=%s", this.parallelism, this.batchSize));
        this.llmPool = Executors.newFixedThreadPool(this.parallelism);
        this.redirectClient = new CompletionRedirectClient(contextProjectKey, llmRef, usedDataset, usageTimeGuardrails, this.parallelism, jobContext);
    }

    /**
     * Creates a streamer that batches the queries it is given and runs each batch through the
     * redirect endpoint.
     *
     * <p>The returned streamer overrides {@code done()} so that the batching buffer is flushed
     * before delegating to the parent implementation.
     *
     * @param settings completion settings forwarded with every batch
     * @return a new streamer backed by this client's thread pool
     */
    @Override
    public CompletionRecipeLLMMeshClient.CompletionsStreamer completeQueriesAsyncStream(LLMClient.CompletionSettings settings) {
        logger.info((Object)String.format("Using parallelism=%s and batchSize=%s", this.parallelism, this.batchSize));
        final BatchBufferProcessor<ParallelLLMClient.SingleCompletionQueryWithTrace, LLMClient.SimpleCompletionResponseOrError> queryProcessor = this.completeQueriesAsyncBatcher(settings, this.batchSize);
        int queueSize = this.batchSize * (this.parallelism + 1);
        return new CompletionRecipeLLMMeshClient.CompletionsStreamer(queryProcessor, queueSize, queueSize){

            @Override
            public CompletableFuture<Void> done() {
                // Push out whatever is still sitting in the batching buffer before completing.
                queryProcessor.flush();
                return super.done();
            }
        };
    }

    /**
     * Builds the batching processor that turns buffered queries into calls to the redirect client.
     *
     * <p>Any exception raised while processing a batch is logged and converted into one error
     * response per query of that batch, so the failure of one batch does not propagate as an
     * exception.
     *
     * @param settings completion settings forwarded with every batch
     * @param batchSize number of queries per batch
     * @return the batching processor, running on {@link #llmPool}
     */
    private BatchBufferProcessor<ParallelLLMClient.SingleCompletionQueryWithTrace, LLMClient.SimpleCompletionResponseOrError> completeQueriesAsyncBatcher(LLMClient.CompletionSettings settings, int batchSize) {
        return new BatchBufferProcessor<ParallelLLMClient.SingleCompletionQueryWithTrace, LLMClient.SimpleCompletionResponseOrError>(batchSize, this.llmPool, queriesWithTrace -> {
            List<LLMClient.SingleCompletionQuery> queries = queriesWithTrace.stream().map(q -> q.query).collect(Collectors.toList());
            try {
                return this.redirectClient.completeBatch(queries, settings);
            }
            catch (Exception e) {
                logger.warn((Object)"Got fatal error from backend while retrieving batch completion", (Throwable)e);
                return Collections.nCopies(queries.size(), LLMClient.SimpleCompletionResponseOrError.fromError(e));
            }
        });
    }

    /**
     * Releases the resources owned by this client: the batch thread pool and the redirect client.
     *
     * <p>The pool is shut down in an orderly fashion (already submitted batches are not
     * interrupted); the redirect client is closed even if shutting down the pool fails.
     *
     * @throws Exception if closing the redirect client fails
     */
    @Override
    public void close() throws Exception {
        try {
            // Without this, the pool's non-daemon worker threads outlive the client.
            this.llmPool.shutdown();
        }
        finally {
            this.redirectClient.close();
        }
    }

    /**
     * @param usageType usage type to report
     * @return the resource usage accumulated by the redirect client so far
     */
    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType) {
        return this.redirectClient.getTotalCRU(usageType);
    }

    /** @return the enriched reference read from the temporary LLM client at construction time */
    @Override
    public EnrichedLLMStructuredRef getEnrichedRef() {
        return this.llmRef;
    }

    /** @return the connection read from the temporary LLM client at construction time */
    @Override
    public AbstractLLMConnection getConnection() {
        return this.connection;
    }

    /**
     * Posts batches of completion queries to the intercom completions endpoint and accumulates
     * usage figures (tokens, estimated cost, cache hits/misses, computation time) from the
     * responses.
     */
    private static class CompletionRedirectClient
    implements AutoCloseable {
        private final APITicketService.TicketUsage tu;
        private final TicketBasedIntercomAPIClient apiClient;
        // NOTE(review): completeBatch is invoked from the pool threads, so this object is updated
        // concurrently; confirm that LLMUsageData's increment methods are thread-safe.
        private final ComputeResourceUsage.LLMUsageData usageData = new ComputeResourceUsage.LLMUsageData();
        private final String projectKey;
        private final LLMStructuredRef enrichedRef;
        private final AnyLoc usedDataset;
        private final JobContext jobContext;
        private final GuardrailsPipelineSettings usageTimeGuardrails;

        /**
         * Acquires an API ticket and builds an API client authenticated with that ticket's secret.
         * The client's total and per-route connection limits are both set to {@code parallelism}.
         *
         * @param projectKey project key sent with every batch
         * @param enrichedRef reference of the LLM; its id, connection and type are used
         * @param usedDataset dataset whose smart name is sent with every batch
         * @param usageTimeGuardrails guardrails settings sent with every batch
         * @param parallelism connection limit applied to the API client
         * @param jobContext job context providing the job id and activity
         */
        public CompletionRedirectClient(String projectKey, LLMStructuredRef enrichedRef, AnyLoc usedDataset, GuardrailsPipelineSettings usageTimeGuardrails, int parallelism, JobContext jobContext) {
            APITicketService ticketService = (APITicketService)SpringUtils.getBean(APITicketService.class);
            this.tu = ticketService.getAndUseSingleTicket();
            this.enrichedRef = enrichedRef;
            this.projectKey = projectKey;
            this.usedDataset = usedDataset;
            this.usageTimeGuardrails = usageTimeGuardrails;
            this.apiClient = TicketBasedIntercomAPIClient.forLocalHost(this.tu.getTicket().getSecret());
            this.jobContext = jobContext;
            this.apiClient.setMaxTotalConnections(parallelism);
            this.apiClient.setDefaultMaxConnectionsPerRoute(parallelism);
        }

        /**
         * Closes the ticket usage, then the API client. The API client is closed even when
         * closing the ticket usage throws.
         *
         * @throws Exception if either close operation fails
         */
        @Override
        public void close() throws Exception {
            try {
                this.tu.close();
            }
            finally {
                this.apiClient.close();
            }
        }

        /**
         * Sends one batch of queries to the completions endpoint and records usage from the
         * responses.
         *
         * @param queries queries of the batch, serialized to JSON in the request
         * @param settings completion settings, serialized to JSON in the request
         * @return a new list holding the responses returned by the endpoint
         * @throws IOException if the call fails or the endpoint returns no response list
         */
        public List<LLMClient.SimpleCompletionResponseOrError> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws IOException {
            long before = System.currentTimeMillis();
            CompletionsResponse completionResponse = (CompletionsResponse)this.apiClient.postFormToJSON("/dip/api/tintercom/llms/completions", CompletionsResponse.class, new Object[]{"projectKey", this.projectKey, "jobId", this.jobContext.jobId, "activity", this.jobContext.activity, "llmId", this.enrichedRef.id, "usedDatasetSmartName", this.usedDataset.getSmartName(this.projectKey), "queries", JSON.json(queries), "settings", JSON.json((Object)settings), "usageTimeGuardrails", JSON.json((Object)this.usageTimeGuardrails)});
            long computationTime = System.currentTimeMillis() - before;
            this.usageData.incrementTotalComputationTimeMS(Long.valueOf(computationTime));
            // Fail with an explicit message instead of a bare NullPointerException on an empty payload.
            if (completionResponse == null || completionResponse.responses == null) {
                throw new IOException("Completions endpoint returned no responses for a batch of " + queries.size() + " queries");
            }
            for (LLMClient.SimpleCompletionResponseOrError scroe : completionResponse.responses) {
                this.usageData.incrementTotalPromptTokens(scroe.promptTokens);
                this.usageData.incrementTotalCompletionTokens(scroe.completionTokens);
                this.usageData.incrementEstimatedCostUSD(scroe.estimatedCost);
                this.usageData.incrementTotalQueries();
                if (scroe.fromCache) {
                    this.usageData.incrementCacheHitQueries();
                } else {
                    this.usageData.incrementCacheMissQueries();
                }
            }
            return new ArrayList<LLMClient.SimpleCompletionResponseOrError>(completionResponse.responses);
        }

        /**
         * Builds a resource-usage report from the figures accumulated so far.
         *
         * @param usageType usage type to report
         * @return a new {@link ComputeResourceUsage} populated from the accumulated usage data
         */
        public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType) {
            ComputeResourceUsage cru = new ComputeResourceUsage();
            cru.setupLLMUsage(usageType, this.enrichedRef.connection, this.enrichedRef.type.toString(), this.enrichedRef.id);
            cru.llmUsage.setFromInternal((ComputeResourceUsage.InternalLLMUsageData)this.usageData);
            cru.llmUsage.cacheMissQueries = this.usageData.getCacheMissQueries();
            cru.llmUsage.cacheHitQueries = this.usageData.getCacheHitQueries();
            cru.llmUsage.totalQueries = this.usageData.getTotalQueries();
            return cru;
        }

        /** JSON payload returned by the completions endpoint. */
        private static class CompletionsResponse {
            List<LLMClient.SimpleCompletionResponseOrError> responses;

            private CompletionsResponse() {
            }
        }
    }
}

