/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.openai;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.AbstractLLMClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMQueryRunner;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.openai.OpenAIImageHandling;
import com.dataiku.dip.llm.online.openai.OpenAIPricing;
import com.dataiku.dip.llm.online.openai.RawOpenAIClient;
import com.dataiku.dip.llm.utils.ImageGenerationUtils;
import com.dataiku.dip.llm.utils.OnlineLLMUtils;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ExceptionUtils;
import com.dataiku.dip.utils.JSON;
import com.google.common.base.Stopwatch;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

/**
 * Base class for Azure LLM clients that delegate the actual HTTP calls to a {@link RawOpenAIClient}.
 *
 * <p>Subclasses provide the raw client, the model/deployment id sent to it, whether the target is a
 * chat model or a legacy-completion model, and the image-generation handling mode. This class runs
 * every call through an {@link LLMQueryRunner} (configured with
 * {@link OnlineLLMUtils#isRetryableException} as its retry predicate) and accumulates prompt tokens,
 * completion tokens, computation time and estimated cost into a per-instance usage record that is
 * reported by {@link #getTotalCRU}.
 *
 * @param <M> the connection-level model definition used for pricing and batch-size settings
 */
public abstract class AzureClientBase<M extends AbstractLLMConnection.BaseModel>
extends AbstractLLMClient
implements LLMClient {
    /** Batch size used for text embeddings when the model does not define one. */
    private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 10;
    /** Batch size used for every query type other than text embedding. */
    private static final int DEFAULT_BATCH_SIZE = 1;

    /** Usage accumulated across all calls made through this client instance. */
    private final ComputeResourceUsage.InternalLLMUsageData usageData = new ComputeResourceUsage.LLMUsageData();
    protected M model;
    protected final LLMQueryRunner queryRunner;

    /**
     * Creates the client and its query runner.
     *
     * @param enrichedRef reference to the LLM this client serves
     * @param model connection-level model definition (pricing, batch size, ...)
     * @param networkSettings HTTP settings handed to the {@link LLMQueryRunner}
     * @throws DKUSecurityException propagated from the superclass or query-runner construction
     * @throws IOException propagated from the superclass or query-runner construction
     */
    public AzureClientBase(EnrichedLLMStructuredRef enrichedRef, M model, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings) throws DKUSecurityException, IOException {
        super(enrichedRef);
        this.model = model;
        this.queryRunner = new LLMQueryRunner(this.getProviderId(), enrichedRef, (LLMModelHandle.Model)model, networkSettings, OnlineLLMUtils::isRetryableException);
    }

    /** Returns the low-level client used to issue requests; it is also closed by {@link #close()}. */
    public abstract RawOpenAIClient getRaw();

    /** Returns {@code true} for chat models, {@code false} for legacy-completion models. */
    public abstract boolean isChatModel();

    /** Returns the prefix prepended to this client's log and error messages. */
    public abstract String getLogPrefix();

    /** Returns the model/deployment identifier passed to every {@link RawOpenAIClient} call. */
    public abstract String getAzureModelId();

    /** Returns the image-handling mode used for image generation and its cost estimation. */
    public abstract OpenAIImageHandling getImageHandlingMode();

    @Override
    public boolean supportNativeBatch() {
        return false;
    }

    /**
     * Adapts chat messages to the model type.
     *
     * <p>Chat models receive the messages unchanged. Legacy-completion models receive a single
     * message with role {@code "prompt"} whose text is every message's text joined with newlines.
     */
    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
        if (this.isChatModel()) {
            return chatMessages;
        }
        List<LLMClient.ChatMessage> formattedPromptMessages = new ArrayList<>();
        LLMClient.ChatMessage formattedPrompt = new LLMClient.ChatMessage();
        formattedPrompt.setTextOnly(chatMessages.stream().map(LLMClient.ChatMessage::getText).collect(Collectors.joining("\n")));
        formattedPrompt.role = "prompt";
        formattedPromptMessages.add(formattedPrompt);
        return formattedPromptMessages;
    }

    /**
     * Completes each query sequentially, one request per query.
     *
     * <p>Token counts, elapsed time and the estimated cost of every successful response are added to
     * the usage record. For chat models, usage carried by an {@link LLMClient.LLMException} is also
     * recorded before the exception is rethrown.
     *
     * @return one response per query, in query order
     */
    @Override
    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws Exception {
        List<LLMClient.SimpleCompletionResponse> ret = new ArrayList<>(queries.size());
        for (LLMClient.SingleCompletionQuery query : queries) {
            this.getLogger().info(this.getLogPrefix() + " Complete: " + JSON.json((Object)query));
            long before = System.currentTimeMillis();
            CoreCompletionSettings ccs = this.getCoreCompletionSettings(settings);
            LLMClient.SimpleCompletionResponse scr = this.queryRunner.run(() -> {
                if (!this.isChatModel()) {
                    // Legacy-completion endpoints take one prompt string instead of a message list.
                    String prompt = query.messages.stream().map(LLMClient.ChatMessage::getText).collect(Collectors.joining("\n"));
                    return this.getRaw().complete(this.getAzureModelId(), prompt, ccs);
                }
                List<LLMClient.ChatMessage> chatMessages = this.getFormattedPrompt(query.messages);
                try {
                    return this.getRaw().chatComplete(this.getAzureModelId(), chatMessages, ccs, (LLMModelHandle.Model)this.model);
                }
                catch (LLMClient.LLMException e) {
                    // The failed call may still have consumed tokens; account for them before rethrowing.
                    // NOTE(review): elapsed time is measured from `before`, so if the runner retries
                    // after this exception, overlapping time is counted more than once. Confirm intent.
                    this.usageData.incrementTotalComputationTimeMS(Long.valueOf(System.currentTimeMillis() - before));
                    this.usageData.incrementTotalPromptTokens(e.promptTokens);
                    this.usageData.incrementTotalCompletionTokens(e.completionTokens);
                    this.usageData.incrementEstimatedCostUSD(e.estimatedCost);
                    throw e;
                }
            });
            scr.estimatedCost = this.model.getEstimatedCompletionCost(scr.promptTokens, scr.completionTokens);
            this.usageData.incrementTotalComputationTimeMS(Long.valueOf(System.currentTimeMillis() - before));
            this.usageData.incrementTotalPromptTokens(scr.promptTokens);
            this.usageData.incrementTotalCompletionTokens(scr.completionTokens);
            this.usageData.incrementEstimatedCostUSD(scr.estimatedCost);
            ret.add(scr);
        }
        return ret;
    }

    /**
     * Embeds all query texts in a single request.
     *
     * <p>Elapsed time is recorded once for the whole batch. Prompt tokens and estimated cost are
     * recorded per response.
     *
     * @param settings unused by this implementation
     * @return the embedding responses in the order returned by the raw client
     */
    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws Exception {
        List<String> batchTexts = queries.stream().map(query -> query.text).collect(Collectors.toList());
        this.getLogger().info(this.getLogPrefix() + " Embed sending the following batch : " + JSON.json(batchTexts));
        long before = System.currentTimeMillis();
        List<LLMClient.SimpleEmbeddingResponse> embeddingResponses = this.queryRunner.run(() -> this.getRaw().embed(this.getAzureModelId(), batchTexts));
        this.usageData.incrementTotalComputationTimeMS(Long.valueOf(System.currentTimeMillis() - before));
        List<LLMClient.SimpleEmbeddingResponse> ret = new ArrayList<>(embeddingResponses.size());
        for (LLMClient.SimpleEmbeddingResponse scr : embeddingResponses) {
            scr.estimatedCost = this.model.getEstimatedEmbeddingCost(scr.promptTokens, 0);
            this.usageData.incrementTotalPromptTokens(scr.promptTokens);
            this.usageData.incrementEstimatedCostUSD(scr.estimatedCost);
            ret.add(scr);
        }
        return ret;
    }

    /**
     * Generates images and records elapsed time and estimated cost.
     *
     * <p>Cost is estimated from the image-handling mode, the number of images, the requested
     * quality and the size string from {@link ImageGenerationUtils#getWxHString}, which is called
     * with 1024x1024 as its fallback dimensions.
     */
    @Override
    public LLMClient.ImageGenerationResponse generateImages(LLMClient.ImageGenerationQuery query) throws Exception {
        long before = System.currentTimeMillis();
        LLMClient.ImageGenerationResponse resp = this.queryRunner.run(() -> this.getRaw().generateImage(this.getImageHandlingMode(), this.getAzureModelId(), query));
        this.usageData.incrementTotalComputationTimeMS(Long.valueOf(System.currentTimeMillis() - before));
        String size = ImageGenerationUtils.getWxHString(query, 1024, 1024);
        resp.estimatedCost = OpenAIPricing.getOpenAIEstimatedImageGenerationCost(this.getImageHandlingMode(), query.nbImagesToGenerate, size, query.quality);
        this.usageData.incrementEstimatedCostUSD(Double.valueOf(resp.estimatedCost));
        return resp;
    }

    /** Builds a compute-resource-usage record for {@code llmRef} from the usage accumulated so far. */
    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        ComputeResourceUsage cru = new ComputeResourceUsage();
        cru.setupLLMUsage(usageType, llmRef.connection, llmRef.type.toString(), llmRef.id);
        cru.llmUsage.setFromInternal(this.usageData);
        return cru;
    }

    /**
     * Streams a chat completion to {@code consumer}.
     *
     * <p>Usage is recorded when the stream footer is delivered, or from an
     * {@link LLMClient.LLMException} when the call fails.
     *
     * @throws IllegalArgumentException if this client targets a legacy-completion model
     */
    @Override
    public void streamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
        CoreCompletionSettings ccs = this.getCoreCompletionSettings(settings);
        this.getLogger().info(this.getLogPrefix() + " streamed complete: " + JSON.json((Object)query));
        if (!this.isChatModel()) {
            throw new IllegalArgumentException("Streaming not supported on legacy-completion " + this.getLogPrefix() + " endpoints");
        }
        Stopwatch stopwatch = Stopwatch.createStarted();
        // Explicitly typed target: a raw ThrowingConsumer would make `footer` an Object.
        ExceptionUtils.ThrowingConsumer<LLMClient.StreamedCompletionResponseFooter, Exception> recordFooterUsage = footer -> {
            this.usageData.incrementTotalPromptTokens(footer.promptTokens);
            this.usageData.incrementTotalCompletionTokens(footer.completionTokens);
            this.usageData.incrementTotalComputationTimeMS(Long.valueOf(stopwatch.elapsed(TimeUnit.MILLISECONDS)));
            this.usageData.incrementEstimatedCostUSD(footer.estimatedCost);
        };
        LLMClient.StreamedCompletionResponseConsumerProxy wrappedConsumer = new LLMClient.StreamedCompletionResponseConsumerProxy(consumer, recordFooterUsage);
        List<LLMClient.ChatMessage> chatMessages = this.getFormattedPrompt(query.messages);
        this.queryRunner.run(() -> {
            try {
                this.getRaw().streamChatComplete(wrappedConsumer, (LLMModelHandle.Model)this.model, chatMessages, ccs);
            }
            catch (LLMClient.LLMException e) {
                // NOTE(review): if a footer was delivered before this exception, or the runner
                // retries, the stopwatch time is counted more than once. Confirm intent.
                this.usageData.incrementTotalComputationTimeMS(Long.valueOf(stopwatch.elapsed(TimeUnit.MILLISECONDS)));
                this.usageData.incrementTotalPromptTokens(e.promptTokens);
                this.usageData.incrementTotalCompletionTokens(e.completionTokens);
                this.usageData.incrementEstimatedCostUSD(e.estimatedCost);
                throw e;
            }
            return null;
        });
    }

    @Override
    public boolean supportsStream() {
        return true;
    }

    /**
     * Returns the number of queries to send per request.
     *
     * <p>Text embeddings use the model's configured batch size when present, otherwise
     * {@value #DEFAULT_EMBEDDING_BATCH_SIZE}. All other query types use
     * {@value #DEFAULT_BATCH_SIZE}.
     */
    @Override
    public int getBatchSize(AbstractLLMConnection.QueryType queryType, LLMStructuredRef llmRef) {
        if (queryType != AbstractLLMConnection.QueryType.textEmbedding) {
            return DEFAULT_BATCH_SIZE;
        }
        if (this.model.getBatchSize().isPresent()) {
            this.getLogger().info("Using user defined batch size");
            return this.model.getBatchSize().getAsInt();
        }
        return DEFAULT_EMBEDDING_BATCH_SIZE;
    }

    /** Closes the underlying raw client. */
    @Override
    public void close() {
        this.getRaw().close();
    }

    /** Returns the logger used for this client's request and diagnostic messages. */
    protected abstract DKULogger getLogger();
}

