/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.externalinfras.sagemaker;

import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.connections.SageMakerConnection;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.externalinfras.ExternalInfrasUtils;
import com.dataiku.dip.externalinfras.sagemaker.SageMakerUtils;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.aws.AWSClientBrokerService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.http.SdkHttpClient;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.regions.Region;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.CloudWatchClient;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.CloudWatchClientBuilder;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.model.Dimension;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.model.GetMetricDataRequest;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.model.GetMetricDataResponse;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.model.ListMetricsRequest;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.model.ListMetricsResponse;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.model.MessageData;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.model.Metric;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.model.MetricDataResult;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.model.RecentlyActive;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Static helpers for querying Amazon CloudWatch about SageMaker endpoints: client creation,
 * paginated metric-data and metric-definition retrieval, and {@link Metric} dimension lookups.
 */
public class CloudWatchUtils {
    /** CloudWatch namespace under which SageMaker publishes its metrics. */
    private static final String SAGEMAKER_NAMESPACE = "AWS/SageMaker";
    /** Name of the CloudWatch dimension holding the SageMaker endpoint name. */
    private static final String ENDPOINT_NAME_DIMENSION = "EndpointName";
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.externalInfras.cloudWatch.utils");

    /**
     * Builds a CloudWatch client that uses the proxy settings and credentials of a SageMaker connection.
     *
     * @param authCtx authentication context used to fetch and check the connection
     * @param sageMakerConnectionName name of the SageMaker connection whose proxy and credentials are reused
     * @param region AWS region identifier, passed to {@link Region#of(String)}
     * @param connectTimeout connect timeout forwarded to {@code getHttpClientBuilder} (units as defined there)
     * @param socketTimeout socket timeout forwarded to {@code getHttpClientBuilder} (units as defined there)
     * @return a newly built client. NOTE(review): presumably the caller owns it and should close it when done.
     * @throws DKUSecurityException propagated from the connection / credential helpers
     * @throws IOException propagated from the connection / credential helpers
     */
    public static CloudWatchClient loginCloudWatch(AuthCtx authCtx, String sageMakerConnectionName, String region, int connectTimeout, int socketTimeout) throws DKUSecurityException, IOException {
        SageMakerConnection connection = (SageMakerConnection) ExternalInfrasUtils.getAndCheckConnection(authCtx, sageMakerConnectionName);
        logger.debugV("Trying to login into CloudWatch on region %s ...", new Object[]{region});
        ProxySettings proxySettings = ExternalInfrasUtils.getProxy(connection);
        AWSClientBrokerService awsClientBrokerService = (AWSClientBrokerService) SpringUtils.getBean(AWSClientBrokerService.class);
        SdkHttpClient.Builder httpClientBuilder = awsClientBrokerService.getHttpClientBuilder(proxySettings, connectTimeout, socketTimeout);
        CloudWatchClientBuilder builder = awsClientBrokerService.createCloudWatchClientBuilder(httpClientBuilder, Region.of(region));
        SageMakerUtils.addCredentialsFromConnection(builder, authCtx, connection);
        return builder.build();
    }

    /**
     * Runs a {@code GetMetricData} query and follows {@code nextToken} pagination until no token is returned.
     *
     * @param initialDataRequest request used as the template for the first page
     * @param cloudWatchClient client used to issue the requests
     * @return all {@link MetricDataResult}s returned across every page, in page order
     * @throws IOException if any page comes back with an HTTP status other than 200
     */
    public static List<MetricDataResult> fetchMetricDataResults(GetMetricDataRequest initialDataRequest, CloudWatchClient cloudWatchClient) throws IOException {
        List<MetricDataResult> results = new ArrayList<>();
        GetMetricDataRequest.Builder dataRequestBuilder = initialDataRequest.toBuilder();
        String nextToken;
        do {
            GetMetricDataResponse response = cloudWatchClient.getMetricData(dataRequestBuilder.build());
            int statusCode = response.sdkHttpResponse().statusCode();
            // NOTE(review): the AWS SDK normally throws on error responses; this check is defensive.
            if (statusCode != 200) {
                // Include the first service message, if any, so the failure is actionable.
                MessageData message = response.messages().isEmpty() ? null : response.messages().get(0);
                String messageCode = message != null ? message.code() : "";
                String messageValue = message != null ? message.value() : "";
                throw new IOException(String.format("Error %d %s %s", statusCode, messageCode, messageValue));
            }
            if (response.metricDataResults() != null) {
                results.addAll(response.metricDataResults());
            }
            nextToken = response.nextToken();
            dataRequestBuilder.nextToken(nextToken);
            if (nextToken != null) {
                logger.debug((Object)"More endpoint activity metrics data available from CloudWatch, fetching next page.");
            }
        } while (nextToken != null);
        return results;
    }

