/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.huggingface;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.ConnectionUtils;
import com.dataiku.dip.connections.HuggingFaceInferenceAPIConnection;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.AbstractLLMClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.huggingface.RawHuggingFaceInferenceAPIClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.shaker.processors.expr.TokenizedText;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

/**
 * {@link LLMClient} implementation that sends completion queries to the
 * HuggingFace Inference API through a {@link RawHuggingFaceInferenceAPIClient}.
 *
 * <p>Only text completion is implemented. Embedding, reranking and prompt
 * formatting requests are rejected with an {@link IllegalArgumentException}.
 * Token counts reported on responses are estimates computed locally (see
 * {@link #estimateTokenCount(String)}), and the estimated cost is always
 * reported as {@code 0.0}.
 */
public class HuggingFaceInferenceAPIClient
        extends AbstractLLMClient
        implements LLMClient {
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.huggingface");

    /**
     * Multiplier applied to {@code TokenizedText.size()} to produce the
     * estimated token counts reported on completion responses.
     * NOTE(review): presumably approximates the number of model tokens per
     * unit counted by TokenizedText -- confirm the origin of this value.
     */
    private static final float TOKEN_ESTIMATE_FACTOR = 2.5f;

    private final HuggingFaceInferenceAPIConnection connection;
    private final RawHuggingFaceInferenceAPIClient raw;
    private final HuggingFaceInferenceAPIConnection.HFAPIModel model;
    /** Usage accumulated over every completion served by this client instance. */
    private final ComputeResourceUsage.InternalLLMUsageData usageData = new ComputeResourceUsage.LLMUsageData();

    /**
     * Builds a client bound to one connection and one model.
     *
     * @param connection  connection supplying the proxy settings and the DKU
     *                    properties; the boolean property
     *                    {@code dku.connection.llm.forceContentLength}
     *                    (default {@code false}) is read from it and forwarded
     *                    to the raw client
     * @param modelHandle handle providing the model to query and the enriched
     *                    reference passed to the superclass
     * @param apiKey      API key handed to the raw client
     */
    public HuggingFaceInferenceAPIClient(HuggingFaceInferenceAPIConnection connection, LLMModelHandle<HuggingFaceInferenceAPIConnection.HFAPIModel> modelHandle, String apiKey) {
        super(modelHandle.getEnrichedRef());
        boolean forceContentLength = ConnectionUtils.getParamsFromProperties(connection.getDkuProperties())
                .getBoolParam("dku.connection.llm.forceContentLength", false);
        this.raw = new RawHuggingFaceInferenceAPIClient(apiKey, connection.getProxySettings(), forceContentLength);
        this.model = modelHandle.getModel();
        this.connection = connection;
    }

    /** Closes the underlying raw client. */
    @Override
    public void close() {
        this.raw.close();
    }

    /** @return always {@code false} */
    @Override
    public boolean supportNativeBatch() {
        return false;
    }

    /** @return always {@code false} */
    @Override
    public boolean requiresCostLimiting() {
        return false;
    }

    /** @return the constant provider identifier {@code "HuggingFaceInferenceAPI"} */
    @Override
    public String getProviderId() {
        return "HuggingFaceInferenceAPI";
    }

    /** @return the connection this client was built with */
    @Override
    public AbstractLLMConnection getConnection() {
        return this.connection;
    }

    /** @return always {@code 2} */
    @Override
    public int getMaxParallelism() {
        return 2;
    }

    /**
     * Runs each query sequentially, one call to the raw client per query.
     *
     * <p>The prompt sent for a query is the text of all its messages joined
     * with a newline. Each response is annotated with estimated token counts
     * and a zero estimated cost, and is folded into this client's usage data
     * together with the elapsed wall-clock time of the query.
     *
     * @param queries  queries to run, in order
     * @param settings completion settings, converted once to core settings and
     *                 reused for every query
     * @return one response per query, in the same order as {@code queries}
     * @throws IOException if the underlying completion call fails
     */
    @Override
    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws IOException {
        CoreCompletionSettings ccs = this.getCoreCompletionSettings(settings);
        List<LLMClient.SimpleCompletionResponse> ret = new ArrayList<>(queries.size());
        for (LLMClient.SingleCompletionQuery query : queries) {
            logger.info((Object)("HuggingFace single complete query: " + JSON.log((Object)query.getSafeForLoggingCopy())));
            // Timer starts before prompt assembly so the recorded duration covers the whole query.
            long before = System.currentTimeMillis();
            String prompt = query.messages.stream().map(m -> m.getText()).collect(Collectors.joining("\n"));
            LLMClient.SimpleCompletionResponse scr = this.raw.complete(this.model.getId(), prompt, ccs);
            // Token counts are computed locally from the texts, hence flagged as estimated.
            scr.promptTokens = estimateTokenCount(prompt);
            scr.completionTokens = estimateTokenCount(scr.text);
            scr.tokenCountsAreEstimated = true;
            scr.estimatedCost = 0.0;
            scr.includeInUsageData(this.usageData, System.currentTimeMillis() - before);
            ret.add(scr);
        }
        return ret;
    }

    /**
     * Estimates a token count for {@code text} as
     * {@code TokenizedText(text).size()} scaled by
     * {@link #TOKEN_ESTIMATE_FACTOR} and truncated to an {@code int}.
     */
    private static int estimateTokenCount(String text) {
        return (int)(TOKEN_ESTIMATE_FACTOR * (float)new TokenizedText(text).size());
    }

    /**
     * Builds a resource-usage record carrying the usage accumulated by this
     * client so far.
     *
     * @param usageType usage type recorded on the returned object
     * @param llmRef    reference whose connection, type and id are recorded
     * @return a new {@link ComputeResourceUsage} populated from the internal usage data
     */
    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        ComputeResourceUsage cru = new ComputeResourceUsage();
        cru.setupLLMUsage(usageType, llmRef.connection, llmRef.type.toString(), llmRef.id);
        cru.llmUsage.setFromInternal(this.usageData);
        return cru;
    }

    /**
     * Not implemented for this client.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
        throw new IllegalArgumentException("Not implemented");
    }

    /**
     * Embeddings are not supported by this client.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws IOException {
        throw new IllegalArgumentException("Embeddings not supported on this LLM");
    }

    /**
     * Reranking is not supported by this client.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public List<LLMClient.SingleRerankingResponse> rerankBatch(List<LLMClient.RerankingQuery> queries, LLMClient.RerankingSettings settings) throws Exception {
        throw new IllegalArgumentException("Rerankings not supported on this LLM");
    }
}

