/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.bedrock;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.BedrockConnection;
import com.dataiku.dip.connections.HuggingFaceLocalConnection;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.local.HuggingFaceLocalClient;
import com.dataiku.dip.llm.online.AbstractLLMClient;
import com.dataiku.dip.llm.online.ISavedModelDeployer;
import com.dataiku.dip.llm.online.LLMChatMessageUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMQueryRunner;
import com.dataiku.dip.llm.online.RemoteFineTuningClient;
import com.dataiku.dip.llm.online.anthropic.AnthropicClient;
import com.dataiku.dip.llm.online.bedrock.BedrockFineTuningClient;
import com.dataiku.dip.llm.online.bedrock.BedrockSavedModelDeployer;
import com.dataiku.dip.llm.online.bedrock.RawBedrockClient;
import com.dataiku.dip.llm.online.cohere.CohereClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.mistralai.MistralAIClient;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericLLMHandling;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ExceptionUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.legacy.aws.com.amazonaws.services.s3.AmazonS3Client;
import com.dataiku.dss.shadelibawssk2.org.apache.http.NoHttpResponseException;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.awscore.exception.AwsServiceException;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.core.exception.ApiCallAttemptTimeoutException;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.core.exception.RetryableException;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.core.exception.SdkClientException;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.core.exception.SdkException;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.core.exception.SdkServiceException;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.core.retry.RetryUtils;
import com.google.common.base.Stopwatch;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;

