/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.online.LLMCostLimitingQuotasRepository;
import com.dataiku.dip.llm.online.LLMCostLimitingService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.transactions.fs.RelFile;
import com.dataiku.dip.transactions.fs.ifaces.ReadWriteFS;
import com.dataiku.dip.transactions.fs.utils.NativeCache;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dss.shadelib.com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.io.File;
import java.io.IOException;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class LLMCostLimitingCountersRepository {
    /** Repository of the configured quotas; injected by Spring and queried for the quota lists used below. */
    @Autowired
    private LLMCostLimitingQuotasRepository quotasRepository;
    /** Relative location of the persisted counters, built from {@code RelFile.global("cost-limiting")} and {@code counters.json}. */
    private static final RelFile DATA_FILE = new RelFile(RelFile.global((String)"cost-limiting"), new String[]{"counters.json"});
    /** Default flush period, in seconds. NOTE(review): not referenced by name in this class; the value 60 appears as a literal in init(). */
    private static final int DEFAULT_FLUSH_INTERVAL_SEC = 60;
    /** Number of extra period buckets kept when pruning. NOTE(review): not referenced by name; pruning uses {@code rollingPeriods + 1}. */
    private static final int EXTRA_BUCKET_RETENTION = 1;
    /** File-system abstraction used to read and write {@link #DATA_FILE}; assigned in init(). */
    private ReadWriteFS cache;
    /** Handle on the scheduled periodic flush task; assigned in init() and cancelled in destroy(). */
    private ScheduledFuture<?> flushExecutor;
    /** In-memory counters; {@code null} until init() has loaded or created them (see isInitialized()). */
    private CostLimitingCountersData data = null;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.costlimiting");

    /**
     * Loads the persisted counters and schedules the periodic flush task.
     * <p>
     * Does nothing when the cost limiting feature is disabled or when the repository is
     * already initialized. A failure to read the persisted counters is logged and the
     * repository starts from empty counters instead.
     * <p>
     * NOTE(review): only the {@code ScheduledFuture} of the flush task is retained, not the
     * executor that owns it, so the executor itself is never shut down.
     *
     * @throws IOException declared by the signature; presumably raised while building the
     *                     file cache - TODO confirm
     */
    @PostConstruct
    public synchronized void init() throws IOException {
        if (!LLMCostLimitingService.isFeatureEnabled() || this.isInitialized()) {
            return;
        }
        logger.info((Object)"Initializing cost limiting counters repository");
        this.cache = NativeCache.build((File)DKUApp.getBaseFolderF());
        try {
            this.data = (CostLimitingCountersData)this.cache.readObjectDefault(DATA_FILE, CostLimitingCountersData.class);
        } catch (Exception e) {
            logger.error((Object)"Error while loading counters data", (Throwable)e);
            this.data = new CostLimitingCountersData();
        }
        if (this.data == null) {
            // Never leave the repository "uninitialized" once the flusher is about to be scheduled,
            // otherwise a later initIfNeeded() would schedule a second flusher.
            this.data = new CostLimitingCountersData();
        }
        long flushIntervalSec = DKUApp.getParams().getIntParam("dku.llm.costLimiting.counters.flushIntervalSeconds", Integer.valueOf(60));
        if (flushIntervalSec <= 0L) {
            // scheduleAtFixedRate rejects a non-positive period with IllegalArgumentException.
            logger.warn((Object)("Invalid counters flush interval (" + flushIntervalSec + "s), falling back to 60s"));
            flushIntervalSec = 60L;
        }
        this.flushExecutor = Executors.newSingleThreadScheduledExecutor(new ThreadFactoryBuilder().setNameFormat("llm-costlimiting-counters-flusher-%d").build()).scheduleAtFixedRate(this::flushData, flushIntervalSec, flushIntervalSec, TimeUnit.SECONDS);
    }

    /**
     * Runs {@link #init()} when the counters have not been loaded yet, rethrowing any
     * {@link IOException} as an unchecked {@link RuntimeException}.
     */
    private void initIfNeeded() {
        if (this.isInitialized()) {
            return;
        }
        try {
            this.init();
        } catch (IOException ioFailure) {
            throw new RuntimeException(ioFailure);
        }
    }

    /**
     * @return {@code true} once the in-memory counters object has been assigned
     */
    private synchronized boolean isInitialized() {
        boolean countersLoaded = this.data != null;
        return countersLoaded;
    }

    /**
     * Stops the periodic flush task and writes the counters one last time.
     * <p>
     * Does nothing when the repository was never initialized. The flush task handle may be
     * missing if scheduling failed during initialization; the final flush still happens then.
     */
    @PreDestroy
    public void destroy() {
        if (!this.isInitialized()) {
            return;
        }
        ScheduledFuture<?> flusher;
        synchronized (this) {
            // init() assigns the field while holding this monitor.
            flusher = this.flushExecutor;
        }
        if (flusher != null) {
            // false: let a flush that is already running complete instead of interrupting it.
            flusher.cancel(false);
        }
        this.flushData();
    }

    /**
     * Prunes counters that are no longer needed and writes the counters to the data file.
     * <p>
     * This method is the body of the periodically scheduled flush task, so it must never let
     * an exception escape: a task scheduled with {@code scheduleAtFixedRate} is silently
     * cancelled after its first failed execution. Every step is therefore best-effort and
     * failures are only logged.
     * <p>
     * The quotas are fetched before taking the repository lock, as in the original
     * implementation; pruning and writing happen under the lock.
     */
    private void flushData() {
        if (!this.isInitialized()) {
            logger.debug((Object)"Skipping LLM counters flush: repository is not initialized");
            return;
        }
        logger.debug((Object)"Flushing LLM counters data");
        List<GeneralSettingsDAO.LLMCostLimitingQuota> quotas = null;
        try {
            quotas = this.quotasRepository.getAllQuotas();
        } catch (RuntimeException e) {
            // Without the up-to-date quotas pruning is impossible, but the counters can still be saved.
            logger.error((Object)"Error retrieving quotas, skipping pruning of deprecated counters", (Throwable)e);
        }
        synchronized (this) {
            if (quotas != null) {
                try {
                    LLMCostLimitingCountersRepository.pruneDeprecatedCounters(quotas, this.data, LocalDateTime.now());
                } catch (Exception e) {
                    logger.error((Object)"Error pruning deprecated counters", (Throwable)e);
                }
            }
            try {
                this.cache.writeObject(DATA_FILE, (Object)this.data);
            } catch (IOException | RuntimeException e) {
                logger.warn((Object)"Failed to flush LLM counters data", (Throwable)e);
            }
        }
    }

    /**
     * Removes counter data that is no longer needed.
     * <ul>
     *   <li>Counters whose quota id is not among {@code upToDateQuotas} are dropped.</li>
     *   <li>For each remaining counter, only the period buckets generated for
     *       {@code rollingPeriods + 1} periods ending at {@code referenceDate} are kept
     *       (one extra period of retention, same value as {@code EXTRA_BUCKET_RETENTION}).</li>
     *   <li>The in-memory rolling cache of each remaining counter is pruned the same way;
     *       otherwise it gains one entry per elapsed period and is never trimmed.</li>
     * </ul>
     *
     * @param upToDateQuotas quotas that currently exist
     * @param data           counters to prune in place
     * @param referenceDate  date from which the retained period ids are generated
     */
    static void pruneDeprecatedCounters(List<GeneralSettingsDAO.LLMCostLimitingQuota> upToDateQuotas, CostLimitingCountersData data, LocalDateTime referenceDate) {
        Set<String> validQuotaIds = upToDateQuotas.stream().map(quota -> quota.getId()).collect(Collectors.toSet());
        // keySet() is a live view: retaining on it removes the matching map entries.
        data.counters.keySet().retainAll(validQuotaIds);
        for (BucketCounter bucketCounter : data.counters.values()) {
            Set<String> expectedBucketIds = bucketCounter.type.generateAggregationIds(referenceDate, bucketCounter.rollingPeriods + 1);
            // The counter guards its own maps with its monitor (all its methods are synchronized).
            synchronized (bucketCounter) {
                bucketCounter.countersByPeriod.keySet().retainAll(expectedBucketIds);
                // Safe to drop: a rolling entry is rebuilt on demand from countersByPeriod.
                bucketCounter.rollingCountersByPeriod.keySet().retainAll(expectedBucketIds);
            }
        }
    }

    /**
     * Returns the counter of the given quota, creating it on first use, and aligns the
     * counter settings with the quota. A warning is logged when that alignment reports
     * that the counter data was cleared.
     *
     * @param quota quota whose counter is requested
     * @return the live counter stored in the repository (not a copy)
     */
    private synchronized BucketCounter getCounter(GeneralSettingsDAO.LLMCostLimitingQuota quota) {
        this.initIfNeeded();
        String quotaId = quota.getId();
        BucketCounter counter = this.data.counters.get(quotaId);
        if (counter == null) {
            counter = new BucketCounter(quota.periodicity.aggregationType, quota.rollingPeriods);
            this.data.counters.put(quotaId, counter);
        }
        boolean dataWasCleared = counter.updateSettings(quota);
        if (dataWasCleared) {
            logger.warn((Object)("Cleared counter data for quota: " + quota.getLogIdentifier()));
        }
        return counter;
    }

    /**
     * Returns a copy of the usage currently aggregated for the given quota.
     *
     * @param quota quota to read
     * @return copy of the aggregated data at the current date
     */
    public synchronized AggregatedData getAggregatedData(GeneralSettingsDAO.LLMCostLimitingQuota quota) {
        BucketCounter counter = this.getCounter(quota);
        return counter.get();
    }

    /**
     * Returns the usage currently aggregated for the given quota, tagged with the quota id
     * and name.
     *
     * @param quota       quota to read
     * @param withDetails when {@code true} the result also references the quota's live
     *                    counter; otherwise that reference is {@code null}
     * @return the aggregated data with identification, and optionally the counter
     */
    public synchronized AggregatedDataWithDetails getAggregatedDataWithDetails(GeneralSettingsDAO.LLMCostLimitingQuota quota, boolean withDetails) {
        BucketCounter quotaCounter = this.getCounter(quota);
        AggregatedData usage = quotaCounter.get();
        BucketCounter exposedCounter = withDetails ? quotaCounter : null;
        return new AggregatedDataWithDetails(quota, usage, exposedCounter);
    }

    /**
     * Builds the progress view of a quota from its currently aggregated usage.
     *
     * @param quota     quota to read
     * @param costLimit cost limit the accrued cost is compared against
     * @return progress data for the quota
     */
    public synchronized AuthorizedProgressData getProgressData(GeneralSettingsDAO.LLMCostLimitingQuota quota, double costLimit) {
        AggregatedData usage = this.getCounter(quota).get();
        return new AuthorizedProgressData(usage, quota.getId(), quota.getName(), costLimit, quota.withFullAccess);
    }

    /**
     * Returns the aggregated usage of every configured quota, keyed by quota id.
     *
     * @param withDetails forwarded to {@link #getAggregatedDataWithDetails}
     * @return map from quota id to its aggregated data
     */
    public Map<String, AggregatedDataWithDetails> getAllAggregatedData(boolean withDetails) {
        Function<GeneralSettingsDAO.LLMCostLimitingQuota, String> quotaId = quota -> quota.getId();
        Function<GeneralSettingsDAO.LLMCostLimitingQuota, AggregatedDataWithDetails> quotaUsage = quota -> this.getAggregatedDataWithDetails(quota, withDetails);
        return this.quotasRepository.getAllQuotas().stream().collect(Collectors.toMap(quotaId, quotaUsage));
    }

    /**
     * Returns the progress data of every quota the given context is authorized to see,
     * keyed by quota id.
     * <p>
     * A quota with no entry (or a {@code null} entry) in the cost-limit map is treated as
     * having a limit of {@code 0.0}, which {@link AuthorizedProgressData} reports as
     * "no progress available" ({@code progress == null}), instead of failing with a
     * {@link NullPointerException} while unboxing.
     *
     * @param authCtx authorization context used to select the quotas
     * @return map from quota id to its progress data
     * @throws DKUSecurityException declared by the signature; presumably raised by the
     *                              authorized-quotas lookup - TODO confirm
     */
    public Map<String, AuthorizedProgressData> getAuthorizedProgressData(AuthCtx authCtx) throws DKUSecurityException {
        List<GeneralSettingsDAO.LLMCostLimitingQuota> quotas = this.quotasRepository.getAuthorizedQuotas(authCtx);
        Map<String, Double> costLimitByQuotaId = this.quotasRepository.getCostLimitByQuotaId();
        return quotas.stream().collect(Collectors.toMap(quota -> quota.getId(), quota -> {
            Double configuredLimit = costLimitByQuotaId.get(quota.getId());
            double costLimit = configuredLimit != null ? configuredLimit.doubleValue() : 0.0;
            return this.getProgressData(quota, costLimit);
        }));
    }

    /**
     * Returns the aggregated usage (without counter details) of every quota reported as
     * available by the quotas repository.
     *
     * @return one entry per available quota, in the order returned by the repository
     */
    public List<AggregatedDataWithDetails> getAggregatedDataFromAvailableQuotas() {
        return this.quotasRepository.getAvailableQuotas()
                .stream()
                .map(availableQuota -> this.getAggregatedDataWithDetails(availableQuota, false))
                .toList();
    }

    /**
     * Clears the counter data of every configured quota.
     * <p>
     * The quotas are fetched before taking the repository lock (as the flush does); the
     * clearing itself runs under the lock so that it cannot interleave with a flush that
     * is pruning or serializing the same counters.
     */
    public void clearAllCounters() {
        List<GeneralSettingsDAO.LLMCostLimitingQuota> quotas = this.quotasRepository.getAllQuotas();
        synchronized (this) {
            for (GeneralSettingsDAO.LLMCostLimitingQuota quota : quotas) {
                logger.info((Object)("Clearing data from cost limit quota: " + quota.getLogIdentifier()));
                this.getCounter(quota).clearData();
            }
        }
    }

    /**
     * Adds the given usage to the counter of the given quota.
     *
     * @param quota quota whose counter is incremented
     * @param data  usage to add
     * @return copy of the quota's aggregated data after the increment
     */
    public synchronized AggregatedData increment(GeneralSettingsDAO.LLMCostLimitingQuota quota, AggregatedData data) {
        this.initIfNeeded();
        BucketCounter quotaCounter = this.getCounter(quota);
        return quotaCounter.increment(data);
    }

    /**
     * Adds a cost and a number of allowed queries to the quota's counter, with no blocked
     * queries.
     *
     * @param quota          quota whose counter is incremented
     * @param estimatedCost  cost to add
     * @param allowedQueries number of allowed queries to add
     * @return copy of the quota's aggregated data after the increment
     */
    public AggregatedData increment(GeneralSettingsDAO.LLMCostLimitingQuota quota, double estimatedCost, int allowedQueries) {
        final int noBlockedQueries = 0;
        return this.increment(quota, estimatedCost, allowedQueries, noBlockedQueries);
    }

    /**
     * Adds a cost and numbers of allowed and blocked queries to the quota's counter.
     *
     * @param quota          quota whose counter is incremented
     * @param estimatedCost  cost to add
     * @param allowedQueries number of allowed queries to add
     * @param blockedQueries number of blocked queries to add
     * @return copy of the quota's aggregated data after the increment
     */
    public AggregatedData increment(GeneralSettingsDAO.LLMCostLimitingQuota quota, double estimatedCost, int allowedQueries, int blockedQueries) {
        AggregatedData delta = new AggregatedData(estimatedCost, allowedQueries, blockedQueries);
        return this.increment(quota, delta);
    }

    /**
     * Root object read from and written to the counters data file.
     */
    public static class CostLimitingCountersData {
        /** Version tag, initialized to "1"; not read anywhere in this class. */
        String version = "1";
        /** Counters keyed by quota id. */
        Map<String, BucketCounter> counters = new HashMap<>();
    }

    /**
     * Usage counter of a single quota.
     * <p>
     * Usage is stored in one {@link AggregatedData} bucket per period, keyed by the id that
     * {@link AggregationType#generateAggregationId} produces for the period. When more than
     * one rolling period is configured, the sum over the last {@code rollingPeriods} buckets
     * is additionally cached per current period in a transient map that is rebuilt on demand.
     * <p>
     * All methods synchronize on the counter instance.
     */
    public static class BucketCounter {
        /** Name of the quota, refreshed by {@link #updateSettings}. */
        public String name;
        /** Period granularity of the buckets. */
        public AggregationType type = AggregationType.MINUTE;
        /** Number of consecutive periods summed by {@link #get}; values of 1 or less mean "current period only". */
        public int rollingPeriods = 1;
        /** Usage per period id. */
        public Map<String, AggregatedData> countersByPeriod = new HashMap<String, AggregatedData>();
        /** Cache of rolling sums, keyed by the id of the period they end at; derived from {@link #countersByPeriod}. */
        private transient Map<String, AggregatedData> rollingCountersByPeriod = new HashMap<String, AggregatedData>();

        public BucketCounter() {
        }

        public BucketCounter(AggregationType aggregationType) {
            this(aggregationType, 1);
        }

        public BucketCounter(AggregationType aggregationType, int rollingPeriods) {
            this.type = aggregationType;
            this.rollingPeriods = rollingPeriods;
        }

        /** Returns the live bucket of the period containing {@code date}, creating an empty one if needed. */
        private synchronized AggregatedData getPeriodCounter(LocalDateTime date) {
            return this.countersByPeriod.computeIfAbsent(this.type.generateAggregationId(date), id -> new AggregatedData());
        }

        /**
         * Returns the live rolling sum ending at the period containing {@code date}. On first
         * access for that period, the sum is built from the existing buckets of the last
         * {@code rollingPeriods} periods.
         */
        private synchronized AggregatedData getOrCreateRollingCounter(LocalDateTime date) {
            String currentPeriodId = this.type.generateAggregationId(date);
            return this.rollingCountersByPeriod.computeIfAbsent(currentPeriodId, id -> {
                AggregatedData rollingSum = new AggregatedData();
                for (String periodId : this.type.generateAggregationIds(date, this.rollingPeriods)) {
                    AggregatedData periodData = this.countersByPeriod.get(periodId);
                    if (periodData != null) {
                        rollingSum.aggregate(periodData);
                    }
                }
                return rollingSum;
            });
        }

        /**
         * Returns a copy of the usage aggregated at {@code dateTime}: the current period only
         * when at most one rolling period is configured, the rolling sum otherwise.
         * <p>
         * The single-period test is {@code rollingPeriods <= 1} so that it agrees with
         * {@link #increment(LocalDateTime, AggregatedData)}, which only uses the rolling sum
         * when {@code rollingPeriods > 1}; a value below 1 would otherwise be rejected by
         * {@link AggregationType#generateAggregationIds}.
         */
        synchronized AggregatedData get(LocalDateTime dateTime) {
            if (this.rollingPeriods <= 1) {
                return this.getPeriodCounter(dateTime).clone();
            }
            return this.getOrCreateRollingCounter(dateTime).clone();
        }

        /** Returns a copy of the usage aggregated at the current date. */
        synchronized AggregatedData get() {
            return this.get(LocalDateTime.now());
        }

        /**
         * Adds {@code data} to the bucket of the period containing {@code dateTime} and, when
         * rolling aggregation is active, to the cached rolling sum.
         *
         * @return copy of the rolling sum when rolling aggregation is active, copy of the
         *         period bucket otherwise
         */
        synchronized AggregatedData increment(LocalDateTime dateTime, AggregatedData data) {
            // The rolling sum must be obtained BEFORE the bucket is incremented: if it is built
            // now, it is built from the buckets, and the increment would otherwise be counted twice.
            AggregatedData rollingCounter = this.rollingPeriods > 1 ? this.getOrCreateRollingCounter(dateTime) : null;
            AggregatedData counter = this.getPeriodCounter(dateTime);
            counter.aggregate(data);
            if (rollingCounter != null) {
                rollingCounter.aggregate(data);
                return rollingCounter.clone();
            }
            return counter.clone();
        }

        /** Adds {@code data} at the current date; see {@link #increment(LocalDateTime, AggregatedData)}. */
        synchronized AggregatedData increment(AggregatedData data) {
            return this.increment(LocalDateTime.now(), data);
        }

        /**
         * Aligns name, aggregation type and rolling period count with the quota.
         * A change of aggregation type clears all data; a change of rolling period count alone
         * only clears the rolling cache.
         *
         * @return {@code true} when all counter data was cleared
         */
        synchronized boolean updateSettings(GeneralSettingsDAO.LLMCostLimitingQuota quota) {
            this.name = quota.getName();
            AggregationType newType = quota.periodicity.aggregationType;
            int newRollingPeriods = quota.rollingPeriods;
            boolean requiresFullClear = false;
            boolean requiresRollingCountersCacheClear = false;
            if (this.type != newType) {
                logger.warn((Object)("Aggregation type has changed from " + String.valueOf((Object)this.type) + " to " + String.valueOf((Object)newType)));
                this.type = newType;
                requiresFullClear = true;
            }
            if (this.rollingPeriods != newRollingPeriods) {
                logger.info((Object)("Rolling periods nb has changed from " + this.rollingPeriods + " to " + newRollingPeriods));
                this.rollingPeriods = newRollingPeriods;
                requiresRollingCountersCacheClear = true;
            }
            if (requiresFullClear) {
                this.clearData();
                return true;
            }
            if (requiresRollingCountersCacheClear) {
                this.clearRollingCountersCache();
            }
            return false;
        }

        /** Drops every bucket and the rolling cache. */
        private synchronized void clearData() {
            this.countersByPeriod.clear();
            this.rollingCountersByPeriod.clear();
        }

        /** Drops the rolling cache only; it is rebuilt from the buckets on next access. */
        private synchronized void clearRollingCountersCache() {
            this.rollingCountersByPeriod.clear();
        }
    }

    /**
     * Usage totals: accrued cost plus the numbers of allowed and blocked queries.
     */
    public static class AggregatedData
    implements Cloneable {
        public double accruedCost;
        public int allowedQueries;
        public int blockedQueries;

        /** Creates totals that are all zero. */
        public AggregatedData() {
            this(0.0, 0, 0);
        }

        /** Creates totals with the given cost and allowed queries, and no blocked queries. */
        public AggregatedData(double accruedCost, int allowedQueries) {
            this(accruedCost, allowedQueries, 0);
        }

        /** Creates totals with the given values. */
        public AggregatedData(double accruedCost, int allowedQueries, int blockedQueries) {
            this.accruedCost = accruedCost;
            this.allowedQueries = allowedQueries;
            this.blockedQueries = blockedQueries;
        }

        /** Adds each total of {@code data} to the matching total of this object. */
        void aggregate(AggregatedData data) {
            this.accruedCost = this.accruedCost + data.accruedCost;
            this.allowedQueries = this.allowedQueries + data.allowedQueries;
            this.blockedQueries = this.blockedQueries + data.blockedQueries;
        }

        /** Returns a field-by-field copy made with {@link Object#clone()}. */
        @Override
        protected AggregatedData clone() {
            Object copy;
            try {
                copy = super.clone();
            } catch (CloneNotSupportedException unexpected) {
                throw new RuntimeException(unexpected);
            }
            return (AggregatedData)copy;
        }
    }

    /**
     * Usage totals of one quota, tagged with the quota id and name and optionally carrying
     * a reference to the quota's counter.
     */
    public static class AggregatedDataWithDetails
    extends AggregatedData {
        public String id;
        public String name;
        public BucketCounter counters;

        /**
         * @param quota          quota providing the id and name
         * @param aggregatedData totals added to this (initially zero) object
         * @param counter        counter to expose, or {@code null}
         */
        AggregatedDataWithDetails(GeneralSettingsDAO.LLMCostLimitingQuota quota, AggregatedData aggregatedData, BucketCounter counter) {
            super();
            this.id = quota.getId();
            this.name = quota.getName();
            this.aggregate(aggregatedData);
            this.counters = counter;
        }
    }

    /**
     * Progress of a quota against its cost limit. The raw totals are only filled in when
     * the quota grants full access; otherwise they stay {@code null}.
     */
    public static class AuthorizedProgressData {
        public Double accruedCost;
        public Integer allowedQueries;
        public Integer blockedQueries;
        public Double progress;
        public String id;
        public String name;

        /**
         * @param data           totals of the quota
         * @param id             quota id
         * @param name           quota name
         * @param costLimit      cost limit; when not strictly positive, {@code progress} is {@code null}
         * @param withFullAccess whether the raw totals are exposed
         */
        AuthorizedProgressData(AggregatedData data, String id, String name, double costLimit, boolean withFullAccess) {
            this.id = id;
            this.name = name;
            Double ratio = null;
            if (costLimit > 0.0) {
                double rawRatio = data.accruedCost / costLimit;
                // An infinite ratio is reported as no progress.
                ratio = Double.isInfinite(rawRatio) ? 0.0 : rawRatio;
            }
            this.progress = ratio;
            if (!withFullAccess) {
                return;
            }
            this.accruedCost = data.accruedCost;
            this.allowedQueries = data.allowedQueries;
            this.blockedQueries = data.blockedQueries;
        }
    }

    /**
     * Period granularity of the counters. Each constant pairs the date pattern used to
     * build a period id with the function that moves a date back by one period.
     * {@code NONE} maps every date to the single id {@code default}.
     */
    public static enum AggregationType {
        YEAR("yyyy", date -> date.minusYears(1L)),
        QUARTER("yyyy-'Q'q", date -> date.minusMonths(3L)),
        MONTH("yyyy-MM", date -> date.minusMonths(1L)),
        DAY("yyyy-MM-dd", date -> date.minusDays(1L)),
        MINUTE("yyyy-MM-dd'T'HH:mm", date -> date.minusMinutes(1L)),
        NONE("'default'", date -> date);

        private final DateTimeFormatter formatter;
        private final Function<LocalDateTime, LocalDateTime> shiftBackFunction;

        AggregationType(String pattern, Function<LocalDateTime, LocalDateTime> shiftBackFunction) {
            this.shiftBackFunction = shiftBackFunction;
            this.formatter = DateTimeFormatter.ofPattern(pattern);
        }

        /** Returns the id of the period containing {@code dateTime}. */
        public String generateAggregationId(LocalDateTime dateTime) {
            return this.formatter.format(dateTime);
        }

        /** Returns {@code dateTime} moved back by one period. */
        private LocalDateTime previousPeriod(LocalDateTime dateTime) {
            return this.shiftBackFunction.apply(dateTime);
        }

        /**
         * Returns the ids of the period containing {@code referenceDate} and of the periods
         * before it, most recent first, stepping back {@code nbRollingPeriods - 1} times.
         * Identical ids are only present once.
         *
         * @throws IllegalArgumentException presumably, when {@code nbRollingPeriods} is below 1
         *                                  (raised by the precondition check)
         */
        public Set<String> generateAggregationIds(LocalDateTime referenceDate, int nbRollingPeriods) {
            Preconditions.checkArgument(nbRollingPeriods >= 1, (Object)"Expecting nbRollingPeriods to be at least 1");
            Set<String> periodIds = new LinkedHashSet<>();
            LocalDateTime cursor = referenceDate;
            periodIds.add(this.generateAggregationId(cursor));
            int remainingSteps = nbRollingPeriods - 1;
            while (remainingSteps > 0) {
                cursor = this.previousPeriod(cursor);
                periodIds.add(this.generateAggregationId(cursor));
                --remainingSteps;
            }
            return periodIds;
        }
    }
}

