/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.resourceusage;

import com.dataiku.common.stereotype.PartOfPublicAPI;
import com.dataiku.common.stereotype.RoutinelyUsedInExtensionCode;
import com.dataiku.dip.resourceusage.ComputeResourceUsageContext;
import com.dataiku.dip.resourceusage.ComputeResourceUsageReportingService;
import com.dataiku.dip.resourceusage.CurrentComputeResourceUsageContext;
import com.dataiku.dip.resourceusage.k8s.IK8SContainerLimits;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;

/**
 * Record of one unit of compute resource usage together with the methods that report it to the
 * {@link ComputeResourceUsageReportingService}.
 *
 * <p>A usage is one of: a local process, a single Kubernetes job, a Spark-on-Kubernetes job, a SQL
 * connection, a SQL query, or LLM usage.
 *
 * <p>The public methods suggest this lifecycle:
 * <ol>
 *   <li>Construct the usage, directly or via a static factory such as {@link #forLocalProcess()}.</li>
 *   <li>Call one {@code setup*} method. It sets {@link #type} and attaches the matching payload.</li>
 *   <li>Call {@link #reportStartNoFail()}.</li>
 *   <li>Optionally call {@link #reportUpdateNoFail()} or {@link #reportUpdateNoFailIfNeeded()}.</li>
 *   <li>Finally call {@link #reportCompleteNoFail()}. {@link #close()} delegates to it, so instances
 *       fit in try-with-resources.</li>
 * </ol>
 *
 * <p>The {@code report*NoFail} methods catch any {@link Exception} raised while reporting and log it as a
 * warning instead of propagating it. When the instance was created with {@code doNotReport == true},
 * they do nothing.
 *
 * <p>NOTE(review): outside the nested LLM usage counters this class performs no synchronization; it
 * presumably expects to be driven from a single thread — confirm with callers.
 */
