/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.utils;

import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMTracingUtils;
import com.dataiku.dip.streaming.endpoints.httpsse.MiniSSEEmitter;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.lang.mutable.MutableBoolean;

/**
 * Consumer for a streamed LLM completion.
 *
 * <p>It does three things:
 * <ul>
 *   <li>forwards every chunk to a {@link MiniSSEEmitter} as a {@code "completion-chunk"} event;</li>
 *   <li>accumulates the full answer (text, tool calls, log probs);</li>
 *   <li>records stream lifecycle events on the LLM call span.</li>
 * </ul>
 *
 * <p>Expected use (presumably; the driver lives outside this file):
 * <ol>
 *   <li>call {@link #onStreamStarted()};</li>
 *   <li>call {@link #onStreamChunk} zero or more times;</li>
 *   <li>call {@link #onStreamComplete}, then {@link #buildSimpleCompletionResponseOrErrorFromSuccess()}.</li>
 * </ol>
 * If the stream is cut short, call
 * {@link #buildSimpleCompletionResponseOrErrorFromStreamingInterrupted()} instead.
 *
 * <p>The accumulators are a plain {@link StringBuilder} and {@link ArrayList}s. This class
 * therefore presumably expects all callbacks to be invoked from a single thread.
 * NOTE(review): confirm with callers.
 */
