/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.cache;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.DSSMetrics;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.cache.ILLMCacheService;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClientFactory;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.IsolationLevel;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKUEhcacheSerializer;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.EhcacheMetrics;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.org.apache.commons.codec.binary.Hex;
import com.dataiku.dss.shadelib.org.apache.commons.codec.digest.DigestUtils;
import com.google.common.annotations.VisibleForTesting;
import java.io.File;
import java.io.IOException;
import java.security.MessageDigest;
import java.time.Duration;
import javax.annotation.PostConstruct;
import org.ehcache.Cache;
import org.ehcache.PersistentCacheManager;
import org.ehcache.config.Builder;
import org.ehcache.config.CacheConfiguration;
import org.ehcache.config.builders.CacheConfigurationBuilder;
import org.ehcache.config.builders.CacheManagerBuilder;
import org.ehcache.config.builders.ExpiryPolicyBuilder;
import org.ehcache.config.builders.ResourcePoolsBuilder;
import org.ehcache.config.units.MemoryUnit;
import org.ehcache.core.internal.statistics.DefaultStatisticsService;
import org.ehcache.core.spi.service.StatisticsService;
import org.ehcache.impl.config.store.disk.OffHeapDiskStoreConfiguration;
import org.ehcache.spi.serialization.Serializer;
import org.ehcache.spi.service.Service;
import org.ehcache.spi.service.ServiceConfiguration;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Backend implementation of {@link ILLMCacheService} that caches LLM completion, embedding and reranking
 * responses in three persistent Ehcache caches (heap + off-heap + disk tiers), one cache manager and
 * one directory per kind of call.
 */
