/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.apideployer.monitoring;

import com.dataiku.dip.apideployer.datamodel.config.SageMakerAPIDeploymentInfra;
import com.dataiku.dip.apideployer.deploymentinfo.SageMakerDeploymentInfo;
import com.dataiku.dip.apideployer.deployments.APIServiceDeploymentsService;
import com.dataiku.dip.apideployer.monitoring.ActivityMetric;
import com.dataiku.dip.apideployer.monitoring.ActivityMetricsFetchingService;
import com.dataiku.dip.apideployer.monitoring.ApiEndpointActivityMonitoringService;
import com.dataiku.dip.dao.UnifiedMonitoringSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.externalinfras.sagemaker.CloudWatchUtils;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.TransactionContext;
import com.dataiku.dip.unifiedmonitoring.externalendpoint.UnifiedMonitoringExternalEndpointsScope;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.Pair;
import com.dataiku.dip.utils.Tuple3;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.CloudWatchClient;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.model.GetMetricDataRequest;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.model.Metric;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.model.MetricDataQuery;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.model.MetricDataResult;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.cloudwatch.model.MetricStat;
import com.google.common.base.Preconditions;
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class SageMakerActivityMetricsFetchingService
extends ActivityMetricsFetchingService {
    private static final int CLOUDWATCH_DELAY_MINUTES = 3;
    private static final int QUERY_PERIOD_SECONDS = 60;
    private static final int MAX_METRIC_DATA_QUERIES_PER_REQUEST = 500;
    private static final int MAX_ENDPOINT_VARIANTS = 100;
    private static final List<Tuple3<CloudWatchUtils.CloudWatchEndpointMetrics, String, ActivityMetric.Type>> metricsAndAggregationsAndTypes = List.of(new Tuple3((Object)CloudWatchUtils.CloudWatchEndpointMetrics.INVOCATIONS, (Object)"SUM", (Object)ActivityMetric.Type.ALL_REQUESTS_IN_COUNT_PER_S), new Tuple3((Object)CloudWatchUtils.CloudWatchEndpointMetrics.INVOCATION_MODEL_ERRORS, (Object)"SUM", (Object)ActivityMetric.Type.ERROR_REQUESTS_IN_COUNT_PER_S), new Tuple3((Object)CloudWatchUtils.CloudWatchEndpointMetrics.MODEL_LATENCY, (Object)"AVG", (Object)ActivityMetric.Type.AVG_PROCESSING_TIME_IN_MS_PER_REQUEST), new Tuple3((Object)CloudWatchUtils.CloudWatchEndpointMetrics.OVERHEAD_LATENCY, (Object)"AVG", (Object)ActivityMetric.Type.AVG_PROCESSING_TIME_IN_MS_PER_REQUEST));
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.deployer.monitoring.sageMakerActivityMetricsFetchingService");

    @Autowired
    public SageMakerActivityMetricsFetchingService(TransactionService transactionService, UnifiedMonitoringSettingsDAO unifiedMonitoringSettingsDAO, APIServiceDeploymentsService apiServiceDeploymentsService) {
        super(transactionService, unifiedMonitoringSettingsDAO, apiServiceDeploymentsService);
    }

    public ActivityMetric.TimeAndMetricsByDeploymentAndEndpoint getInfraActivityMetrics_NT(AuthCtx authCtx, SageMakerAPIDeploymentInfra infra, int connectTimeout, int socketTimeout, String overridingConnectionName) {
        TransactionContext.assertNoAttachedTransaction();
        String usedConnection = overridingConnectionName != null ? overridingConnectionName : infra.authConnection;
        String displayedUsedConnection = overridingConnectionName != null ? overridingConnectionName : infra.getDisplayedConnectionName();
        logger.infoV("Getting activity metrics from CloudWatch for infrastructure %s using connection %s", new Object[]{infra.id, displayedUsedConnection});
        try {
            List deploymentAndEndpoints = this.getDeploymentAndEndpoints(infra.id, authCtx);
            CloudWatchClient cloudWatchClient = CloudWatchUtils.loginCloudWatch(authCtx, usedConnection, infra.awsRegion, connectTimeout, socketTimeout);
            logger.infoV("Will fetch activity metrics from CloudWatch for at least %d dss endpoints for infrastructure %s using connection %s", new Object[]{deploymentAndEndpoints.size(), infra.id, displayedUsedConnection});
            Pair<Instant, Instant> queryRange = this.getMetricsRequestRange(3);
            ActivityMetric.TimeAndMetricsByDeploymentAndEndpoint timeAndMetricsByDeploymentAndEndpoint = new ActivityMetric.TimeAndMetricsByDeploymentAndEndpoint();
            for (Tuple3<CloudWatchUtils.CloudWatchEndpointMetrics, String, ActivityMetric.Type> metric : metricsAndAggregationsAndTypes) {
                String metricName = ((CloudWatchUtils.CloudWatchEndpointMetrics)((Object)metric._1)).getMetricName();
                String aggregation = (String)metric._2;
                ActivityMetric.Type type = (ActivityMetric.Type)metric._3;
                List<MetricDataResult> invocationsMetricData = this.fetchMetricData(metricName, aggregation, cloudWatchClient, (Instant)queryRange.first, (Instant)queryRange.second);
                SageMakerActivityMetricsFetchingService.fillDssEndpointTimeAndMetrics(timeAndMetricsByDeploymentAndEndpoint, deploymentAndEndpoints, metricName, invocationsMetricData, type);
            }
            String p95LatencyMetric = "Latency";
            List<String> endpointNames = deploymentAndEndpoints.stream().map(de -> ((SageMakerDeploymentInfo)de.deploymentInfo).sageMakerEndpointName).collect(Collectors.toList());
            List<MetricDataResult> p95LatencyData = this.fetchP95LatencyMetricData(cloudWatchClient, endpointNames, (Instant)queryRange.first, (Instant)queryRange.second);
            SageMakerActivityMetricsFetchingService.fillDssEndpointTimeAndMetrics(timeAndMetricsByDeploymentAndEndpoint, deploymentAndEndpoints, p95LatencyMetric, p95LatencyData, ActivityMetric.Type.P95_PROCESSING_TIME_IN_MS_PER_REQUEST);
            Pair<Long, Long> stepAlignedMetricsRange = this.getStepAlignedMetricsRange((Instant)queryRange.first, (Instant)queryRange.second);
            SageMakerActivityMetricsFetchingService.fillSparseActivityMetrics(timeAndMetricsByDeploymentAndEndpoint, deploymentAndEndpoints, (long)((Long)stepAlignedMetricsRange.first), (long)((Long)stepAlignedMetricsRange.second));
            ActivityMetric.TimeAndMetricsByDeploymentAndEndpoint withTotalMetrics = ApiEndpointActivityMonitoringService.mapAverageToTotal(timeAndMetricsByDeploymentAndEndpoint);
            return withTotalMetrics;
        }
        catch (DKUSecurityException | IOException e) {
            logger.errorV(e, "Failed to fetch activity metrics from CloudWatch for infrastructure %s using connection %s", new Object[]{infra.id, displayedUsedConnection});
            return new ActivityMetric.TimeAndMetricsByDeploymentAndEndpoint.None();
        }
    }

    public ActivityMetric.TimeAndMetricsByEndpoint getScopeActivityMetrics_NT(AuthCtx authCtx, UnifiedMonitoringExternalEndpointsScope.Sagemaker externalEndpointsScope, int connectTimeout, int socketTimeout) {
        TransactionContext.assertNoAttachedTransaction();
        logger.infoV("Getting activity metrics from CloudWatch for scope %s using connection %s", new Object[]{externalEndpointsScope.name, externalEndpointsScope.connectionName});
        try {
            CloudWatchClient cloudWatchClient = CloudWatchUtils.loginCloudWatch(authCtx, externalEndpointsScope.connectionName, externalEndpointsScope.region, connectTimeout, socketTimeout);
            List<String> endpointNames = externalEndpointsScope.listEndpoints(authCtx).stream().map(monitoring -> monitoring.endpointName).collect(Collectors.toList());
            logger.infoV("Will fetch activity metrics from CloudWatch for at least %d external endpoints for scope %s using connection %s", new Object[]{endpointNames.size(), externalEndpointsScope.name, externalEndpointsScope.connectionName});
            Pair<Instant, Instant> queryRange = this.getMetricsRequestRange(3);
            ActivityMetric.TimeAndMetricsByEndpoint timeAndMetricsByEndpoint = new ActivityMetric.TimeAndMetricsByEndpoint();
            for (Tuple3<CloudWatchUtils.CloudWatchEndpointMetrics, String, ActivityMetric.Type> metric : metricsAndAggregationsAndTypes) {
                String metricName = ((CloudWatchUtils.CloudWatchEndpointMetrics)((Object)metric._1)).getMetricName();
                String aggregation = (String)metric._2;
                ActivityMetric.Type type = (ActivityMetric.Type)metric._3;
                List<MetricDataResult> invocationsMetricData = this.fetchMetricData(metricName, aggregation, cloudWatchClient, (Instant)queryRange.first, (Instant)queryRange.second);
                this.fillExternalEndpointTimeAndMetrics(timeAndMetricsByEndpoint, endpointNames, metricName, invocationsMetricData, type);
            }
            String p95LatencyMetric = "Latency";
            List<MetricDataResult> latencyData = this.fetchP95LatencyMetricData(cloudWatchClient, endpointNames, (Instant)queryRange.first, (Instant)queryRange.second);
            this.fillExternalEndpointTimeAndMetrics(timeAndMetricsByEndpoint, endpointNames, p95LatencyMetric, latencyData, ActivityMetric.Type.P95_PROCESSING_TIME_IN_MS_PER_REQUEST);
            Pair<Long, Long> stepAlignedMetricsRange = this.getStepAlignedMetricsRange((Instant)queryRange.first, (Instant)queryRange.second);
            SageMakerActivityMetricsFetchingService.fillSparseActivityMetrics(timeAndMetricsByEndpoint, endpointNames, (long)((Long)stepAlignedMetricsRange.first), (long)((Long)stepAlignedMetricsRange.second));
            ActivityMetric.TimeAndMetricsByEndpoint withTotalMetrics = ApiEndpointActivityMonitoringService.mapAverageToTotal(timeAndMetricsByEndpoint);
            return withTotalMetrics;
        }
        catch (Exception e) {
            logger.errorV((Throwable)e, "Failed to get activity metrics from CloudWatch for scope %s using connection %s", new Object[]{externalEndpointsScope.name, externalEndpointsScope.connectionName});
            return new ActivityMetric.TimeAndMetricsByEndpoint.None();
        }
    }

    private static void fillDssEndpointTimeAndMetrics(ActivityMetric.TimeAndMetricsByDeploymentAndEndpoint timeAndMetricsByDeploymentAndEndpoint, List<ActivityMetricsFetchingService.DeploymentAndEndpoint<SageMakerDeploymentInfo>> deploymentAndEndpoints, String metricName, List<MetricDataResult> metricResults, ActivityMetric.Type metricType) {
        int countDssEndpoints = 0;
        int countExternalEndpoints = 0;
        for (MetricDataResult resultData : metricResults) {
            logger.infoV(String.format("Processing fetched '%s' activity metric data", metricName), new Object[0]);
            String remoteSageMakerEndpointName = resultData.label();
            ActivityMetricsFetchingService.DeploymentAndEndpoint deploymentAndEndpoint = deploymentAndEndpoints.stream().filter(de -> ((SageMakerDeploymentInfo)de.deploymentInfo).sageMakerEndpointName.equals(remoteSageMakerEndpointName)).findFirst().orElse(null);
            if (deploymentAndEndpoint != null) {
                ++countDssEndpoints;
                ActivityMetric.TimeAndMetricsForDeployment timeAndMetricsForDeployment = timeAndMetricsByDeploymentAndEndpoint.computeIfAbsent(deploymentAndEndpoint.deploymentId, ActivityMetric.TimeAndMetricsForDeployment::new);
                ActivityMetric.TimeAndMetricsForEndpoint timeAndMetricsForEndpoint = timeAndMetricsForDeployment.timeAndMetricsForEndpoint.computeIfAbsent(deploymentAndEndpoint.endpointId, ActivityMetric.TimeAndMetricsForEndpoint::new);
                SageMakerActivityMetricsFetchingService.fillTimeAndMetrics(resultData, timeAndMetricsForEndpoint.timeAndMetricsOrderedByTime, metricType);
                logger.debugV("Processed fetched '%s' activity metric data for endpoint %s on deployment %s", new Object[]{metricName, deploymentAndEndpoint.endpointId, deploymentAndEndpoint.deploymentId});
                continue;
            }
            ++countExternalEndpoints;
        }
        logger.infoV("Processed fetched '%s' activity metric for %d DSS endpoints. %d external endpoints were detected and skipped.", new Object[]{metricName, countDssEndpoints, countExternalEndpoints});
    }

    private void fillExternalEndpointTimeAndMetrics(ActivityMetric.TimeAndMetricsByEndpoint timeAndMetricsByEndpoint, List<String> endpointNames, String metricName, List<MetricDataResult> metricResults, ActivityMetric.Type metricType) {
        int countExternalEndpoints = 0;
        for (MetricDataResult resultData : metricResults) {
            logger.infoV(String.format("Processing fetched '%s' activity metric data", metricName), new Object[0]);
            String remoteSageMakerEndpointName = resultData.label();
            if (!endpointNames.contains(remoteSageMakerEndpointName)) {
                logger.debugV("Could not find endpoint '%s' in the endpoint listing", new Object[]{remoteSageMakerEndpointName});
                continue;
            }
            ++countExternalEndpoints;
            ActivityMetric.TimeAndMetricsForEndpoint timeAndMetricsForEndpoint = timeAndMetricsByEndpoint.computeIfAbsent(remoteSageMakerEndpointName, ActivityMetric.TimeAndMetricsForEndpoint::new);
            SageMakerActivityMetricsFetchingService.fillTimeAndMetrics(resultData, timeAndMetricsForEndpoint.timeAndMetricsOrderedByTime, metricType);
        }
        logger.infoV("Processed fetched '%s' activity metric data for %d external endpoints.", new Object[]{metricName, countExternalEndpoints});
    }

    private List<MetricDataResult> fetchMetricData(String metricName, String aggregation, CloudWatchClient cloudWatchClient, Instant start, Instant end) throws IOException {
        GetMetricDataRequest dataRequest = this.buildMetricInsightsDataRequest(metricName, aggregation, start, end);
        logger.infoV("Querying CloudWatch Metric Insights for activity metric '%s'.", new Object[]{metricName});
        return CloudWatchUtils.fetchMetricDataResults(dataRequest, cloudWatchClient);
    }

    private List<MetricDataResult> fetchP95LatencyMetricData(CloudWatchClient cloudWatchClient, List<String> endpointNames, Instant start, Instant end) throws IOException {
        List<GetMetricDataRequest> latencyDataRequests = this.buildLatencyMetricRequests(cloudWatchClient, endpointNames, "p95", start, end);
        logger.infoV("Querying CloudWatch for activity metrics '%s' and '%s'.", new Object[]{CloudWatchUtils.CloudWatchEndpointMetrics.MODEL_LATENCY.getMetricName(), CloudWatchUtils.CloudWatchEndpointMetrics.OVERHEAD_LATENCY.getMetricName()});
        ArrayList<MetricDataResult> results = new ArrayList<MetricDataResult>();
        for (GetMetricDataRequest latencyRequestBatch : latencyDataRequests) {
            results.addAll(CloudWatchUtils.fetchMetricDataResults(latencyRequestBatch, cloudWatchClient));
        }
        return results;
    }

    private static void fillTimeAndMetrics(MetricDataResult data, Map<Long, ActivityMetric.TimeAndMetrics> timeAndMetricsOrderedByTime, ActivityMetric.Type metricType) {
        List values = data.values();
        List timestamps = data.timestamps();
        for (int i = 0; i < timestamps.size(); ++i) {
            ActivityMetric.TimeAndMetrics metrics = timeAndMetricsOrderedByTime.computeIfAbsent(((Instant)timestamps.get(i)).getEpochSecond(), ActivityMetric.TimeAndMetrics::new);
            ActivityMetric metric = metrics.activityMetricsByType.computeIfAbsent(metricType, t -> new ActivityMetric((ActivityMetric.Type)t, 0.0));
            double value = (Double)values.get(i);
            value = metricType != ActivityMetric.Type.AVG_PROCESSING_TIME_IN_MS_PER_REQUEST ? (value /= 60.0) : (value /= 1000.0);
            metric.setValue(metric.getValue() + value);
        }
    }

    private GetMetricDataRequest buildMetricInsightsDataRequest(String metricName, String aggregation, Instant start, Instant end) {
        String query = "SELECT " + aggregation + "(" + metricName + ")\nFROM SCHEMA(\"AWS/SageMaker\", EndpointName,VariantName) \nGROUP BY EndpointName";
        MetricDataQuery dataQuery = (MetricDataQuery)MetricDataQuery.builder().id("dku_" + metricName).returnData(Boolean.valueOf(true)).period(Integer.valueOf(60)).expression(query).build();
        return (GetMetricDataRequest)GetMetricDataRequest.builder().metricDataQueries(new MetricDataQuery[]{dataQuery}).startTime(start).endTime(end).build();
    }

    private List<GetMetricDataRequest> buildLatencyMetricRequests(CloudWatchClient cloudWatchClient, List<String> endpointNames, String aggregation, Instant start, Instant end) throws IOException {
        Map<String, Set<Metric>> endpointMetricDefinitions = SageMakerActivityMetricsFetchingService.fetchLatencyMetricDefinitions(cloudWatchClient, endpointNames);
        ArrayList<GetMetricDataRequest> requests = new ArrayList<GetMetricDataRequest>();
        ArrayList<MetricDataQuery> metricDataQueries = new ArrayList<MetricDataQuery>();
        for (Map.Entry<String, Set<Metric>> entry : endpointMetricDefinitions.entrySet()) {
            String endpointName = entry.getKey();
            Set<Metric> endpointVariantMetrics = entry.getValue();
            try {
                if (endpointVariantMetrics.size() / 2 > 100) {
                    logger.warnV(String.format("Due to CloudWatch limitations, we cannot compute activity metrics for endpoints with more than %d variants", 100), new Object[0]);
                    continue;
                }
                if (metricDataQueries.size() + endpointVariantMetrics.size() + 1 > 500) {
                    requests.add((GetMetricDataRequest)GetMetricDataRequest.builder().metricDataQueries(metricDataQueries).startTime(start).endTime(end).build());
                    metricDataQueries.clear();
                }
                List<MetricDataQuery> endpointDataQueries = this.buildEndpointMetricDataQueries(endpointName, endpointVariantMetrics, aggregation, metricDataQueries.size());
                metricDataQueries.addAll(endpointDataQueries);
            }
            catch (Exception e) {
                logger.warnV((Throwable)e, "Failed to build latency metric data query for endpoint %s", new Object[]{endpointName});
            }
        }
        if (!metricDataQueries.isEmpty()) {
            requests.add((GetMetricDataRequest)GetMetricDataRequest.builder().metricDataQueries(metricDataQueries).startTime(start).endTime(end).build());
        }
        return requests;
    }

    private List<MetricDataQuery> buildEndpointMetricDataQueries(String endpointName, Collection<Metric> endpointVariantMetrics, String aggregation, int startingCounter) {
        Preconditions.checkArgument((!aggregation.equals("Average") ? 1 : 0) != 0, (Object)"buildEndpointMetricDataQueries() should not be used with Average aggregation, use buildMetricInsightsGetMetricDataRequest() instead");
        ArrayList<MetricDataQuery> metricDataQueries = new ArrayList<MetricDataQuery>();
        for (Metric metric : endpointVariantMetrics) {
            MetricStat metricStat = (MetricStat)MetricStat.builder().stat(aggregation).period(Integer.valueOf(60)).metric(metric).build();
            metricDataQueries.add((MetricDataQuery)MetricDataQuery.builder().id(String.format("i%d", startingCounter++)).metricStat(metricStat).returnData(Boolean.valueOf(false)).build());
        }
        if (!metricDataQueries.isEmpty()) {
            List queryIds = metricDataQueries.stream().map(MetricDataQuery::id).collect(Collectors.toList());
            String listExpression = String.format("[%s]", String.join((CharSequence)",", queryIds));
            metricDataQueries.add((MetricDataQuery)MetricDataQuery.builder().id(String.format("%s__latency__%s", endpointName.replace("-", "_"), aggregation)).expression(String.format("SUM(%s) / DATAPOINT_COUNT(%s) * 2", listExpression, listExpression)).returnData(Boolean.valueOf(true)).label(endpointName).build());
        }
        return metricDataQueries;
    }

    private static Map<String, Set<Metric>> fetchLatencyMetricDefinitions(CloudWatchClient cloudWatchClient, List<String> endpointNames) throws IOException {
        Map<String, Set<Metric>> endpointToMetricDefinitions = CloudWatchUtils.fetchMetricDefinitions(cloudWatchClient, endpointNames, Arrays.asList(CloudWatchUtils.CloudWatchEndpointMetrics.MODEL_LATENCY.getMetricName(), CloudWatchUtils.CloudWatchEndpointMetrics.OVERHEAD_LATENCY.getMetricName()), true);
        endpointToMetricDefinitions.values().removeIf(Collection::isEmpty);
        for (Set<Metric> definitions : endpointToMetricDefinitions.values()) {
            definitions.removeIf(def -> CloudWatchUtils.getDimensionNames(def).contains("EndpointConfigName"));
        }
        return endpointToMetricDefinitions;
    }

    protected Pair<Long, Long> getStepAlignedMetricsRange(Instant start, Instant end) {
        long firstDataTimestampSeconds = start.getEpochSecond() - start.getEpochSecond() % 60L;
        long lastDataTimestampSeconds = end.getEpochSecond() - end.getEpochSecond() % 60L - 60L;
        return new Pair((Object)firstDataTimestampSeconds, (Object)lastDataTimestampSeconds);
    }
}

