/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.analysis.ml.llm.LLMSMMgmtService;
import com.dataiku.dip.analysis.model.llm.LLMModelSnippetData;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.DSSConnection;
import com.dataiku.dip.connections.HuggingFaceLocalConnection;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.exceptions.UnauthorizedException;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.retrieval.RAGLLMSettings;
import com.dataiku.dip.savedmodels.agents.AgentTypesRegistry;
import com.dataiku.dip.savedmodels.agents.LoadedCustomAgent;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.IsolationLevel;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class LLMRefEnricherService {
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private ConnectionsDAO connectionsDAO;
    @Autowired
    private SavedModelsDAO savedModelsDAO;
    private static final DKULogger logger = DKULogger.getLogger("dku.llm-ref-enricher");

    /**
     * Loads a saved model inside a YOLO-isolation read transaction.
     *
     * @param contextProjectKey project key passed to {@link AnyLoc#resolveSmart} as resolution context
     * @param savedModelSmartId smart id of the saved model to load
     * @return the saved model returned by {@code savedModelsDAO.getMandatory}
     */
    private SavedModel loadSavedModel(String contextProjectKey, String savedModelSmartId) throws Exception {
        AnyLoc loc = AnyLoc.resolveSmart(contextProjectKey, savedModelSmartId);
        try (Transaction ignored = this.transactionService.retrieveOrBeginRead(IsolationLevel.YOLO)) {
            return (SavedModel) this.savedModelsDAO.getMandatory(loc);
        }
    }

    /**
     * Enriches a reference to an LLM served directly by an LLM connection.
     *
     * @throws UnauthorizedException if the connection is not freely usable by {@code authCtx}
     * @throws IllegalArgumentException if the connection is not an {@link AbstractLLMConnection}
     */
    private EnrichedLLMStructuredRef getEnrichedLLMRefFromConnection(AuthCtx authCtx, LLMStructuredRef llmRef) throws Exception {
        AbstractLLMConnection llmConnection;
        try (Transaction ignored = this.transactionService.retrieveOrBeginRead(IsolationLevel.YOLO)) {
            DSSConnection conn = this.connectionsDAO.getMandatoryConnection(authCtx, llmRef.connection);
            if (!conn.isFreelyUsableBy(authCtx)) {
                throw new UnauthorizedException("You may not use this connection", "denied");
            }
            if (!(conn instanceof AbstractLLMConnection)) {
                throw new IllegalArgumentException(String.format("Connection %s is not a LLM connection", conn.name));
            }
            llmConnection = (AbstractLLMConnection) conn;
        }
        // The model lookup itself happens after the read transaction is closed.
        return llmConnection.getLLMModel(llmRef).getEnrichedRef();
    }

    /**
     * Builds the fine-tuned LLM reference for one version of an LLM_GENERIC saved model.
     *
     * @return the reference, or {@code null} if the version carries no LLM info or is not the active version
     */
    private static LLMStructuredRef getLLMRefFromLLMGenericSMVersion(SavedModel sm, String contextProjectKey, LLMSMMgmtService.LLMSMVersionHeader versionHeader) {
        LLMModelSnippetData snippet = (LLMModelSnippetData) versionHeader.snippet;
        if (snippet.llmSMInfo == null || !Objects.equals(sm.activeVersion, versionHeader.versionId)) {
            return null;
        }
        String savedModelSmartId = new AnyLoc(sm.projectKey, sm.id).getSmartName(contextProjectKey);
        return LLMStructuredRef.forFinetunedSavedModelVersion(snippet.llmSMInfo.llmType, snippet.llmSMInfo.connection, savedModelSmartId, versionHeader.versionId);
    }

    /**
     * Returns the LLM id of the active version of an LLM_GENERIC saved model.
     *
     * @return the id, or {@code null} if no version yields a reference
     */
    public static String getLLMIdFromLLMGenericSM(SavedModel sm, String contextProjectKey) {
        LLMSMMgmtService.LLMSMStatus status = LLMSMMgmtService.getStatus_NT(sm);
        for (LLMSMMgmtService.LLMSMVersionHeader versionHeader : status.versions) {
            LLMStructuredRef savedModelLLMRef = getLLMRefFromLLMGenericSMVersion(sm, contextProjectKey, versionHeader);
            if (savedModelLLMRef != null) {
                return savedModelLLMRef.id;
            }
        }
        return null;
    }

    /**
     * Applies the connection's finetuning flag and the saved-model image-input gate to a fine-tuned LLM ref.
     * Must run after model capabilities are loaded, since it narrows {@code supportsImageInputs}.
     */
    private static void applyFinetunedConnectionSettings(EnrichedLLMStructuredRef enrichedRef, AbstractLLMConnection connection) {
        enrichedRef.canBeFinetuned = connection.getLLMConnectionParams().allowFinetuning;
        enrichedRef.supportsImageInputs = AbstractLLMConnection.LLMUsagePurpose.IMAGE_INPUT.supportedBySavedModelLLMs && enrichedRef.supportsImageInputs;
    }

    /**
     * Enriches the active version of an LLM_GENERIC (fine-tuned) saved model.
     *
     * @return the enriched ref, or {@code null} if no version yields a reference or its LLM type is not handled
     *         (types other than the OpenAI / Azure OpenAI / Bedrock / HuggingFace fine-tuned ones)
     */
    public static EnrichedLLMStructuredRef getEnrichedLLMRefFromLLMGenericSM(AuthCtx authCtx, SavedModel sm, String contextProjectKey) throws IOException, DKUSecurityException {
        LLMSMMgmtService.LLMSMStatus status = LLMSMMgmtService.getStatus_NT(sm);
        for (LLMSMMgmtService.LLMSMVersionHeader versionHeader : status.versions) {
            LLMStructuredRef savedModelLLMRef = getLLMRefFromLLMGenericSMVersion(sm, contextProjectKey, versionHeader);
            if (savedModelLLMRef == null) {
                continue;
            }
            LLMModelSnippetData snippet = (LLMModelSnippetData) versionHeader.snippet;
            String friendlyName = savedModelLLMRef.getFinetunedModelMetaName() + " - " + sm.name + " - " + versionHeader.versionId;
            EnrichedLLMStructuredRef enrichedRef = new EnrichedLLMStructuredRef(savedModelLLMRef, friendlyName, friendlyName);
            // A null flag is treated as "does not handle system messages".
            enrichedRef.handlesSystemMessage = Boolean.TRUE.equals(snippet.llmSMInfo.handlesSystemMessage);
            enrichedRef.embeddingSize = snippet.llmSMInfo.embeddingSize;
            enrichedRef.maxTokensLimit = snippet.llmSMInfo.maxTokensLimit;
            switch (snippet.llmSMInfo.llmType) {
                case SAVED_MODEL_FINETUNED_OPENAI:
                case SAVED_MODEL_FINETUNED_AZURE_OPENAI:
                case SAVED_MODEL_FINETUNED_BEDROCK: {
                    AbstractLLMConnection connection = ConnectionsDAO.get().getMandatoryConnectionAs(authCtx, enrichedRef.connection, AbstractLLMConnection.class);
                    enrichedRef.connectionDescription = connection.description;
                    // Capabilities come from the base model the fine-tune was derived from.
                    LLMModelHandle.Model model = connection.getLLMModel(LLMStructuredRef.decodeId(snippet.llmSMInfo.originalLLMId)).getModel();
                    enrichedRef.loadModelCapabilities(model.getModelCapabilities());
                    if (snippet.deployment != null) {
                        enrichedRef.deployment = snippet.deployment.deploymentId;
                    } else if (Arrays.asList(LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_AZURE_OPENAI, LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_BEDROCK).contains(snippet.llmSMInfo.llmType)) {
                        enrichedRef.friendlyName = enrichedRef.friendlyName + " (no attached deployment)";
                    }
                    applyFinetunedConnectionSettings(enrichedRef, connection);
                    return enrichedRef;
                }
                case SAVED_MODEL_FINETUNED_HUGGINGFACE_TRANSFORMER: {
                    HuggingFaceLocalConnection connection = ConnectionsDAO.get().getMandatoryConnectionAs(authCtx, enrichedRef.connection, HuggingFaceLocalConnection.class);
                    enrichedRef.connectionDescription = connection.description;
                    enrichedRef.loadModelCapabilities(connection.getLLMModelFromSMInfo(snippet.llmSMInfo).getModelCapabilities());
                    applyFinetunedConnectionSettings(enrichedRef, connection);
                    return enrichedRef;
                }
                default:
                    // Unhandled fine-tuned type: keep scanning the remaining versions, as before.
                    break;
            }
        }
        return null;
    }

    /**
     * Enriches an agent saved model reference: the active version when the ref has no version id,
     * otherwise the referenced version.
     */
    public EnrichedLLMStructuredRef getEnrichedLLMRefFromAgentSM(AuthCtx authCtx, String contextProjectKey, LLMStructuredRef llmRef) throws Exception {
        SavedModel sm = this.loadSavedModel(contextProjectKey, llmRef.savedModelSmartId);
        if (llmRef.savedModelVersionId == null) {
            return this.getEnrichedLLMRefFromAgentSM(authCtx, sm, contextProjectKey);
        }
        return this.getEnrichedLLMRefFromAgentSMVersion(authCtx, sm, llmRef.savedModelVersionId, contextProjectKey);
    }

    /** Returns the (unversioned) agent LLM id for an agent saved model. */
    private String getLLMIdFromAgentSM(SavedModel sm, String contextProjectKey) {
        String savedModelSmartId = new AnyLoc(sm.projectKey, sm.id).getSmartName(contextProjectKey);
        return LLMStructuredRef.forAgentSavedModel(savedModelSmartId).id;
    }

    /** Enriches the active version of an agent saved model. */
    public EnrichedLLMStructuredRef getEnrichedLLMRefFromAgentSM(AuthCtx authCtx, SavedModel sm, String contextProjectKey) throws Exception {
        return this.getEnrichedLLMRefFromAgentSMVersion(authCtx, sm, sm.activeVersion, contextProjectKey, true);
    }

    /** Enriches a specific version of an agent saved model. */
    public EnrichedLLMStructuredRef getEnrichedLLMRefFromAgentSMVersion(AuthCtx authCtx, SavedModel sm, String versionId, String contextProjectKey) throws Exception {
        return this.getEnrichedLLMRefFromAgentSMVersion(authCtx, sm, versionId, contextProjectKey, false);
    }

    /**
     * Shared implementation for agent saved model enrichment.
     *
     * @param activeVersion when true, builds an unversioned ref on {@code sm.activeVersion};
     *                      otherwise a versioned ref on {@code versionId}, with the version in the friendly name
     * @throws IllegalArgumentException if the requested version does not exist
     */
    private EnrichedLLMStructuredRef getEnrichedLLMRefFromAgentSMVersion(AuthCtx authCtx, SavedModel sm, String versionId, String contextProjectKey, boolean activeVersion) throws Exception {
        String versionText = activeVersion ? "active version" : String.format("version %s", versionId);
        String errorMessage = String.format("Couldn't find %s in saved model %s", versionText, sm.id);
        String savedModelSmartId = new AnyLoc(sm.projectKey, sm.id).getSmartName(contextProjectKey);
        String friendlyName = "Agent - " + sm.name;
        LLMStructuredRef llmRef;
        SavedModel.SavedModelInlineVersion smiv;
        if (activeVersion) {
            llmRef = LLMStructuredRef.forAgentSavedModel(savedModelSmartId);
            smiv = sm.getVersion(sm.activeVersion).orElseThrow(() -> new IllegalArgumentException(errorMessage));
        } else {
            llmRef = LLMStructuredRef.forAgentSavedModelVersion(savedModelSmartId, versionId);
            smiv = sm.getVersion(versionId).orElseThrow(() -> new IllegalArgumentException(errorMessage));
            friendlyName = friendlyName + " - " + smiv.versionId;
        }
        EnrichedLLMStructuredRef enrichedRef = new EnrichedLLMStructuredRef(llmRef, friendlyName, friendlyName);
        enrichedRef.promptDriven = true;
        enrichedRef.supportsImageInputs = this.doesAgentSupportImages(sm.savedModelType, smiv, authCtx, contextProjectKey);
        enrichedRef.canBeFinetuned = false;
        return enrichedRef;
    }

    /**
     * Decides whether an agent saved model version accepts image inputs.
     *
     * <p>Per agent type, as far as the decompiled control flow below can be read:
     * <ul>
     *   <li>PYTHON_AGENT: {@code smiv.pythonAgentSettings.supportsImageInputs}.</li>
     *   <li>TOOLS_USING_AGENT: {@code false} when the {@code dku.agents.visual.agenticLoopV2} param is off
     *       (it defaults to on). Otherwise the value is driven by {@code toolsUsingAgentSettings.imageSupportMode}:
     *       <ul>
     *         <li>The first switch case yields {@code true} for BLOCKS_GRAPH agents. Otherwise, when {@code llmId}
     *             is set and {@code allowAgentsAsLLM} is false, it inherits the base LLM's
     *             {@code supportsImageInputs}. If that lookup throws (logged) or does not apply, it falls through
     *             to the second case.</li>
     *         <li>The second case yields {@code true}.</li>
     *         <li>The third case yields {@code false}.</li>
     *         <li>Any other mode throws.</li>
     *       </ul></li>
     *   <li>PLUGIN_AGENT: the loaded plugin agent descriptor's {@code supportsImageInputs}.</li>
     *   <li>Any other type: throws {@link Error}.</li>
     * </ul>
     * The result is additionally gated by {@code IMAGE_INPUT.supportedBySavedModelLLMs}.
     *
     * <p>NOTE(review): this body is raw CFR output that the decompiler could not fully structure.
     * It contains a {@code GOTO} pseudo-statement, the synthetic {@code 1.$SwitchMap$...} lookup and
     * undeclared locals, so it is not valid Java as-is. It is intentionally left untouched, because the
     * ImageSupportMode constant behind each switch case is not recoverable from the bytecode view.
     * Restore it from the original source rather than guessing enum names.
     */
    private boolean doesAgentSupportImages(SavedModel.SavedModelType smType, SavedModel.SavedModelInlineVersion smiv, AuthCtx authCtx, String contextProjectKey) throws Exception {
        block10: {
            block12: {
                block13: {
                    block11: {
                        if (smType != SavedModel.SavedModelType.PYTHON_AGENT) break block11;
                        supportImages = smiv.pythonAgentSettings.supportsImageInputs;
                        break block10;
                    }
                    if (smType != SavedModel.SavedModelType.TOOLS_USING_AGENT) break block12;
                    if (!DKUApp.getParams().getBoolParam("dku.agents.visual.agenticLoopV2", true)) break block13;
                    switch (1.$SwitchMap$com$dataiku$dip$dao$SavedModel$ToolsUsingAgentSettings$ImageSupportMode[smiv.toolsUsingAgentSettings.imageSupportMode.ordinal()]) {
                        case 1: {
                            if (smiv.toolsUsingAgentSettings.mode != SavedModel.ToolsUsingAgentMode.BLOCKS_GRAPH) ** GOTO lbl12
                            supportImages = true;
                            break block10;
lbl12:
                            // 1 sources

                            if (smiv.toolsUsingAgentSettings.llmId != null && !smiv.toolsUsingAgentSettings.allowAgentsAsLLM) {
                                try {
                                    baseLLMRef = this.getEnrichedLLMRef(smiv.toolsUsingAgentSettings.llmId, authCtx, contextProjectKey);
                                    supportImages = baseLLMRef.supportsImageInputs;
                                    break block10;
                                }
                                catch (Exception e) {
                                    LLMRefEnricherService.logger.warn((Object)"Couldn't determine image support from base LLM, defaulting to enabled image support", (Throwable)e);
                                }
                            }
                        }
                        case 2: {
                            supportImages = true;
                            break block10;
                        }
                        case 3: {
                            supportImages = false;
                            break block10;
                        }
                        default: {
                            throw new Error("Unknown image support mode: " + String.valueOf((Object)smiv.toolsUsingAgentSettings.imageSupportMode));
                        }
                    }
                }
                supportImages = false;
                break block10;
            }
            if (smType == SavedModel.SavedModelType.PLUGIN_AGENT) {
                meta = AgentTypesRegistry.getMeta(smiv.pluginAgentType);
                loadedDesc = (LoadedCustomAgent)meta.getLoadedDesc();
                supportImages = loadedDesc.desc.supportsImageInputs;
            } else {
                throw new Error("Unknown agent saved model type: " + String.valueOf((Object)smType));
            }
        }
        return AbstractLLMConnection.LLMUsagePurpose.IMAGE_INPUT.supportedBySavedModelLLMs != false && supportImages != false;
    }

    /**
     * Returns the LLM id exposed by a saved model.
     *
     * @return the id, or {@code null} for an untyped saved model or a type that exposes no LLM
     */
    public String getLLMIdFromSM(SavedModel sm, String contextProjectKey) {
        if (sm.savedModelType == null) {
            return null;
        }
        switch (sm.savedModelType) {
            case PYTHON_AGENT:
            case PLUGIN_AGENT:
            case TOOLS_USING_AGENT: {
                return this.getLLMIdFromAgentSM(sm, contextProjectKey);
            }
            case LLM_GENERIC: {
                return getLLMIdFromLLMGenericSM(sm, contextProjectKey);
            }
            case RETRIEVAL_AUGMENTED_LLM: {
                return this.getLLMIdForRetrievalAugmentedLLMSM(sm);
            }
            default:
                return null;
        }
    }

    /**
     * Enriches the LLM exposed by a saved model, dispatching on its saved model type.
     *
     * @return the enriched ref, or {@code null} for an untyped saved model or a type that exposes no LLM
     */
    public EnrichedLLMStructuredRef getEnrichedLLMRefFromSM(AuthCtx authCtx, SavedModel sm, String contextProjectKey) throws Exception {
        if (sm.savedModelType == null) {
            return null;
        }
        switch (sm.savedModelType) {
            case PYTHON_AGENT:
            case PLUGIN_AGENT:
            case TOOLS_USING_AGENT: {
                return this.getEnrichedLLMRefFromAgentSM(authCtx, sm, contextProjectKey);
            }
            case LLM_GENERIC: {
                return getEnrichedLLMRefFromLLMGenericSM(authCtx, sm, contextProjectKey);
            }
            case RETRIEVAL_AUGMENTED_LLM: {
                return this.getEnrichedLLMRefFromRetrievalAugmentedLLMSM(authCtx, sm, contextProjectKey);
            }
            default:
                return null;
        }
    }

    /** Loads the saved model named by {@code llmRef.savedModelSmartId} and enriches the LLM it exposes. */
    public EnrichedLLMStructuredRef getEnrichedLLMRefFromSM(AuthCtx authCtx, String contextProjectKey, LLMStructuredRef llmRef) throws Exception {
        SavedModel sm = this.loadSavedModel(contextProjectKey, llmRef.savedModelSmartId);
        return this.getEnrichedLLMRefFromSM(authCtx, sm, contextProjectKey);
    }

    /**
     * Returns the underlying LLM id configured on the active version of a RAG saved model.
     *
     * @return the id, or {@code null} if there is no active version or it has no RAG settings
     */
    private String getLLMIdForRetrievalAugmentedLLMSM(SavedModel sm) {
        Optional<SavedModel.SavedModelInlineVersion> activeVersion = sm.getVersion(sm.activeVersion);
        return activeVersion
                .map(smiv -> smiv.ragllmSettings)
                .map(settings -> settings.llmId)
                .orElse(null);
    }

    /**
     * Returns the RAG settings of the active version of a RAG saved model.
     *
     * @throws IllegalArgumentException if there is no active version, or it has no RAG settings
     */
    private static RAGLLMSettings getActiveRAGLLMSettings(SavedModel sm) {
        SavedModel.SavedModelInlineVersion smiv = sm.getVersion(sm.activeVersion)
                .orElseThrow(() -> new IllegalArgumentException(String.format("Couldn't find active version in saved model %s", sm.id)));
        if (smiv.ragllmSettings == null) {
            throw new IllegalArgumentException(String.format("Active version of saved model %s has no retrieval-augmented LLM settings", sm.id));
        }
        return smiv.ragllmSettings;
    }

    /** Enriches a RAG saved model from its active version's settings. */
    private EnrichedLLMStructuredRef getEnrichedLLMRefFromRetrievalAugmentedLLMSM(AuthCtx authCtx, SavedModel sm, String contextProjectKey) throws Exception {
        RAGLLMSettings settings = getActiveRAGLLMSettings(sm);
        LLMStructuredRef llmRef = LLMStructuredRef.forRetrievalAugmentedLLM(AnyLoc.resolveSmart(sm.projectKey, sm.id).getSmartName(contextProjectKey));
        return this.getEnrichedLLMRefFromRetrievalAugmentedLLM(authCtx, contextProjectKey, sm.name, settings, llmRef);
    }

    /**
     * Builds the enriched ref of a RAG model from the enriched ref of its underlying LLM.
     *
     * @throws IllegalArgumentException if the underlying LLM is the RAG model itself, is another RAG model,
     *                                  cannot be resolved, or lives on a non-LLM connection
     */
    private EnrichedLLMStructuredRef getEnrichedLLMRefFromRetrievalAugmentedLLM(AuthCtx authCtx, String projectKey, String modelName, RAGLLMSettings settings, LLMStructuredRef llmRef) throws Exception {
        if (Objects.equals(settings.llmId, llmRef.id)) {
            throw new IllegalArgumentException("The underlying LLM cannot be the same as the RAG model");
        }
        // Reject RAG-on-RAG *before* resolving the underlying LLM. Resolving it first recurses into the other
        // RAG model, which never terminates (StackOverflowError) when RAG models reference each other.
        if (LLMStructuredRef.decodeId(settings.llmId).type == LLMStructuredRef.LLMType.RETRIEVAL_AUGMENTED) {
            throw new IllegalArgumentException("Double-RAG LLMs are not supported.");
        }
        EnrichedLLMStructuredRef originalLLMRef = this.getEnrichedLLMRef(settings.llmId, authCtx, projectKey);
        if (originalLLMRef == null) {
            // getEnrichedLLMRefFromSM may return null (e.g. untyped saved model); fail clearly instead of NPE.
            throw new IllegalArgumentException(String.format("Couldn't resolve the underlying LLM %s", settings.llmId));
        }
        // Safety net: the resolved ref's type follows the saved model's actual type, not only the decoded id.
        if (originalLLMRef.type == LLMStructuredRef.LLMType.RETRIEVAL_AUGMENTED) {
            throw new IllegalArgumentException("Double-RAG LLMs are not supported.");
        }
        String friendlyName = modelName + ", using " + originalLLMRef.friendlyNameShort;
        EnrichedLLMStructuredRef enrichedLLMRef = new EnrichedLLMStructuredRef(llmRef, friendlyName, friendlyName);
        switch (originalLLMRef.type) {
            case SAVED_MODEL_FINETUNED_OPENAI:
            case SAVED_MODEL_FINETUNED_AZURE_OPENAI:
            case SAVED_MODEL_FINETUNED_BEDROCK:
            case SAVED_MODEL_FINETUNED_HUGGINGFACE_TRANSFORMER:
            case SAVED_MODEL_AGENT: {
                // Saved-model-backed LLMs already carry their capabilities on the enriched ref.
                enrichedLLMRef.loadModelCapabilities(originalLLMRef.getModelCapabilities());
                break;
            }
            default: {
                DSSConnection connection = this.connectionsDAO.getConnection(authCtx, originalLLMRef.connection);
                if (!(connection instanceof AbstractLLMConnection)) {
                    throw new IllegalArgumentException("Connection is not a LLM connection");
                }
                LLMModelHandle.Model originalModel = ((AbstractLLMConnection) connection).getLLMModel(originalLLMRef).getModel();
                enrichedLLMRef.loadModelCapabilities(originalModel.getModelCapabilities());
            }
        }
        enrichedLLMRef.supportsImageInputs = AbstractLLMConnection.LLMUsagePurpose.IMAGE_INPUT.supportedByKbAugmentedModels && enrichedLLMRef.supportsImageInputs;
        enrichedLLMRef.canBeFinetuned = false;
        return enrichedLLMRef;
    }

    /**
     * Loads the RAG saved model named by {@code llmRef.savedModelSmartId} and enriches it from its active version.
     *
     * @throws IllegalArgumentException if the saved model has no active version or no RAG settings
     */
    public EnrichedLLMStructuredRef getEnrichedLLMRefFromRetrievalAugmentedLLM(AuthCtx authCtx, String projectKey, LLMStructuredRef llmRef) throws Exception {
        SavedModel sm = this.loadSavedModel(projectKey, llmRef.savedModelSmartId);
        RAGLLMSettings settings = getActiveRAGLLMSettings(sm);
        return this.getEnrichedLLMRefFromRetrievalAugmentedLLM(authCtx, projectKey, sm.name, settings, llmRef);
    }

    /**
     * Decodes an LLM id and enriches it, dispatching on the decoded LLM type.
     *
     * @throws IllegalArgumentException for SAVED_MODEL_FINETUNED_VERTEX (not implemented) or a type not handled here
     */
    public EnrichedLLMStructuredRef getEnrichedLLMRef(String llmId, AuthCtx authCtx, String projectKey) throws Exception {
        LLMStructuredRef llmRef = LLMStructuredRef.decodeId(llmId);
        switch (llmRef.type) {
            case ANTHROPIC:
            case AZURE_OPENAI_DEPLOYMENT:
            case AZURE_OPENAI_MODEL:
            case AZURE_AI_FOUNDRY_DEPLOYMENT:
            case BEDROCK:
            case COHERE:
            case HUGGINGFACE_API:
            case HUGGINGFACE_TRANSFORMER_LOCAL:
            case MOSAICML:
            case OPENAI:
            case NVIDIA_NIM:
            case MISTRALAI:
            case VERTEX:
            case DATABRICKS:
            case SNOWFLAKE_CORTEX:
            case CUSTOM:
            case SAGEMAKER_GENERICLLM:
            case AZURE_LLM: {
                return this.getEnrichedLLMRefFromConnection(authCtx, llmRef);
            }
            case SAVED_MODEL_FINETUNED_OPENAI:
            case SAVED_MODEL_FINETUNED_AZURE_OPENAI:
            case SAVED_MODEL_FINETUNED_BEDROCK:
            case SAVED_MODEL_FINETUNED_HUGGINGFACE_TRANSFORMER:
            case RETRIEVAL_AUGMENTED: {
                return this.getEnrichedLLMRefFromSM(authCtx, projectKey, llmRef);
            }
            case SAVED_MODEL_AGENT: {
                return this.getEnrichedLLMRefFromAgentSM(authCtx, projectKey, llmRef);
            }
            case SAVED_MODEL_FINETUNED_VERTEX: {
                throw new IllegalArgumentException("Not implemented");
            }
            default:
                // Reached when an LLMType is added without being wired up here.
                throw new IllegalArgumentException("Unsupported LLM type: " + llmRef.type);
        }
    }
}