@org.springframework.stereotype.Service
public class BackendLLMCacheService
implements ILLMCacheService {
    @Autowired
    private TransactionService transactionService;
    private LlmResponseCache<LLMClient.SingleCompletionQuery, LLMClient.CompletionSettings, LLMClient.SimpleCompletionResponseOrError> llmCompletionCache;
    private LlmResponseCache<LLMClient.EmbeddingQuery, LLMClient.EmbeddingSettings, LLMClient.SimpleEmbeddingResponseOrError> embeddingsExtractionCache;
    private LlmResponseCache<LLMClient.RerankingQuery, LLMClient.RerankingSettings, LLMClient.SingleRerankingResponseOrError> rerankingCache;
    // Mixed into every cache key, so entries written by another DSS version never match.
    private DKUApp.DSSVersion dssVersion;
    private PersistentCacheManager completionCacheManager;
    private PersistentCacheManager embeddingsCacheManager;
    private PersistentCacheManager rerankingCacheManager;
    private File completionCacheDir;
    private File embeddingsCacheDir;
    private File rerankingCacheDir;
    // Shared by the three cache managers; the metrics registered in init() read from it.
    private StatisticsService statisticsService;
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.cache");
    private static final String llmCompletionEhcacheCacheName = "llmCompletion";
    private static final String embeddingsExtractionEhcacheCacheName = "embeddingsExtraction";
    private static final String rerankingEhcacheCacheName = "reranking";

    /**
     * Creates the three persistent cache managers and their caches.
     *
     * <p>Cache directories already set on this instance (normally by {@link #init()}) are used as-is, so the
     * directories that are logged, opened and (on recovery) wiped are always the same. Directories still
     * unset are resolved to their default location under {@code caches/llms}. If anything fails, every
     * cache manager created so far is closed before the exception is rethrown.
     *
     * @throws IOException if resolving a default cache directory fails
     */
    @VisibleForTesting
    void initCaches() throws IOException {
        if (this.completionCacheDir == null) {
            this.completionCacheDir = cacheDirFor(CacheType.COMPLETION);
        }
        if (this.embeddingsCacheDir == null) {
            this.embeddingsCacheDir = cacheDirFor(CacheType.EMBEDDINGS);
        }
        if (this.rerankingCacheDir == null) {
            this.rerankingCacheDir = cacheDirFor(CacheType.RERANKING);
        }
        logger.info("Initializing completion LLM cache in " + String.valueOf(this.completionCacheDir));
        logger.info("Initializing embeddings LLM cache in " + String.valueOf(this.embeddingsCacheDir));
        logger.info("Initializing reranking LLM cache in " + String.valueOf(this.rerankingCacheDir));
        this.completionCacheManager = null;
        this.embeddingsCacheManager = null;
        this.rerankingCacheManager = null;
        this.statisticsService = new DefaultStatisticsService();
        try {
            this.completionCacheManager = this.buildCacheManager(this.completionCacheDir);
            this.embeddingsCacheManager = this.buildCacheManager(this.embeddingsCacheDir);
            this.rerankingCacheManager = this.buildCacheManager(this.rerankingCacheDir);
            this.completionCacheManager.init();
            this.embeddingsCacheManager.init();
            this.rerankingCacheManager.init();
            this.llmCompletionCache = new LlmResponseCache<>(this.completionCacheManager.createCache(llmCompletionEhcacheCacheName,
                    buildCacheConfiguration(LLMClient.SimpleCompletionResponseOrError.class,
                            new DKUEhcacheSerializer<LLMClient.SimpleCompletionResponseOrError>(LLMClient.SimpleCompletionResponseOrError.class),
                            CacheType.COMPLETION)));
            this.embeddingsExtractionCache = new LlmResponseCache<>(this.embeddingsCacheManager.createCache(embeddingsExtractionEhcacheCacheName,
                    buildCacheConfiguration(LLMClient.SimpleEmbeddingResponseOrError.class,
                            new DKUEhcacheSerializer<LLMClient.SimpleEmbeddingResponseOrError>(LLMClient.SimpleEmbeddingResponseOrError.class),
                            CacheType.EMBEDDINGS)));
            this.rerankingCache = new LlmResponseCache<>(this.rerankingCacheManager.createCache(rerankingEhcacheCacheName,
                    buildCacheConfiguration(LLMClient.SingleRerankingResponseOrError.class,
                            new DKUEhcacheSerializer<LLMClient.SingleRerankingResponseOrError>(LLMClient.SingleRerankingResponseOrError.class),
                            CacheType.RERANKING)));
        }
        catch (Exception e) {
            logger.error("Failed to initialize LLM caches", e);
            // closeCacheManager() never throws, so the original failure is what propagates.
            this.closeCacheManager();
            throw e;
        }
    }

    /**
     * Spring post-construct hook. Resolves the cache directories and opens the caches. If opening fails,
     * it deletes all on-disk cache data and retries once; a second failure propagates. It then registers
     * Ehcache metrics and a JVM shutdown hook that closes the cache managers.
     *
     * @throws IOException if a cache directory cannot be resolved or wiped
     */
    @PostConstruct
    public void init() throws IOException {
        this.dssVersion = DKUApp.getDSSVersion();
        this.completionCacheDir = cacheDirFor(CacheType.COMPLETION);
        this.embeddingsCacheDir = cacheDirFor(CacheType.EMBEDDINGS);
        this.rerankingCacheDir = cacheDirFor(CacheType.RERANKING);
        try {
            this.initCaches();
        }
        catch (Exception e) {
            // Most likely unreadable/corrupt persisted data: start over from empty caches.
            logger.error("Clearing all cache data", e);
            deleteIfExists(this.completionCacheDir);
            deleteIfExists(this.embeddingsCacheDir);
            deleteIfExists(this.rerankingCacheDir);
            this.initCaches();
        }
        DSSMetrics.registry().registerAll(EhcacheMetrics.cacheMetricsSet("llmCompletions", this.statisticsService, llmCompletionEhcacheCacheName));
        DSSMetrics.registry().registerAll(EhcacheMetrics.cacheMetricsSet("embeddingsExtractions", this.statisticsService, embeddingsExtractionEhcacheCacheName));
        DSSMetrics.registry().registerAll(EhcacheMetrics.cacheMetricsSet(rerankingEhcacheCacheName, this.statisticsService, rerankingEhcacheCacheName));
        Runtime.getRuntime().addShutdownHook(new Thread(this::closeCacheManager));
    }

    /**
     * Closes every open cache manager and clears the corresponding fields. Safe to call repeatedly.
     * Each manager is closed independently: a failure closing one (for instance, a manager whose
     * {@code init()} never completed) is logged and does not prevent closing the others.
     */
    @VisibleForTesting
    void closeCacheManager() {
        PersistentCacheManager completion = this.completionCacheManager;
        PersistentCacheManager embeddings = this.embeddingsCacheManager;
        PersistentCacheManager reranking = this.rerankingCacheManager;
        this.completionCacheManager = null;
        this.embeddingsCacheManager = null;
        this.rerankingCacheManager = null;
        closeQuietly(completion, "completion");
        closeQuietly(embeddings, "embeddings");
        closeQuietly(reranking, "reranking");
    }

    /** Closes {@code manager} if non-null, logging rather than propagating any runtime failure. */
    private static void closeQuietly(PersistentCacheManager manager, String kind) {
        if (manager == null) {
            return;
        }
        logger.info("Shutting down persistent " + kind + " LLM cache");
        try {
            manager.close();
        }
        catch (RuntimeException e) {
            logger.warn("Failed to shut down persistent " + kind + " LLM cache", e);
        }
    }

    /** Default persistence directory for a cache type: {@code caches/llms/<type name>}. */
    private static File cacheDirFor(CacheType cacheType) throws IOException {
        return DKUApp.getFile(new String[]{"caches", "llms", cacheType.name});
    }

    /** Recursively deletes {@code dir} if it exists. */
    private static void deleteIfExists(File dir) throws IOException {
        if (dir.exists()) {
            DKUFileUtils.forceDelete(dir);
        }
    }

    /** Builds (without initializing) a persistent cache manager rooted at {@code cacheDir}. */
    private PersistentCacheManager buildCacheManager(File cacheDir) {
        return CacheManagerBuilder.newCacheManagerBuilder()
                .using(this.statisticsService)
                .with(CacheManagerBuilder.persistence(cacheDir))
                .build();
    }

    /**
     * Builds the Ehcache configuration for one cache type. Tier sizes, disk segments and TTL come from the
     * {@code dku.llm.cache.<type>.*} parameters, defaulting to the values declared on {@link CacheType}.
     * The disk tier is persistent ({@code disk(..., true)}).
     */
    private static <TResult> CacheConfiguration<String, TResult> buildCacheConfiguration(Class<TResult> valueClass, Serializer<TResult> valueSerializer, CacheType cacheType) {
        String paramPrefix = "dku.llm.cache." + cacheType.name + ".";
        int onHeapCachedEntries = DKUApp.getParams().getIntParam(paramPrefix + "onHeapCachedEntries", Integer.valueOf(cacheType.defaultOnHeapCachedEntries));
        int offHeapCachedMB = DKUApp.getParams().getIntParam(paramPrefix + "offHeapCachedMB", Integer.valueOf(cacheType.defaultOffHeapCachedMB));
        int diskCachedMB = DKUApp.getParams().getIntParam(paramPrefix + "diskCachedMB", Integer.valueOf(cacheType.defaultDiskCachedMB));
        int diskCacheSegments = DKUApp.getParams().getIntParam(paramPrefix + "diskCacheSegments", Integer.valueOf(cacheType.defaultDiskCacheSegments));
        int cacheExpirationMinutes = DKUApp.getParams().getIntParam(paramPrefix + "cacheExpirationMinutes", Integer.valueOf(cacheType.defaultCacheExpirationMinutes));
        ResourcePoolsBuilder resourcePools = ResourcePoolsBuilder.heap(onHeapCachedEntries)
                .offheap(offHeapCachedMB, MemoryUnit.MB)
                .disk(diskCachedMB, MemoryUnit.MB, true);
        return CacheConfigurationBuilder.newCacheConfigurationBuilder(String.class, valueClass, resourcePools)
                .withExpiry(ExpiryPolicyBuilder.timeToLiveExpiration(Duration.ofMinutes(cacheExpirationMinutes)))
                .withService(new OffHeapDiskStoreConfiguration(diskCacheSegments))
                .withValueSerializer(valueSerializer)
                .build();
    }

    /**
     * Computes the cache key as the hex-encoded SHA-256 of the JSON digests of, in order: the DSS version,
     * the resolved cache location, the query and the settings.
     */
    private <TQuery, TSettings> String queryCacheKey(ResolvedCacheLocationInfo locationInfo, TQuery query, TSettings settings) {
        MessageDigest digester = DigestUtils.getSha256Digest();
        JSON.updateDigest(digester, (Object) this.dssVersion);
        JSON.updateDigest(digester, locationInfo.getCacheKeyRepresentation());
        JSON.updateDigest(digester, query);
        JSON.updateDigest(digester, settings);
        return Hex.encodeHexString(digester.digest());
    }

    /**
     * Returns why the cache must not be used for this connection/user, or {@code null} if it may be used.
     */
    private String getCacheUsageIneligibilityReason(AbstractLLMConnection<?, ?, ?> connection, AuthCtx authCtx) {
        if (connection == null) {
            return "No connection associated to LLM";
        }
        if (!connection.isFreelyUsableBy(authCtx)) {
            return "Connection usage denied";
        }
        return null;
    }

    /**
     * Decodes {@code llmId} and resolves, inside a read transaction, the connection that governs it in the
     * context of {@code contextProjectKey}.
     */
    private ResolvedCacheLocationInfo resolveCacheLocationInfo(AuthCtx authCtx, String llmId, String contextProjectKey) throws IOException, DKUSecurityException {
        ResolvedCacheLocationInfo ret = new ResolvedCacheLocationInfo();
        ret.contextProjectKey = contextProjectKey;
        ret.enrichedLLMRef = LLMStructuredRef.decodeId(llmId);
        try (Transaction t = this.transactionService.retrieveOrBeginRead(IsolationLevel.YOLO)) {
            ret.connection = LLMClientFactory.getFinalConnectionForGovernance(authCtx, contextProjectKey, ret.enrichedLLMRef);
        }
        return ret;
    }

    @Override
    public ILLMCacheService.QueryCacheResult<LLMClient.SimpleCompletionResponseOrError> get(AuthCtx authCtx, String llmId, String contextProjectKey, LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings) {
        return this.llmCompletionCache.get(authCtx, llmId, contextProjectKey, query, settings);
    }

    @Override
    public void put(AuthCtx authCtx, String llmId, String contextProjectKey, LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.SimpleCompletionResponseOrError result) {
        this.llmCompletionCache.put(authCtx, llmId, contextProjectKey, query, settings, result);
    }

    @Override
    public ILLMCacheService.QueryCacheResult<LLMClient.SimpleEmbeddingResponseOrError> get(AuthCtx authCtx, String llmId, String contextProjectKey, LLMClient.EmbeddingQuery query, LLMClient.EmbeddingSettings settings) {
        return this.embeddingsExtractionCache.get(authCtx, llmId, contextProjectKey, query, settings);
    }

    @Override
    public void put(AuthCtx authCtx, String llmId, String contextProjectKey, LLMClient.EmbeddingQuery query, LLMClient.EmbeddingSettings settings, LLMClient.SimpleEmbeddingResponseOrError result) {
        this.embeddingsExtractionCache.put(authCtx, llmId, contextProjectKey, query, settings, result);
    }

    @Override
    public ILLMCacheService.QueryCacheResult<LLMClient.SingleRerankingResponseOrError> get(AuthCtx authCtx, String llmId, String contextProjectKey, LLMClient.RerankingQuery query, LLMClient.RerankingSettings settings) {
        return this.rerankingCache.get(authCtx, llmId, contextProjectKey, query, settings);
    }

    @Override
    public void put(AuthCtx authCtx, String llmId, String contextProjectKey, LLMClient.RerankingQuery query, LLMClient.RerankingSettings settings, LLMClient.SingleRerankingResponseOrError result) {
        this.rerankingCache.put(authCtx, llmId, contextProjectKey, query, settings, result);
    }

    /**
     * Kinds of cached LLM calls. {@code name} is both the cache sub-directory and the segment in the
     * {@code dku.llm.cache.<name>.*} tuning parameters; the remaining fields are those parameters' defaults.
     */
    private static enum CacheType {
        COMPLETION("completion", 1000, 1440),
        EMBEDDINGS("embeddings", 2000, 44640),
        RERANKING("reranking", 2000, 44640);

        private final String name;
        private final int defaultOnHeapCachedEntries = 10;
        private final int defaultOffHeapCachedMB = 20;
        private final int defaultDiskCachedMB;
        private final int defaultDiskCacheSegments = 2;
        private final int defaultCacheExpirationMinutes;

        private CacheType(String name, int defaultDiskCachedMB, int cacheExpirationMinutes) {
            this.name = name;
            this.defaultDiskCachedMB = defaultDiskCachedMB;
            this.defaultCacheExpirationMinutes = cacheExpirationMinutes;
        }
    }

    /**
     * Typed, best-effort wrapper around one Ehcache cache. Lookups and stores are skipped when the
     * connection is not eligible, and any exception is logged instead of being propagated to the caller.
     */
    private class LlmResponseCache<TQuery, TSettings, TResult> {
        private final Cache<String, TResult> cache;

        LlmResponseCache(Cache<String, TResult> cache) {
            this.cache = cache;
        }

        /**
         * Looks up a cached response. The returned result has {@code cacheHit} set on a hit, and
         * {@code cacheIneligibilityReason} set when caching is not allowed. On error it is a default
         * (empty) result.
         */
        private ILLMCacheService.QueryCacheResult<TResult> get(AuthCtx authCtx, String llmId, String contextProjectKey, TQuery query, TSettings settings) {
            try {
                ResolvedCacheLocationInfo locationInfo = BackendLLMCacheService.this.resolveCacheLocationInfo(authCtx, llmId, contextProjectKey);
                ILLMCacheService.QueryCacheResult<TResult> ret = new ILLMCacheService.QueryCacheResult<>();
                String cacheIneligibilityReason = BackendLLMCacheService.this.getCacheUsageIneligibilityReason(locationInfo.connection, authCtx);
                if (cacheIneligibilityReason != null) {
                    ret.cacheIneligibilityReason = cacheIneligibilityReason;
                    ret.cacheHit = false;
                    return ret;
                }
                TResult cachedResponse = this.cache.get(BackendLLMCacheService.this.queryCacheKey(locationInfo, query, settings));
                if (cachedResponse != null) {
                    ret.cacheHit = true;
                    ret.result = cachedResponse;
                }
                return ret;
            }
            catch (Exception e) {
                logger.warn("Failed to lookup LLM response in cache", e);
                return new ILLMCacheService.QueryCacheResult<>();
            }
        }

        /** Stores {@code result} unless the connection is ineligible; failures are only logged. */
        private void put(AuthCtx authCtx, String llmId, String contextProjectKey, TQuery query, TSettings settings, TResult result) {
            try {
                ResolvedCacheLocationInfo locationInfo = BackendLLMCacheService.this.resolveCacheLocationInfo(authCtx, llmId, contextProjectKey);
                String cacheIneligibilityReason = BackendLLMCacheService.this.getCacheUsageIneligibilityReason(locationInfo.connection, authCtx);
                if (cacheIneligibilityReason != null) {
                    logger.debug("Not caching, not eligible: " + cacheIneligibilityReason);
                    return;
                }
                this.cache.put(BackendLLMCacheService.this.queryCacheKey(locationInfo, query, settings), result);
            }
            catch (Exception e) {
                logger.warn("Failed to put LLM response in cache", e);
            }
        }
    }

    /** Where a query is cached: the LLM reference, its governing connection and (if project-bound) the project. */
    static class ResolvedCacheLocationInfo {
        String contextProjectKey;
        LLMStructuredRef enrichedLLMRef;
        AbstractLLMConnection<?, ?, ?> connection;

        ResolvedCacheLocationInfo() {
        }

        /**
         * Object digested into the cache key. The project key is included only for project-bound LLMs, so
         * non-project-bound LLMs share entries across projects.
         */
        private Object getCacheKeyRepresentation() {
            if (this.enrichedLLMRef.isProjectBound()) {
                return new Object[]{this.contextProjectKey, this.enrichedLLMRef, this.connection};
            }
            return new Object[]{this.enrichedLLMRef, this.connection};
        }
    }
}

