/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.connections;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.connections.OpenAIConnection;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.server.services.ConnectionsTestService;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.variables.VariablesContext;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import org.apache.commons.lang.StringUtils;

/**
 * LLM connection of type {@value #AZUREML_LLM_CONNECTION_TYPE}. Its models are declared in
 * {@link AzureMLGenericLLMConnectionParams#customModels}; each one carries its own target URI, key,
 * model type and custom headers.
 */
public class AzureLLMConnection
extends AbstractLLMConnection<AzureMLModel, AbstractLLMConnection.IHardcodedConnectionModel<AzureMLModel>, CustomAzureMLModel> {
    /** Connection type identifier returned by {@link #getType()}. */
    public static final String AZUREML_LLM_CONNECTION_TYPE = "AzureLLM";
    /** Connection-level settings, including the declared custom models. */
    public AzureMLGenericLLMConnectionParams params = new AzureMLGenericLLMConnectionParams();
    private static final DKULogger logger = DKULogger.getLogger("dip.connections.azurellm");

    @Override
    protected DKULogger getLogger() {
        return logger;
    }

    @Override
    public AbstractLLMConnection.AbstractLLMConnectionParams getLLMConnectionParams() {
        return this.params;
    }

    /** Returns the live list of custom models declared on this connection (not a copy). */
    @Override
    protected List<CustomAzureMLModel> listRawCustomModels() {
        return this.params.customModels;
    }

    /**
     * Converts a declared custom model into a runtime {@link AzureMLModel} and applies the
     * default custom-model settings from the base class.
     */
    @Override
    protected AzureMLModel loadRawCustomModel(CustomAzureMLModel rawCustomModel) {
        AzureMLModel model = rawCustomModel.toModel();
        this.loadDefaultCustomModelSettings(rawCustomModel, model);
        return model;
    }

    /**
     * Encrypts secrets in place using {@code encryptIfNotEncryptedOrEmpty}. This covers the
     * connection's default key and, for each custom model, its key. Each model's custom headers
     * are handed to {@code params.encryptProperties}.
     */
    @Override
    public void encryptFields(PasswordEncryptionService cryptoService, GeneralSettingsDAO.SecuritySettings securitySettings) {
        this.params.defaultKey = cryptoService.encryptIfNotEncryptedOrEmpty(this.params.defaultKey);
        for (CustomAzureMLModel model : this.params.customModels) {
            model.key = cryptoService.encryptIfNotEncryptedOrEmpty(model.key);
            // customHeaders is a public, non-@Nonnull field: a null list means "no headers".
            if (model.customHeaders != null) {
                this.params.encryptProperties(model.customHeaders, cryptoService);
            }
        }
    }

    /**
     * Reverses {@link #encryptFields}. It decrypts the default key and each custom model's key
     * with {@code decryptIfEncrypted}, and hands custom headers to
     * {@code params.decryptProperties}.
     */
    @Override
    public void decryptFields(PasswordEncryptionService cryptoService) {
        this.params.defaultKey = cryptoService.decryptIfEncrypted(this.params.defaultKey);
        for (CustomAzureMLModel model : this.params.customModels) {
            model.key = cryptoService.decryptIfEncrypted(model.key);
            if (model.customHeaders != null) {
                this.params.decryptProperties(model.customHeaders, cryptoService);
            }
        }
    }

    @Override
    public String getType() {
        return AZUREML_LLM_CONNECTION_TYPE;
    }

    /** No-op: this connection performs no variable expansion on its parameters. */
    @Override
    public void expandParametersInPlaceAtDAOLevelUsingGlobalContextOnly(VariablesContext vc) {
    }

    /** This connection resolves no credentials through this mechanism; always returns {@code null}. */
    @Override
    protected <T> T getFullyResolvedCredentials_internal(ConnectionWithBasicCredential.CredentialResolutionContext ctx, Class<T> clazz) throws DKUSecurityException, IOException, SQLException {
        return null;
    }

    /** Delegates to {@link ConnectionsTestService#testAzureLLM}. */
    @Override
    public ConnectionsTestService.ConnectionTestResult testConnection(AuthCtx authCtx, ConnectionsTestService connectionsTestService) throws Exception {
        return connectionsTestService.testAzureLLM(this, authCtx);
    }

    /** Serializable settings of an {@link AzureLLMConnection}. */
    public static class AzureMLGenericLLMConnectionParams
    extends AbstractLLMConnection.AbstractLLMConnectionParams {
        public AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = new AbstractLLMConnection.HTTPBasedLLMNetworkSettings();
        /** Connection-wide key; encrypted/decrypted by encryptFields/decryptFields. */
        public String defaultKey;
        /** Defaults to 8. */
        public int maxParallelism = 8;
        @Nonnull
        public List<CustomAzureMLModel> customModels = new ArrayList<CustomAzureMLModel>();
    }

    /** A user-declared model, as stored in the connection settings. */
    public static class CustomAzureMLModel
    extends AbstractLLMConnection.CustomModel<AzureMLModel> {
        public String id;
        public String displayName;
        public String targetURI;
        public String key;
        /** Defaults to {@link CustomAzureMLModelType#COMPLETION_CHAT}. */
        public CustomAzureMLModelType modelType = CustomAzureMLModelType.COMPLETION_CHAT;
        public List<AbstractSQLConnection.CustomDatabaseProperty> customHeaders = new ArrayList<AbstractSQLConnection.CustomDatabaseProperty>();

        /**
         * Builds a runtime model from this declaration. The base fields come from
         * {@code loadFromCustomModel}; then id, display name, key, target URI and model type are
         * copied. Custom headers are copied into new {@code CustomDatabaseProperty} instances, so
         * the returned model does not share header objects with this declaration. A null header
         * list yields an empty one.
         */
        @Override
        public AzureMLModel toModel() {
            AzureMLModel model = new AzureMLModel();
            model.loadFromCustomModel(this);
            model.id = this.id;
            model.displayName = this.displayName;
            model.key = this.key;
            model.targetURI = this.targetURI;
            model.modelType = this.modelType;
            if (this.customHeaders == null) {
                // Previously this threw an NPE; treat a missing header list as "no headers".
                model.customHeaders = new ArrayList<AbstractSQLConnection.CustomDatabaseProperty>();
            } else {
                model.customHeaders = this.customHeaders.stream()
                        .map(header -> new AbstractSQLConnection.CustomDatabaseProperty(header.name, header.value, header.secret))
                        .collect(Collectors.toList());
            }
            return model;
        }
    }

    /** Runtime representation of a model served by an {@link AzureLLMConnection}. */
    public static class AzureMLModel
    extends AbstractLLMConnection.BaseModel {
        public String targetURI;
        public String key;
        public CustomAzureMLModelType modelType;
        public List<AbstractSQLConnection.CustomDatabaseProperty> customHeaders = new ArrayList<AbstractSQLConnection.CustomDatabaseProperty>();

        @Override
        public LLMStructuredRef asStructuredRef(String connection) {
            return LLMStructuredRef.forAzureLLMConnection(connection, this.id);
        }

        /**
         * Returns the first validation problem found. Checks run in this order: blank id,
         * missing model type, blank target URI. Returns empty when the model is valid.
         */
        @Override
        public Optional<String> getInvalidityReason() {
            if (StringUtils.isBlank((String) this.getId())) {
                return Optional.of("Missing model id");
            }
            if (this.modelType == null) {
                return Optional.of("Missing model type");
            }
            if (StringUtils.isBlank(this.targetURI)) {
                return Optional.of("Missing targetURI");
            }
            return Optional.empty();
        }

        /** True only if a model type is set and that type lists {@code purpose} among its purposes. */
        @Override
        public boolean canBeUsedForPurpose(@Nonnull AbstractLLMConnection.LLMUsagePurpose purpose) {
            return this.modelType != null && this.modelType.matchesPurpose(purpose);
        }

        /**
         * Capabilities: cross-language output and system messages are always supported. Image
         * inputs are supported only when the model type allows {@code IMAGE_INPUT}. Temperature
         * and top-k ranges reuse the ones defined by {@link OpenAIConnection}.
         */
        @Override
        public AbstractLLMConnection.ModelCapabilities getModelCapabilities() {
            AbstractLLMConnection.ModelCapabilities capabilities = new AbstractLLMConnection.ModelCapabilities();
            capabilities.canGenerateCrossLanguageOutput = true;
            capabilities.handlesSystemMessage = true;
            capabilities.supportsImageInputs = this.canBeUsedForPurpose(AbstractLLMConnection.LLMUsagePurpose.IMAGE_INPUT);
            capabilities.temperatureRange = OpenAIConnection.TEMPERATURE_RANGE;
            capabilities.topKRange = OpenAIConnection.TOP_K_RANGE;
            return capabilities;
        }
    }

    /**
     * Kind of Azure model. Each kind defines the usage purposes it serves and whether the chat
     * API is used ({@link #useChatAPI}).
     */
    public static enum CustomAzureMLModelType {
        COMPLETION_SIMPLE(Arrays.asList(AbstractLLMConnection.LLMUsagePurpose.SUMMARIZATION, AbstractLLMConnection.LLMUsagePurpose.CLASSIFICATION_WITH_USER_PROVIDED_CLASSES, AbstractLLMConnection.LLMUsagePurpose.EMOTION_ANALYSIS, AbstractLLMConnection.LLMUsagePurpose.SENTIMENT_ANALYSIS, AbstractLLMConnection.LLMUsagePurpose.GENERIC_COMPLETION), false),
        COMPLETION_CHAT(Arrays.asList(AbstractLLMConnection.LLMUsagePurpose.SUMMARIZATION, AbstractLLMConnection.LLMUsagePurpose.CLASSIFICATION_WITH_USER_PROVIDED_CLASSES, AbstractLLMConnection.LLMUsagePurpose.EMOTION_ANALYSIS, AbstractLLMConnection.LLMUsagePurpose.SENTIMENT_ANALYSIS, AbstractLLMConnection.LLMUsagePurpose.GENERIC_COMPLETION), true),
        COMPLETION_CHAT_MULTIMODAL(Arrays.asList(AbstractLLMConnection.LLMUsagePurpose.SUMMARIZATION, AbstractLLMConnection.LLMUsagePurpose.CLASSIFICATION_WITH_USER_PROVIDED_CLASSES, AbstractLLMConnection.LLMUsagePurpose.EMOTION_ANALYSIS, AbstractLLMConnection.LLMUsagePurpose.SENTIMENT_ANALYSIS, AbstractLLMConnection.LLMUsagePurpose.GENERIC_COMPLETION, AbstractLLMConnection.LLMUsagePurpose.IMAGE_INPUT), true),
        TEXT_EMBEDDING_EXTRACTION(Collections.singletonList(AbstractLLMConnection.LLMUsagePurpose.TEXT_EMBEDDING_EXTRACTION), false);

        /** Read-only list of purposes this model type serves. */
        public final List<AbstractLLMConnection.LLMUsagePurpose> matchingPurposes;
        public final boolean useChatAPI;

        private CustomAzureMLModelType(List<AbstractLLMConnection.LLMUsagePurpose> purposes, boolean useChatAPI) {
            // Arrays.asList views allow set(...); wrap so this public field cannot be used to
            // mutate the purposes of a shared enum constant.
            this.matchingPurposes = Collections.unmodifiableList(purposes);
            this.useChatAPI = useChatAPI;
        }

        public boolean matchesPurpose(AbstractLLMConnection.LLMUsagePurpose purpose) {
            return this.matchingPurposes.contains(purpose);
        }
    }
}

