/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.server.controllers;

import com.dataiku.dip.analysis.model.preprocessing.SentenceEmbeddingModelMeta;
import com.dataiku.dip.code.CodeEnvSelection;
import com.dataiku.dip.code.CodeEnvSelector;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.DSSConnection;
import com.dataiku.dip.connections.HuggingFaceLocalConnection;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.PretrainedModelsService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.PermissionsService;
import com.dataiku.dip.security.Privileges;
import com.dataiku.dip.security.auth.UIAuthService;
import com.dataiku.dip.server.controllers.AuditedCall;
import com.dataiku.dip.server.controllers.DIPInternalControllerBase;
import com.dataiku.dip.server.services.ConnectionsService;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.server.services.licensing.LicenseEnforcementService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.JSON;
import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.ResponseBody;

/**
 * REST endpoints listing the pretrained / LLM-mesh models (embedding models, LLMs and LLM
 * connections) that are available to the calling user.
 */
@Controller
public class PretrainedModelsController extends DIPInternalControllerBase {
    @Autowired
    private ConnectionsDAO connectionsDAO;
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private UIAuthService authService;
    @Autowired
    private ProjectsService projectsService;
    @Autowired
    private PretrainedModelsService pretrainedModelsService;
    @Autowired
    private ConnectionsService connectionsService;
    @Autowired
    private LicenseEnforcementService licenseEnforcementService;
    @Autowired
    private PermissionsService permissionsService;
    private final CodeEnvSelector codeEnvSelector = new CodeEnvSelector();

    /**
     * Lists the connection-provided models usable for image embedding extraction.
     *
     * <p>Every available model is returned, but entries are flagged {@code disabled} when the
     * user's license profile does not grant {@code mayAdvancedVisualML}.
     *
     * @param req the incoming request, used to resolve the mandatory authenticated user
     * @return one {@link EmbeddingModel} per connection LLM available for
     *     {@code IMAGE_EMBEDDING_EXTRACTION}
     * @throws IOException as declared by this endpoint; may propagate from the services called
     */
    @AuditedCall(value={"msgType", "get-image-embeddings-pretrained-models"})
    @RequestMapping(value={"/api/pretrained-models/image-embedding"}, method={RequestMethod.GET})
    @ResponseBody
    public List<EmbeddingModel> getImageEmbeddingModels(HttpServletRequest req) throws IOException {
        AuthCtx authCtx;
        List<EnrichedLLMStructuredRef> structuredRefsFromConnections;
        // Only user resolution and connection listing run inside the read transaction;
        // the license lookup below happens once it has been closed.
        try (Transaction t = this.transactionService.beginRead()) {
            authCtx = this.authService.getMandatoryUser(req);
            structuredRefsFromConnections = this.connectionsService.listAvailableConnectionLLMs(authCtx, AbstractLLMConnection.LLMUsagePurpose.IMAGE_EMBEDDING_EXTRACTION);
        }
        boolean disabled = !this.mayUseModelsFromConnections(authCtx);
        return structuredRefsFromConnections.stream()
                .map(ref -> EmbeddingModel.fromStructuredRef(ref, disabled))
                .collect(Collectors.toList());
    }

