/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.savedmodels;

import com.dataiku.dip.analysis.coreservices.flow.SavedModelsCRUDService;
import com.dataiku.dip.analysis.ml.MLPaths;
import com.dataiku.dip.coremodel.VersionTag;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.retrieval.RAGLLMSettings;
import com.dataiku.dip.llm.retrieval.RetrievableKnowledge;
import com.dataiku.dip.llm.retrieval.RetrievableKnowledgeDAO;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.impersonation.FilesystemACLUtils;
import com.dataiku.dip.server.services.FlowZonesService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.RWTransaction;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import java.io.File;
import java.util.List;
import java.util.Objects;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class SavedModelsRetrievalAugmentedLlmService {
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private SavedModelsDAO savedModelsDAO;
    @Autowired
    private SavedModelsCRUDService savedModelsCRUDService;
    @Autowired
    private RetrievableKnowledgeDAO retrievableKnowledgeDAO;
    @Autowired
    private FlowZonesService flowZonesService;
    static DKULogger logger = DKULogger.getLogger("dku.llm.rag.service");

    /**
     * Creates a new Retrieval Augmented LLM saved model with a single inline version {@code "v1"}
     * pointing at the given knowledge bank and LLM.
     *
     * <p>The knowledge bank must exist. The new model is attached to the knowledge bank's flow zone
     * when that zone is non-empty.
     *
     * @param authCtx          caller; used for version tags, folder ACLs and the write transaction
     * @param projectKey       project in which the saved model is created
     * @param name             display name of the saved model
     * @param knowledgeBankRef reference to the retrievable knowledge (resolved relative to {@code projectKey})
     * @param llmId            id of the underlying LLM; must be non-blank and not itself a RAG model
     * @return the created saved model
     * @throws IllegalArgumentException if the knowledge bank does not exist, the LLM settings are invalid,
     *                                  or the generated id is already taken
     */
    public SavedModel createRetrievalAugmentedLlmAndVersion(AuthCtx authCtx, String projectKey, String name, String knowledgeBankRef, String llmId) throws Exception {
        AnyLoc rkSmartLoc = AnyLoc.resolveSmart(projectKey, knowledgeBankRef);
        RetrievableKnowledge rk;
        String zoneId;
        try (Transaction ignored = this.transactionService.beginRead()) {
            rk = (RetrievableKnowledge) this.retrievableKnowledgeDAO.getOrNull(rkSmartLoc);
            if (rk == null) {
                throw new IllegalArgumentException(String.format("The requested RetrievableKnowledge %s doesn't exist", knowledgeBankRef));
            }
            zoneId = this.flowZonesService.retrieveZone(projectKey, rk);
        }

        SavedModel sm = new SavedModel();
        sm.projectKey = projectKey;
        sm.id = SecretKeyGenerator.generate(8);
        sm.name = name;
        sm.savedModelType = SavedModel.SavedModelType.RETRIEVAL_AUGMENTED_LLM;

        SavedModel.SavedModelInlineVersion smiv = new SavedModel.SavedModelInlineVersion();
        smiv.versionId = "v1";
        smiv.creationTag = new VersionTag(authCtx.getIdentifier());
        smiv.versionTag = new VersionTag(authCtx.getIdentifier());
        smiv.ragllmSettings.kbRef = knowledgeBankRef;
        smiv.ragllmSettings.llmId = llmId;
        // Knowledge banks with a multimodal column are retrieved in multimodal mode, others via embeddings.
        smiv.ragllmSettings.retrievalSource = StringUtils.isNotEmpty((String) rk.multimodalColumn)
                ? RAGLLMSettings.RetrievalSource.MULTIMODAL
                : RAGLLMSettings.RetrievalSource.EMBEDDING;
        sm.inlineVersions.add(smiv);
        sm.activeVersion = smiv.versionId;

        // Validate before touching the filesystem so that a rejected request leaves no orphan folder.
        this.checkRAGModelValidity(smiv.ragllmSettings);

        this.createSavedModelFolder(authCtx, sm);

        try (RWTransaction t = this.transactionService.beginWriteAsLoggedInUser(authCtx)) {
            if (this.savedModelsDAO.getOrNull(sm.projectKey, sm.id) != null) {
                throw new IllegalArgumentException("Retrieval Augmented LLM " + sm.id + " already exists");
            }
            this.savedModelsCRUDService.save(sm, true, false);
            // Null-safe: a missing zone id is treated like an empty one (no zone attachment) instead of
            // failing with an NPE after the model has already been saved in this transaction.
            if (StringUtils.isNotEmpty(zoneId)) {
                this.flowZonesService.attachObjectToZone(zoneId, projectKey, sm, true);
            }
            t.commit("Created Retrieval Augmented LLM: " + name + ": " + sm.id);
            return sm;
        }
    }

    /**
     * Creates the on-disk base folder of the saved model and applies its filesystem ACLs.
     */
    private void createSavedModelFolder(AuthCtx authCtx, SavedModel sm) throws Exception {
        File targetSMFolder = MLPaths.savedModelBaseFolder(sm.projectKey, sm.id);
        DKUFileUtils.mkdirs(targetSMFolder);
        FilesystemACLUtils.restrictRwxToDSSIfImpersonationEnabled(targetSMFolder);
        FilesystemACLUtils.grantFSReadACLs(authCtx, sm.projectKey, targetSMFolder);
    }

    /**
     * Saves an updated Retrieval Augmented LLM saved model.
     *
     * <p>The RAG settings of every inline version are validated first. The active version's
     * {@code versionTag} is then incremented from the currently stored model. Its {@code creationTag}
     * is restored from the stored model so callers cannot rewrite it.
     *
     * @param authCtx caller; its identifier is recorded in the incremented version tag
     * @param sm      the saved model to persist; must already exist
     * @return the result of {@link SavedModelsCRUDService#save}
     * @throws IllegalArgumentException if any inline version has invalid RAG settings
     */
    public SavedModel save(AuthCtx authCtx, SavedModel sm) throws Exception {
        this.checkRAGModelValidity(sm);
        SavedModel preExisting = this.savedModelsCRUDService.getMandatory(sm.projectKey, sm.id);
        SavedModel.SavedModelInlineVersion preExistingInlineVersion = preExisting.getActiveSaveModelInlineVersion();
        SavedModel.SavedModelInlineVersion currentInlineVersion = sm.getActiveSaveModelInlineVersion();
        currentInlineVersion.versionTag = VersionTag.increment(preExistingInlineVersion.versionTag, authCtx.getIdentifier());
        currentInlineVersion.creationTag = (VersionTag) JSON.deepCopy(preExistingInlineVersion.creationTag);
        return this.savedModelsCRUDService.save(sm, false, false);
    }

    /** Validates the RAG settings of every inline version of the saved model. */
    private void checkRAGModelValidity(SavedModel sm) {
        for (SavedModel.SavedModelInlineVersion smiv : sm.inlineVersions) {
            this.checkRAGModelValidity(smiv.ragllmSettings);
        }
    }

    /**
     * Lists the Retrieval Augmented LLM saved models of a project whose active version uses the given
     * knowledge bank.
     *
     * @param projectKey project to scan
     * @param kbRef      knowledge bank reference to match (compared with {@code equals} against the stored kbRef)
     * @return the matching saved models (unmodifiable list)
     * @throws IllegalArgumentException if a RAG saved model has no inline version matching its active version
     */
    public List<SavedModel> listRetrievalAugmentedLLMFromKbRef(String projectKey, String kbRef) throws Exception {
        List<SavedModel> savedModelList = this.savedModelsDAO.listUnsafe(projectKey);
        return savedModelList.stream()
                .filter(sm -> isRetrievalAugmentedLlmOnKb(sm, kbRef))
                .toList();
    }

    /** Returns true when {@code sm} is a RAG saved model whose active version references {@code kbRef}. */
    private static boolean isRetrievalAugmentedLlmOnKb(SavedModel sm, String kbRef) {
        if (sm.savedModelType != SavedModel.SavedModelType.RETRIEVAL_AUGMENTED_LLM) {
            return false;
        }
        SavedModel.SavedModelInlineVersion smiv = sm.getVersion(sm.activeVersion)
                .orElseThrow(() -> new IllegalArgumentException(String.format("Couldn't find active version in saved model %s", sm.id)));
        return kbRef.equals(smiv.ragllmSettings.kbRef);
    }

    /**
     * Rejects RAG settings without an LLM id, or whose LLM is itself a Retrieval Augmented LLM
     * (nested RAG is not supported).
     *
     * @throws IllegalArgumentException if the settings are invalid
     */
    private void checkRAGModelValidity(RAGLLMSettings settings) {
        if (StringUtils.isBlank((String) settings.llmId)) {
            throw new IllegalArgumentException("Missing LLM to use in the settings of the RAG model");
        }
        if (Objects.equals(LLMStructuredRef.decodeId((String) settings.llmId).type, LLMStructuredRef.LLMType.RETRIEVAL_AUGMENTED)) {
            throw new IllegalArgumentException(String.format("Double RAG is not supported, please choose another underlying model instead of: %s", settings.llmId));
        }
    }
}

