/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.nvidia;

import com.dataiku.dip.connections.ConnectionUtils;
import com.dataiku.dip.connections.NvidiaNimConnection;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.AbstractLLMClient;
import com.dataiku.dip.llm.online.LLMChatMessageUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMQueryRunner;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.openai.OpenAIClient;
import com.dataiku.dip.llm.online.openai.RawOpenAIClient;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.ExceptionUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.NotImplementedException;
import com.dataiku.dip.utils.Params;
import com.google.common.base.Stopwatch;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

/**
 * {@link LLMClient} implementation for models exposed through an {@link NvidiaNimConnection}.
 *
 * <p>All remote calls are delegated to a {@link RawOpenAIClient} built with
 * {@code RawOpenAIClient.forNvidiaNim(...)} and are executed through an {@link LLMQueryRunner}.
 * Usage figures reported by responses and by {@link LLMClient.LLMException}s are accumulated in
 * a single usage-data object that {@link #getTotalCRU} exposes.
 */
public class NvidiaNimClient extends AbstractLLMClient implements LLMClient {
    private final NvidiaNimConnection connection;
    private final LLMQueryRunner queryRunner;
    private final RawOpenAIClient rawOpenAIClient;
    private final NvidiaNimConnection.NvidiaNimModel model;
    /** Usage accumulated by every completion / embedding call made through this client. */
    private final ComputeResourceUsage.InternalLLMUsageData usageData = new ComputeResourceUsage.LLMUsageData();

    /**
     * Builds a client for the model referenced by {@code modelHandle}.
     *
     * @param nimConnection connection providing network settings, API key, custom headers and proxy settings
     * @param modelHandle   handle giving access to the model definition and its enriched reference
     * @param projectKey    project key forwarded to the underlying raw client, may be {@code null}
     * @param authCtx       authentication context forwarded to the underlying raw client
     * @throws IllegalArgumentException if the model reports an invalidity reason
     */
    public NvidiaNimClient(NvidiaNimConnection nimConnection, LLMModelHandle<NvidiaNimConnection.NvidiaNimModel> modelHandle, @Nullable String projectKey, AuthCtx authCtx) {
        super(modelHandle.getEnrichedRef());
        this.connection = nimConnection;
        this.queryRunner = new LLMQueryRunner(this.getProviderId(), modelHandle, this.connection.params.networkSettings, OpenAIClient::isRetryableException);
        this.model = modelHandle.getModel();
        // Refuse to build a client for a model whose parameters are reported as invalid.
        this.model.getInvalidityReason().ifPresent(reason -> {
            throw new IllegalArgumentException("Model parameters are invalid: " + reason);
        });
        // Optional per-connection switches read from the connection's DKU properties (both default to false).
        Params connectionProperties = ConnectionUtils.getParamsFromProperties(this.connection.getDkuProperties());
        boolean trustAllSSLCertificates = connectionProperties.getBoolParam("dku.connection.llm.trustAllSSLCertificates", false);
        boolean forceContentLength = connectionProperties.getBoolParam("dku.connection.llm.forceContentLength", false);
        this.rawOpenAIClient = RawOpenAIClient.forNvidiaNim(this.model.url, this.connection.params.apiKey, this.connection.params.customHeaders, projectKey, this.queryRunner.getHttpClientNetworkSettings(), this.connection.getProxySettings(), trustAllSSLCertificates, forceContentLength, authCtx, nimConnection);
    }

    /** @return always {@code true} */
    @Override
    public boolean supportsStream() {
        return true;
    }

    /** @return always {@code false} */
    @Override
    public boolean supportNativeBatch() {
        return false;
    }

    /** @return always {@code true} */
    @Override
    public boolean requiresCostLimiting() {
        return true;
    }

    /** @return the constant provider identifier {@code "NVIDIA-NIM"} */
    @Override
    public String getProviderId() {
        return "NVIDIA-NIM";
    }

    /** @return the connection this client was created with */
    @Override
    public NvidiaNimConnection getConnection() {
        return this.connection;
    }

    /** @return the {@code maxParallelism} value configured on the connection parameters */
    @Override
    public int getMaxParallelism() {
        return this.connection.params.maxParallelism;
    }