public class BedrockClient
extends AbstractLLMClient {
    private final BedrockConnection connection;
    private final RawBedrockClient raw;
    private final BedrockConnection.BedrockModel model;
    private final GenericLLMHandling handling;
    private final AuthCtx authCtx;
    private final LLMQueryRunner queryRunner;
    private ComputeResourceUsage.InternalLLMUsageData usageData = new ComputeResourceUsage.LLMUsageData();
    private static final Set<String> RETRYABLE_ERROR_CODES = Set.of("PriorRequestNotComplete", "RequestTimeout", "RequestTimeoutException", "InternalError");
    private static final Set<Integer> RETRYABLE_STATUS_CODES = Set.of(Integer.valueOf(500), Integer.valueOf(502), Integer.valueOf(503), Integer.valueOf(504));
    private static final Set<Class<? extends Exception>> RETRYABLE_EXCEPTIONS = Set.of(RetryableException.class, IOException.class, UncheckedIOException.class, ApiCallAttemptTimeoutException.class);
    private static DKULogger logger = DKULogger.getLogger((String)"dku.llm.bedrock");

    public BedrockClient(AuthCtx authCtx, BedrockConnection connection, LLMModelHandle<BedrockConnection.BedrockModel> modelHandle) throws IOException, DKUSecurityException {
        super(modelHandle.getEnrichedRef());
        this.connection = connection;
        this.authCtx = authCtx;
        this.queryRunner = new LLMQueryRunner(this.getProviderId(), modelHandle, connection.params.networkSettings, BedrockClient::isRetryableAmazonClientException);
        if (StringUtils.isBlank((CharSequence)connection.params.region)) {
            throw new IllegalArgumentException("Region is required in connection settings");
        }
        if (StringUtils.isBlank((CharSequence)connection.params.s3Connection)) {
            throw new IllegalArgumentException("S3 connection is required in connection settings");
        }
        this.raw = new RawBedrockClient(connection.params.region, connection.getCredentialsProvider(authCtx), this.queryRunner.getHttpClientNetworkSettings(), connection.getProxySettings(), connection.params.fineTuningSettings);
        this.model = modelHandle.getModel();
        this.handling = this.model.handlingMode;
        if (this.model.inferenceProfile != null) {
            logger.info((Object)("Adding cross-region inference profile as prefix. Full model id: " + this.model.getInferenceModelId()));
        }
    }

    public static boolean isRetryableAmazonClientException(Throwable t) {
        SdkServiceException sdkServiceException;
        AwsServiceException awsServiceException;
        String errorCode;
        if (t instanceof AwsServiceException && (errorCode = (awsServiceException = (AwsServiceException)t).awsErrorDetails().errorCode()) != null && RETRYABLE_ERROR_CODES.contains(errorCode)) {
            return true;
        }
        if (t instanceof SdkException) {
            SdkException sdkException = (SdkException)t;
            if (RetryUtils.isRetryableException((SdkException)sdkException)) {
                return true;
            }
            if (RetryUtils.isClockSkewException((SdkException)sdkException)) {
                return true;
            }
            if (RetryUtils.isThrottlingException((SdkException)sdkException)) {
                return true;
            }
        }
        if (t instanceof SdkServiceException && RETRYABLE_STATUS_CODES.contains((sdkServiceException = (SdkServiceException)t).statusCode())) {
            return true;
        }
        if (t instanceof SdkClientException) {
            SdkClientException sdkClientException = (SdkClientException)t;
            if (sdkClientException.getCause() instanceof NoHttpResponseException) {
                return true;
            }
            if (sdkClientException.getCause() instanceof IOException) {
                return true;
            }
        }
        Class<?> classOfThrowable = t.getClass();
        for (Class<? extends Exception> retryableException : RETRYABLE_EXCEPTIONS) {
            if (!retryableException.isAssignableFrom(classOfThrowable)) continue;
            return true;
        }
        return false;
    }

    @Override
    public void close() {
        this.raw.close();
    }

    public RawBedrockClient getRaw() {
        return this.raw;
    }

    public String getModelId() {
        return this.model.id;
    }

    @Override
    public boolean supportNativeBatch() {
        return false;
    }

    @Override
    public boolean requiresCostLimiting() {
        return true;
    }

    @Override
    public String getProviderId() {
        return "Bedrock";
    }

    @Override
    public AbstractLLMConnection getConnection() {
        return this.connection;
    }

    @Override
    public int getMaxParallelism() {
        return this.connection.params.maxParallelism;
    }

    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
        if (this.model.supportsConverseAPI) {
            chatMessages = LLMChatMessageUtils.convertMessageRole(chatMessages, "tool", "user");
            if (this.model.supportsSystemPrompt) {
                chatMessages = LLMChatMessageUtils.collapseAdjacentSameRoleMessages(chatMessages);
                chatMessages = LLMChatMessageUtils.convertExtraSystemMessageToUser(chatMessages);
            } else {
                chatMessages = LLMChatMessageUtils.convertMessageRole(chatMessages, "system", "user");
            }
            return LLMChatMessageUtils.collapseAdjacentSameRoleMessages(chatMessages);
        }
        if (this.handling.isChatModel) {
            if (this.handling == GenericLLMHandling.ANTHROPIC_CLAUDE_CHAT) {
                return AnthropicClient.formatChatMessages(chatMessages);
            }
            if (this.handling == GenericLLMHandling.MISTRAL_AI_CHAT) {
                return MistralAIClient.formatChatMessages(chatMessages);
            }
            if (this.handling == GenericLLMHandling.COHERE_COMMAND_CHAT) {
                return CohereClient.formatChatMessages(chatMessages);
            }
            return chatMessages;
        }
        ArrayList<LLMClient.ChatMessage> formattedPromptMessages = new ArrayList<LLMClient.ChatMessage>();
        LLMClient.ChatMessage formattedPrompt = new LLMClient.ChatMessage();
        formattedPrompt.setTextOnly(BedrockClient.getFormattedPromptContent(chatMessages, this.handling));
        formattedPrompt.role = this.handling == GenericLLMHandling.AMAZON_TITAN ? "inputText" : "prompt";
        formattedPromptMessages.add(formattedPrompt);
        return formattedPromptMessages;
    }

    private static String getFormattedPromptContent(List<LLMClient.ChatMessage> chatMessages, GenericLLMHandling handling) {
        if (handling == GenericLLMHandling.ANTHROPIC_CLAUDE) {
            return AnthropicClient.getFormattedPromptContent(chatMessages);
        }
        if (handling == GenericLLMHandling.META_LLAMA_2_BEDROCK) {
            return HuggingFaceLocalClient.getFormattedPromptContent(chatMessages, HuggingFaceLocalConnection.HuggingFaceHandlingMode.TEXT_GENERATION_LLAMA_2);
        }
        if (handling == GenericLLMHandling.META_LLAMA_3_BEDROCK) {
            return HuggingFaceLocalClient.getFormattedPromptContent(chatMessages, HuggingFaceLocalConnection.HuggingFaceHandlingMode.TEXT_GENERATION_LLAMA_3);
        }
        if (handling == GenericLLMHandling.MISTRAL_AI) {
            return HuggingFaceLocalClient.getFormattedPromptContent(chatMessages, HuggingFaceLocalConnection.HuggingFaceHandlingMode.TEXT_GENERATION_MISTRAL);
        }
        return chatMessages.stream().map(m -> m.getText()).collect(Collectors.joining("\n\n"));
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    @Override
    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws Exception {
        ArrayList<LLMClient.SimpleCompletionResponse> ret = new ArrayList<LLMClient.SimpleCompletionResponse>();
        for (LLMClient.SingleCompletionQuery query : queries) {
            String modelId;
            Stopwatch stopwatch = Stopwatch.createStarted();
            CoreCompletionSettings ccs = this.getCoreCompletionSettings(settings);
            if (this.model.isDKUFineTuned && this.model.deploymentId != null) {
                Optional<RawBedrockClient.BedrockProvisionedThroughput> provisionedThroughput = this.raw.getProvisionedThroughput_NT(this.model.deploymentId);
                if (!provisionedThroughput.isPresent()) throw new IllegalArgumentException("Deployment " + this.model.deploymentId + " not found on AWS Bedrock");
                modelId = provisionedThroughput.get().provisionedModelArn;
            } else {
                modelId = this.model.getInferenceModelId();
            }
            RawBedrockClient.Settings bedrockSettings = new RawBedrockClient.Settings(modelId, this.handling, this.connection.params.useBedrockGuardrail, this.connection.params.guardrailIdentifier, this.connection.params.guardrailVersion);
            LLMClient.SimpleCompletionResponse scr = this.queryRunner.run(() -> {
                if (this.model.supportsConverseAPI) {
                    List<LLMClient.ChatMessage> chatMessages = this.getFormattedPrompt(query.messages);
                    return this.raw.chatCompleteConverseAPI(chatMessages, bedrockSettings, ccs);
                }
                if (this.handling.isChatModel) {
                    List<LLMClient.ChatMessage> chatMessages = this.getFormattedPrompt(query.messages);
                    return this.raw.chatCompleteInvokeAPI(chatMessages, bedrockSettings, ccs);
                }
                String prompt = BedrockClient.getFormattedPromptContent(query.messages, this.handling);
                return this.raw.completeInvokeAPI(prompt, bedrockSettings, ccs);
            });
            scr.estimatedCost = this.model.getEstimatedCompletionCost(scr.promptTokens, scr.completionTokens);
            scr.includeInUsageData(this.usageData, stopwatch.elapsed(TimeUnit.MILLISECONDS));
            ret.add(scr);
        }
        return ret;
    }

    @Override
    public void streamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
        List<LLMClient.ChatMessage> chatMessages;
        String prompt;
        if (this.model.supportsConverseAPI) {
            prompt = null;
            chatMessages = this.getFormattedPrompt(query.messages);
        } else if (this.handling.isChatModel) {
            prompt = null;
            chatMessages = this.getFormattedPrompt(query.messages);
        } else {
            prompt = BedrockClient.getFormattedPromptContent(query.messages, this.handling);
            chatMessages = null;
        }
        Stopwatch stopwatch = Stopwatch.createStarted();
        LLMClient.StreamedCompletionResponseConsumerProxy wrappedConsumer = new LLMClient.StreamedCompletionResponseConsumerProxy(consumer, (ExceptionUtils.ThrowingConsumer<LLMClient.StreamedCompletionResponseFooter, Exception>)((ExceptionUtils.ThrowingConsumer)footer -> footer.includeInUsageData(this.usageData, stopwatch.elapsed(TimeUnit.MILLISECONDS))));
        CoreCompletionSettings ccs = this.getCoreCompletionSettings(settings);
        RawBedrockClient.Settings bedrockSettings = new RawBedrockClient.Settings(this.model.getInferenceModelId(), this.handling, this.connection.params.useBedrockGuardrail, this.connection.params.guardrailIdentifier, this.connection.params.guardrailVersion);
        this.queryRunner.run(() -> {
            if (this.model.supportsConverseAPI) {
                this.raw.streamCompleteConverseAPI(wrappedConsumer, this.model, chatMessages, bedrockSettings, ccs);
            } else {
                this.raw.streamCompleteInvokeAPI(wrappedConsumer, this.model, chatMessages, prompt, bedrockSettings, ccs);
            }
            return null;
        });
    }

    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws Exception {
        if (LLMClient.TextOverflowMode.TRUNCATE.equals((Object)settings.textOverflowMode)) {
            logger.warn((Object)"Truncation for long texts overflow is not supported yet for Bedrock, defaulting to Failure mode");
        }
        RawBedrockClient.Settings bedrockSettings = new RawBedrockClient.Settings(this.model.getInferenceModelId(), this.handling, this.connection.params.useBedrockGuardrail, this.connection.params.guardrailIdentifier, this.connection.params.guardrailVersion);
        ArrayList<LLMClient.SimpleEmbeddingResponse> ret = new ArrayList<LLMClient.SimpleEmbeddingResponse>();
        for (LLMClient.EmbeddingQuery query : queries) {
            Stopwatch stopwatch = Stopwatch.createStarted();
            logger.info((Object)("Bedrock Embed: " + JSON.json((Object)query)));
            LLMClient.SimpleEmbeddingResponse scr = this.queryRunner.run(() -> this.raw.embed(query, bedrockSettings));
            scr.estimatedCost = this.model.getEstimatedEmbeddingCost(scr.promptTokens, query.hasImage() ? 1 : 0);
            this.usageData.incrementTotalComputationTimeMS(Long.valueOf(stopwatch.elapsed(TimeUnit.MILLISECONDS)));
            scr.includeInUsageData(this.usageData);
            ret.add(scr);
        }
        return ret;
    }

    @Override
    public List<LLMClient.SimpleRerankingResponse> rerankBatch(List<LLMClient.RerankingQuery> queries) throws Exception {
        throw new IllegalArgumentException("Rerankings not supported on this LLM");
    }

    @Override
    public LLMClient.ImageGenerationResponse generateImages(LLMClient.ImageGenerationQuery query) throws Exception {
        Stopwatch stopwatch = Stopwatch.createStarted();
        RawBedrockClient.Settings bedrockSettings = new RawBedrockClient.Settings(this.model.getInferenceModelId(), this.handling, this.connection.params.useBedrockGuardrail, this.connection.params.guardrailIdentifier, this.connection.params.guardrailVersion);
        LLMClient.ImageGenerationResponse ret = this.queryRunner.run(() -> this.raw.generateImage(query, bedrockSettings));
        ret.estimatedCost = this.model.getEstimatedImageGenerationCost(query);
        this.usageData.incrementTotalComputationTimeMS(Long.valueOf(stopwatch.elapsed(TimeUnit.MILLISECONDS)));
        this.usageData.incrementEstimatedCostUSD(Double.valueOf(ret.estimatedCost));
        return ret;
    }

    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        ComputeResourceUsage cru = new ComputeResourceUsage();
        cru.setupLLMUsage(usageType, llmRef.connection, llmRef.type.toString(), llmRef.id);
        cru.llmUsage.setFromInternal(this.usageData);
        return cru;
    }

    @Override
    public boolean supportsStream() {
        return true;
    }

    @Override
    public RemoteFineTuningClient newFineTuningClient() {
        return new BedrockFineTuningClient(this);
    }

    @Override
    public ISavedModelDeployer newSavedModelDeployer(AuthCtx authCtx) {
        return new BedrockSavedModelDeployer(this.getRaw());
    }

    public AmazonS3Client getS3Client() throws Exception {
        return this.connection.getS3Client(this.authCtx);
    }

    public BedrockConnection.BedrockModel getModel() {
        return this.model;
    }
}