    /**
     * Lists the models usable for sentence (text) embedding extraction.
     *
     * <p>The result always contains the connection-provided models available for
     * {@code TEXT_EMBEDDING_EXTRACTION}. These are flagged {@code disabled} when the user's
     * license profile does not grant {@code mayAdvancedVisualML}.
     *
     * <p>When the code env selected for {@code projectKey}/{@code envSelection} is neither
     * {@code null} nor {@code "__BUILTIN_ENV__"}, the sentence-embedding models found in that
     * code env's resources are appended after the connection models.
     *
     * @param req the incoming request, used to resolve the mandatory authenticated user
     * @param envSelection how the code env should be chosen
     * @param projectKey project used to resolve the code env selection
     * @return connection models first, then code-env resource models (if any)
     * @throws Exception as declared by this endpoint; may propagate from the services called
     */
    @AuditedCall(value={"msgType", "get-sentence-embeddings-pretrained-models"})
    @RequestMapping(value={"/api/pretrained-models/sentence-embedding"}, method={RequestMethod.GET})
    @ResponseBody
    public List<SentenceEmbeddingModel> getSentenceEmbeddingsModels(HttpServletRequest req, @RequestParam CodeEnvSelection envSelection, @RequestParam String projectKey) throws Exception {
        AuthCtx authCtx;
        List<EnrichedLLMStructuredRef> structuredRefsFromConnections;
        try (Transaction t = this.transactionService.beginRead()) {
            authCtx = this.authService.getMandatoryUser(req);
            structuredRefsFromConnections = this.connectionsService.listAvailableConnectionLLMs(authCtx, AbstractLLMConnection.LLMUsagePurpose.TEXT_EMBEDDING_EXTRACTION);
        }
        boolean disabled = !this.mayUseModelsFromConnections(authCtx);
        String envName = this.codeEnvSelector.selectForDoctor(projectKey, envSelection);
        // The stream is lazy: structured refs are only converted when it is finally collected,
        // i.e. after any code-env resource lookup below.
        Stream<SentenceEmbeddingModel> connectionModels = structuredRefsFromConnections.stream()
                .map(ref -> SentenceEmbeddingModel.fromStructuredRef(ref, disabled));
        // No code-env resources are looked up for the "__BUILTIN_ENV__" sentinel
        // (presumably the DSS built-in env) or when no env could be selected.
        if (envName == null || "__BUILTIN_ENV__".equals(envName)) {
            return connectionModels.collect(Collectors.toList());
        }
        Collection<SentenceEmbeddingModelMeta> resourcesModels = this.pretrainedModelsService.getCodeEnvResourcesSentenceEmbeddingModels(authCtx, envName, projectKey).values();
        Stream<SentenceEmbeddingModel> codeEnvModels = resourcesModels.stream()
                .map(modelMeta -> SentenceEmbeddingModel.fromSentenceEmbeddingModelMeta(modelMeta, envName));
        return Stream.concat(connectionModels, codeEnvModels).collect(Collectors.toList());
    }

