/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.connections;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.cohere.CoherePricing;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.security.model.ICredentialsService;
import com.dataiku.dip.server.services.ConnectionsTestService;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.variables.VariablesContext;
import java.io.IOException;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

public class CohereConnection
extends AbstractLLMConnection<CohereModel, HardcodedCohereModel, AbstractLLMConnection.CustomModel<CohereModel>> {
    private static final EnrichedLLMStructuredRef.FieldRange TEMPERATURE_RANGE = new EnrichedLLMStructuredRef.FieldRange(0.0, 5.0, 0.05);
    private static final EnrichedLLMStructuredRef.FieldRange TOP_K_RANGE = new EnrichedLLMStructuredRef.FieldRange(0.0, 500.0, 1.0);
    public CohereConnectionParams params = new CohereConnectionParams();
    public static final String connectionType = "Cohere";
    private static final DKULogger logger = DKULogger.getLogger((String)"dip.connections.cohere");

    @Override
    public AbstractLLMConnection.AbstractLLMConnectionParams getLLMConnectionParams() {
        return this.params;
    }

    @Override
    protected DKULogger getLogger() {
        return logger;
    }

    @Override
    protected List<HardcodedCohereModel> listRawHardcodedModels() {
        return Arrays.asList(HardcodedCohereModel.values());
    }

    @Override
    protected CohereModel loadRawHardcodedModel(HardcodedCohereModel hardcodedModel) {
        CohereModel model = hardcodedModel.toModel();
        this.loadDefaultHardcodedModelSettings(hardcodedModel, model);
        return model;
    }

    @Override
    public void expandParametersInPlaceAtDAOLevelUsingGlobalContextOnly(VariablesContext vc) {
    }

    @Override
    protected <T> T getFullyResolvedCredentials_internal(ConnectionWithBasicCredential.CredentialResolutionContext ctx, Class<T> clazz) throws DKUSecurityException, IOException, SQLException {
        assert (clazz.isAssignableFrom(ICredentialsService.BasicCredential.class));
        ICredentialsService.BasicCredential creds = new ICredentialsService.BasicCredential("", this.params.apiKey);
        return clazz.cast(creds);
    }

    @Override
    public String getType() {
        return connectionType;
    }

    @Override
    protected void encryptLocalFields(PasswordEncryptionService cryptoService, GeneralSettingsDAO.SecuritySettings securitySettings) {
        this.params.apiKey = cryptoService.encryptIfNotEncryptedOrEmpty(this.params.apiKey);
    }

    @Override
    protected void decryptLocalFields(PasswordEncryptionService cryptoService) {
        this.params.apiKey = cryptoService.decryptIfEncrypted(this.params.apiKey);
    }

    @Override
    protected boolean isHardcodedModelEnabled(HardcodedCohereModel cohereModel) {
        return cohereModel.allowedModel.apply(this.params);
    }

    @Override
    public ConnectionsTestService.ConnectionTestResult testConnection(AuthCtx authCtx, ConnectionsTestService connectionsTestService) throws Exception {
        return connectionsTestService.testCohere(this, authCtx);
    }

    public static class CohereConnectionParams
    extends AbstractLLMConnection.AbstractLLMConnectionParams {
        public String apiKey;
        public int maxParallelism = 1;
        public AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = new AbstractLLMConnection.HTTPBasedLLMNetworkSettings();
        public boolean allowCohereCommand = false;
        public boolean allowCohereCommandLight = false;
        public boolean allowCohereCommandR = false;
        public boolean allowCohereCommandRPlus = false;
    }

    public static enum HardcodedCohereModel implements AbstractLLMConnection.IHardcodedConnectionModel<CohereModel>
    {
        COMMAND("command", "Cohere Command (standard model) (deprecated)", p -> p.allowCohereCommand),
        COMMAND_LIGHT("command-light", "Cohere Command Light (faster) (deprecated)", p -> p.allowCohereCommandLight),
        COMMAND_R("command-r", "Cohere Command R", p -> p.allowCohereCommandR),
        COMMAND_R_PLUS("command-r-plus", "Cohere Command R+", p -> p.allowCohereCommandRPlus);

        public final String id;
        public final String displayName;
        public final Function<CohereConnectionParams, Boolean> allowedModel;

        private HardcodedCohereModel(String id, String displayName, Function<CohereConnectionParams, Boolean> allowedModel) {
            this.id = id;
            this.displayName = displayName;
            this.allowedModel = allowedModel;
        }

        @Override
        public CohereModel toModel() {
            CohereModel model = new CohereModel();
            model.id = this.id;
            model.displayName = this.displayName;
            model.completionCost = CoherePricing.getCohereCompletionCostPer1KTokens(this.id);
            model.promptCost = CoherePricing.getCoherePromptCostPer1KTokens(this.id);
            return model;
        }
    }

    public static class CohereModel
    extends AbstractLLMConnection.BaseModel {
        @Override
        public AbstractLLMConnection.ModelCapabilities getModelCapabilities() {
            AbstractLLMConnection.ModelCapabilities capabilities = new AbstractLLMConnection.ModelCapabilities();
            capabilities.temperatureRange = TEMPERATURE_RANGE;
            capabilities.topKRange = TOP_K_RANGE;
            return capabilities;
        }

        @Override
        public LLMStructuredRef asStructuredRef(String connection) {
            return LLMStructuredRef.forCohereConnection(connection, this.getId());
        }
    }
}

