/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.connections;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.mistralai.MistralAIPricing;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.security.model.ICredentialsService;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.variables.VariablesContext;
import java.io.IOException;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import javax.annotation.Nonnull;

/**
 * LLM connection of type {@code "MistralAI"} exposing a fixed catalogue of Mistral AI models.
 *
 * <p>Each hardcoded model ({@link HardcodedMistralAIModel}) is enabled only when its matching
 * {@code allowMistral*} flag is set in {@link MistralAIConnectionParams}. The credentials handed
 * out by this connection are built from the single {@code apiKey} stored in the params.
 */
public class MistralAIConnection
extends AbstractLLMConnection<MistralAIModel, HardcodedMistralAIModel, AbstractLLMConnection.CustomModel> {
    /** Temperature range advertised by every Mistral model: FieldRange(0.0, 2.0, 0.02) — presumably (min, max, step). */
    public static final EnrichedLLMStructuredRef.FieldRange TEMPERATURE_RANGE = new EnrichedLLMStructuredRef.FieldRange(0.0, 2.0, 0.02);
    /** No top-k range is advertised for Mistral models (presumably top-k is not supported — confirm). */
    public static final EnrichedLLMStructuredRef.FieldRange TOP_K_RANGE = null;
    /** Connection settings; never null after construction. */
    public MistralAIConnectionParams params = new MistralAIConnectionParams();
    /** Type identifier returned by {@link #getType()}. */
    public static final String connectionType = "MistralAI";
    private static final DKULogger logger = DKULogger.getLogger("dip.connections.mistralai");

    /** No-op: this connection does not expand any variables in its parameters. */
    @Override
    public void expandParametersInPlaceAtDAOLevelUsingGlobalContextOnly(VariablesContext vc) {
    }

    /**
     * Returns the connection's API key wrapped in an {@link ICredentialsService.BasicCredential}
     * built from an empty string and {@code params.apiKey}.
     *
     * <p>Only {@code BasicCredential} (or one of its supertypes) may be requested. Any other
     * {@code clazz} trips the assertion when assertions are enabled; otherwise
     * {@link Class#cast} throws {@link ClassCastException}.
     *
     * @param ctx   resolution context (unused by this connection)
     * @param clazz the credential type requested by the caller
     * @return the API-key credential cast to {@code clazz}
     */
    @Override
    protected <T> T getFullyResolvedCredentials_internal(ConnectionWithBasicCredential.CredentialResolutionContext ctx, Class<T> clazz) throws DKUSecurityException, IOException, SQLException {
        assert clazz.isAssignableFrom(ICredentialsService.BasicCredential.class)
                : "MistralAI connections only provide BasicCredential, requested " + clazz.getName();
        ICredentialsService.BasicCredential creds = new ICredentialsService.BasicCredential("", this.params.apiKey);
        return clazz.cast(creds);
    }

    /** @return the constant {@link #connectionType}, {@code "MistralAI"} */
    @Override
    public String getType() {
        return connectionType;
    }

    /** @return this connection's {@link #params} */
    @Override
    public AbstractLLMConnection.AbstractLLMConnectionParams getLLMConnectionParams() {
        return this.params;
    }

    /**
     * Encrypts the API key in place, delegating to
     * {@link PasswordEncryptionService#encryptIfNotEncryptedOrEmpty}.
     */
    @Override
    public void encryptFields(PasswordEncryptionService cryptoService, GeneralSettingsDAO.SecuritySettings securitySettings) {
        this.params.apiKey = cryptoService.encryptIfNotEncryptedOrEmpty(this.params.apiKey);
    }

    /**
     * Decrypts the API key in place, delegating to
     * {@link PasswordEncryptionService#decryptIfEncrypted}.
     */
    @Override
    public void decryptFields(PasswordEncryptionService cryptoService) {
        this.params.apiKey = cryptoService.decryptIfEncrypted(this.params.apiKey);
    }

    /**
     * A hardcoded model is enabled when its {@code allowedModel} function returns {@code true}
     * for this connection's params (i.e. its {@code allowMistral*} flag is set).
     */
    @Override
    protected boolean isHardcodedModelEnabled(HardcodedMistralAIModel mistralAIModel) {
        return mistralAIModel.allowedModel.apply(this.params);
    }

    /** @return every {@link HardcodedMistralAIModel} constant, enabled or not */
    @Override
    protected List<HardcodedMistralAIModel> listRawHardcodedModels() {
        return Arrays.asList(HardcodedMistralAIModel.values());
    }

    /**
     * Builds a fresh {@link MistralAIModel} from the enum constant, then applies the
     * default hardcoded-model settings inherited from {@link AbstractLLMConnection}.
     */
    @Override
    protected MistralAIModel loadRawHardcodedModel(HardcodedMistralAIModel hardcodedModel) {
        MistralAIModel model = hardcodedModel.toModel();
        this.loadDefaultHardcodedModelSettings(hardcodedModel, model);
        return model;
    }

    /** @return the {@code dip.connections.mistralai} logger */
    @Override
    protected DKULogger getLogger() {
        return logger;
    }

    /** User-editable settings of a Mistral AI connection. */
    public static class MistralAIConnectionParams
    extends AbstractLLMConnection.AbstractLLMConnectionParams {
        public AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = new AbstractLLMConnection.HTTPBasedLLMNetworkSettings();
        /** Defaults to 4. */
        public int maxParallelism = 4;
        // Per-model enablement flags; every model is disabled by default.
        public boolean allowMistralSmall = false;
        public boolean allowMistralMedium = false;
        public boolean allowMistralLarge = false;
        public boolean allowMistralEmbed = false;
        /** API key; encrypted/decrypted in place by encryptFields/decryptFields. */
        public String apiKey;
    }

    /**
     * The fixed catalogue of Mistral AI models, each paired with the params flag that
     * enables it.
     */
    public static enum HardcodedMistralAIModel implements AbstractLLMConnection.IHardcodedConnectionModel
    {
        MISTRAL_LARGE("mistral-large-latest", "Mistral large", p -> p.allowMistralLarge),
        MISTRAL_MEDIUM("mistral-medium-latest", "Mistral medium", p -> p.allowMistralMedium),
        MISTRAL_SMALL("mistral-small-latest", "Mistral small", p -> p.allowMistralSmall),
        MISTRAL_EMBED("mistral-embed", "Mistral embed", EnumSet.of(AbstractLLMConnection.LLMUsagePurpose.TEXT_EMBEDDING_EXTRACTION), 1024, 8192, p -> p.allowMistralEmbed);

        public final String id;
        public final String displayName;
        /** Only set (1024) for the embedding model; null otherwise. */
        public final Integer embeddingSize;
        /** Only set (8192) for the embedding model; null otherwise. */
        public final Integer maxTokensLimit;
        public final Set<AbstractLLMConnection.LLMUsagePurpose> purposeSet;
        /** Reads the enablement flag for this model from the connection params. */
        public final Function<MistralAIConnectionParams, Boolean> allowedModel;

        /** Chat/completion model: uses {@code PROMPT_DRIVEN_PURPOSE_SET} and no embedding limits. */
        private HardcodedMistralAIModel(String id, String displayName, Function<MistralAIConnectionParams, Boolean> allowedModel) {
            this(id, displayName, AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET, allowedModel);
        }

        private HardcodedMistralAIModel(String id, String displayName, Set<AbstractLLMConnection.LLMUsagePurpose> purposeSet, Function<MistralAIConnectionParams, Boolean> allowedModel) {
            this(id, displayName, purposeSet, null, null, allowedModel);
        }

        private HardcodedMistralAIModel(String id, String displayName, Set<AbstractLLMConnection.LLMUsagePurpose> purposeSet, Integer embeddingSize, Integer maxTokensLimit, Function<MistralAIConnectionParams, Boolean> allowedModel) {
            this.id = id;
            this.displayName = displayName;
            this.purposeSet = purposeSet;
            this.embeddingSize = embeddingSize;
            this.maxTokensLimit = maxTokensLimit;
            this.allowedModel = allowedModel;
        }

        /**
         * Creates a new {@link MistralAIModel} carrying this constant's metadata and pricing.
         *
         * <p>Embedding models get {@code embeddingCost}; every other model gets
         * {@code completionCost} and {@code promptCost}. All costs come from {@link MistralAIPricing}.
         */
        public MistralAIModel toModel() {
            MistralAIModel model = new MistralAIModel();
            model.id = this.id;
            model.purposeSet = this.purposeSet;
            model.displayName = this.displayName;
            model.embeddingSize = this.embeddingSize;
            model.maxTokensLimit = this.maxTokensLimit;
            if (this.purposeSet.contains(AbstractLLMConnection.LLMUsagePurpose.TEXT_EMBEDDING_EXTRACTION)) {
                model.embeddingCost = MistralAIPricing.getMistralAIEmbeddingCostPer1kTokens(this.id);
            } else {
                model.completionCost = MistralAIPricing.getMistralAICompletionCostPer1kTokens(this.id);
                model.promptCost = MistralAIPricing.getMistralAIPromptCostPer1kTokens(this.id);
            }
            return model;
        }
    }

    /** A Mistral AI model instance, as produced by {@link HardcodedMistralAIModel#toModel()}. */
    public static class MistralAIModel
    extends AbstractLLMConnection.BaseModel {
        /** Usage purposes supported by this model; only assigned by HardcodedMistralAIModel.toModel(). */
        private Set<AbstractLLMConnection.LLMUsagePurpose> purposeSet;

        @Override
        public LLMStructuredRef asStructuredRef(String connection) {
            return LLMStructuredRef.forMistralAIConnection(connection, this.getId());
        }

        /**
         * @return whether {@code purpose} is among this model's purposes; {@code false} when the
         *         model carries no purpose set at all
         */
        @Override
        public boolean canBeUsedForPurpose(@Nonnull AbstractLLMConnection.LLMUsagePurpose purpose) {
            // purposeSet is only filled in by toModel(); an instance created any other way
            // (presumably e.g. by deserialization) has none and must answer false rather than NPE.
            return this.purposeSet != null && this.purposeSet.contains(purpose);
        }

        /**
         * Builds a fresh capabilities object: cross-language output and system messages are
         * supported, with {@link #TEMPERATURE_RANGE} and {@link #TOP_K_RANGE}.
         */
        @Override
        public AbstractLLMConnection.ModelCapabilities getModelCapabilities() {
            AbstractLLMConnection.ModelCapabilities capabilities = new AbstractLLMConnection.ModelCapabilities();
            capabilities.canGenerateCrossLanguageOutput = true;
            capabilities.handlesSystemMessage = true;
            capabilities.temperatureRange = TEMPERATURE_RANGE;
            capabilities.topKRange = TOP_K_RANGE;
            return capabilities;
        }
    }
}