    /**
     * Lists the LLMs available to the user for the given purpose.
     *
     * @param req the incoming request, used to resolve the mandatory authenticated user
     * @param projectKey optional project; when non-blank the user must hold
     *     {@code READ_CONF} on it
     * @param purpose name of an {@code AbstractLLMConnection.LLMUsagePurpose} constant, parsed
     *     with {@code valueOf}. Presumably an enum, so an unknown name throws
     *     {@link IllegalArgumentException}.
     * @return the models list computed by {@link PretrainedModelsService}
     * @throws Exception as declared by this endpoint; may propagate from the services called
     */
    @AuditedCall(value={"msgType", "llm-mesh-llms-list", "projectKey", "${projectKey}", "purpose", "${purpose}"})
    @RequestMapping(value={"/api/llm/list-available-llms"})
    @ResponseBody
    public PretrainedModelsService.ModelsList listAvailableLLMs(HttpServletRequest req, @RequestParam(required=false) String projectKey, @RequestParam String purpose) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            // projectKey is optional: project permissions are only enforced when one is given.
            if (StringUtils.isNotBlank(projectKey)) {
                this.projectsService.checkPerm(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            }
            return this.pretrainedModelsService.listAvailableLLMs(authCtx, projectKey, AbstractLLMConnection.LLMUsagePurpose.valueOf(purpose));
        }
    }

    /**
     * Lists the name and type of every LLM connection the user may freely use.
     *
     * @param req the incoming request, used to resolve the mandatory authenticated user
     * @return one entry per {@link AbstractLLMConnection} for which
     *     {@code isFreelyUsableBy(authCtx)} holds
     * @throws Exception as declared by this endpoint; may propagate from the services called
     */
    @AuditedCall(value={"msgType", "connections-list"})
    @RequestMapping(value={"/api/llm/list-available-llm-connections"})
    @ResponseBody
    public List<BasicLLMConnectionInfo> listAvailableLLMConnections(HttpServletRequest req) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            List<BasicLLMConnectionInfo> ret = new ArrayList<>();
            // NOTE(review): listUnsafe() presumably performs no permission filtering, which is
            // why usability is checked explicitly for each connection here.
            for (DSSConnection connection : this.connectionsDAO.listUnsafe().values()) {
                if (!(connection instanceof AbstractLLMConnection) || !connection.isFreelyUsableBy(authCtx)) {
                    continue;
                }
                BasicLLMConnectionInfo bc = new BasicLLMConnectionInfo();
                bc.name = connection.name;
                bc.type = connection.type;
                ret.add(bc);
            }
            return ret;
        }
    }

    /**
     * Lists the LLMs from connections that are available for the given purpose.
     *
     * @param req the incoming request, used to resolve the mandatory authenticated user
     * @param purpose name of an {@code AbstractLLMConnection.LLMUsagePurpose} constant, parsed
     *     with {@code valueOf}
     * @return the models list computed by {@link PretrainedModelsService}
     * @throws Exception as declared by this endpoint; may propagate from the services called
     */
    @AuditedCall(value={"msgType", "llm-mesh-llms-list", "purpose", "${purpose}"})
    @RequestMapping(value={"/api/llm/list-available-connection-llms"})
    @ResponseBody
    public PretrainedModelsService.ModelsList listAvailableConnectionLLMs(HttpServletRequest req, @RequestParam String purpose) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            return this.pretrainedModelsService.listAvailableConnectionLLMs(authCtx, AbstractLLMConnection.LLMUsagePurpose.valueOf(purpose));
        }
    }

    /**
     * Returns the parameters of a Hugging Face local connection with sensitive fields removed.
     *
     * <p>The returned object is a copy. {@code apiKey}, {@code dkuProperties} and
     * {@code guardrailsPipelineSettings} are set to {@code null}.
     *
     * @param req the incoming request, used to resolve the mandatory authenticated user
     * @param connectionName name of the connection to read
     * @return a sanitized copy of the connection's params
     * @throws DKUSecurityException if the user cannot freely use the connection
     *     ("Access to connection denied"), or if it is not a {@link HuggingFaceLocalConnection}
     *     ("Invalid connection")
     * @throws Exception as declared by this endpoint; may propagate from the services called
     */
    @AuditedCall(value={"msgType", "connection-get-info", "connectionName", "${connectionName}"})
    @RequestMapping(value={"/api/llm/get-huggingface-connection-non-sensitive-params"})
    @ResponseBody
    public HuggingFaceLocalConnection.HuggingFaceLocalConnectionParams getHFConnectionParams(HttpServletRequest req, @RequestParam String connectionName) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            DSSConnection connection = this.connectionsDAO.getMandatoryConnectionUnsafeUnexpanded(authCtx, connectionName);
            // Usability is checked before the type, so an unusable connection is rejected
            // without revealing what kind of connection it is.
            if (!connection.isFreelyUsableBy(authCtx)) {
                throw new DKUSecurityException("Access to connection denied");
            }
            if (!(connection instanceof HuggingFaceLocalConnection)) {
                throw new DKUSecurityException("Invalid connection");
            }
            // Blank the sensitive fields on a deep copy so the connection's own params are left untouched.
            HuggingFaceLocalConnection.HuggingFaceLocalConnectionParams hp = (HuggingFaceLocalConnection.HuggingFaceLocalConnectionParams)JSON.deepCopy((Object)((HuggingFaceLocalConnection)connection).params);
            hp.apiKey = null;
            hp.dkuProperties = null;
            hp.guardrailsPipelineSettings = null;
            return hp;
        }
    }

    /**
     * Whether the user's license profile allows using models coming from LLM connections in the
     * embedding model lists, i.e. its {@code mayAdvancedVisualML} flag.
     */
    private boolean mayUseModelsFromConnections(AuthCtx authCtx) {
        return this.licenseEnforcementService.getUserProfileByNameOrFallback((String)authCtx.getUserProfile()).mayAdvancedVisualML;
    }

    /** Minimal description of an LLM connection: its name and connection type. */
    static class BasicLLMConnectionInfo {
        String type;
        String name;

        BasicLLMConnectionInfo() {
        }
    }

    /**
     * Sentence-embedding model entry. It comes either from an LLM connection (structured ref) or
     * from a code env's resources.
     */
    private static class SentenceEmbeddingModel
    extends EmbeddingModel {
        /** {@code maxPositionEmbeddings} for code-env models; {@code maxTokensLimit} for connection models. */
        public Integer maxTokensLimit;
        /** Copied from the code-env model metadata; always {@code true} for connection models. */
        public boolean compat;
        /** Model type from the code-env metadata; empty string for connection models. */
        public String type;
        /** {@code true} for connection (structured ref) models, {@code false} for code-env resources. */
        public boolean isStructuredRef;

        private SentenceEmbeddingModel() {
        }

        /**
         * Builds an entry for a model found in a code env's resources.
         *
         * @param modelMeta metadata of the resource model
         * @param envName code env name, shown in {@code modelType}
         */
        public static SentenceEmbeddingModel fromSentenceEmbeddingModelMeta(SentenceEmbeddingModelMeta modelMeta, String envName) {
            SentenceEmbeddingModel model = new SentenceEmbeddingModel();
            model.modelRefId = modelMeta.name;
            model.modelFriendlyName = modelMeta.name;
            model.modelType = "Resources models (code env: '" + envName + "')";
            model.maxTokensLimit = modelMeta.maxPositionEmbeddings;
            model.compat = modelMeta.compat;
            model.type = modelMeta.type;
            model.isStructuredRef = false;
            model.embeddingSize = null;
            model.disabled = false;
            return model;
        }

        /**
         * Builds an entry for a model exposed by an LLM connection.
         *
         * @param structuredRef the connection model reference
         * @param disabled whether the entry should be flagged as unusable
         */
        public static SentenceEmbeddingModel fromStructuredRef(EnrichedLLMStructuredRef structuredRef, boolean disabled) {
            SentenceEmbeddingModel model = new SentenceEmbeddingModel();
            EmbeddingModel.copyCommonFields(model, structuredRef, disabled);
            model.maxTokensLimit = structuredRef.maxTokensLimit;
            model.compat = true;
            model.type = "";
            model.isStructuredRef = true;
            return model;
        }
    }

    /** Embedding model entry as returned to the UI. */
    public static class EmbeddingModel {
        public String modelRefId;
        public String modelFriendlyName;
        public String modelType;
        public Integer embeddingSize;
        public boolean disabled;

        /**
         * Builds an entry for a model exposed by an LLM connection.
         *
         * @param structuredRef the connection model reference
         * @param disabled whether the entry should be flagged as unusable
         */
        public static EmbeddingModel fromStructuredRef(EnrichedLLMStructuredRef structuredRef, boolean disabled) {
            EmbeddingModel embeddingModel = new EmbeddingModel();
            copyCommonFields(embeddingModel, structuredRef, disabled);
            return embeddingModel;
        }

        /** Copies the fields shared by every structured-ref-based entry onto {@code target}. */
        private static void copyCommonFields(EmbeddingModel target, EnrichedLLMStructuredRef structuredRef, boolean disabled) {
            target.modelRefId = structuredRef.id;
            target.modelFriendlyName = structuredRef.friendlyName;
            target.modelType = structuredRef.type.toString();
            target.embeddingSize = structuredRef.embeddingSize;
            target.disabled = disabled;
        }
    }
}

