/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.connections;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.connections.ConnectionWithDatabricksCredentials;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.DatabricksModelDeploymentConnection;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.ConnectionsTestService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.IsolationLevel;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.variables.VariablesContext;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import javax.annotation.Nonnull;

/**
 * LLM connection for models addressed by Databricks endpoint ids: a fixed set of
 * hardcoded models (each gated by an {@code allow*} flag in the params) plus any
 * user-declared custom models.
 *
 * <p>This connection does not store credentials itself. They are resolved on demand
 * from the {@link DatabricksModelDeploymentConnection} whose name is held in
 * {@link DatabricksLLMConnectionParams#modelDeploymentConnection}.
 */
public class DatabricksLLMConnection
extends AbstractLLMConnection<DatabricksLLMModel, HardcodedDatabricksLLMModel, CustomDatabricksLLMModel> {
    /** Temperature range advertised for every model of this connection: [0, 2], step 0.01. */
    private static final EnrichedLLMStructuredRef.FieldRange TEMPERATURE_RANGE = new EnrichedLLMStructuredRef.FieldRange(0.0, 2.0, 0.01);
    /** Top-K range advertised for every model of this connection: [0, 1e7], step 1. */
    private static final EnrichedLLMStructuredRef.FieldRange TOP_K_RANGE = new EnrichedLLMStructuredRef.FieldRange(0.0, 1.0E7, 1.0);
    public DatabricksLLMConnectionParams params = new DatabricksLLMConnectionParams();
    /** Type identifier returned by {@link #getType()}. */
    public static final String connectionType = "DatabricksLLM";
    private static final DKULogger logger = DKULogger.getLogger("dip.connections.databricks-llm");

    /** Intentionally a no-op: no parameter of this connection is variable-expanded here. */
    @Override
    public void expandParametersInPlaceAtDAOLevelUsingGlobalContextOnly(VariablesContext vc) {
    }

    /**
     * Resolves credentials by delegating to the referenced Databricks model-deployment connection.
     *
     * <p>The deployment connection is looked up inside a read transaction; the credential
     * resolution itself happens after the transaction is closed.
     *
     * @param ctx   resolution context; only its {@code authCtx} is forwarded
     * @param clazz requested credential type; must be assignable from
     *              {@link ConnectionWithDatabricksCredentials.SerializableDatabricksCredentials}
     * @return the deployment connection's credentials, cast to {@code clazz}
     * @throws ClassCastException if {@code clazz} is incompatible (when assertions are disabled)
     */
    @Override
    protected <T> T getFullyResolvedCredentials_internal(ConnectionWithBasicCredential.CredentialResolutionContext ctx, Class<T> clazz) throws DKUSecurityException, IOException, SQLException {
        assert clazz.isAssignableFrom(ConnectionWithDatabricksCredentials.SerializableDatabricksCredentials.class)
                : "Unsupported credential type for " + connectionType + ": " + clazz.getName();
        DatabricksModelDeploymentConnection dbDeplConn;
        try (Transaction t = ((TransactionService)SpringUtils.getBean(TransactionService.class)).retrieveOrBeginRead(IsolationLevel.YOLO)) {
            // "Mandatory" lookup: presumably throws when the named connection is missing — TODO confirm.
            dbDeplConn = ConnectionsDAO.get().getMandatoryConnectionAs(ctx.authCtx, this.params.modelDeploymentConnection, DatabricksModelDeploymentConnection.class);
        }
        ConnectionWithDatabricksCredentials.SerializableDatabricksCredentials creds = dbDeplConn.getFullyResolvedCredentials_fsLike(
                new ConnectionWithBasicCredential.CredentialResolutionContext(ctx.authCtx, null),
                ConnectionWithDatabricksCredentials.SerializableDatabricksCredentials.class);
        return clazz.cast(creds);
    }

    /** @return {@value #connectionType} */
    @Override
    public String getType() {
        return connectionType;
    }

    @Override
    public AbstractLLMConnection.AbstractLLMConnectionParams getLLMConnectionParams() {
        return this.params;
    }

    /** No-op: no field of {@link DatabricksLLMConnectionParams} is encrypted by this connection. */
    @Override
    public void encryptFields(PasswordEncryptionService cryptoService, GeneralSettingsDAO.SecuritySettings securitySettings) {
    }

    /** No-op counterpart of {@link #encryptFields}. */
    @Override
    public void decryptFields(PasswordEncryptionService cryptoService) {
    }

    /**
     * @return whether the {@code allow*} flag tied to {@code databricksLLMModel} is set in {@link #params}
     */
    @Override
    protected boolean isHardcodedModelEnabled(HardcodedDatabricksLLMModel databricksLLMModel) {
        // Null-safe unboxing of the Boolean-returning Function.
        return Boolean.TRUE.equals(databricksLLMModel.allowedModel.apply(this.params));
    }

    /** @return every hardcoded model, enabled or not */
    @Override
    protected List<HardcodedDatabricksLLMModel> listRawHardcodedModels() {
        return Arrays.asList(HardcodedDatabricksLLMModel.values());
    }

    /**
     * @return the user-declared custom models; an empty list when the params carry
     *         {@code null} (e.g. after deserializing a config with an explicit null)
     */
    @Override
    protected List<CustomDatabricksLLMModel> listRawCustomModels() {
        List<CustomDatabricksLLMModel> customModels = this.params.customModels;
        return customModels != null ? customModels : new ArrayList<CustomDatabricksLLMModel>();
    }

    /** Builds the runtime model for a hardcoded entry and applies the default hardcoded-model settings. */
    @Override
    protected DatabricksLLMModel loadRawHardcodedModel(HardcodedDatabricksLLMModel hardcodedModel) {
        DatabricksLLMModel model = hardcodedModel.toModel();
        this.loadDefaultHardcodedModelSettings(hardcodedModel, model);
        return model;
    }

    /** Builds the runtime model for a custom entry and applies the default custom-model settings. */
    @Override
    protected DatabricksLLMModel loadRawCustomModel(CustomDatabricksLLMModel rawCustomModel) {
        DatabricksLLMModel model = rawCustomModel.toModel();
        this.loadDefaultCustomModelSettings(model);
        return model;
    }

    @Override
    protected DKULogger getLogger() {
        return logger;
    }

    /** Delegates the connection test to {@link ConnectionsTestService#testDatabricksLLM}. */
    @Override
    public ConnectionsTestService.ConnectionTestResult testConnection(AuthCtx authCtx, ConnectionsTestService connectionsTestService) throws Exception {
        return connectionsTestService.testDatabricksLLM(this, authCtx);
    }

    /** Persisted settings of a {@link DatabricksLLMConnection}. */
    public static class DatabricksLLMConnectionParams
    extends AbstractLLMConnection.AbstractLLMConnectionParams {
        public AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = new AbstractLLMConnection.HTTPBasedLLMNetworkSettings();
        public int maxParallelism = 1;
        /** Name of the {@link DatabricksModelDeploymentConnection} that supplies credentials. */
        public String modelDeploymentConnection;
        // Per-model opt-in flags, read by HardcodedDatabricksLLMModel#allowedModel; all off by default.
        public boolean allowClaude3_7_Sonnet = false;
        public boolean allowLlama4_Maverick = false;
        public boolean allowLlama3_3_70BChat = false;
        public boolean allowLlama3_1_405BChat = false;
        public boolean allowBGELargeEn = false;
        public List<CustomDatabricksLLMModel> customModels = new ArrayList<CustomDatabricksLLMModel>();
    }

    /**
     * Built-in models. Each entry carries its endpoint id, display name, type, optional
     * embedding size / token limit, and the params flag that enables it.
     */
    public static enum HardcodedDatabricksLLMModel implements AbstractLLMConnection.IHardcodedConnectionModel<DatabricksLLMModel>
    {
        CLAUDE3_7_SONNET("databricks-claude-3-7-sonnet", "Claude 3.7 Sonnet", DatabricksLLMModelType.CHAT, p -> p.allowClaude3_7_Sonnet),
        LLAMA_4_MAVERICK("databricks-llama-4-maverick", "Llama 4 Maverick", DatabricksLLMModelType.CHAT, p -> p.allowLlama4_Maverick),
        LLAMA3_3_70B_CHAT("databricks-meta-llama-3-3-70b-instruct", "Llama 3.3 70B Chat", DatabricksLLMModelType.CHAT, p -> p.allowLlama3_3_70BChat),
        LLAMA3_1_405B_CHAT("databricks-meta-llama-3-1-405b-instruct", "Llama 3.1 405B Chat", DatabricksLLMModelType.CHAT, p -> p.allowLlama3_1_405BChat),
        BGE_LARGE_EN("databricks-bge-large-en", "BGE Large (En)", DatabricksLLMModelType.TEXT_EMBEDDING, 1024, 8192, p -> p.allowBGELargeEn);

        public final String id;
        public final String displayName;
        public final DatabricksLLMModelType modelType;
        /** Embedding dimension; {@code null} for non-embedding models. */
        public final Integer embeddingSize;
        /** Token limit; {@code null} when not specified. */
        public final Integer maxTokensLimit;
        /** Reads the enabling flag for this model from the connection params. */
        public final Function<DatabricksLLMConnectionParams, Boolean> allowedModel;

        private HardcodedDatabricksLLMModel(String id, String displayName, DatabricksLLMModelType modelType, Function<DatabricksLLMConnectionParams, Boolean> allowedModel) {
            this(id, displayName, modelType, null, null, allowedModel);
        }

        private HardcodedDatabricksLLMModel(String id, String displayName, DatabricksLLMModelType modelType, Integer embeddingSize, Integer maxTokensLimit, Function<DatabricksLLMConnectionParams, Boolean> allowedModel) {
            this.id = id;
            this.displayName = displayName;
            this.modelType = modelType;
            this.embeddingSize = embeddingSize;
            this.maxTokensLimit = maxTokensLimit;
            this.allowedModel = allowedModel;
        }

        /** @return a new runtime model populated from this entry's fields */
        @Override
        public DatabricksLLMModel toModel() {
            DatabricksLLMModel model = new DatabricksLLMModel();
            model.id = this.id;
            model.displayName = this.displayName;
            model.embeddingSize = this.embeddingSize;
            model.maxTokensLimit = this.maxTokensLimit;
            model.modelType = this.modelType;
            return model;
        }
    }

    /** Runtime model of a Databricks LLM connection, hardcoded or custom. */
    public static class DatabricksLLMModel
    extends AbstractLLMConnection.BaseModel {
        /** May be {@code null} for a custom model declared without a type. */
        public DatabricksLLMModelType modelType;

        /**
         * @return capabilities shared by all models of this connection; system messages are
         *         only reported as handled for {@link DatabricksLLMModelType#CHAT} models
         */
        @Override
        public AbstractLLMConnection.ModelCapabilities getModelCapabilities() {
            AbstractLLMConnection.ModelCapabilities capabilities = new AbstractLLMConnection.ModelCapabilities();
            capabilities.handlesSystemMessage = this.modelType == DatabricksLLMModelType.CHAT;
            capabilities.canGenerateCrossLanguageOutput = true;
            capabilities.temperatureRange = TEMPERATURE_RANGE;
            capabilities.topKRange = TOP_K_RANGE;
            return capabilities;
        }

        @Override
        public LLMStructuredRef asStructuredRef(String connection) {
            return LLMStructuredRef.forDatabricksLLMConnection(connection, this.getId());
        }

        /**
         * @return whether this model's type supports {@code purpose}; a model with no type
         *         matches no purpose (instead of throwing a NullPointerException)
         */
        @Override
        public boolean canBeUsedForPurpose(@Nonnull AbstractLLMConnection.LLMUsagePurpose purpose) {
            return this.modelType != null && this.modelType.matchesPurpose(purpose);
        }
    }

    /** User-declared model, as stored in {@link DatabricksLLMConnectionParams#customModels}. */
    public static class CustomDatabricksLLMModel
    extends AbstractLLMConnection.CustomModel<DatabricksLLMModel> {
        public String id;
        public String displayName;
        public DatabricksLLMModelType modelType;

        /**
         * @return a new runtime model: base custom-model fields are loaded first, then
         *         id, display name and type are set from this entry
         */
        @Override
        public DatabricksLLMModel toModel() {
            DatabricksLLMModel model = new DatabricksLLMModel();
            model.loadFromCustomModel(this);
            model.id = this.id;
            model.displayName = this.displayName;
            model.modelType = this.modelType;
            return model;
        }
    }

    /** Model kinds, each mapped to the set of usage purposes it can serve. */
    public static enum DatabricksLLMModelType {
        CHAT(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET),
        TEXT_EMBEDDING(EnumSet.of(AbstractLLMConnection.LLMUsagePurpose.TEXT_EMBEDDING_EXTRACTION));

        public final Set<AbstractLLMConnection.LLMUsagePurpose> matchingPurposes;

        private DatabricksLLMModelType(Set<AbstractLLMConnection.LLMUsagePurpose> purposes) {
            this.matchingPurposes = purposes;
        }

        /** @return whether {@code purpose} is in {@link #matchingPurposes} */
        public boolean matchesPurpose(AbstractLLMConnection.LLMUsagePurpose purpose) {
            return this.matchingPurposes.contains(purpose);
        }
    }
}