    /**
     * Runs each query one after the other and returns the responses in query order.
     *
     * <p>For every query the prompt is formatted with {@link #getFormattedPrompt}, the call is
     * dispatched according to the model's configured API, the estimated cost is set on the
     * response, and the response (or a thrown {@link LLMClient.LLMException}) is recorded in the
     * usage data together with the time elapsed since the query was started.
     *
     * @param queries  completion queries to run
     * @param settings completion settings, converted with {@code getCoreCompletionSettings}
     * @return one response per query, in the same order
     * @throws NotImplementedException if the model's API is neither chat-completions nor responses
     * @throws Exception               any failure raised by the query runner or the underlying client
     */
    @Override
    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws Exception {
        List<LLMClient.SimpleCompletionResponse> ret = new ArrayList<>(queries.size());
        CoreCompletionSettings ccs = this.getCoreCompletionSettings(settings);
        for (LLMClient.SingleCompletionQuery query : queries) {
            logger.info("NVIDIA NIM single complete query: " + JSON.log(query.getSafeForLoggingCopy()));
            long start = System.currentTimeMillis();
            LLMClient.SimpleCompletionResponse scr = this.queryRunner.run(() -> {
                List<LLMClient.ChatMessage> chatMessages = this.getFormattedPrompt(query.messages);
                LLMClient.SimpleCompletionResponse response;
                try {
                    if (this.model.api == NvidiaNimConnection.NimApi.OPENAI_V1_CHAT_COMPLETIONS) {
                        response = this.rawOpenAIClient.chatComplete(this.model.getId(), chatMessages, ccs, this.model);
                    } else if (this.model.api == NvidiaNimConnection.NimApi.OPENAI_V1_RESPONSES) {
                        response = this.rawOpenAIClient.completeResponsesAPI(this.model.getId(), chatMessages, ccs, this.model);
                    } else {
                        throw new NotImplementedException("API not implemented: " + this.model.api);
                    }
                } catch (LLMClient.LLMException e) {
                    // Record the failed attempt in the usage data before propagating it.
                    e.includeInUsageData(this.usageData, System.currentTimeMillis() - start);
                    throw e;
                }
                response.estimatedCost = this.model.getEstimatedCompletionCost(response.promptTokens, response.completionTokens);
                return response;
            });
            scr.includeInUsageData(this.usageData, System.currentTimeMillis() - start);
            ret.add(scr);
        }
        return ret;
    }

