/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.sagemakergeneric;

import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.connections.SageMakerGenericLLMConnection;
import com.dataiku.dip.coremodel.SimpleKeyValue;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericEmbeddingLLMMarshall;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericTextCompletionLLMMarshall;
import com.dataiku.dip.security.aws.AWSClientBrokerService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.core.SdkBytes;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.http.SdkHttpClient;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.regions.Region;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemaker.SageMakerClient;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemaker.SageMakerClientBuilder;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeClient;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeClientBuilder;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang.StringUtils;

/**
 * Wrapper around the AWS SDK SageMaker clients used to call a generic SageMaker LLM endpoint
 * for text completion and embedding.
 *
 * <p>Holds a {@link SageMakerClient} (only created and closed in this class) and a
 * {@link SageMakerRuntimeClient} used for {@code InvokeEndpoint} calls. Both are released by
 * {@link #close()}, so instances should be used with try-with-resources.
 */
public class RawSageMakerClient
implements AutoCloseable {
    private static final String CUSTOM_ATTRIBUTES_HEADER = "X-Amzn-SageMaker-Custom-Attributes";
    private static final String EULA_ATTRIBUTE = "accept_eula=true";

    private final SageMakerClient awsClient;
    private final SageMakerRuntimeClient awsRuntimeClient;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.sagemaker.client");

    /**
     * Builds both SageMaker clients for the connection's region, using HTTP clients obtained
     * from {@link AWSClientBrokerService} with the connection's query timeout.
     *
     * @param connectionParams    connection parameters (region, network settings, model headers, EULA flag);
     *                            they are read but never modified
     * @param credentialsProvider AWS credentials used by both clients
     * @param proxySettings       proxy configuration passed to the HTTP client builders
     */
    public RawSageMakerClient(SageMakerGenericLLMConnection.SageMakerGenericLLMConnectionParams connectionParams, AwsCredentialsProvider credentialsProvider, ProxySettings proxySettings) {
        Region awsRegion = Region.of((String)connectionParams.region);
        AWSClientBrokerService awsClientBrokerService = (AWSClientBrokerService)SpringUtils.getBean(AWSClientBrokerService.class);
        int queryTimeout = connectionParams.networkSettings.queryTimeoutMS;
        SdkHttpClient.Builder httpClientBuilder = awsClientBrokerService.getHttpClientBuilder(proxySettings, queryTimeout, queryTimeout);
        ClientOverrideConfiguration clientOverrideConfiguration = this.createOverrideConfig(connectionParams);
        this.awsClient = (SageMakerClient)((SageMakerClientBuilder)awsClientBrokerService.createSageMakerClientBuilder(httpClientBuilder, clientOverrideConfiguration, awsRegion).credentialsProvider(credentialsProvider)).build();
        try {
            // Each SDK client gets its own HTTP client builder.
            httpClientBuilder = awsClientBrokerService.getHttpClientBuilder(proxySettings, queryTimeout, queryTimeout);
            this.awsRuntimeClient = (SageMakerRuntimeClient)((SageMakerRuntimeClientBuilder)awsClientBrokerService.createSageMakerRuntimeClientBuilder(httpClientBuilder, clientOverrideConfiguration, awsRegion).credentialsProvider(credentialsProvider)).build();
        } catch (RuntimeException | Error e) {
            // The caller never receives this instance, so release the client built above ourselves.
            try {
                this.awsClient.close();
            } catch (RuntimeException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Closes both SDK clients; the runtime client is closed even if closing the first one fails.
     */
    @Override
    public void close() {
        try {
            this.awsClient.close();
        } finally {
            this.awsRuntimeClient.close();
        }
    }

    /**
     * Builds the SDK override configuration.
     *
     * <p>It starts from the retry configuration produced by the connection's network settings.
     * It then adds the model's custom headers, plus the {@code accept_eula=true} custom attribute
     * when {@code acceptEULA} is set. Headers with a blank key or value are skipped. A header
     * already present in the builder is not overridden.
     *
     * <p>The configured header list is copied rather than modified. Mutating it would make every
     * client built from the same parameters append the EULA attribute again.
     */
    private ClientOverrideConfiguration createOverrideConfig(SageMakerGenericLLMConnection.SageMakerGenericLLMConnectionParams connectionParams) {
        List<SimpleKeyValue> configuredHeaders = connectionParams.sageMakerModel.customHeaders;
        List<SimpleKeyValue> customHeaders = configuredHeaders == null ? new ArrayList<>() : new ArrayList<>(configuredHeaders);
        if (connectionParams.acceptEULA) {
            RawSageMakerClient.addEulaAttribute(customHeaders);
        }
        ClientOverrideConfiguration clientOverrideConfiguration = connectionParams.networkSettings.createRetryStrategy();
        if (!customHeaders.isEmpty()) {
            ClientOverrideConfiguration.Builder builder = clientOverrideConfiguration.toBuilder();
            for (SimpleKeyValue skv : customHeaders) {
                if (skv == null || StringUtils.isBlank((String)skv.key)) {
                    logger.debug((Object)"Ignoring header with empty key");
                    continue;
                }
                if (StringUtils.isBlank((String)skv.value)) {
                    logger.debug((Object)("Ignoring header with empty value for key: " + skv.key));
                    continue;
                }
                // Re-read on each iteration so that duplicates within customHeaders are also detected.
                Map<String, List<String>> existingHeaders = builder.headers();
                if (existingHeaders.containsKey(skv.key)) {
                    logger.debug((Object)("Header " + skv.key + " already defined with value " + String.valueOf(existingHeaders.get(skv.key)) + ". Not setting other value " + skv.value));
                    continue;
                }
                builder.putHeader(skv.key, skv.value);
                logger.infoV("Added SageMaker custom header: %s=%s", new Object[]{skv.key, skv.value});
            }
            clientOverrideConfiguration = (ClientOverrideConfiguration)builder.build();
        }
        return clientOverrideConfiguration;
    }

    /**
     * Adds {@code accept_eula=true} to the SageMaker custom-attributes header in the given (local) list.
     *
     * <p>If the header is already present (name matched case-insensitively, as HTTP header names
     * are), its entry is replaced by a new one whose value has the attribute appended. A blank
     * existing value is replaced outright. Otherwise a new header entry is added.
     */
    private static void addEulaAttribute(List<SimpleKeyValue> headers) {
        for (int i = 0; i < headers.size(); i++) {
            SimpleKeyValue header = headers.get(i);
            if (header != null && CUSTOM_ATTRIBUTES_HEADER.equalsIgnoreCase(header.key)) {
                String merged = StringUtils.isBlank((String)header.value) ? EULA_ATTRIBUTE : header.value + "," + EULA_ATTRIBUTE;
                // New object: the original entry may be shared with the connection params.
                headers.set(i, new SimpleKeyValue(header.key, merged));
                return;
            }
        }
        headers.add(new SimpleKeyValue(CUSTOM_ATTRIBUTES_HEADER, EULA_ATTRIBUTE));
    }

    /**
     * Invokes the endpoint with the JSON-serialized request body ({@code application/json}).
     */
    private InvokeEndpointResponse executeInvokeRequest(JsonObject requestBody, String endpointName) throws IOException {
        InvokeEndpointRequest.Builder requestBuilder = InvokeEndpointRequest.builder().endpointName(endpointName).contentType("application/json").body(SdkBytes.fromUtf8String((String)JSON.json((Object)requestBody)));
        return this.awsRuntimeClient.invokeEndpoint((InvokeEndpointRequest)requestBuilder.build());
    }

    /**
     * Decodes the response body as UTF-8 and parses it as a JSON object.
     */
    private static JsonObject getBodyFromInvokeResponse(InvokeEndpointResponse response) {
        return (JsonObject)JSON.parse((String)response.body().asUtf8String(), JsonObject.class);
    }

    /**
     * Sends one text-completion request to the connection's endpoint.
     *
     * @param connection the connection. {@code params.sageMakerModel.handling} selects the marshaller,
     *                   and {@code params.endpointName} selects the endpoint.
     * @param prompt     the prompt text
     * @param ccs        completion settings forwarded to the marshaller
     * @return the completion as parsed by the marshaller
     * @throws IOException as declared by the request execution
     */
    public LLMClient.SimpleCompletionResponse complete(SageMakerGenericLLMConnection connection, String prompt, CoreCompletionSettings ccs) throws IOException {
        GenericTextCompletionLLMMarshall marshall = GenericTextCompletionLLMMarshall.get(connection.params.sageMakerModel.handling, connection.params);
        JsonObject requestBody = marshall.prepareTextCompletionQuery(prompt, ccs);
        logger.infoV("SageMaker LLM raw completion request: %s", new Object[]{JSON.prettyLog((Object)requestBody)});
        InvokeEndpointResponse response = this.executeInvokeRequest(requestBody, connection.params.endpointName);
        JsonObject responseBody = RawSageMakerClient.getBodyFromInvokeResponse(response);
        logger.trace(() -> String.format("SageMaker LLM raw completion response: %s", JSON.prettyLog((Object)responseBody)));
        return marshall.parseTextCompletionResponse((JsonElement)responseBody, prompt);
    }

    /**
     * Sends one embedding request to the connection's endpoint.
     *
     * @param connection the connection. {@code params.sageMakerModel.handling} selects the marshaller,
     *                   and {@code params.endpointName} selects the endpoint.
     * @param query      the embedding query to marshal
     * @return the embedding as parsed by the marshaller
     * @throws IOException              as declared by the request execution
     * @throws IllegalArgumentException declared for callers (presumably raised by the marshaller — confirm)
     */
    public LLMClient.SimpleEmbeddingResponse embed(SageMakerGenericLLMConnection connection, LLMClient.EmbeddingQuery query) throws IOException, IllegalArgumentException {
        GenericEmbeddingLLMMarshall marshall = GenericEmbeddingLLMMarshall.get(connection.params.sageMakerModel.handling, connection.params);
        JsonObject requestBody = marshall.prepareInputsEmbedding(query);
        InvokeEndpointResponse response = this.executeInvokeRequest(requestBody, connection.params.endpointName);
        JsonObject responseBody = RawSageMakerClient.getBodyFromInvokeResponse(response);
        return marshall.parseEmbeddingResponse((JsonElement)responseBody);
    }
}

