/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.analysis.ml.prediction.overrides.ReadOnlyColumnFactory;
import com.dataiku.dip.analysis.ml.prediction.overrides.ReadOnlyRowObservation;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.dataflow.exec.filter.FilterDescUtils;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.expressions.Expression;
import com.dataiku.dip.llm.online.LLMCostLimitingService;
import com.dataiku.dip.server.notifications.backend.GeneralSettingsChangedEvent;
import com.dataiku.dip.server.services.IPubSubService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.server.services.licensing.LicenseEnforcementService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.com.google.common.annotations.VisibleForTesting;
import com.dataiku.dss.shadelib.com.google.common.base.MoreObjects;
import com.dataiku.scoring.util.RawObservation;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import javax.annotation.PostConstruct;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Holds the LLM cost-limiting quotas configured in the general settings and resolves which of
 * them apply to a given LLM call context.
 *
 * <p>The quotas are loaded once at startup and reloaded whenever a "general-settings-changed"
 * event carries cost-limiting settings that differ (by JSON equality) from the previous ones.
 * The quota fields are replaced wholesale and are only read or written while holding this
 * instance's monitor. Filter evaluation and license checks run outside the lock, on a snapshot.
 */
@Service
public class LLMCostLimitingQuotasRepository {
    /** Quota id for which {@link #getQuota(String)} returns the fallback quota. */
    private static final String FALLBACK_QUOTA_ID = "DKU-FALLBACK-QUOTA";

    @Autowired
    private TransactionService transactionService;
    @Autowired
    private GeneralSettingsDAO generalSettingsDAO;
    @Autowired
    private IPubSubService pubSubService;
    private final LicenseEnforcementService licenseEnforcementService;
    // Guarded by "this". Never null after updateSettings() (a missing list becomes an empty one).
    private List<GeneralSettingsDAO.CustomLLMCostLimitingQuota> customQuotas;
    // Guarded by "this".
    private GeneralSettingsDAO.FallbackLLMCostLimitingQuota fallbackQuota;

    public LLMCostLimitingQuotasRepository(@Autowired LicenseEnforcementService licenseEnforcementService) {
        this.licenseEnforcementService = licenseEnforcementService;
    }

    /**
     * Loads the current cost-limiting settings inside a read transaction. It then subscribes to
     * "general-settings-changed" events so the quotas are reloaded whenever the cost-limiting part
     * of the settings changes.
     *
     * @throws IOException propagated from the initial settings load
     */
    @PostConstruct
    public void init() throws IOException {
        try (Transaction t = this.transactionService.beginRead()) {
            this.updateSettings(this.generalSettingsDAO.read().generativeAISettings.costLimitingSettings);
        }
        this.pubSubService.subscribe("general-settings-changed", evt -> {
            GeneralSettingsChangedEvent settingsChangedEvent = (GeneralSettingsChangedEvent) evt;
            GeneralSettingsDAO.LLMCostLimitingSettings newCostLimitingSettings =
                    settingsChangedEvent.newSettings.generativeAISettings.costLimitingSettings;
            // Only reload when the cost-limiting section itself changed, not on every settings save.
            boolean costLimitingSettingsChanged = !JSON.jsonEquals(
                    (Object) settingsChangedEvent.previousSettings.generativeAISettings.costLimitingSettings,
                    (Object) newCostLimitingSettings);
            if (costLimitingSettingsChanged) {
                this.updateSettings(newCostLimitingSettings);
            }
        });
    }

    /**
     * Replaces the cached quotas with the ones from {@code newSettings}.
     *
     * @param newSettings the cost-limiting settings to use; a {@code null} quota list is treated
     *     as "no custom quotas"
     */
    @VisibleForTesting
    synchronized void updateSettings(GeneralSettingsDAO.LLMCostLimitingSettings newSettings) {
        this.customQuotas = newSettings.quotas != null ? newSettings.quotas : List.of();
        this.fallbackQuota = newSettings.fallbackQuota;
    }

    /**
     * Looks up a quota by id.
     *
     * @param quotaId the quota id; {@code "DKU-FALLBACK-QUOTA"} designates the fallback quota
     * @return the fallback quota for the fallback id, otherwise the first custom quota with that
     *     id, or {@code null} if none matches
     */
    @Nullable
    public synchronized GeneralSettingsDAO.LLMCostLimitingQuota getQuota(String quotaId) {
        if (FALLBACK_QUOTA_ID.equals(quotaId)) {
            return this.fallbackQuota;
        }
        return this.customQuotas.stream().filter(q -> Objects.equals(quotaId, q.getId())).findFirst().orElse(null);
    }

    /** Returns the configured custom quotas (the cached list itself, not a copy). */
    public synchronized List<GeneralSettingsDAO.CustomLLMCostLimitingQuota> getCustomQuotas() {
        return this.customQuotas;
    }