    /**
     * Lists SageMaker metric definitions and groups the ones with a requested metric name by endpoint.
     *
     * @param cloudWatchClient client used to list metrics
     * @param endpointNames endpoints to report on; each one gets an entry in the result, possibly an empty set
     * @param metricNames metric names to keep
     * @param onlyRecentlyActive if true, only list metrics that CloudWatch reports as recently active
     * @return a map from each requested endpoint name to the metrics whose name is in {@code metricNames}
     *         and whose {@code EndpointName} dimension equals that endpoint name
     * @throws IOException if listing the metric definitions fails
     */
    public static Map<String, Set<Metric>> fetchMetricDefinitions(CloudWatchClient cloudWatchClient, List<String> endpointNames, List<String> metricNames, boolean onlyRecentlyActive) throws IOException {
        List<Metric> metricDefinitions = CloudWatchUtils.fetchAllMetricDefinitions(cloudWatchClient, onlyRecentlyActive);
        Set<String> wantedMetricNames = new HashSet<>(metricNames);

        // Group the wanted metrics by endpoint in one pass, rather than rescanning every definition per endpoint.
        Map<String, Set<Metric>> metricsByEndpoint = new HashMap<>();
        for (Metric metricDefinition : metricDefinitions) {
            if (!wantedMetricNames.contains(metricDefinition.metricName())) {
                continue;
            }
            CloudWatchUtils.getMetricDimensionValue(metricDefinition, ENDPOINT_NAME_DIMENSION)
                    .ifPresent(endpointName -> metricsByEndpoint
                            .computeIfAbsent(endpointName, ignored -> new HashSet<>())
                            .add(metricDefinition));
        }

        Map<String, Set<Metric>> result = new HashMap<>();
        for (String endpointName : endpointNames) {
            Set<Metric> endpointMetrics = metricsByEndpoint.get(endpointName);
            // Requested endpoints with no matching metric still get a (mutable) empty set.
            result.put(endpointName, endpointMetrics != null ? endpointMetrics : new HashSet<>());
        }
        return result;
    }

    /**
     * Lists all metric definitions in the SageMaker namespace, following {@code nextToken} pagination.
     *
     * @param cloudWatchClient client used to issue the {@code ListMetrics} requests
     * @param onlyRecentlyActive if true, sets {@link RecentlyActive#PT3_H} on the request
     * @return every {@link Metric} returned across all pages, in page order
     * @throws IOException if any page comes back with an HTTP status other than 200
     */
    public static List<Metric> fetchAllMetricDefinitions(CloudWatchClient cloudWatchClient, boolean onlyRecentlyActive) throws IOException {
        List<Metric> results = new ArrayList<>();
        ListMetricsRequest.Builder listMetricsRequestBuilder = ListMetricsRequest.builder().namespace(SAGEMAKER_NAMESPACE);
        if (onlyRecentlyActive) {
            listMetricsRequestBuilder.recentlyActive(RecentlyActive.PT3_H);
        }
        String nextToken;
        do {
            ListMetricsResponse response = cloudWatchClient.listMetrics(listMetricsRequestBuilder.build());
            int responseStatusCode = response.sdkHttpResponse().statusCode();
            // NOTE(review): the AWS SDK normally throws on error responses; this check is defensive.
            if (responseStatusCode != 200) {
                throw new IOException(String.format("Listing available CloudWatch endpoint metrics failed with code %d", responseStatusCode));
            }
            if (response.metrics() != null) {
                results.addAll(response.metrics());
            }
            nextToken = response.nextToken();
            listMetricsRequestBuilder.nextToken(nextToken);
            if (nextToken != null) {
                logger.debug((Object)"More queryable endpoint metrics available from CloudWatch, fetching next page.");
            }
        } while (nextToken != null);
        return results;
    }

    /**
     * Returns the value of the first dimension of {@code metric} named {@code dimensionName}.
     *
     * @param metric metric whose dimensions are searched
     * @param dimensionName exact dimension name to look for
     * @return the first matching dimension's value, or empty if there is no match or the value is null
     */
    public static Optional<String> getMetricDimensionValue(Metric metric, String dimensionName) {
        return metric.dimensions().stream().filter(dimension -> dimension.name().equals(dimensionName)).map(Dimension::value).findFirst();
    }

    /**
     * Returns the names of all dimensions of {@code metric}, in their original order.
     *
     * @param metric metric whose dimensions are listed
     * @return the dimension names
     */
    public static List<String> getDimensionNames(Metric metric) {
        return metric.dimensions().stream().map(Dimension::name).collect(Collectors.toList());
    }

    /** SageMaker endpoint metrics queried from CloudWatch, each mapped to its CloudWatch metric name. */
    public static enum CloudWatchEndpointMetrics {
        INVOCATIONS("Invocations"),
        INVOCATION_MODEL_ERRORS("InvocationModelErrors"),
        MODEL_LATENCY("ModelLatency"),
        OVERHEAD_LATENCY("OverheadLatency");

        /** Metric name as it appears in CloudWatch. */
        public final String metricName;

        private CloudWatchEndpointMetrics(String metricName) {
            this.metricName = metricName;
        }

        /** @return the CloudWatch metric name for this constant */
        public String getMetricName() {
            return this.metricName;
        }
    }
}

