/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.sagemakergeneric;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.HuggingFaceLocalConnection;
import com.dataiku.dip.connections.SageMakerGenericLLMConnection;
import com.dataiku.dip.connections.aws.ShadedS3ConnectionAWSSessionCredentialsProviderV2;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.local.HuggingFaceLocalClient;
import com.dataiku.dip.llm.online.AbstractLLMClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericLLMHandling;
import com.dataiku.dip.llm.online.sagemakergeneric.RawSageMakerClient;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;

/**
 * {@link LLMClient} implementation that sends completion and embedding queries to a
 * SageMaker endpoint through a {@link RawSageMakerClient}.
 *
 * <p>Batches are processed sequentially, one query at a time; there is no native batch support.
 */
public class SageMakerLLMClient
extends AbstractLLMClient
implements LLMClient {
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.sagemaker");

    private final SageMakerGenericLLMConnection connection;
    private final RawSageMakerClient raw;
    private final SageMakerGenericLLMConnection.SageMakerModel model;

    /**
     * Builds a client bound to the given connection and model.
     *
     * @param authCtx     authentication context handed to the AWS credentials provider
     * @param connection  SageMaker LLM connection; its params must define a region and a SageMaker connection
     * @param modelHandle handle providing the enriched LLM reference and the model definition
     * @throws IllegalArgumentException if the region or the SageMaker connection is blank in the connection params
     * @throws IOException              declared for the underlying client setup
     * @throws DKUSecurityException     declared for the underlying credentials setup
     */
    public SageMakerLLMClient(AuthCtx authCtx, SageMakerGenericLLMConnection connection, LLMModelHandle<SageMakerGenericLLMConnection.SageMakerModel> modelHandle) throws IOException, DKUSecurityException {
        super(modelHandle.getEnrichedRef());
        this.connection = connection;
        this.model = modelHandle.getModel();
        // Validate mandatory settings before creating the raw client.
        if (StringUtils.isBlank(connection.params.region)) {
            throw new IllegalArgumentException("Region is required in connection settings");
        }
        if (StringUtils.isBlank(connection.params.sageMakerConnection)) {
            throw new IllegalArgumentException("SageMaker connection is required in connection settings");
        }
        this.raw = new RawSageMakerClient(connection.params, new ShadedS3ConnectionAWSSessionCredentialsProviderV2(authCtx, connection), connection.getProxySettings());
    }

    /** Closes the underlying raw SageMaker client. */
    @Override
    public void close() {
        this.raw.close();
    }

    /** @return always {@code false}; batches are handled query by query */
    @Override
    public boolean supportNativeBatch() {
        return false;
    }

    /** @return always {@code true} */
    @Override
    public boolean requiresCostLimiting() {
        return true;
    }

    /** @return the fixed provider identifier {@code "SageMaker-GenericLLM"} */
    @Override
    public String getProviderId() {
        return "SageMaker-GenericLLM";
    }

    /** @return the connection this client was created with */
    @Override
    public AbstractLLMConnection getConnection() {
        return this.connection;
    }

    /** @return the {@code maxParallelism} value configured in the connection params */
    @Override
    public int getMaxParallelism() {
        return this.connection.params.maxParallelism;
    }

    /**
     * Collapses the chat messages into a single message holding the formatted prompt text.
     *
     * @param chatMessages messages to format
     * @return a one-element list; the message role is {@code "inputText"} for the
     *         {@code AMAZON_TITAN} handling and {@code "prompt"} otherwise
     */
    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
        LLMClient.ChatMessage formattedPrompt = new LLMClient.ChatMessage();
        formattedPrompt.setTextOnly(SageMakerLLMClient.getFormattedPromptContent(chatMessages, this.model.handling));
        formattedPrompt.role = GenericLLMHandling.AMAZON_TITAN == this.model.handling ? "inputText" : "prompt";

        List<LLMClient.ChatMessage> formattedPromptMessages = new ArrayList<>();
        formattedPromptMessages.add(formattedPrompt);
        return formattedPromptMessages;
    }

    /**
     * Builds the prompt text for the given handling mode: the {@code META_LLAMA_2_SAGEMAKER}
     * handling delegates to the Hugging Face Llama 2 formatter; every other handling joins the
     * message texts with a blank line.
     */
    private static String getFormattedPromptContent(List<LLMClient.ChatMessage> chatMessages, GenericLLMHandling handling) {
        if (GenericLLMHandling.META_LLAMA_2_SAGEMAKER == handling) {
            return HuggingFaceLocalClient.getFormattedPromptContent(chatMessages, HuggingFaceLocalConnection.HuggingFaceHandlingMode.TEXT_GENERATION_LLAMA_2);
        }
        return chatMessages.stream().map(m -> m.getText()).collect(Collectors.joining("\n\n"));
    }

    /**
     * Runs each completion query in order against the endpoint.
     *
     * @param queries  completion queries
     * @param settings completion settings, converted once via {@link #getCoreCompletionSettings}
     * @return one response per query, in the same order
     * @throws IOException if a call to the raw client fails
     */
    @Override
    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws IOException {
        // The settings are the same for every query, so resolve them once up front.
        CoreCompletionSettings ccs = this.getCoreCompletionSettings(settings);
        List<LLMClient.SimpleCompletionResponse> ret = new ArrayList<>();
        for (LLMClient.SingleCompletionQuery query : queries) {
            String prompt = SageMakerLLMClient.getFormattedPromptContent(query.messages, this.model.handling);
            ret.add(this.raw.complete(this.connection, prompt, ccs));
        }
        return ret;
    }

    /**
     * Runs each embedding query in order against the endpoint.
     *
     * <p>The {@code TRUNCATE} text overflow mode is not honored; a warning is logged when it is requested.
     *
     * @param queries  embedding queries
     * @param settings embedding settings; only {@code textOverflowMode} is inspected here
     * @return one response per query, in the same order
     * @throws IOException if a call to the raw client fails
     */
    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws IOException {
        if (LLMClient.TextOverflowMode.TRUNCATE.equals(settings.textOverflowMode)) {
            logger.warn("Truncation for long texts overflow is not supported yet for SageMaker, defaulting to Failure mode");
        }
        List<LLMClient.SimpleEmbeddingResponse> ret = new ArrayList<>();
        for (LLMClient.EmbeddingQuery query : queries) {
            logger.debugV("SageMaker Embed: %s", new Object[]{JSON.json(query.getSafeForLoggingCopy())});
            ret.add(this.raw.embed(this.connection, query));
        }
        return ret;
    }

    /**
     * Reranking is not available for this client.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public List<LLMClient.SingleRerankingResponse> rerankBatch(List<LLMClient.RerankingQuery> queries, LLMClient.RerankingSettings settings) throws Exception {
        throw new IllegalArgumentException("Reranking not supported on this LLM");
    }

    /** @return always {@code null}; this client reports no compute resource usage */
    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        return null;
    }

    /**
     * Resolves the core completion settings, capping the output token budget by the limit of the
     * model configured in the connection params.
     *
     * <ul>
     *   <li>no requested value, model limit set: the model limit is used;</li>
     *   <li>requested value and model limit set: the smaller of the two is used;</li>
     *   <li>requested value, no model limit: the requested value is used as-is;</li>
     *   <li>neither set: {@code maxTokens} is {@code null}.</li>
     * </ul>
     *
     * A model limit that is {@code null} or {@code 0} counts as "no limit".
     *
     * @param cs2 completion settings requested by the caller
     * @return the settings produced by the superclass, with {@code maxTokens} overridden as described
     */
    @Override
    public CoreCompletionSettings getCoreCompletionSettings(LLMClient.CompletionSettings cs2) {
        // Named customModel to avoid shadowing the differently-typed 'model' field.
        SageMakerGenericLLMConnection.CustomSageMakerModel customModel = this.connection.params.sageMakerModel;
        Integer limit = customModel.maxTokensLimit;
        boolean hasLimit = limit != null && limit != 0;

        Integer maxOutputTokens;
        if (cs2.maxOutputTokens == null) {
            maxOutputTokens = hasLimit ? limit : null;
        } else if (hasLimit) {
            maxOutputTokens = Math.min(cs2.maxOutputTokens, limit);
        } else {
            // Without a model limit the caller's request must be kept, not silently discarded.
            maxOutputTokens = cs2.maxOutputTokens;
        }

        CoreCompletionSettings ccs = super.getCoreCompletionSettings(cs2);
        ccs.maxTokens = maxOutputTokens;
        return ccs;
    }
}

