/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.connections;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.security.model.ICredentialsService;
import com.dataiku.dip.server.services.ConnectionsTestService;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.variables.VariablesContext;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;

/**
 * LLM connection of type {@value #CONNECTION_TYPE} pointing at NVIDIA NIM endpoints.
 *
 * <p>The models offered by this connection are the deployments declared by the user in
 * {@link NvidiaNimConnectionParams#availableDeployments}; each one is converted into an
 * {@link NvidiaNimModel} by {@link NvidiaNimDeployment#toModel()}.
 */
public class NvidiaNimConnection
extends AbstractLLMConnection<NvidiaNimModel, AbstractLLMConnection.IHardcodedConnectionModel<AbstractLLMConnection.BaseModel>, NvidiaNimDeployment> {
    /** Connection type identifier returned by {@link #getType()}. */
    public static final String CONNECTION_TYPE = "NVIDIA-NIM";

    /** Connection settings: API key, network settings, declared deployments and custom headers. */
    public NvidiaNimConnectionParams params = new NvidiaNimConnectionParams();

    private static final DKULogger logger = DKULogger.getLogger("dip.connections.nvidia-nim");

    /** Returns the user-declared deployments (the live list held by {@link #params}, not a copy). */
    @Override
    protected List<NvidiaNimDeployment> listRawCustomModels() {
        return this.params.availableDeployments;
    }

    /**
     * Converts a declared deployment into a model and applies the default custom-model settings.
     *
     * @param rawCustomModel deployment declared in the connection params
     * @return the model built from {@code rawCustomModel}
     */
    @Override
    protected NvidiaNimModel loadRawCustomModel(NvidiaNimDeployment rawCustomModel) {
        NvidiaNimModel model = rawCustomModel.toModel();
        this.loadDefaultCustomModelSettings(model);
        return model;
    }

    @Override
    protected DKULogger getLogger() {
        return logger;
    }

    @Override
    public AbstractLLMConnection.AbstractLLMConnectionParams getLLMConnectionParams() {
        return this.params;
    }

    /**
     * Encrypts the secret fields of the params in place: the API key (unless already encrypted
     * or empty) and the custom headers.
     */
    @Override
    protected void encryptLocalFields(PasswordEncryptionService cryptoService, GeneralSettingsDAO.SecuritySettings securitySettings) {
        this.params.apiKey = cryptoService.encryptIfNotEncryptedOrEmpty(this.params.apiKey);
        this.params.encryptProperties(this.params.customHeaders, cryptoService);
    }

    /** Reverses {@link #encryptLocalFields}: decrypts the API key (if encrypted) and the custom headers in place. */
    @Override
    protected void decryptLocalFields(PasswordEncryptionService cryptoService) {
        this.params.apiKey = cryptoService.decryptIfEncrypted(this.params.apiKey);
        this.params.decryptProperties(this.params.customHeaders, cryptoService);
    }

    /**
     * Resolves the credentials of this connection as a {@link ICredentialsService.BasicCredential}
     * built from an empty first field and the configured API key.
     *
     * @param ctx   resolution context (not used by this connection type)
     * @param clazz requested credential class; must be a supertype of {@code BasicCredential}
     * @return the credential cast to {@code clazz}
     */
    @Override
    protected <T> T getFullyResolvedCredentials_internal(ConnectionWithBasicCredential.CredentialResolutionContext ctx, Class<T> clazz) throws DKUSecurityException, IOException, SQLException {
        assert clazz.isAssignableFrom(ICredentialsService.BasicCredential.class)
                : "NVIDIA NIM connections only provide BasicCredential, requested " + clazz;
        ICredentialsService.BasicCredential creds = new ICredentialsService.BasicCredential("", this.params.apiKey);
        return clazz.cast(creds);
    }

    @Override
    public String getType() {
        return CONNECTION_TYPE;
    }

    /** Intentionally a no-op: this connection type expands no parameters at DAO level. */
    @Override
    public void expandParametersInPlaceAtDAOLevelUsingGlobalContextOnly(VariablesContext vc) {
    }

    /** Delegates the connection test to {@link ConnectionsTestService#testNvidiaNim}. */
    @Override
    public ConnectionsTestService.ConnectionTestResult testConnection(AuthCtx authCtx, ConnectionsTestService connectionsTestService) throws Exception {
        return connectionsTestService.testNvidiaNim(this, authCtx);
    }

    /** Persisted settings of an NVIDIA NIM connection. */
    public static class NvidiaNimConnectionParams
    extends AbstractLLMConnection.AbstractLLMConnectionParams {
        /** API key; stored encrypted, see {@code encryptLocalFields}. */
        public String apiKey;
        public AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = new AbstractLLMConnection.HTTPBasedLLMNetworkSettings();
        public int maxParallelism = 8;
        /** Deployments exposed as models by this connection. */
        public List<NvidiaNimDeployment> availableDeployments = new ArrayList<NvidiaNimDeployment>();
        /** Extra headers; encrypted/decrypted together with the API key. */
        public List<AbstractSQLConnection.CustomDatabaseProperty> customHeaders = new ArrayList<AbstractSQLConnection.CustomDatabaseProperty>();
    }

    /** A user-declared NIM deployment, as stored in the connection params. */
    public static class NvidiaNimDeployment
    extends AbstractLLMConnection.CustomModel<NvidiaNimModel> {
        /** Optional label; {@link #toModel()} falls back to {@link #id} when this is blank. */
        @Nullable
        public String displayName;
        @Nullable
        public String id;
        @Nullable
        public String url;
        @Nullable
        public NvidiaNimModelType modelType;
        @Nullable
        public NimApi api;
        public boolean supportsImageInputs = false;
        public boolean supportsSystemPrompts = true;

        /**
         * Builds the model corresponding to this deployment. Common custom-model fields are loaded
         * first, then the NIM-specific fields are copied over.
         */
        @Override
        NvidiaNimModel toModel() {
            NvidiaNimModel model = new NvidiaNimModel();
            model.loadFromCustomModel(this);
            model.id = this.id;
            // Blank display names fall back to the model id (null if both are blank).
            model.displayName = StringUtils.firstNonBlank(this.displayName, model.id);
            model.url = this.url;
            model.modelType = this.modelType;
            model.api = this.api;
            model.supportsImageInputs = this.supportsImageInputs;
            model.supportsSystemPrompts = this.supportsSystemPrompts;
            return model;
        }
    }

    /** Runtime model built from a {@link NvidiaNimDeployment}. */
    public static class NvidiaNimModel
    extends AbstractLLMConnection.BaseModel {
        @Nullable
        public String url;
        @Nullable
        public NvidiaNimModelType modelType;
        @Nullable
        public NimApi api;
        public boolean supportsImageInputs = false;
        public boolean supportsSystemPrompts = true;

        /**
         * Reports why this model cannot be used: a blank id, no API, a blank URL or no model type.
         * If none applies, defers to the base-class checks.
         *
         * @return the first invalidity reason found, or the base-class result
         */
        @Override
        public Optional<String> getInvalidityReason() {
            if (StringUtils.isBlank(this.id)) {
                return Optional.of("Empty model id");
            }
            if (this.api == null) {
                return Optional.of("No api");
            }
            // Blank URLs are as unusable as missing ones; check them the same way as the id.
            if (StringUtils.isBlank(this.url)) {
                return Optional.of("No url");
            }
            if (this.modelType == null) {
                return Optional.of("No modelType");
            }
            return super.getInvalidityReason();
        }

        @Override
        LLMStructuredRef asStructuredRef(String connection) {
            return LLMStructuredRef.forNvidiaNimConnection(connection, this.id);
        }

        /**
         * Derives capabilities from the deployment flags. Image-input support comes from
         * {@link #canBeUsedForPurpose}; system-message handling comes from {@link #supportsSystemPrompts}.
         */
        @Override
        public AbstractLLMConnection.ModelCapabilities getModelCapabilities() {
            AbstractLLMConnection.ModelCapabilities capabilities = new AbstractLLMConnection.ModelCapabilities();
            capabilities.supportsImageInputs = this.canBeUsedForPurpose(AbstractLLMConnection.LLMUsagePurpose.IMAGE_INPUT);
            capabilities.handlesSystemMessage = this.supportsSystemPrompts;
            return capabilities;
        }

        /**
         * Tells whether this model can serve {@code purpose}. IMAGE_INPUT is allowed when
         * {@link #supportsImageInputs} is set. Any purpose is also allowed when it is listed in the
         * model type's {@link NvidiaNimModelType#matchingPurposes}.
         */
        @Override
        public boolean canBeUsedForPurpose(@Nonnull AbstractLLMConnection.LLMUsagePurpose purpose) {
            // Image input is an opt-in flag on the deployment, checked regardless of the model type.
            if (purpose == AbstractLLMConnection.LLMUsagePurpose.IMAGE_INPUT && this.supportsImageInputs) {
                return true;
            }
            return this.modelType != null && this.modelType.matchingPurposes.contains(purpose);
        }
    }

    /** API flavours a NIM deployment can be called through. */
    public static enum NimApi {
        OPENAI_V1_CHAT_COMPLETIONS,
        OPENAI_V1_EMBEDDINGS,
        OPENAI_V1_RESPONSES
    }

    /** Kind of model served by a NIM deployment, with the usage purposes it covers. */
    public static enum NvidiaNimModelType {
        CHAT_COMPLETIONS(AbstractLLMConnection.LLMUsagePurpose.SUMMARIZATION, AbstractLLMConnection.LLMUsagePurpose.CLASSIFICATION_WITH_USER_PROVIDED_CLASSES, AbstractLLMConnection.LLMUsagePurpose.EMOTION_ANALYSIS, AbstractLLMConnection.LLMUsagePurpose.SENTIMENT_ANALYSIS, AbstractLLMConnection.LLMUsagePurpose.GENERIC_COMPLETION),
        EMBEDDINGS(AbstractLLMConnection.LLMUsagePurpose.TEXT_EMBEDDING_EXTRACTION);

        /** Purposes served by this model type; unmodifiable (built with {@code Stream.toList()}). */
        public final List<AbstractLLMConnection.LLMUsagePurpose> matchingPurposes;

        private NvidiaNimModelType(AbstractLLMConnection.LLMUsagePurpose ... purposes) {
            // Stream.toList() (not List.of) so that contains(null) returns false instead of throwing.
            this.matchingPurposes = Arrays.stream(purposes).toList();
        }
    }
}

