/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.vertex;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.VertexAILLMConnection;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.AbstractLLMClient;
import com.dataiku.dip.llm.online.LLMChatMessageUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMQueryRunner;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.vertex.RawVertexClient;
import com.dataiku.dip.llm.utils.OnlineLLMUtils;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ExceptionUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelibgcp.com.google.api.client.googleapis.auth.oauth2.GoogleCredential;
import com.google.common.base.Stopwatch;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

public class VertexClient
extends AbstractLLMClient
implements LLMClient {
    private final VertexAILLMConnection connection;
    private final LLMQueryRunner queryRunner;
    private final RawVertexClient raw;
    private final VertexAILLMConnection.VertexModel model;
    private final AuthCtx authCtx;
    private ComputeResourceUsage.InternalLLMUsageData usageData = new ComputeResourceUsage.LLMUsageData();
    private static DKULogger logger = DKULogger.getLogger((String)"dku.llm.vertex");

    public VertexClient(AuthCtx authCtx, VertexAILLMConnection connection, GoogleCredential credential, String gcpProject, String region, LLMModelHandle<VertexAILLMConnection.VertexModel> modelHandle, boolean useGlobalEndpoint) {
        super(modelHandle.getEnrichedRef());
        this.authCtx = authCtx;
        this.queryRunner = new LLMQueryRunner(this.getProviderId(), modelHandle, connection.params.networkSettings, VertexClient::isRetryableException);
        this.raw = new RawVertexClient(credential, gcpProject, region, connection.getProxySettings(), this.queryRunner.getHttpClientNetworkSettings(), useGlobalEndpoint);
        this.model = modelHandle.getModel();
        this.connection = connection;
    }

    public static boolean isRetryableException(Throwable t) {
        if (t instanceof RawVertexClient.VertexResponseException) {
            RawVertexClient.VertexResponseException exception = (RawVertexClient.VertexResponseException)t;
            return OnlineLLMUtils.retryRequired(exception.statusCode);
        }
        return t instanceof IOException;
    }

    @Override
    public void close() {
        this.raw.close();
    }

    @Override
    public boolean supportNativeBatch() {
        return false;
    }

    @Override
    public boolean requiresCostLimiting() {
        return true;
    }

    @Override
    public String getProviderId() {
        return "VertexAILLM";
    }

    @Override
    public AbstractLLMConnection getConnection() {
        return this.connection;
    }

    @Override
    public int getMaxParallelism() {
        return this.connection.params.maxParallelism;
    }

    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
        chatMessages = LLMChatMessageUtils.addToolOutputPartsAsMultiPartUserMessages(this.authCtx, chatMessages, this.model.getModelCapabilities().supportsImageInputs);
        chatMessages = LLMChatMessageUtils.convertExtraSystemMessageToUser(chatMessages);
        return LLMChatMessageUtils.collapseAdjacentSameRoleMessages(chatMessages);
    }

    @Override
    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws Exception {
        ArrayList<LLMClient.SimpleCompletionResponse> ret = new ArrayList<LLMClient.SimpleCompletionResponse>();
        CoreCompletionSettings ccs = this.getCoreCompletionSettings(settings);
        for (LLMClient.SingleCompletionQuery query : queries) {
            logger.info((Object)("Vertex single complete query: " + JSON.log((Object)query.getSafeForLoggingCopy())));
            long before = System.currentTimeMillis();
            List<LLMClient.ChatMessage> chatMessages = this.getFormattedPrompt(query.messages);
            LLMClient.SimpleCompletionResponse scr = this.queryRunner.run(() -> this.raw.chatComplete(this.model.getId(), chatMessages, ccs));
            scr.estimatedCost = this.model.getEstimatedCompletionCost(scr.promptTokens, scr.completionTokens);
            scr.includeInUsageData(this.usageData, System.currentTimeMillis() - before);
            ret.add(scr);
        }
        return ret;
    }

    @Override
    public boolean supportsStream() {
        return true;
    }

    @Override
    public void streamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
        CoreCompletionSettings ccs = this.getCoreCompletionSettings(settings);
        logger.trace(() -> "Vertex stream complete: " + JSON.log((Object)query.getSafeForLoggingCopy()));
        Stopwatch stopwatch = Stopwatch.createStarted();
        LLMClient.StreamedCompletionResponseConsumerProxy wrappedConsumer = new LLMClient.StreamedCompletionResponseConsumerProxy(consumer, (ExceptionUtils.ThrowingConsumer<LLMClient.StreamedCompletionResponseFooter, Exception>)((ExceptionUtils.ThrowingConsumer)footer -> footer.includeInUsageData(this.usageData, stopwatch.elapsed(TimeUnit.MILLISECONDS))));
        List<LLMClient.ChatMessage> chatMessages = this.getFormattedPrompt(query.messages);
        this.queryRunner.run(() -> {
            try {
                this.raw.streamChatComplete(wrappedConsumer, this.model, chatMessages, ccs);
            }
            catch (LLMClient.LLMException e) {
                e.includeInUsageData(this.usageData, stopwatch.elapsed(TimeUnit.MILLISECONDS));
                throw e;
            }
            return null;
        });
    }

    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws Exception {
        if (LLMClient.TextOverflowMode.TRUNCATE.equals((Object)settings.textOverflowMode)) {
            logger.warn((Object)"Truncation for long texts overflow is not supported for VertexAI, defaulting to Failure mode");
        }
        ArrayList<LLMClient.SimpleEmbeddingResponse> ret = new ArrayList<LLMClient.SimpleEmbeddingResponse>();
        for (LLMClient.EmbeddingQuery query : queries) {
            Stopwatch stopwatch = Stopwatch.createStarted();
            logger.info((Object)("Vertex Embed: " + JSON.log((Object)query.getSafeForLoggingCopy())));
            LLMClient.SimpleEmbeddingResponse scr = this.queryRunner.run(() -> this.raw.embed(this.model, query));
            scr.estimatedCost = this.model.getVertexEstimatedEmbeddingCost(query);
            this.usageData.incrementTotalComputationTimeMS(Long.valueOf(stopwatch.elapsed(TimeUnit.MILLISECONDS)));
            scr.includeInUsageData(this.usageData);
            ret.add(scr);
        }
        return ret;
    }

    @Override
    public List<LLMClient.SingleRerankingResponse> rerankBatch(List<LLMClient.RerankingQuery> queries, LLMClient.RerankingSettings settings) throws Exception {
        throw new IllegalArgumentException("Reranking not supported on this LLM");
    }

    @Override
    public LLMClient.ImageGenerationResponse generateImages(LLMClient.ImageGenerationQuery query) throws Exception {
        long before = System.currentTimeMillis();
        LLMClient.ImageGenerationResponse resp = this.queryRunner.run(() -> this.raw.imagenGenerateImage(this.model.getId(), query));
        long computationTimeMS = System.currentTimeMillis() - before;
        resp.estimatedCost = this.model.getEstimatedImageGenerationCost(query);
        this.usageData.incrementTotalComputationTimeMS(Long.valueOf(computationTimeMS));
        this.usageData.incrementEstimatedCostUSD(Double.valueOf(resp.estimatedCost));
        return resp;
    }

    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        ComputeResourceUsage cru = new ComputeResourceUsage();
        cru.setupLLMUsage(usageType, llmRef.connection, llmRef.type.toString(), llmRef.id);
        cru.llmUsage.setFromInternal(this.usageData);
        return cru;
    }
}

