/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.externalinfras.azureml.AzureMLUtils;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.ISavedModelDeployer;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMCostLimitingService;
import com.dataiku.dip.llm.online.RemoteFineTuningClient;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.SmartLogTail;
import com.dataiku.dss.shadelib.com.google.common.base.MoreObjects;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nullable;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * {@link LLMClient} decorator that surrounds every cost-incurring call
 * (completion, embedding, streaming completion, image generation) with calls to
 * {@link LLMCostLimitingService}: {@code checkQuery} before the request is sent and
 * {@code reportCost} once the request has finished, successfully or not.
 *
 * <p>All other {@link LLMClient} methods are forwarded unchanged to the wrapped client,
 * except {@link #requiresCostLimiting()} and {@link #getProviderId()}, which return
 * fixed values (see their documentation).
 */
public class LLMClientCostLimitingWrapper implements LLMClient {
    /** Injected by {@code SpringUtils.autowire} in the constructor. */
    @Autowired
    private LLMCostLimitingService costLimitingService;
    /** The client that actually performs the requests. */
    private final LLMClient llmClient;
    /** Identifies who/what is being charged; passed to every service call. */
    private final LLMCostLimitingService.LLMCostLimitingContext context;

    /**
     * Wraps {@code wrappedClient} and builds the cost-limiting context used for all
     * subsequent calls.
     *
     * @param wrappedClient client to delegate to; must not itself be a
     *                      {@code LLMClientCostLimitingWrapper} (checked by an {@code assert},
     *                      so only enforced when assertions are enabled)
     * @param authCtx       caller's auth context; may be {@code null}, in which case the
     *                      context's user is left {@code null}
     * @param projectKey    project key recorded in the context
     * @param identifier    LLM reference providing the LLM id and connection name
     */
    public LLMClientCostLimitingWrapper(LLMClient wrappedClient, AuthCtx authCtx, String projectKey, LLMStructuredRef identifier) {
        assert (!(wrappedClient instanceof LLMClientCostLimitingWrapper));
        SpringUtils.getInstance().autowire((Object)this);
        this.llmClient = wrappedClient;
        this.context = new LLMCostLimitingService.LLMCostLimitingContext();
        this.context.projectKey = projectKey;
        this.context.user = authCtx != null ? authCtx.getAssociatedDSSUser() : null;
        this.context.llmId = identifier.id;
        this.context.connectionName = identifier.connection;
        // The provider is captured from the wrapped client, since this wrapper's own
        // getProviderId() returns null.
        this.context.provider = wrappedClient.getProviderId();
    }

    @Override
    public boolean supportNativeBatch() {
        return this.llmClient.supportNativeBatch();
    }

    /**
     * Always {@code false} for the wrapper itself.
     * NOTE(review): presumably so an already-wrapped client is not wrapped again — confirm against callers.
     */
    @Override
    public boolean requiresCostLimiting() {
        return false;
    }

    /**
     * Always {@code null}; the wrapped client's provider id is only stored in the context.
     * NOTE(review): not delegated to the wrapped client — confirm this is intentional.
     */
    @Override
    public String getProviderId() {
        return null;
    }

    @Override
    public AbstractLLMConnection getConnection() {
        return this.llmClient.getConnection();
    }

    @Override
    public int getMaxParallelism() {
        return this.llmClient.getMaxParallelism();
    }

    @Override
    public int getBatchSize(AbstractLLMConnection.QueryType queryType, LLMStructuredRef llmRef) {
        return this.llmClient.getBatchSize(queryType, llmRef);
    }

    /**
     * Checks the query against the cost-limiting service, runs the batch on the wrapped
     * client, then reports the summed estimated cost and the number of responses.
     *
     * @throws Exception whatever {@code checkQuery}, the wrapped client or {@code reportCost} throws
     */
    @Override
    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws Exception {
        this.costLimitingService.checkQuery(this.context);
        return this.handleQueryProcessingError(queries.size(), () -> {
            List<LLMClient.SimpleCompletionResponse> responses = this.llmClient.completeBatch(queries, settings);
            double estimatedCost = responses.stream().mapToDouble(r -> LLMClientCostLimitingWrapper.safeCost(r.estimatedCost)).sum();
            this.costLimitingService.reportCost(this.context, estimatedCost, responses.size());
            return responses;
        });
    }

    /**
     * Same flow as {@link #completeBatch} for embedding queries.
     *
     * @throws Exception whatever {@code checkQuery}, the wrapped client or {@code reportCost} throws
     */
    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws Exception {
        this.costLimitingService.checkQuery(this.context);
        return this.handleQueryProcessingError(queries.size(), () -> {
            List<LLMClient.SimpleEmbeddingResponse> responses = this.llmClient.embedBatch(queries, settings);
            double cost = responses.stream().mapToDouble(r -> LLMClientCostLimitingWrapper.safeCost(r.estimatedCost)).sum();
            this.costLimitingService.reportCost(this.context, cost, responses.size());
            return responses;
        });
    }

    @Override
    public boolean supportsStream() {
        return this.llmClient.supportsStream();
    }

    /**
     * Streams a completion through the wrapped client. Stream events are forwarded to
     * {@code consumer}; the cost found in the footer is reported just before the
     * completion event is forwarded.
     *
     * <p>If {@code consumer.onStreamComplete} (or anything after it) fails once the cost
     * has been reported, the error path does not report the query a second time.
     *
     * @throws Exception whatever {@code checkQuery}, the wrapped client, the consumer or
     *                   {@code reportCost} throws
     */
    @Override
    public void streamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, final LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
        this.costLimitingService.checkQuery(this.context);
        // Flipped once the footer cost has been reported, so that a later failure does not
        // make the error handler count this query again.
        final AtomicBoolean costReported = new AtomicBoolean(false);
        this.handleQueryProcessingError(1, () -> {
            this.llmClient.streamComplete(query, settings, new LLMClient.StreamedCompletionResponseConsumer(){

                @Override
                public void onStreamStarted() throws Exception {
                    consumer.onStreamStarted();
                }

                @Override
                public void onStreamChunk(LLMClient.StreamedCompletionResponseChunk chunk) throws Exception {
                    consumer.onStreamChunk(chunk);
                }

                @Override
                public void onStreamComplete(LLMClient.StreamedCompletionResponseFooter footer) throws Exception {
                    LLMClientCostLimitingWrapper.this.costLimitingService.reportCost(LLMClientCostLimitingWrapper.this.context, LLMClientCostLimitingWrapper.safeCost(footer.estimatedCost), 1);
                    // Only marked after reportCost returned normally.
                    costReported.set(true);
                    consumer.onStreamComplete(footer);
                }
            });
            return null;
        }, costReported);
    }

    /**
     * Checks the query, generates images with the wrapped client and reports the
     * response's estimated cost as a single query.
     *
     * @throws Exception whatever {@code checkQuery}, the wrapped client or {@code reportCost} throws
     */
    @Override
    public LLMClient.ImageGenerationResponse generateImages(LLMClient.ImageGenerationQuery query) throws Exception {
        this.costLimitingService.checkQuery(this.context);
        return this.handleQueryProcessingError(1, () -> {
            LLMClient.ImageGenerationResponse response = this.llmClient.generateImages(query);
            this.costLimitingService.reportCost(this.context, LLMClientCostLimitingWrapper.safeCost(response.estimatedCost), 1);
            return response;
        });
    }

    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        return this.llmClient.getTotalCRU(usageType, llmRef);
    }

    @Override
    public EnrichedLLMStructuredRef getEnrichedRef() throws Exception {
        return this.llmClient.getEnrichedRef();
    }

    @Override
    public RemoteFineTuningClient newFineTuningClient() throws UnsupportedOperationException {
        return this.llmClient.newFineTuningClient();
    }

    @Override
    public ISavedModelDeployer newSavedModelDeployer(AuthCtx authCtx) throws UnsupportedOperationException, AzureMLUtils.AzureAuthenticationException, IOException, DKUSecurityException {
        return this.llmClient.newSavedModelDeployer(authCtx);
    }

    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
        return this.llmClient.getFormattedPrompt(chatMessages);
    }

    @Override
    public void setDevMode(boolean devMode) {
        this.llmClient.setDevMode(devMode);
    }

    @Override
    public SmartLogTail getKernelLog() throws Exception {
        return this.llmClient.getKernelLog();
    }

    @Override
    public void close() throws Exception {
        this.llmClient.close();
    }

    @Override
    public String toString() {
        return "LLMClientCostLimitingWrapper{llmClient=" + this.llmClient + "}";
    }

    /**
     * Returns {@code cost}, or {@code 0.0} when the client did not provide an estimate.
     */
    private static double safeCost(@Nullable Double cost) {
        return cost != null ? cost : 0.0;
    }

    /**
     * Runs {@code processingCall}; on failure, reports the failed queries before rethrowing.
     * Equivalent to the three-argument overload with no "already reported" flag.
     */
    private <R> R handleQueryProcessingError(int nbQueries, Callable<R> processingCall) throws Exception {
        return this.handleQueryProcessingError(nbQueries, processingCall, null);
    }

    /**
     * Runs {@code processingCall} and makes sure a failed request is still reported to the
     * cost-limiting service.
     *
     * <ul>
     *   <li>A {@code LLMCostLimitingReportingException} is logged and rethrown without any
     *       further reporting.</li>
     *   <li>Any other exception triggers {@code reportCost} for {@code nbQueries} queries,
     *       using the exception's estimated cost when it is an {@code LLMException} carrying
     *       a positive one and {@code 0.0} otherwise, and is then rethrown.</li>
     * </ul>
     *
     * @param nbQueries           number of queries to report on failure
     * @param processingCall      the work to execute
     * @param costAlreadyReported optional flag set by {@code processingCall} once it has
     *                            reported the cost itself; when set, the failure path skips
     *                            reporting so the query is not counted twice
     * @return the value returned by {@code processingCall}
     * @throws Exception the exception thrown by {@code processingCall}, or by
     *                   {@code reportCost} if reporting the failure itself fails
     */
    private <R> R handleQueryProcessingError(int nbQueries, Callable<R> processingCall, @Nullable AtomicBoolean costAlreadyReported) throws Exception {
        try {
            return processingCall.call();
        }
        catch (LLMCostLimitingService.LLMCostLimitingReportingException e) {
            // NOTE(review): `logger` is not declared in this class; presumably inherited from LLMClient.
            logger.error((Object)"Cost limiting reporting error", (Throwable)e);
            throw e;
        }
        catch (Exception exception) {
            if (costAlreadyReported != null && costAlreadyReported.get()) {
                // The request was already accounted for; reporting again would double count it.
                throw exception;
            }
            double errorCost = 0.0;
            if (exception instanceof LLMClient.LLMException) {
                LLMClient.LLMException llmException = (LLMClient.LLMException)exception;
                if (llmException.estimatedCost != null && llmException.estimatedCost > 0.0) {
                    errorCost = llmException.estimatedCost;
                }
            }
            this.costLimitingService.reportCost(this.context, errorCost, nbQueries);
            throw exception;
        }
    }
}

