/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.connections.CustomLLMConnection;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.custom.CustomJavaLLMClientWrapper;
import com.dataiku.dip.plugins.IPluginsRegistryService;
import com.dataiku.dip.plugins.model.InstalledPluginDesc;
import com.dataiku.dip.plugins.model.PluginDesc;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.notifications.DSSEvent;
import com.dataiku.dip.server.notifications.backend.GeneralSettingsChangedEvent;
import com.dataiku.dip.server.notifications.backend.LLMRateLimitingSettingsChanged;
import com.dataiku.dip.server.services.ConnectionsService;
import com.dataiku.dip.server.services.IPubSubService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import javax.annotation.PostConstruct;
import org.jetbrains.annotations.NotNull;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Holds the LLM rate-limiting settings currently in effect, keyed by provider id.
 *
 * <p>Two maps are maintained:
 * <ul>
 *   <li>the <em>baseline</em> settings: hard-coded defaults for the built-in providers
 *       (see {@link #buildManagedBaselineSettings()}) plus one entry per loaded plugin that
 *       declares custom Java or Python LLMs;</li>
 *   <li>the <em>applicable</em> settings: a copy of the baseline on which every user-provided
 *       configuration that is not flagged {@code managed} has been applied.</li>
 * </ul>
 *
 * <p>Both maps, and the set of "smooth" providers, are rebuilt from scratch by
 * {@link #updateSettings(GeneralSettingsDAO.RateLimitingSettings)}, which runs once at startup,
 * whenever the rate-limiting part of the general settings changes, and on every
 * {@code "plugin-changed"} event. Each rebuild publishes a {@link LLMRateLimitingSettingsChanged}
 * event.
 *
 * <p>All public accessors and {@code updateSettings} are {@code synchronized} on this instance.
 */
@Service
public class LLMRateLimitingSettingsService {
    /** Provider id used for each built-in LLM type. Custom LLM types are resolved separately. */
    @VisibleForTesting
    static final Map<LLMStructuredRef.LLMType, String> SUPPORTED_PROVIDER_IDS_BY_LLM_TYPES =
            // The explicit type witness is required: without it the chained builder is inferred
            // as Builder<Object, Object> and the result is not assignable to this field.
            ImmutableMap.<LLMStructuredRef.LLMType, String>builder()
                    .put(LLMStructuredRef.LLMType.ANTHROPIC, "Anthropic")
                    .put(LLMStructuredRef.LLMType.AZURE_LLM, "AzureLLM")
                    .put(LLMStructuredRef.LLMType.AZURE_OPENAI_DEPLOYMENT, "AzureOpenAI")
                    .put(LLMStructuredRef.LLMType.BEDROCK, "Bedrock")
                    .put(LLMStructuredRef.LLMType.COHERE, "Cohere")
                    .put(LLMStructuredRef.LLMType.MISTRALAI, "MistralAI")
                    .put(LLMStructuredRef.LLMType.OPENAI, "OpenAI")
                    .put(LLMStructuredRef.LLMType.STABILITYAI, "StabilityAI")
                    .put(LLMStructuredRef.LLMType.VERTEX, "VertexAILLM")
                    .build();

    /** Built-in providers that are always part of the smooth-provider set. */
    private static final Set<@NotNull String> DEFAULT_SMOOTH_PROVIDERS = Set.of("AzureOpenAI", "StabilityAI", "MistralAI");

    @Autowired
    private TransactionService transactionService;
    @Autowired
    private GeneralSettingsDAO generalSettingsDAO;
    @Autowired
    private IPubSubService pubSubService;
    @Autowired
    private IPluginsRegistryService pluginsService;

    /** Baseline settings per provider id; replaced wholesale on every update. */
    private Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> baselineSettings = new HashMap<>();
    /** Baseline settings with user overrides applied; replaced wholesale on every update. */
    private Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> applicableSettings = new HashMap<>();
    /** Default smooth providers plus plugin providers whose baseline has {@code isSmooth} set. */
    private Set<@NotNull String> smoothProviders = new HashSet<>();

    /**
     * Computes the initial settings from the persisted general settings and subscribes to the
     * events that require a recomputation.
     *
     * @throws IOException if the general settings cannot be read
     */
    @PostConstruct
    public void init() throws IOException {
        this.updateSettings(this.readUserRateLimitingSettings());

        this.pubSubService.subscribe("general-settings-changed", evt -> {
            GeneralSettingsChangedEvent settingsChangedEvent = (GeneralSettingsChangedEvent) evt;
            GeneralSettingsDAO.RateLimitingSettings newUserSettings = settingsChangedEvent.newSettings.generativeAISettings.rateLimitingSettings;
            GeneralSettingsDAO.RateLimitingSettings previousUserSettings = settingsChangedEvent.previousSettings.generativeAISettings.rateLimitingSettings;
            // Only rebuild (and re-publish) when the rate-limiting section itself changed.
            if (!JSON.jsonEquals(previousUserSettings, newUserSettings)) {
                this.updateSettings(newUserSettings);
            }
        });

        // A plugin change can add or remove custom providers, so the baselines must be rebuilt
        // against the user settings as they currently are on disk.
        this.pubSubService.subscribe("plugin-changed", evt -> {
            this.updateSettings(this.readUserRateLimitingSettings());
        });
    }

    /**
     * Reads the user-defined rate-limiting settings from the general settings inside a read
     * transaction.
     *
     * @return the rate-limiting section of the generative AI settings
     * @throws IOException if the settings cannot be read
     */
    private GeneralSettingsDAO.RateLimitingSettings readUserRateLimitingSettings() throws IOException {
        try (Transaction t = this.transactionService.beginRead()) {
            return this.generalSettingsDAO.read().generativeAISettings.rateLimitingSettings;
        }
    }

    /**
     * Returns the current baseline settings keyed by provider id.
     *
     * <p>The returned map is the internal instance, not a copy; it is replaced (not mutated) by
     * the next update.
     *
     * @return the baseline settings per provider id
     */
    public synchronized Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> getBaselineSettings() {
        return this.baselineSettings;
    }

    /**
     * Returns the current applicable settings (baseline plus user overrides) keyed by provider id.
     *
     * <p>The returned map is the internal instance, not a copy; it is replaced (not mutated) by
     * the next update.
     *
     * @return the applicable settings per provider id
     */
    public synchronized Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> getApplicableSettings() {
        return this.applicableSettings;
    }

    /**
     * Builds the baseline provider settings declared by a plugin descriptor. When the plugin
     * declares no baseline, the values of a freshly constructed
     * {@link PluginDesc.LLMRateLimitingBaseline} are used; the descriptor itself is not modified.
     *
     * @param ipd the installed plugin
     * @return provider settings built from the plugin's completion, embedding and image
     *         generation baseline values
     */
    private static GeneralSettingsDAO.RateLimitingProviderSettings buildPluginBaseline(InstalledPluginDesc ipd) {
        PluginDesc.LLMRateLimitingBaseline baseline = ipd.desc.llmRateLimitingBaseline;
        if (baseline == null) {
            baseline = new PluginDesc.LLMRateLimitingBaseline();
        }
        return GeneralSettingsDAO.RateLimitingProviderSettings.builder()
                .withCompletionDefaultConfig(baseline.completionDefaultRPM)
                .withEmbeddingDefaultConfig(baseline.embeddingDefaultRPM)
                .withImageGenerationDefaultConfig(baseline.imageGenerationDefaultRPM)
                .build();
    }

    /**
     * Builds one baseline entry per loaded plugin that declares custom Java LLMs and/or custom
     * Python LLMs, and adds the corresponding provider ids to {@link #smoothProviders} when the
     * plugin's baseline is flagged smooth.
     *
     * <p>Mutates {@link #smoothProviders} without taking a lock of its own; the only caller in
     * this class is the synchronized
     * {@link #updateSettings(GeneralSettingsDAO.RateLimitingSettings)}.
     *
     * @return the plugin baselines keyed by provider id
     */
    private Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> buildCustomPluginBaselines() {
        Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> customPluginBaselines = new HashMap<>();
        for (InstalledPluginDesc ipd : this.pluginsService.getLoadedPlugins()) {
            if (!ipd.customJavaLLMs.isEmpty()) {
                this.registerPluginProvider(CustomJavaLLMClientWrapper.getCustomProviderId(ipd.desc.id), ipd, customPluginBaselines);
            }
            if (!ipd.customPythonLLMs.isEmpty()) {
                this.registerPluginProvider("CustomLLM:" + ipd.desc.id, ipd, customPluginBaselines);
            }
        }
        return customPluginBaselines;
    }

    /**
     * Registers the baseline of one plugin-backed provider and records whether it is smooth.
     *
     * <p>Side effect: a missing {@code llmRateLimitingBaseline} on the plugin descriptor is
     * replaced by a default instance.
     *
     * @param providerId the provider id to register the plugin under
     * @param ipd the installed plugin
     * @param baselines the map receiving the plugin's baseline settings
     */
    private void registerPluginProvider(String providerId, InstalledPluginDesc ipd, Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> baselines) {
        baselines.put(providerId, LLMRateLimitingSettingsService.buildPluginBaseline(ipd));
        if (ipd.desc.llmRateLimitingBaseline == null) {
            ipd.desc.llmRateLimitingBaseline = new PluginDesc.LLMRateLimitingBaseline();
        }
        if (ipd.desc.llmRateLimitingBaseline.isSmooth) {
            this.smoothProviders.add(providerId);
        }
    }

    /**
     * Resolves the rate-limiting provider id for an LLM reference.
     *
     * <p>For {@code CUSTOM} LLMs the id is derived from the plugin id of the referenced
     * connection; for every other type it is looked up in
     * {@link #SUPPORTED_PROVIDER_IDS_BY_LLM_TYPES}.
     *
     * <p>NOTE(review): custom Python LLMs are registered under {@code "CustomLLM:" + pluginId}
     * while this method always goes through
     * {@code CustomJavaLLMClientWrapper.getCustomProviderId} - confirm both yield the same id.
     *
     * @param llmRef the LLM reference
     * @return the provider id, or {@code null} when the LLM type has no supported provider
     * @throws IOException declared for callers; presumably raised by the connection lookup
     */
    public synchronized String getSupportedProviderId(LLMStructuredRef llmRef) throws IOException {
        if (llmRef.type == LLMStructuredRef.LLMType.CUSTOM) {
            CustomLLMConnection conn = (CustomLLMConnection) ((ConnectionsService) SpringUtils.getBean(ConnectionsService.class)).get(llmRef.connection);
            return CustomJavaLLMClientWrapper.getCustomProviderId(conn.params.pluginID);
        }
        return SUPPORTED_PROVIDER_IDS_BY_LLM_TYPES.get(llmRef.type);
    }

    /**
     * Tells whether the provider is currently in the smooth-provider set.
     *
     * @param providerId the provider id to test
     * @return {@code true} if the provider is one of the default smooth providers or a plugin
     *         provider whose baseline is flagged smooth
     */
    public synchronized boolean isSmoothProvider(String providerId) {
        return this.smoothProviders.contains(providerId);
    }

    /**
     * Builds the hard-coded baseline settings of the built-in providers.
     *
     * <p>The local variable names suggest the numeric values are requests-per-minute limits;
     * the exact unit is defined by {@code RateLimitingProviderSettings} (not visible here).
     *
     * @return a new mutable map of baseline settings keyed by provider id
     */
    @VisibleForTesting
    static Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> buildManagedBaselineSettings() {
        Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> baseline = new HashMap<>();

        GeneralSettingsDAO.RateLimitingProviderSettings anthropicSettings = GeneralSettingsDAO.RateLimitingProviderSettings.builder()
                .withCompletionDefaultConfig(1000L)
                .build();
        baseline.put("Anthropic", anthropicSettings);

        GeneralSettingsDAO.RateLimitingProviderSettings azureOpenAISettings = GeneralSettingsDAO.RateLimitingProviderSettings.builder()
                .withCompletionDefaultConfig(500L)
                .withEmbeddingDefaultConfig(500L)
                .withImageGenerationDefaultConfig(500L)
                .build();
        baseline.put("AzureOpenAI", azureOpenAISettings);

        int azureLLMRpmLimit = 600;
        GeneralSettingsDAO.RateLimitingProviderSettings azureLLMSettings = GeneralSettingsDAO.RateLimitingProviderSettings.builder()
                .withCompletionDefaultConfig(azureLLMRpmLimit)
                .withEmbeddingDefaultConfig(azureLLMRpmLimit)
                .withImageGenerationDefaultConfig(azureLLMRpmLimit)
                .build();
        baseline.put("AzureLLM", azureLLMSettings);

        GeneralSettingsDAO.RateLimitingProviderSettings cohereSettings = GeneralSettingsDAO.RateLimitingProviderSettings.builder()
                .withCompletionDefaultConfig(500L)
                .withEmbeddingDefaultConfig(2000L)
                .build();
        baseline.put("Cohere", cohereSettings);

        GeneralSettingsDAO.RateLimitingProviderSettings bedrockSettings = GeneralSettingsDAO.RateLimitingProviderSettings.builder()
                .withCompletionDefaultConfig(1000L)
                .withEmbeddingDefaultConfig(2000L)
                .withImageGenerationDefaultConfig(10L)
                .build();
        baseline.put("Bedrock", bedrockSettings);

        int mistralRpmLimit = 300;
        GeneralSettingsDAO.RateLimitingProviderSettings mistralSettings = GeneralSettingsDAO.RateLimitingProviderSettings.builder()
                .withCompletionDefaultConfig(mistralRpmLimit)
                .withEmbeddingDefaultConfig(mistralRpmLimit)
                .build();
        baseline.put("MistralAI", mistralSettings);

        GeneralSettingsDAO.RateLimitingProviderSettings openAISettings = GeneralSettingsDAO.RateLimitingProviderSettings.builder()
                .withCompletionDefaultConfig(5000L)
                .withEmbeddingDefaultConfig(5000L)
                .withImageGenerationDefaultConfig(2500L)
                .withPerModelConfig("gpt-3.5-turbo", 3500L)
                .build();
        baseline.put("OpenAI", openAISettings);

        GeneralSettingsDAO.RateLimitingProviderSettings vertexSettings = GeneralSettingsDAO.RateLimitingProviderSettings.builder()
                .withCompletionDefaultConfig(200L)
                .withEmbeddingDefaultConfig(1500L)
                .withImageGenerationDefaultConfig(100L)
                .withPerModelConfig("gemini-pro", 300L)
                .withPerModelConfig("gemini-1.5-pro", 60L)
                .withPerModelConfig("gemini-pro-vision", 100L)
                .withPerModelConfig("multimodalembedding@001", 120L)
                .withPerModelConfig("text-embedding-004", 360L)
                .build();
        baseline.put("VertexAILLM", vertexSettings);

        int stabilityRPMLimit = 900;
        GeneralSettingsDAO.RateLimitingProviderSettings stabilitySettings = GeneralSettingsDAO.RateLimitingProviderSettings.builder()
                .withImageGenerationDefaultConfig(stabilityRPMLimit)
                .build();
        baseline.put("StabilityAI", stabilitySettings);

        return baseline;
    }

    /**
     * Rebuilds the smooth-provider set, the baseline settings and the applicable settings, then
     * publishes a {@link LLMRateLimitingSettingsChanged} event.
     *
     * @param newUserSettings the user-defined settings to layer on top of the baselines; a
     *        {@code null} value is treated as "no user overrides"
     */
    @VisibleForTesting
    protected synchronized void updateSettings(GeneralSettingsDAO.RateLimitingSettings newUserSettings) {
        // Reset first: buildCustomPluginBaselines() adds the smooth plugin providers to this set.
        this.smoothProviders = new HashSet<>(DEFAULT_SMOOTH_PROVIDERS);
        Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> customPluginBaselines = this.buildCustomPluginBaselines();
        this.baselineSettings = LLMRateLimitingSettingsService.buildManagedBaselineSettings();
        // Plugin entries are added last, so they win over a built-in entry with the same id.
        this.baselineSettings.putAll(customPluginBaselines);
        this.applicableSettings = this.buildApplicableProvidersSettings(newUserSettings);
        this.pubSubService.publish((DSSEvent) new LLMRateLimitingSettingsChanged());
    }

    /**
     * Copies a settings map, duplicating each value through the
     * {@code RateLimitingProviderSettings} copy constructor so that the copy can be modified
     * without touching the source map's values.
     *
     * @param settingsMap the map to copy
     * @return a new map with copied values
     */
    private static Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> copySettings(Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> settingsMap) {
        Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> mapCopy = new HashMap<>();
        for (Map.Entry<String, GeneralSettingsDAO.RateLimitingProviderSettings> entry : settingsMap.entrySet()) {
            mapCopy.put(entry.getKey(), new GeneralSettingsDAO.RateLimitingProviderSettings(entry.getValue()));
        }
        return mapCopy;
    }

    /**
     * Builds the applicable settings: a copy of the current baseline in which every user
     * configuration that is present and not flagged {@code managed} replaces the baseline one.
     * User settings for providers that have no baseline entry are ignored.
     *
     * @param userSettings the user-defined settings; {@code null}, or a {@code null}
     *        per-provider map, yields an unmodified copy of the baseline
     * @return the applicable settings keyed by provider id
     */
    private Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> buildApplicableProvidersSettings(GeneralSettingsDAO.RateLimitingSettings userSettings) {
        Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> newApplicableSettings = LLMRateLimitingSettingsService.copySettings(this.baselineSettings);
        if (userSettings == null || userSettings.perProviderSettings == null) {
            return newApplicableSettings;
        }
        for (Map.Entry<String, GeneralSettingsDAO.RateLimitingProviderSettings> applicableEntry : newApplicableSettings.entrySet()) {
            GeneralSettingsDAO.RateLimitingProviderSettings userProviderSettings = userSettings.perProviderSettings.get(applicableEntry.getKey());
            if (userProviderSettings == null) {
                // No (or an empty) user entry for this provider: keep the baseline as is.
                continue;
            }
            GeneralSettingsDAO.RateLimitingProviderSettings applicableProviderSettings = applicableEntry.getValue();
            if (userProviderSettings.completionDefault != null && !userProviderSettings.completionDefault.managed) {
                applicableProviderSettings.completionDefault = userProviderSettings.completionDefault;
            }
            if (userProviderSettings.embeddingDefault != null && !userProviderSettings.embeddingDefault.managed) {
                applicableProviderSettings.embeddingDefault = userProviderSettings.embeddingDefault;
            }
            if (userProviderSettings.imageGenerationDefault != null && !userProviderSettings.imageGenerationDefault.managed) {
                applicableProviderSettings.imageGenerationDefault = userProviderSettings.imageGenerationDefault;
            }
            if (userProviderSettings.perModelConfigs == null) {
                continue;
            }
            for (Map.Entry<String, GeneralSettingsDAO.RateLimitingConfig> modelConfig : userProviderSettings.perModelConfigs.entrySet()) {
                GeneralSettingsDAO.RateLimitingConfig userModelConfig = modelConfig.getValue();
                if (userModelConfig == null || userModelConfig.managed) {
                    continue;
                }
                applicableProviderSettings.perModelConfigs.put(modelConfig.getKey(), userModelConfig);
            }
        }
        return newApplicableSettings;
    }
}