    /**
     * Runs a single query in streaming mode, forwarding events to {@code consumer}.
     *
     * <p>The consumer is wrapped in a proxy whose footer callback records the footer in the usage
     * data with the time elapsed since this method started. The prompt is formatted once, before
     * the call is handed to the query runner.
     *
     * @param query    completion query to run
     * @param settings completion settings, converted with {@code getCoreCompletionSettings}
     * @param consumer receiver of the streamed response
     * @throws IllegalArgumentException if the model's API is neither chat-completions nor responses
     * @throws Exception                any failure raised by the query runner or the underlying client
     */
    @Override
    public void streamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
        CoreCompletionSettings ccs = this.getCoreCompletionSettings(settings);
        logger.trace(() -> "NVIDIA NIM streamed complete: " + JSON.log(query.getSafeForLoggingCopy()));
        Stopwatch stopwatch = Stopwatch.createStarted();
        ExceptionUtils.ThrowingConsumer<LLMClient.StreamedCompletionResponseFooter, Exception> footerRecorder =
                footer -> footer.includeInUsageData(this.usageData, stopwatch.elapsed(TimeUnit.MILLISECONDS));
        LLMClient.StreamedCompletionResponseConsumerProxy wrappedConsumer = new LLMClient.StreamedCompletionResponseConsumerProxy(consumer, footerRecorder);
        List<LLMClient.ChatMessage> chatMessages = this.getFormattedPrompt(query.messages);
        this.queryRunner.run(() -> {
            try {
                if (this.model.api == NvidiaNimConnection.NimApi.OPENAI_V1_CHAT_COMPLETIONS) {
                    this.rawOpenAIClient.streamChatComplete(wrappedConsumer, this.model, chatMessages, ccs);
                } else if (this.model.api == NvidiaNimConnection.NimApi.OPENAI_V1_RESPONSES) {
                    this.rawOpenAIClient.streamResponsesAPI(wrappedConsumer, this.model, chatMessages, ccs);
                } else {
                    throw new IllegalArgumentException("Stream API not implemented: " + this.model.api);
                }
            } catch (LLMClient.LLMException e) {
                // Record the failed attempt in the usage data before propagating it.
                e.includeInUsageData(this.usageData, stopwatch.elapsed(TimeUnit.MILLISECONDS));
                throw e;
            }
            // The runner expects a result; streaming produces none.
            return null;
        });
    }

    /**
     * Embeds the texts of all queries in a single call to the underlying client.
     *
     * <p>A warning is logged when {@code TRUNCATE} overflow mode is requested. An empty
     * {@code queries} list yields an empty result without contacting the backend, which matches
     * what {@link #completeBatch} does for empty input.
     *
     * @param queries  embedding queries; only their {@code text} is sent
     * @param settings embedding settings; only {@code textOverflowMode} is inspected here
     * @return the embedding responses with their estimated cost set
     * @throws Exception any failure raised by the query runner or the underlying client
     */
    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws Exception {
        if (LLMClient.TextOverflowMode.TRUNCATE.equals(settings.textOverflowMode)) {
            logger.warn("Truncation for long texts overflow is not supported yet for OpenAI API (under NVIDIA NIM), defaulting to Failure mode");
        }
        if (queries.isEmpty()) {
            // Nothing to embed: avoid sending an empty batch to the remote API.
            return new ArrayList<>();
        }
        long start = System.currentTimeMillis();
        List<String> batchTexts = queries.stream().map(query -> query.text).collect(Collectors.toList());
        logger.info("NVIDIA NIM Embed sending the following batch : " + JSON.log(batchTexts));
        List<LLMClient.SimpleEmbeddingResponse> embeddingResponses = this.queryRunner.run(() -> this.rawOpenAIClient.embed(this.model.id, batchTexts));
        // Computation time is recorded once for the whole batch; per-response usage is added below.
        this.usageData.incrementTotalComputationTimeMS(System.currentTimeMillis() - start);
        List<LLMClient.SimpleEmbeddingResponse> ret = new ArrayList<>(embeddingResponses.size());
        for (LLMClient.SimpleEmbeddingResponse embeddingResponse : embeddingResponses) {
            embeddingResponse.estimatedCost = this.model.getEstimatedEmbeddingCost(embeddingResponse.promptTokens, 0);
            embeddingResponse.includeInUsageData(this.usageData);
            ret.add(embeddingResponse);
        }
        return ret;
    }

    /**
     * Reranking is not available through this client.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public List<LLMClient.SingleRerankingResponse> rerankBatch(List<LLMClient.RerankingQuery> queries, LLMClient.RerankingSettings settings) throws Exception {
        throw new IllegalArgumentException("Rerankings not supported on this LLM");
    }

    /**
     * Builds a {@link ComputeResourceUsage} describing the usage accumulated by this client so far.
     *
     * @param usageType usage type to report
     * @param llmRef    reference whose connection, type and id are copied into the report
     * @return a new usage object populated from the accumulated usage data
     */
    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        ComputeResourceUsage cru = new ComputeResourceUsage();
        cru.setupLLMUsage(usageType, llmRef.connection, llmRef.type.toString(), llmRef.id);
        cru.llmUsage.setFromInternal(this.usageData);
        return cru;
    }

    /**
     * Adapts chat messages to what the configured model accepts.
     *
     * <p>Messages are first checked with
     * {@code LLMChatMessageUtils.throwIfUnsupportedToolOutputParts}. When the model does not
     * support system prompts, the {@code "system"} role is converted to {@code "user"}. When the
     * model does not support image inputs, any message that is not text-only is rejected.
     *
     * @param chatMessages messages to adapt
     * @return the messages to send to the model
     * @throws UnsupportedOperationException if a non text-only message is present and the model
     *                                       does not support image inputs
     */
    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
        LLMChatMessageUtils.throwIfUnsupportedToolOutputParts(chatMessages);
        List<LLMClient.ChatMessage> formatted = chatMessages;
        if (!this.model.supportsSystemPrompts) {
            formatted = LLMChatMessageUtils.convertMessageRole(formatted, "system", "user");
        }
        if (!this.model.supportsImageInputs && formatted.stream().anyMatch(chatMessage -> !chatMessage.isTextOnly())) {
            throw new UnsupportedOperationException(String.format("The model '%s' with id '%s' does not support image inputs.", this.model.getDisplayName(), this.model.getId()));
        }
        return formatted;
    }

    /** Closes the underlying raw client. */
    @Override
    public void close() throws Exception {
        this.rawOpenAIClient.close();
    }
}