    /**
     * Returns every configured quota, regardless of licensing.
     *
     * @return a new mutable list with all custom quotas followed by the fallback quota
     */
    public synchronized List<GeneralSettingsDAO.LLMCostLimitingQuota> getAllQuotas() {
        List<GeneralSettingsDAO.LLMCostLimitingQuota> allQuotas = new ArrayList<GeneralSettingsDAO.LLMCostLimitingQuota>(this.customQuotas);
        // The fallback quota is already an LLMCostLimitingQuota. The former cast to
        // CustomLLMCostLimitingQuota was unnecessary and could throw ClassCastException.
        allQuotas.add(this.fallbackQuota);
        return allQuotas;
    }

    /** Returns the custom quotas among {@code candidates} whose filter accepts {@code context}. */
    private static List<GeneralSettingsDAO.CustomLLMCostLimitingQuota> getMatchingCustomQuotas(
            List<GeneralSettingsDAO.CustomLLMCostLimitingQuota> candidates,
            LLMCostLimitingService.LLMCostLimitingContext context) {
        return candidates.stream().filter(quota -> matches(context, quota)).collect(Collectors.toList());
    }

    /** Custom quotas are only honoured when the license allows the advanced LLM Mesh feature. */
    private boolean areCustomQuotaEnabled() {
        return this.licenseEnforcementService.getFeaturesStatus().advancedLLMMeshAllowed;
    }

    /**
     * Returns the quotas usable under the current license.
     *
     * @return a new mutable list with the custom quotas (only when custom quotas are licensed)
     *     followed by the fallback quota
     */
    public List<GeneralSettingsDAO.LLMCostLimitingQuota> getAvailableQuotas() {
        // License check runs outside the lock; both fields are then read under one lock.
        boolean includeCustomQuotas = this.areCustomQuotaEnabled();
        List<GeneralSettingsDAO.LLMCostLimitingQuota> availableQuotas = new ArrayList<GeneralSettingsDAO.LLMCostLimitingQuota>();
        synchronized (this) {
            if (includeCustomQuotas) {
                availableQuotas.addAll(this.customQuotas);
            }
            availableQuotas.add(this.fallbackQuota);
        }
        return availableQuotas;
    }

    /**
     * Resolves which quotas apply to an LLM call.
     *
     * @param context the call context matched against each custom quota's filter
     * @return the custom quotas whose filter matches, when custom quotas are licensed and at least
     *     one matches; otherwise a single-element list containing the fallback quota
     * @throws RuntimeException if a custom quota's filter cannot be turned into an expression
     */
    public List<GeneralSettingsDAO.LLMCostLimitingQuota> getApplicableQuotas(LLMCostLimitingService.LLMCostLimitingContext context) {
        // Snapshot both fields together so a concurrent updateSettings() cannot mix old custom
        // quotas with a new fallback quota (or vice versa). Filter evaluation then runs unlocked.
        List<GeneralSettingsDAO.CustomLLMCostLimitingQuota> customSnapshot;
        GeneralSettingsDAO.FallbackLLMCostLimitingQuota fallbackSnapshot;
        synchronized (this) {
            customSnapshot = this.customQuotas;
            fallbackSnapshot = this.fallbackQuota;
        }
        if (this.areCustomQuotaEnabled()) {
            List<GeneralSettingsDAO.CustomLLMCostLimitingQuota> matchingQuotas = getMatchingCustomQuotas(customSnapshot, context);
            if (!matchingQuotas.isEmpty()) {
                return new ArrayList<GeneralSettingsDAO.LLMCostLimitingQuota>(matchingQuotas);
            }
        }
        return Arrays.asList(fallbackSnapshot);
    }

    /**
     * Tells whether a custom quota applies to the given context.
     *
     * <p>A quota without a filter, or with a disabled filter, matches every context. Otherwise the
     * filter is converted to a GREL expression. That expression is evaluated against a row built
     * from the context, in which null context values are replaced by the empty string.
     *
     * @throws RuntimeException wrapping any failure while building the filter expression
     */
    private static boolean matches(LLMCostLimitingService.LLMCostLimitingContext context, GeneralSettingsDAO.CustomLLMCostLimitingQuota quota) {
        if (quota.filter == null || !quota.filter.enabled) {
            return true;
        }
        // Several names are exposed for the same value (e.g. "project"/"projectKey") so that
        // filters may refer to either.
        RawObservation observation = new RawObservation(Map.of(
                "projectKey", orEmpty(context.projectKey),
                "project", orEmpty(context.projectKey),
                "user", orEmpty(context.user),
                "userLogin", orEmpty(context.user),
                "provider", orEmpty(context.provider),
                "connection", orEmpty(context.connectionName),
                "connectionName", orEmpty(context.connectionName),
                "llmId", orEmpty(context.llmId)));
        Expression expression;
        try {
            expression = new Expression(FilterDescUtils.getGrelExpression(quota.filter));
            expression.setColumnFactory((ColumnFactory) new ReadOnlyColumnFactory(observation.keys().toArray(new String[0])));
        } catch (Exception e) {
            throw new RuntimeException("Failed to build filter expression of LLM cost-limiting quota '" + quota.getId() + "'", e);
        }
        return expression.isTrueish(new ReadOnlyRowObservation(observation));
    }

    /** Returns {@code value}, or the empty string when it is {@code null}. */
    private static Object orEmpty(@Nullable Object value) {
        return MoreObjects.firstNonNull(value, (Object) "");
    }
}

