/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.DKUApp;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.llm.online.LLMCostLimitingCountersRepository;
import com.dataiku.dip.llm.online.LLMCostLimitingQuotasRepository;
import com.dataiku.dip.server.notifications.DSSEvent;
import com.dataiku.dip.server.notifications.EmailNotificationsSender;
import com.dataiku.dip.server.notifications.VariableLookup;
import com.dataiku.dip.server.notifications.backend.LlmCostLimitingThresholdReachedEvent;
import com.dataiku.dip.server.notifications.emails.MessageContentBuilder;
import com.dataiku.dip.server.notifications.emails.TemplatedContent;
import com.dataiku.dip.server.services.GeneralSettingsService;
import com.dataiku.dip.server.services.IPubSubService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.server.services.UserSettingsService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.com.google.common.base.Strings;
import com.dataiku.dss.shadelib.com.google.common.cache.CacheBuilder;
import com.dataiku.dss.shadelib.com.google.common.cache.CacheLoader;
import com.dataiku.dss.shadelib.com.google.common.cache.LoadingCache;
import java.io.File;
import java.time.LocalDateTime;
import java.time.Year;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import javax.annotation.PostConstruct;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class LLMCostLimitingService {
    @Autowired
    private LLMCostLimitingCountersRepository countersRepository;
    @Autowired
    private LLMCostLimitingQuotasRepository quotasRepository;
    @Autowired
    private IPubSubService pubSubService;
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private GeneralSettingsService generalSettingsService;
    private final LoadingCache<LLMCostLimitingContext, List<GeneralSettingsDAO.LLMCostLimitingQuota>> applicableQuotasCache = CacheBuilder.newBuilder().expireAfterAccess(1L, TimeUnit.HOURS).build((CacheLoader)new CacheLoader<LLMCostLimitingContext, List<GeneralSettingsDAO.LLMCostLimitingQuota>>(){

        public List<GeneralSettingsDAO.LLMCostLimitingQuota> load(LLMCostLimitingContext context) {
            return LLMCostLimitingService.this.quotasRepository.getApplicableQuotas(context);
        }
    });
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.costlimiting");

    @PostConstruct
    public void init() {
        this.pubSubService.subscribe("general-settings-changed", evt -> {
            boolean costLimitingSettingsChanged;
            boolean bl = costLimitingSettingsChanged = !JSON.jsonEquals((Object)evt.previousSettings.generativeAISettings.costLimitingSettings, (Object)evt.newSettings.generativeAISettings.costLimitingSettings);
            if (costLimitingSettingsChanged) {
                this.applicableQuotasCache.invalidateAll();
            }
        });
        this.pubSubService.subscribe("costlimitingthreshold-reached", evt -> {
            GeneralSettingsDAO.LLMCostLimitingQuota quota = this.quotasRepository.getQuota(evt.quotaId);
            if (quota == null) {
                logger.warn((Object)("Unknown cost limiting quota: " + evt.quotaId));
                return;
            }
            this.sendThresholdReachedNotification((LlmCostLimitingThresholdReachedEvent)evt, quota);
        });
    }

    public void checkQuery(LLMCostLimitingContext context) throws LLMCostLimitingBlockedException {
        List applicableQuotas = (List)this.applicableQuotasCache.getUnchecked((Object)context);
        ArrayList<GeneralSettingsDAO.LLMCostLimitingQuota> exceededQuotas = new ArrayList<GeneralSettingsDAO.LLMCostLimitingQuota>();
        for (GeneralSettingsDAO.LLMCostLimitingQuota applicableQuota : applicableQuotas) {
            if (!applicableQuota.blockingLimit) continue;
            LLMCostLimitingCountersRepository.AggregatedData aggregatedData = this.countersRepository.getAggregatedData(applicableQuota);
            if (!(aggregatedData.accruedCost >= applicableQuota.costLimit)) continue;
            this.countersRepository.increment(applicableQuota, 0.0, 0, 1);
            exceededQuotas.add(applicableQuota);
        }
        if (!exceededQuotas.isEmpty()) {
            String exceededQuotaIds = exceededQuotas.stream().map(l -> l.getLogIdentifier()).collect(Collectors.joining(", "));
            throw new LLMCostLimitingBlockedException("Request blocked by cost limiter because of exceeded quotas: " + exceededQuotaIds);
        }
    }

    public void reportCost(LLMCostLimitingContext context, double estimatedCost, int nbQueries) {
        try {
            if (nbQueries == 0) {
                return;
            }
            if (estimatedCost < 0.0) {
                logger.warn((Object)"Negative cost provided, reporting 0");
                estimatedCost = 0.0;
            }
            List applicableQuotas = (List)this.applicableQuotasCache.getUnchecked((Object)context);
            for (GeneralSettingsDAO.LLMCostLimitingQuota applicableQuota : applicableQuotas) {
                String currentBucket = applicableQuota.periodicity.aggregationType.generateAggregationId(LocalDateTime.now());
                LLMCostLimitingCountersRepository.AggregatedData aggregatedData = this.countersRepository.increment(applicableQuota, estimatedCost, nbQueries);
                logger.info((Object)("Bucket " + currentBucket + " incremented for quota " + applicableQuota.getLogIdentifier() + ", new value: " + aggregatedData.accruedCost));
                double previousCost = aggregatedData.accruedCost - estimatedCost;
                ArrayList<GeneralSettingsDAO.ReportingAction> thresholdsToCheck = new ArrayList<GeneralSettingsDAO.ReportingAction>(applicableQuota.reportingActions);
                if (this.requiresAdditionalBlockedQueriesEmailAlert(applicableQuota)) {
                    thresholdsToCheck.add(new GeneralSettingsDAO.ReportingAction(100.0));
                }
                for (GeneralSettingsDAO.ReportingAction reportingAction : thresholdsToCheck) {
                    if (reportingAction.threshold < 0.0) {
                        logger.warn((Object)("Ignoring reporting alert for quota " + applicableQuota.getLogIdentifier() + " because of negative threshold of " + reportingAction.threshold));
                        continue;
                    }
                    double thresholdCostLimitUsd = applicableQuota.costLimit * reportingAction.threshold / 100.0;
                    boolean thresholdJustCrossed = aggregatedData.accruedCost > thresholdCostLimitUsd && previousCost <= thresholdCostLimitUsd;
                    if (!thresholdJustCrossed) continue;
                    logger.warn((Object)("Reporting threshold exceeded: quotaId=" + applicableQuota.getLogIdentifier() + ", threshold=" + reportingAction.threshold));
                    this.pubSubService.publish((DSSEvent)new LlmCostLimitingThresholdReachedEvent(applicableQuota.getId(), applicableQuota.costLimit, reportingAction.threshold, aggregatedData.accruedCost));
                }
            }
        }
        catch (Exception e) {
            throw new LLMCostLimitingReportingException(e);
        }
    }

    private boolean requiresAdditionalBlockedQueriesEmailAlert(GeneralSettingsDAO.LLMCostLimitingQuota quota) {
        if (!quota.blockingLimit) {
            return false;
        }
        boolean alreadyHas100PctAlert = quota.reportingActions.stream().anyMatch(a -> a.threshold == 100.0);
        return !alreadyHas100PctAlert;
    }

    private void sendThresholdReachedNotification(LlmCostLimitingThresholdReachedEvent event, GeneralSettingsDAO.LLMCostLimitingQuota quota) throws Exception {
        if (CollectionUtils.isEmpty(quota.reportingTargets)) {
            logger.warn((Object)("Bypassing cost limiting alerting because no recipients defined for quota:" + quota.getLogIdentifier()));
            return;
        }
        if (Strings.isNullOrEmpty((String)quota.emailChannelId)) {
            logger.warn((Object)"Bypassing cost limiting alerting because notification emails channel is not configured");
            return;
        }
        UserSettingsService.EmailNotificationsSettings params = new UserSettingsService.EmailNotificationsSettings();
        params.enabled = true;
        params.subject = event.reachedValue > quota.costLimit && quota.blockingLimit ? "Dataiku \u2022 LLM quota \"" + quota.getName() + "\" has been blocked!" : "Dataiku \u2022 LLM quota \"" + quota.getName() + "\" has hit " + event.reportingThreshold.intValue() + "% consumption!";
        String body = this.makeBody(event);
        EmailNotificationsSender sender = new EmailNotificationsSender(quota.emailChannelId);
        for (GeneralSettingsDAO.ReportingTarget target : quota.reportingTargets) {
            if (Strings.isNullOrEmpty((String)target.email)) continue;
            logger.info((Object)("Sending quota reached email to " + target.email));
            sender.sendToUser(target.email, params, body);
        }
    }

    private String makeBody(LlmCostLimitingThresholdReachedEvent event) throws Exception {
        VariableLookup lookup = new VariableLookup();
        try (Transaction ignored = this.transactionService.beginRead();){
            GeneralSettingsDAO.GeneralSettings settings = this.generalSettingsService.read();
            GeneralSettingsDAO.LLMCostLimitingQuota quota = this.quotasRepository.getQuota(event.quotaId);
            lookup.addVariable("studioExternalUrl", StringUtils.isBlank((String)settings.studioExternalUrl) ? null : settings.studioExternalUrl);
            lookup.addVariable("quotaAdminLink", settings.studioExternalUrl + "/admin/general/genai/#cost-control");
            lookup.addVariable("currentYear", String.valueOf(Year.now().getValue()));
            lookup.addVariable("threshold", event.reportingThreshold);
            lookup.addVariable("consumedAmount", event.reachedValue);
            lookup.addVariable("consumedPercent", event.quotaAmount != 0.0 ? Double.valueOf(event.reachedValue * 100.0 / event.quotaAmount) : null);
            lookup.addVariable("quotaId", quota.getId());
            lookup.addVariable("quotaName", quota.getName());
            lookup.addVariable("quotaResetPeriod", quota.periodicity.toDisplayLabel(quota.rollingPeriods));
            lookup.addVariable("quotaThresholds", quota.reportingActions.stream().map(a -> (int)a.threshold + "%").collect(Collectors.joining(", ")));
            lookup.addVariable("quotaBlocking", quota.blockingLimit);
            lookup.addVariable("quotaAmount", quota.costLimit);
            lookup.addVariable("queriesBlocked", quota.blockingLimit && event.reachedValue >= quota.costLimit);
        }
        File templateFile = ApplicationConfigurator.getResourceFile((String[])new String[]{"notifications", "costlimiting-quota-exceeded-email.ftl"});
        MessageContentBuilder.ExpandedTemplate expandedTemplate = new MessageContentBuilder(lookup).buildMessage(new TemplatedContent(), templateFile);
        return expandedTemplate.message;
    }

    public static boolean isFeatureEnabled() {
        return DKUApp.getParams().getBoolParam("dku.llm.costLimiting.enabled", true);
    }

    public static class LLMCostLimitingBlockedException
    extends RuntimeException {
        public LLMCostLimitingBlockedException(String message) {
            super(message);
        }
    }

    public static class LLMCostLimitingReportingException
    extends RuntimeException {
        public LLMCostLimitingReportingException(Throwable cause) {
            super("Error while reporting llm cost: " + cause.getMessage(), cause);
        }
    }

    public static class LLMCostLimitingContext {
        public String projectKey;
        public String user;
        public String provider;
        public String connectionName;
        public String llmId;

        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            LLMCostLimitingContext that = (LLMCostLimitingContext)o;
            return Objects.equals(this.projectKey, that.projectKey) && Objects.equals(this.user, that.user) && Objects.equals(this.provider, that.provider) && Objects.equals(this.connectionName, that.connectionName) && Objects.equals(this.llmId, that.llmId);
        }

        public int hashCode() {
            return Objects.hash(this.projectKey, this.user, this.provider, this.connectionName, this.llmId);
        }
    }
}

