/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.connections;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.custom.CustomLLMClient;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.variables.VariablesContext;
import com.dataiku.j2py.annotations.PyModel;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import org.apache.commons.lang3.StringUtils;

/**
 * LLM connection of type {@value #connectionType} whose models are declared in the
 * connection settings ({@link CustomLLMConnectionParams#models}) and, presumably, served
 * by the plugin identified by {@link CustomLLMConnectionParams#pluginID} (TODO confirm).
 *
 * <p>This connection carries no credentials or encrypted fields of its own: encryption,
 * decryption and variable expansion are no-ops, and credential resolution is unsupported.
 */
public class CustomLLMConnection
extends AbstractLLMConnection<LLMModel, AbstractLLMConnection.IHardcodedConnectionModel<LLMModel>, CustomLLMModel> {
    /** Connection settings: owning plugin and the list of declared custom models. */
    public CustomLLMConnectionParams params = new CustomLLMConnectionParams();
    /** Type identifier returned by {@link #getType()}. */
    public static final String connectionType = "CustomLLM";
    private static final DKULogger logger = DKULogger.getLogger("dip.connections.customllm");

    @Override
    protected DKULogger getLogger() {
        return logger;
    }

    @Override
    public AbstractLLMConnection.AbstractLLMConnectionParams getLLMConnectionParams() {
        return this.params;
    }

    /**
     * Returns the custom models declared on this connection.
     *
     * @return the live list held by {@link #params} (not a copy)
     */
    @Override
    public List<CustomLLMModel> listRawCustomModels() {
        return this.params.models;
    }

    /**
     * Converts a declared custom model into its runtime {@link LLMModel} and applies the
     * default custom-model settings from the parent class.
     */
    @Override
    protected LLMModel loadRawCustomModel(CustomLLMModel rawCustomModel) {
        LLMModel model = rawCustomModel.toModel();
        this.loadDefaultCustomModelSettings(rawCustomModel, model);
        return model;
    }

    /**
     * Looks up a declared custom model by its ID.
     *
     * <p>If several models share the same ID, a warning is logged and the first one in
     * declaration order is returned.
     *
     * @param customModelId the model ID to look for
     * @return the first declared model whose {@code id} equals {@code customModelId}
     * @throws IllegalArgumentException if {@code customModelId} is null or empty, or if no
     *     declared model has that ID
     */
    public CustomLLMModel getModel(String customModelId) {
        if (StringUtils.isEmpty(customModelId)) {
            throw new IllegalArgumentException("Empty model name provided");
        }
        // Typed list: avoids the raw type and the unchecked cast on get(0).
        List<CustomLLMModel> matchingModels = this.params.models.stream()
                .filter(model -> customModelId.equals(model.id))
                .collect(Collectors.toList());
        if (matchingModels.isEmpty()) {
            throw new IllegalArgumentException(String.format("No custom LLM found for ID \"%s\" in custom LLM connection \"%s\"", customModelId, this.name));
        }
        if (matchingModels.size() > 1) {
            logger.warn((Object)String.format("More than one custom LLM found for ID \"%s\" in custom LLM connection \"%s\"", customModelId, this.name));
        }
        return matchingModels.get(0);
    }

    /** No-op: this connection's params are not variable-expanded at DAO level. */
    @Override
    public void expandParametersInPlaceAtDAOLevelUsingGlobalContextOnly(VariablesContext vc) {
    }

    @Override
    public String getType() {
        return connectionType;
    }

    /** No-op: this connection declares no fields that require encryption. */
    @Override
    public void encryptFields(PasswordEncryptionService cryptoService, GeneralSettingsDAO.SecuritySettings securitySettings) {
    }

    /** No-op: counterpart of {@link #encryptFields}, nothing to decrypt. */
    @Override
    public void decryptFields(PasswordEncryptionService cryptoService) {
    }

    /**
     * Not supported for this connection type.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected <T> T getFullyResolvedCredentials_internal(ConnectionWithBasicCredential.CredentialResolutionContext ctx, Class<T> clazz) throws DKUSecurityException, IOException, SQLException {
        throw new UnsupportedOperationException("Not applicable");
    }

    /** Returns a new, empty, mutable list: this connection defines no DSS properties. */
    @Override
    public List<AbstractSQLConnection.CustomDatabaseProperty> getDkuProperties() {
        return new ArrayList<>();
    }

    /** Settings of a custom LLM connection. */
    public static class CustomLLMConnectionParams
    extends AbstractLLMConnection.AbstractLLMConnectionParams {
        /** ID of the plugin this connection refers to. */
        public String pluginID;
        /** Custom models declared on this connection. */
        public List<CustomLLMModel> models = new ArrayList<CustomLLMModel>();
    }

    /** A custom model as declared in the connection settings. */
    public static class CustomLLMModel
    extends AbstractLLMConnection.CustomModel<LLMModel> {
        public String type;
        public Capability capability;
        public String id;
        /** Free-form configuration copied by reference into the runtime model. */
        public JsonObject customConfig = new JsonObject();
        public int maxParallelism = 64;
        /** When false, {@link #getRetrySettings()} ignores {@link #retrySettings}. */
        public boolean customRateLimitingEnabled = false;
        public CustomLLMClient.RateLimitingRetrySettings retrySettings = new CustomLLMClient.RateLimitingRetrySettings();

        /**
         * Builds the runtime model from this declaration. The display name is
         * {@code "Custom LLM " + id}.
         */
        @Override
        public LLMModel toModel() {
            LLMModel model = new LLMModel();
            model.loadFromCustomModel(this);
            model.id = this.id;
            model.displayName = "Custom LLM " + this.id;
            model.type = this.type;
            model.capability = this.capability;
            // NOTE(review): the JsonObject is shared, not copied, so mutations on the runtime
            // model's customConfig also affect these settings. Confirm that this is intended.
            model.customConfig = this.customConfig;
            return model;
        }

        /**
         * Returns the retry settings to use.
         *
         * @return {@link #retrySettings} when custom rate limiting is enabled, otherwise a
         *     freshly constructed default {@code RateLimitingRetrySettings}
         */
        public CustomLLMClient.RateLimitingRetrySettings getRetrySettings() {
            if (!this.customRateLimitingEnabled) {
                return new CustomLLMClient.RateLimitingRetrySettings();
            }
            return this.retrySettings;
        }
    }

    /** Runtime representation of a custom model. */
    public static class LLMModel
    extends AbstractLLMConnection.BaseModel {
        public String type;
        public JsonObject customConfig;
        public Capability capability;

        /**
         * Checks, in order, for a blank ID, a missing capability and a missing type.
         *
         * @return the first problem found, or empty if the model is valid
         */
        @Override
        public Optional<String> getInvalidityReason() {
            if (StringUtils.isBlank(this.getId())) {
                return Optional.of("Empty model id");
            }
            if (this.capability == null) {
                return Optional.of("Missing model capability");
            }
            if (this.type == null) {
                return Optional.of("Missing model type");
            }
            return Optional.empty();
        }

        /** A model without a capability cannot be used for any purpose. */
        @Override
        public boolean canBeUsedForPurpose(@Nonnull AbstractLLMConnection.LLMUsagePurpose purpose) {
            return this.capability != null && this.capability.matchesPurpose(purpose);
        }

        @Override
        public LLMStructuredRef asStructuredRef(String connection) {
            return LLMStructuredRef.forCustomConnection(connection, this.id);
        }

        /**
         * Describes what the model accepts. System messages are always handled, and image
         * inputs are supported only for {@link Capability#TEXT_COMPLETION_MULTIMODAL}.
         */
        @Override
        public AbstractLLMConnection.ModelCapabilities getModelCapabilities() {
            AbstractLLMConnection.ModelCapabilities capabilities = new AbstractLLMConnection.ModelCapabilities();
            // Default-constructed ranges; their exact meaning depends on FieldRange (TODO confirm).
            capabilities.temperatureRange = new EnrichedLLMStructuredRef.FieldRange();
            capabilities.topKRange = new EnrichedLLMStructuredRef.FieldRange();
            capabilities.handlesSystemMessage = true;
            capabilities.supportsImageInputs = this.capability == Capability.TEXT_COMPLETION_MULTIMODAL;
            return capabilities;
        }
    }

    /** What a custom model can do, mapped to the usage purposes it can serve. */
    @PyModel
    public enum Capability {
        TEXT_COMPLETION(Arrays.asList(AbstractLLMConnection.LLMUsagePurpose.GENERIC_COMPLETION, AbstractLLMConnection.LLMUsagePurpose.CLASSIFICATION_WITH_USER_PROVIDED_CLASSES, AbstractLLMConnection.LLMUsagePurpose.SENTIMENT_ANALYSIS, AbstractLLMConnection.LLMUsagePurpose.EMOTION_ANALYSIS, AbstractLLMConnection.LLMUsagePurpose.SUMMARIZATION)),
        TEXT_COMPLETION_MULTIMODAL(Arrays.asList(AbstractLLMConnection.LLMUsagePurpose.GENERIC_COMPLETION, AbstractLLMConnection.LLMUsagePurpose.CLASSIFICATION_WITH_USER_PROVIDED_CLASSES, AbstractLLMConnection.LLMUsagePurpose.SENTIMENT_ANALYSIS, AbstractLLMConnection.LLMUsagePurpose.EMOTION_ANALYSIS, AbstractLLMConnection.LLMUsagePurpose.SUMMARIZATION, AbstractLLMConnection.LLMUsagePurpose.IMAGE_INPUT)),
        TEXT_EMBEDDING(Collections.singletonList(AbstractLLMConnection.LLMUsagePurpose.TEXT_EMBEDDING_EXTRACTION)),
        IMAGE_GENERATION(Collections.singletonList(AbstractLLMConnection.LLMUsagePurpose.IMAGE_GENERATION)),
        TEXT_IMAGE_EMBEDDING_EXTRACTION(Arrays.asList(AbstractLLMConnection.LLMUsagePurpose.TEXT_EMBEDDING_EXTRACTION, AbstractLLMConnection.LLMUsagePurpose.IMAGE_EMBEDDING_EXTRACTION));

        /** Purposes served by this capability (read-only). */
        public final List<AbstractLLMConnection.LLMUsagePurpose> matchingPurposes;

        Capability(List<AbstractLLMConnection.LLMUsagePurpose> purposes) {
            // Enum constants are JVM-wide singletons. Arrays.asList is settable, so wrap it to
            // keep callers from silently changing matchesPurpose() for everyone.
            this.matchingPurposes = Collections.unmodifiableList(purposes);
        }

        /** Returns true if {@code purpose} is one of {@link #matchingPurposes}. */
        public boolean matchesPurpose(AbstractLLMConnection.LLMUsagePurpose purpose) {
            return this.matchingPurposes.contains(purpose);
        }
    }
}