public class StreamingConsumer
implements LLMClient.StreamedCompletionResponseConsumer {
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.utils.StreamingConsumer");

    /** Concatenation of every non-null {@code chunk.text} received so far. */
    final StringBuilder completeText = new StringBuilder();

    /** Tool calls from every chunk, appended in arrival order (no merging is performed here). */
    final List<LLMClient.AbstractToolCall> completeToolCalls = new ArrayList<>();

    /** Finish reason copied from the stream footer, when the footer carries one. */
    final AtomicReference<LLMClient.FinishReason> finishReason = new AtomicReference<>();

    /** Log probabilities from every chunk that carried a non-empty list, in arrival order. */
    final List<LLMClient.DetailedLogProb> allLogProbs = new ArrayList<>();

    /** Guards the one-time "first chunk" trace event emitted from {@link #onStreamStarted()}. */
    final MutableBoolean firstChunkReceived = new MutableBoolean(false);

    /** Sink for server-sent events relayed to the client. */
    final MiniSSEEmitter emitter;

    /** Span covering the LLM call itself; receives lifecycle child events. */
    final LLMClient.LLMMeshTraceSpan callSpan;

    /** Outer trace span; attached to the final response object. */
    final LLMClient.LLMMeshTraceSpan trace;

    /** Footer received in {@link #onStreamComplete}; {@code null} until the stream completes. */
    LLMClient.StreamedCompletionResponseFooter footer;

    /**
     * @param miniSSEEmitter emitter that receives the init signal and every streamed chunk
     * @param trace          outer trace span, closed when the stream completes or is interrupted
     * @param callSpan       LLM call span, closed when the stream completes or is interrupted
     */
    public StreamingConsumer(MiniSSEEmitter miniSSEEmitter, LLMClient.LLMMeshTraceSpan trace, LLMClient.LLMMeshTraceSpan callSpan) {
        this.emitter = miniSSEEmitter;
        this.callSpan = callSpan;
        this.trace = trace;
    }

    /**
     * Builds a response from whatever was accumulated before the stream was interrupted.
     *
     * <p>It records a {@code DKU_LLM_MESH_LLM_CALL_STREAM_INTERRUPTED} event on the call span.
     * It also stores the partial output on both spans and closes them.
     *
     * @return a response marked {@code ok} with the partial text, log probs and tool calls, and
     *         {@code trace} set to the outer trace span; finish reason and token counts are not set
     */
    public LLMClient.SimpleCompletionResponseOrError buildSimpleCompletionResponseOrErrorFromStreamingInterrupted() {
        this.callSpan.withChildEvent("DKU_LLM_MESH_LLM_CALL_STREAM_INTERRUPTED");
        LLMClient.SimpleCompletionResponseOrError scre = LLMClient.SimpleCompletionResponseOrError.blank();
        scre.ok = true;
        scre.text = this.completeText.toString();
        scre.logProbs = this.allLogProbs;
        scre.toolCalls = this.completeToolCalls;
        LLMTracingUtils.setCompletionOutput(this.callSpan, scre);
        LLMTracingUtils.setCompletionOutput(this.trace, scre);
        this.callSpan.close();
        this.trace.close();
        scre.trace = this.trace;
        return scre;
    }

    /**
     * Builds the final response after {@link #onStreamComplete} has run.
     *
     * <p>Token counts, cost, additional information and trace come from the stream footer.
     * Text, log probs and tool calls come from the accumulated chunks.
     *
     * @return a response marked {@code ok}; {@code totalTokens} is only set when at least one of
     *         prompt/completion token counts is present (a missing count contributes 0)
     * @throws IllegalStateException if called before {@link #onStreamComplete} delivered a footer
     */
    public LLMClient.SimpleCompletionResponseOrError buildSimpleCompletionResponseOrErrorFromSuccess() {
        if (this.footer == null) {
            // Every field below is read from the footer; fail with a clear message instead of a bare NPE.
            throw new IllegalStateException("Cannot build a successful completion response: the stream has not completed (no footer received)");
        }
        LLMClient.SimpleCompletionResponseOrError scre = LLMClient.SimpleCompletionResponseOrError.blank();
        scre.ok = true;
        scre.finishReason = this.footer.finishReason;
        scre.promptTokens = this.footer.promptTokens;
        scre.completionTokens = this.footer.completionTokens;
        if (this.footer.promptTokens != null || this.footer.completionTokens != null) {
            scre.totalTokens = (this.footer.promptTokens != null ? this.footer.promptTokens : 0) + (this.footer.completionTokens != null ? this.footer.completionTokens : 0);
        }
        scre.tokenCountsAreEstimated = this.footer.tokenCountsAreEstimated;
        scre.estimatedCost = this.footer.estimatedCost;
        scre.additionalInformation = this.footer.additionalInformation;
        scre.trace = this.footer.trace;
        scre.text = this.completeText.toString();
        scre.logProbs = this.allLogProbs;
        scre.toolCalls = this.completeToolCalls;
        LLMClient.FinishReason reason = this.finishReason.get();
        if (reason != null) {
            scre.finishReason = reason;
        } else {
            logger.debug((Object)String.format("The streamed answer does not include a finish reason, defaulting to '%s'", scre.finishReason));
        }
        logger.info((Object)("Done streaming answer from LLM: " + JSON.json((Object)scre)));
        return scre;
    }

    /**
     * Signals the start of the stream.
     *
     * <p>The first time this is called, it emits a
     * {@code DKU_LLM_MESH_LLM_CALL_STREAMED_FIRST_CHUNK} event on the call span. It then
     * initialises the SSE emitter.
     */
    @Override
    public void onStreamStarted() throws Exception {
        if (!this.firstChunkReceived.booleanValue()) {
            this.callSpan.withChildEvent("DKU_LLM_MESH_LLM_CALL_STREAMED_FIRST_CHUNK");
            this.firstChunkReceived.setValue(true);
        }
        this.emitter.initSuccess();
    }

    /**
     * Accumulates one chunk and relays it, JSON-serialised, as a {@code "completion-chunk"} event.
     *
     * @param chunk streamed chunk; its {@code toolCalls}, {@code text} and {@code logProbs} may each be null
     */
    @Override
    public void onStreamChunk(LLMClient.StreamedCompletionResponseChunk chunk) throws Exception {
        if (chunk.toolCalls != null) {
            this.completeToolCalls.addAll(chunk.toolCalls);
        }
        if (chunk.text != null) {
            this.completeText.append(chunk.text);
        }
        if (chunk.logProbs != null && !chunk.logProbs.isEmpty()) {
            this.allLogProbs.addAll(chunk.logProbs);
        }
        this.emitter.sendEventWithData("completion-chunk", JSON.json((Object)chunk), false);
    }

    /**
     * Finalises tracing once the stream has completed.
     *
     * <p>It performs the following steps:
     * <ul>
     *   <li>records the footer's finish reason (if any);</li>
     *   <li>stores the accumulated text and tool calls on both spans;</li>
     *   <li>attaches usage metadata to the call span and closes both spans;</li>
     *   <li>keeps the footer for {@link #buildSimpleCompletionResponseOrErrorFromSuccess()}.</li>
     * </ul>
     * The footer's own {@code trace} is added as an observation of the call span. It is then
     * replaced on the stored footer by the outer trace span.
     *
     * @param streamedCompletionResponseFooter footer sent by the provider at end of stream
     */
    @Override
    public void onStreamComplete(LLMClient.StreamedCompletionResponseFooter streamedCompletionResponseFooter) throws Exception {
        logger.info((Object)("onStreamComplete: " + JSON.json((Object)streamedCompletionResponseFooter)));
        this.callSpan.withChildEvent("DKU_LLM_MESH_LLM_CALL_STREAMED_STREAM_COMPLETE");
        if (streamedCompletionResponseFooter.finishReason != null) {
            this.finishReason.set(streamedCompletionResponseFooter.finishReason);
        }
        LLMClient.SimpleCompletionResponseOrError screForTrace = LLMClient.SimpleCompletionResponseOrError.blank();
        screForTrace.text = this.completeText.toString();
        screForTrace.toolCalls = this.completeToolCalls;
        screForTrace.ok = true;
        LLMTracingUtils.setCompletionOutput(this.callSpan, screForTrace);
        LLMTracingUtils.setCompletionOutput(this.trace, screForTrace);
        this.callSpan.usageMetadata = new LLMClient.UsageMetadata(streamedCompletionResponseFooter);
        this.callSpan.close();
        this.trace.close();
        // NOTE(review): the observation is attached after callSpan.close(); confirm that
        // LLMMeshTraceSpan accepts observations on a closed span, otherwise move this before close().
        if (streamedCompletionResponseFooter.trace != null) {
            this.callSpan.addObservation(streamedCompletionResponseFooter.trace);
        }
        this.footer = streamedCompletionResponseFooter;
        // The provider's trace now lives under callSpan; expose the outer trace to the final response.
        this.footer.trace = this.trace;
    }
}