@PartOfPublicAPI
public class ComputeResourceUsage
implements AutoCloseable {
    /** Context captured from {@link CurrentComputeResourceUsageContext#get()} at construction. */
    public ComputeResourceUsageContext context;
    /** Kind of usage; set by the {@code setup*} methods, each of which also fills the matching payload field. */
    public ComputeResourceUsageType type;
    /** Identifier produced by {@code SecretKeyGenerator.generate(16)} at construction. */
    public final String id = SecretKeyGenerator.generate(16);
    /** Start timestamp from {@link System#currentTimeMillis()}, captured at construction. */
    public Long startTime = System.currentTimeMillis();
    /** End timestamp from {@link System#currentTimeMillis()}; set by {@link #markComplete()}. */
    public Long endTime;
    /** {@code endTime - startTime} in milliseconds; set by {@link #markComplete()} when {@code startTime} is non-null. */
    public Long totalTime;
    /** Payload for {@link ComputeResourceUsageType#LOCAL_PROCESS}; set by {@link #setupLocalProcess()}. */
    public LocalProcessResourceUsageData localProcess;
    /** Payload for {@link ComputeResourceUsageType#SINGLE_K8S_JOB}; set by {@link #setupSingleK8SJob}. */
    public SingleK8SJobResourceUsageData singleK8SJob;
    /** Payload for {@link ComputeResourceUsageType#SPARK_K8S_JOB}; set by {@link #setupSparkK8SJob}. */
    public SparkK8SJobResourceUsageData sparkK8SJob;
    /** Payload for {@link ComputeResourceUsageType#SQL_CONNECTION}; set by {@link #setupSQLConnection}. */
    public SQLConnectionUsageData sqlConnection;
    /** Payload for {@link ComputeResourceUsageType#SQL_QUERY}; set by {@link #setupSQLQuery}. */
    public SQLQueryUsageData sqlQuery;
    /** Payload for {@link ComputeResourceUsageType#LLM_USAGE}; set by {@code setupLLMUsage}. */
    public LLMUsageData llmUsage;
    /**
     * When {@code true}, every reporting method is a no-op.
     * The field is {@code transient}, so serializers that honour that modifier skip it.
     */
    private final transient boolean doNotReport;
    /**
     * Throttling state for {@link #reportUpdateNoFailIfNeeded()}.
     * Cleared by {@link #reportCompleteNoFail()} when reporting is enabled.
     */
    private UpdateChecking checking;
    /** Tags added through {@code withAllocationTag}; {@code null} until the first tag is added. */
    private JsonObject allocationTags;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.resourceusage");

    /** Creates a usage that will be reported; equivalent to {@code new ComputeResourceUsage(false)}. */
    public ComputeResourceUsage() {
        this(false);
    }

    /**
     * Creates a usage bound to the current {@link ComputeResourceUsageContext}.
     *
     * @param doNotReport when {@code true}, all reporting methods of this instance do nothing
     */
    public ComputeResourceUsage(boolean doNotReport) {
        this.context = CurrentComputeResourceUsageContext.get();
        this.doNotReport = doNotReport;
    }

    /**
     * Reports the start of this usage to the {@link ComputeResourceUsageReportingService} bean.
     * Any exception raised while reporting is logged as a warning and swallowed.
     *
     * @return this instance, for chaining
     */
    public ComputeResourceUsage reportStartNoFail() {
        if (this.doNotReport) {
            return this;
        }
        if (logger.isDebugEnabled()) {
            logger.debug((Object)("Reporting start of CRU:" + String.valueOf(this)));
        }
        try {
            ComputeResourceUsageReportingService cruReporting = SpringUtils.getBean(ComputeResourceUsageReportingService.class);
            cruReporting.reportStart(this);
        }
        catch (Exception e) {
            logger.warn((Object)("Failed to report compute resource usage start: " + String.valueOf(this)), (Throwable)e);
        }
        return this;
    }

    /**
     * Installs the throttling state used by {@link #reportUpdateNoFailIfNeeded()}, unless one is already set.
     *
     * @param checking throttling state to install when none is present
     */
    public void setUpdateCheckingIfMissing(UpdateChecking checking) {
        if (this.checking == null) {
            this.checking = checking;
        }
    }

    /**
     * Calls {@link #reportUpdateNoFail()} only when strictly more than the configured interval has elapsed
     * since the last throttled update.
     *
     * <p>If no {@link UpdateChecking} is installed, an error is logged and nothing is reported. This also
     * applies after completion has been reported, because that clears the throttling state.
     */
    public void reportUpdateNoFailIfNeeded() {
        if (this.checking == null) {
            logger.error((Object)"reportUpdateNoFailIfNeeded has been called but no checking interval defined.");
            return;
        }
        // Compare nanoTime differences rather than absolute values, as System.nanoTime() requires.
        long now = System.nanoTime();
        if (now - this.checking.lastUpdateInNanoSeconds > this.checking.intervalInNanoSeconds) {
            this.checking.lastUpdateInNanoSeconds = now;
            this.reportUpdateNoFail();
        }
    }

    /**
     * Records the number of rows inserted so far by the SQL query, then reports an update if the
     * throttling interval has elapsed.
     *
     * <p>If this usage has no SQL query payload, the count cannot be stored. In that case an error is
     * logged and no update is attempted.
     *
     * @param rowCount number of inserted rows
     */
    public void updateInsertedRowCountAndCheck(long rowCount) {
        if (this.setInsertedRowCount(rowCount)) {
            this.reportUpdateNoFailIfNeeded();
        }
    }

    /**
     * Records the final inserted row count, then reports completion.
     *
     * <p>Completion is reported even when this usage has no SQL query payload. A misuse therefore cannot
     * prevent the usage from being closed out; the missing payload is logged as an error.
     *
     * @param rowCount number of inserted rows
     */
    public void reportCompleteWithRowCount(long rowCount) {
        this.setInsertedRowCount(rowCount);
        this.reportCompleteNoFail();
    }

    /**
     * Stores {@code rowCount} into the SQL query payload.
     *
     * @return {@code true} if the count was stored; {@code false}, after logging an error, if there is no payload
     */
    private boolean setInsertedRowCount(long rowCount) {
        if (this.sqlQuery == null) {
            logger.error((Object)("Cannot record inserted row count on a compute resource usage without SQL query data: " + String.valueOf(this)));
            return false;
        }
        this.sqlQuery.insertedRowCount = rowCount;
        return true;
    }

    /**
     * Reports the current state of this usage to the {@link ComputeResourceUsageReportingService} bean.
     * Any exception raised while reporting is logged as a warning and swallowed.
     */
    public void reportUpdateNoFail() {
        if (this.doNotReport) {
            return;
        }
        try {
            ComputeResourceUsageReportingService cruReporting = SpringUtils.getBean(ComputeResourceUsageReportingService.class);
            cruReporting.reportUpdate(this);
        }
        catch (Exception e) {
            logger.warn((Object)("Failed to report compute resource usage update: " + String.valueOf(this)), (Throwable)e);
        }
    }

    /**
     * Sets {@link #endTime} to the current time in milliseconds.
     * When {@link #startTime} is non-null, it also computes {@link #totalTime}.
     */
    public void markComplete() {
        this.endTime = System.currentTimeMillis();
        if (this.startTime != null) {
            this.totalTime = this.endTime - this.startTime;
        }
    }

    /**
     * Reports completion; identical to {@link #reportCompleteNoFail()}.
     * It therefore never throws, and it reports again if completion was already reported explicitly.
     */
    @Override
    public void close() {
        this.reportCompleteNoFail();
    }

    /**
     * Reports completion to the {@link ComputeResourceUsageReportingService} bean.
     *
     * <p>Before reporting, it clears the update throttling state and stamps the end time via
     * {@link #markComplete()}. Any exception raised meanwhile is logged as a warning and swallowed.
     *
     * <p>Not idempotent: each call re-stamps the end time and reports again.
     */
    public void reportCompleteNoFail() {
        if (this.doNotReport) {
            return;
        }
        if (logger.isDebugEnabled()) {
            logger.debug((Object)("Reporting completion of CRU:" + String.valueOf(this)));
        }
        this.checking = null;
        try {
            this.markComplete();
            ComputeResourceUsageReportingService cruReporting = SpringUtils.getBean(ComputeResourceUsageReportingService.class);
            cruReporting.reportComplete(this);
        }
        catch (Exception e) {
            logger.warn((Object)("Failed to report compute resource usage completion: " + String.valueOf(this)), (Throwable)e);
        }
    }

    /** Marks this usage as a local process and attaches an empty {@link LocalProcessResourceUsageData}. */
    public void setupLocalProcess() {
        this.type = ComputeResourceUsageType.LOCAL_PROCESS;
        this.localProcess = new LocalProcessResourceUsageData();
    }

    /**
     * Marks this usage as a Spark-on-Kubernetes job and records its identifiers.
     *
     * @param k8sClusterId    Kubernetes cluster identifier
     * @param executionId     execution identifier
     * @param sparkConfigName name of the Spark configuration used
     */
    public void setupSparkK8SJob(String k8sClusterId, String executionId, String sparkConfigName) {
        this.type = ComputeResourceUsageType.SPARK_K8S_JOB;
        this.sparkK8SJob = new SparkK8SJobResourceUsageData();
        this.sparkK8SJob.k8sClusterId = k8sClusterId;
        this.sparkK8SJob.executionId = executionId;
        this.sparkK8SJob.sparkConfigName = sparkConfigName;
    }

    /**
     * Marks this usage as a single Kubernetes job and records its identifiers and resource requests/limits.
     *
     * @param k8sClusterId        Kubernetes cluster identifier
     * @param executionId         execution identifier
     * @param cpuRequest          CPU request, may be {@code null}
     * @param cpuLimit            CPU limit, may be {@code null}
     * @param memoryRequestMB     memory request in MB, may be {@code null}
     * @param memoryLimitMB       memory limit in MB, may be {@code null}
     * @param containerConfigName name of the container configuration used
     */
    public void setupSingleK8SJob(String k8sClusterId, String executionId, Double cpuRequest, Double cpuLimit, Integer memoryRequestMB, Integer memoryLimitMB, String containerConfigName) {
        this.type = ComputeResourceUsageType.SINGLE_K8S_JOB;
        this.singleK8SJob = new SingleK8SJobResourceUsageData();
        this.singleK8SJob.k8sClusterId = k8sClusterId;
        this.singleK8SJob.executionId = executionId;
        this.singleK8SJob.cpuRequest = cpuRequest;
        this.singleK8SJob.cpuLimit = cpuLimit;
        this.singleK8SJob.memoryRequestMB = memoryRequestMB;
        this.singleK8SJob.memoryLimitMB = memoryLimitMB;
        this.singleK8SJob.containerConfigName = containerConfigName;
    }

    /**
     * Marks this usage as a SQL connection.
     *
     * @param connectionName            name of the connection
     * @param sparkSQLConnectionUsageId related Spark SQL connection usage id, may be {@code null}
     */
    public void setupSQLConnection(String connectionName, String sparkSQLConnectionUsageId) {
        this.type = ComputeResourceUsageType.SQL_CONNECTION;
        this.sqlConnection = new SQLConnectionUsageData();
        this.sqlConnection.connection = connectionName;
        this.sqlConnection.sparkSQLConnectionUsageId = sparkSQLConnectionUsageId;
    }

    /**
     * Marks this usage as a SQL query.
     * This is the only setup that enables the inserted-row-count methods.
     *
     * @param connectionUsageId         id of the owning connection usage
     * @param connectionName            name of the connection
     * @param sparkSQLConnectionUsageId related Spark SQL connection usage id, may be {@code null}
     * @param query                     query text
     */
    public void setupSQLQuery(String connectionUsageId, String connectionName, String sparkSQLConnectionUsageId, String query) {
        this.type = ComputeResourceUsageType.SQL_QUERY;
        this.sqlQuery = new SQLQueryUsageData();
        this.sqlQuery.connectionUsageId = connectionUsageId;
        this.sqlQuery.connection = connectionName;
        this.sqlQuery.sparkSQLConnectionUsageId = sparkSQLConnectionUsageId;
        this.sqlQuery.query = query;
    }

    /**
     * Marks this usage as LLM usage without an LLM id.
     *
     * @deprecated use {@link #setupLLMUsage(LLMUsageType, String, String, String)}; this overload leaves
     *             {@code llmId} as {@code null}.
     */
    @Deprecated(forRemoval=true)
    public void setupLLMUsage(LLMUsageType usageType, String connectionName, String llmType) {
        this.setupLLMUsage(usageType, connectionName, llmType, null);
    }

    /**
     * Marks this usage as LLM usage and attaches a fresh {@link LLMUsageData} with zeroed counters.
     *
     * @param usageType      kind of LLM operation
     * @param connectionName name of the connection
     * @param llmType        LLM type
     * @param llmId          LLM identifier, may be {@code null}
     */
    public void setupLLMUsage(LLMUsageType usageType, String connectionName, String llmType, String llmId) {
        this.type = ComputeResourceUsageType.LLM_USAGE;
        this.llmUsage = new LLMUsageData();
        this.llmUsage.usageType = usageType;
        this.llmUsage.connection = connectionName;
        this.llmUsage.llmType = llmType;
        this.llmUsage.llmId = llmId;
    }

    /**
     * Adds an allocation tag, creating the tag object on first use.
     * Gson's {@link JsonObject#add} stores a {@code null} value as JSON null.
     *
     * @param key   tag name
     * @param value tag value, may be {@code null}
     * @return this instance, for chaining
     */
    public ComputeResourceUsage withAllocationTag(String key, JsonElement value) {
        if (this.allocationTags == null) {
            this.allocationTags = new JsonObject();
        }
        this.allocationTags.add(key, value);
        return this;
    }

    /**
     * Adds a string allocation tag.
     *
     * <p>A {@code null} value is stored as JSON null. {@link JsonPrimitive}'s constructor rejects
     * {@code null}, so it is only used for non-null values.
     *
     * @param key   tag name
     * @param value tag value, may be {@code null}
     * @return this instance, for chaining
     */
    public ComputeResourceUsage withAllocationTag(String key, String value) {
        JsonElement element = value == null ? null : new JsonPrimitive(value);
        return this.withAllocationTag(key, element);
    }

    /**
     * Returns the live tag object, not a copy.
     *
     * @return the tag object, or {@code null} if no tag has been added yet
     */
    public JsonObject getAllocationTags() {
        return this.allocationTags;
    }

    /**
     * Creates a reportable usage already set up as a local process.
     *
     * @return a new, not yet started, usage of type {@link ComputeResourceUsageType#LOCAL_PROCESS}
     */
    public static ComputeResourceUsage forLocalProcess() {
        ComputeResourceUsage cru = new ComputeResourceUsage();
        cru.setupLocalProcess();
        return cru;
    }

    /**
     * Creates a reportable usage set up as a single Kubernetes job.
     * CPU and memory values are read from {@code limits}.
     *
     * @param k8sClusterId        Kubernetes cluster identifier
     * @param executionId         execution identifier
     * @param limits              source of CPU/memory requests and limits; must not be {@code null}
     * @param containerConfigName name of the container configuration used
     * @return a new, not yet started, usage of type {@link ComputeResourceUsageType#SINGLE_K8S_JOB}
     */
    public static ComputeResourceUsage forSingleK8SJob(String k8sClusterId, String executionId, IK8SContainerLimits limits, String containerConfigName) {
        ComputeResourceUsage cru = new ComputeResourceUsage();
        cru.setupSingleK8SJob(k8sClusterId, executionId, limits.getCpuRequest(), limits.getCpuLimit(), limits.getMemRequestMB(), limits.getMemLimitMB(), containerConfigName);
        return cru;
    }

    /**
     * Serializes this usage with {@code JSON.json}.
     * If serialization fails, it logs the error and returns a fixed placeholder, so it is safe in log messages.
     */
    @Override
    public String toString() {
        try {
            return JSON.json((Object)this);
        }
        catch (Exception e) {
            logger.errorV((Throwable)e, "Could not serialize compute resource usage to JSON format", new Object[0]);
            return "<compute resource usage serialization failed>";
        }
    }

    /**
     * Throttling state for {@link ComputeResourceUsage#reportUpdateNoFailIfNeeded()}.
     * The reference point is the construction time, so the first throttled update fires only after more
     * than one interval has elapsed.
     */
    public static class UpdateChecking {
        /** {@link System#nanoTime()} value of the last throttled update; initially the construction time. */
        long lastUpdateInNanoSeconds = System.nanoTime();
        /** Minimum spacing between throttled updates, in nanoseconds. */
        long intervalInNanoSeconds;

        /**
         * @param intervalInMilliSeconds minimum spacing between throttled updates, in milliseconds
         */
        public UpdateChecking(long intervalInMilliSeconds) {
            this.intervalInNanoSeconds = intervalInMilliSeconds * 1000000L;
        }
    }

    /** Payload for a SQL query usage; see {@link ComputeResourceUsage#setupSQLQuery}. */
    @PartOfPublicAPI
    public static class SQLQueryUsageData {
        String connectionUsageId;
        String connection;
        String query;
        String sparkSQLConnectionUsageId;
        /** Not assigned in this file; presumably set by the code running the query (TODO confirm). */
        public Long fetchedRowCount;
        /** Set by the inserted-row-count methods of {@link ComputeResourceUsage}. */
        Long insertedRowCount;
        /** Not assigned in this file; unit not visible here (NOTE(review): confirm, likely ms). */
        public Long statementExecutionTime;
    }

    /** Kind of compute resource usage; determines which payload field of the usage is populated. */
    @PartOfPublicAPI
    public static enum ComputeResourceUsageType {
        SINGLE_K8S_JOB,
        SPARK_K8S_JOB,
        LOCAL_PROCESS,
        SQL_CONNECTION,
        SQL_QUERY,
        LLM_USAGE,
        // No setup method in this class selects this value; presumably set by callers (TODO confirm).
        CONTAINERIZED_PROCESS;

    }

    /**
     * Resource counters for a local process.
     *
     * <p>The fields are package-private and are not assigned in this file, so they are presumably filled
     * by a process monitor elsewhere in the package (TODO confirm). The suffixes {@code MS} and {@code MB}
     * suggest milliseconds and megabytes.
     */
    @PartOfPublicAPI
    public static class LocalProcessResourceUsageData {
        Integer pid;
        String commandName;
        Long cpuUserTimeMS;
        Long cpuSystemTimeMS;
        Long cpuChildrenUserTimeMS;
        Long cpuChildrenSystemTimeMS;
        Long cpuTotalMS;
        double cpuCurrent;
        double cpuAverageOverPast60Seconds;
        Long vmSizeMB;
        Long vmRSSMB;
        Long vmHWMMB;
        Long vmRSSAnonMB;
        Long vmDataMB;
        Long vmSizePeakMB;
        Long vmRSSPeakMB;
        Long vmRSSTotalMBS;
        Long majorFaults;
        Long childrenMajorFaults;
    }

    /** Payload for a Spark-on-Kubernetes job usage; see {@link ComputeResourceUsage#setupSparkK8SJob}. */
    @PartOfPublicAPI
    public static class SparkK8SJobResourceUsageData {
        String k8sClusterId;
        String executionId;
        String sparkConfigName;
    }

    /** Payload for a single Kubernetes job usage; see {@link ComputeResourceUsage#setupSingleK8SJob}. */
    @PartOfPublicAPI
    public static class SingleK8SJobResourceUsageData {
        String k8sClusterId;
        String executionId;
        Double cpuRequest;
        Double cpuLimit;
        Integer memoryRequestMB;
        Integer memoryLimitMB;
        String containerConfigName;
    }

    /** Payload for a SQL connection usage; see {@link ComputeResourceUsage#setupSQLConnection}. */
    @PartOfPublicAPI
    public static class SQLConnectionUsageData {
        String connection;
        String sparkSQLConnectionUsageId;
    }

    /** Kind of LLM operation recorded in {@link LLMUsageData#usageType}. */
    @PartOfPublicAPI
    public static enum LLMUsageType {
        COMPLETION,
        TEXT_EMBEDDING_EXTRACTION,
        IMAGE_EMBEDDING_EXTRACTION,
        MULTIMODAL_EMBEDDING_EXTRACTION,
        IMAGE_GENERATION,
        RERANKING;

    }

    /**
     * LLM usage payload: identifiers, query counters, and (inherited) token/time/cost totals.
     *
     * <p>All methods are {@code synchronized} on the instance. The fields themselves are public, so direct
     * field access bypasses that locking.
     */
    @RoutinelyUsedInExtensionCode
    @PartOfPublicAPI
    public static class LLMUsageData
    extends InternalLLMUsageData {
        public String connection;
        public String llmId;
        public String llmType;
        public LLMUsageType usageType;
        public long totalQueries;
        public long cacheHitQueries;
        public long cacheMissQueries;

        /**
         * Copies the prompt-token, completion-token, computation-time and estimated-cost totals from
         * {@code i}. The query counters are left unchanged.
         *
         * <p>This instance's monitor is held while {@code i}'s synchronized getters are called.
         *
         * @param i source of the totals; must not be {@code null}
         */
        public synchronized void setFromInternal(InternalLLMUsageData i) {
            this.totalPromptTokens = i.getTotalPromptTokens();
            this.totalCompletionTokens = i.getTotalCompletionTokens();
            this.totalComputationTimeMS = i.getTotalComputationTimeMS();
            this.estimatedCostUSD = i.getEstimatedCostUSD();
        }

        /** Adds one to {@link #totalQueries}. */
        public synchronized void incrementTotalQueries() {
            ++this.totalQueries;
        }

        /** Adds one to {@link #cacheHitQueries}. */
        public synchronized void incrementCacheHitQueries() {
            ++this.cacheHitQueries;
        }

        /** Adds one to {@link #cacheMissQueries}. */
        public synchronized void incrementCacheMissQueries() {
            ++this.cacheMissQueries;
        }

        /** @return the total number of queries counted so far */
        public synchronized long getTotalQueries() {
            return this.totalQueries;
        }

        /** @return the number of cache-hit queries counted so far */
        public synchronized long getCacheHitQueries() {
            return this.cacheHitQueries;
        }

        /** @return the number of cache-miss queries counted so far */
        public synchronized long getCacheMissQueries() {
            return this.cacheMissQueries;
        }
    }

    /**
     * Running totals of prompt tokens, completion tokens, computation time and estimated cost.
     *
     * <p>The increment methods ignore {@code null} arguments. All methods are {@code synchronized} on the
     * instance, but the fields are public, so direct field access bypasses that locking.
     */
    @PartOfPublicAPI
    public static class InternalLLMUsageData {
        public long totalPromptTokens;
        public long totalCompletionTokens;
        public long totalComputationTimeMS;
        public double estimatedCostUSD;

        /** Adds {@code computationTimeMS} to the computation-time total; {@code null} is ignored. */
        public synchronized void incrementTotalComputationTimeMS(Long computationTimeMS) {
            if (computationTimeMS != null) {
                this.totalComputationTimeMS += computationTimeMS.longValue();
            }
        }

        /** Adds {@code promptTokens} to the prompt-token total; {@code null} is ignored. */
        public synchronized void incrementTotalPromptTokens(Integer promptTokens) {
            if (promptTokens != null) {
                this.totalPromptTokens += (long)promptTokens.intValue();
            }
        }

        /** Adds {@code completionTokens} to the completion-token total; {@code null} is ignored. */
        public synchronized void incrementTotalCompletionTokens(Integer completionTokens) {
            if (completionTokens != null) {
                this.totalCompletionTokens += (long)completionTokens.intValue();
            }
        }

        /** Adds {@code costUSD} to the estimated-cost total; {@code null} is ignored. */
        public synchronized void incrementEstimatedCostUSD(Double costUSD) {
            if (costUSD != null) {
                this.estimatedCostUSD += costUSD.doubleValue();
            }
        }

        /** @return the prompt-token total */
        public synchronized long getTotalPromptTokens() {
            return this.totalPromptTokens;
        }

        /** @return the completion-token total */
        public synchronized long getTotalCompletionTokens() {
            return this.totalCompletionTokens;
        }

        /** @return the computation-time total */
        public synchronized long getTotalComputationTimeMS() {
            return this.totalComputationTimeMS;
        }

        /** @return the estimated-cost total */
        public synchronized double getEstimatedCostUSD() {
            return this.estimatedCostUSD;
        }
    }
}

