/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.connections;

import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.connections.ConnectionWithGoogleAuthCredentials;
import com.dataiku.dip.connections.DSSConnection;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.externalinfras.ExternalInfrasUtils;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.vertex.VertexPricing;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.variables.VariablesContext;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import org.apache.commons.lang.StringUtils;

/**
 * LLM connection backed by Google Vertex AI.
 *
 * <p>The connection exposes two kinds of models:
 * <ul>
 *   <li>a fixed catalogue of {@link HardcodedVertexModel}s, each gated by an {@code allow*} flag in
 *       {@link VertexConnectionParams};</li>
 *   <li>user-declared {@link CustomVertexModel}s stored in {@link VertexConnectionParams#customModels}.</li>
 * </ul>
 * Authentication goes through the Google auth support of {@code ConnectionWithGoogleAuthCredentials}.
 */
public class VertexAILLMConnection
extends AbstractLLMConnection<VertexModel, HardcodedVertexModel, CustomVertexModel>
implements ConnectionWithGoogleAuthCredentials {
    /** Temperature range advertised in model capabilities (arguments presumably min, max, step — confirm against FieldRange). */
    private static final EnrichedLLMStructuredRef.FieldRange TEMPERATURE_RANGE = new EnrichedLLMStructuredRef.FieldRange(0.0, 1.0, 0.01);
    /** Top-K range advertised in model capabilities (same argument convention as {@link #TEMPERATURE_RANGE}). */
    private static final EnrichedLLMStructuredRef.FieldRange TOP_K_RANGE = new EnrichedLLMStructuredRef.FieldRange(1.0, 40.0, 1.0);
    /** Connection-level settings: project/region, model allow-flags, custom models and auth parameters. */
    public VertexConnectionParams params = new VertexConnectionParams();
    /** Type identifier returned by {@link #getType()}. */
    public static final String connectionType = "VertexAILLM";
    private static final DKULogger logger = DKULogger.getLogger("dku.connections.vertex");

    /**
     * Intentionally a no-op: this connection does not expand any of its parameters against the
     * global variables context.
     */
    @Override
    public void expandParametersInPlaceAtDAOLevelUsingGlobalContextOnly(VariablesContext vc) {
    }

    /** Delegates to the Google-auth rule for deciding whether resolution must happen on the DSS host. */
    @Override
    public boolean mustResolveOnDSSHost() {
        return this.mustResolveGoogleAuthOnDSSHost();
    }

    /**
     * Resolution must happen on the backend when {@code hasRefreshTokenRotation()} reports true, or
     * whenever the parent class requires it.
     */
    @Override
    public boolean mustResolveOnBackend() {
        return this.hasRefreshTokenRotation() || super.mustResolveOnBackend();
    }

    /**
     * Resolves the Google credentials for the given context and returns them as {@code clazz}.
     *
     * <p>Only {@code SerializableGoogleAuthCredentials} (or one of its supertypes) may be requested.
     * This is checked with an {@code assert} only; when assertions are disabled, a mismatching
     * {@code clazz} fails at {@link Class#cast} with a {@link ClassCastException}.
     */
    @Override
    protected <T> T getFullyResolvedCredentials_internal(ConnectionWithBasicCredential.CredentialResolutionContext ctx, Class<T> clazz) throws DKUSecurityException, IOException, SQLException {
        assert (clazz.isAssignableFrom(ConnectionWithGoogleAuthCredentials.SerializableGoogleAuthCredentials.class));
        ConnectionWithGoogleAuthCredentials.SerializableGoogleAuthCredentials creds = this.getResolvedCredential(ctx.authCtx);
        return clazz.cast(creds);
    }

    /** @return {@value #connectionType} */
    @Override
    public String getType() {
        return connectionType;
    }

    /** @return the live {@link #params} object (not a copy). */
    @Override
    public AbstractLLMConnection.AbstractLLMConnectionParams getLLMConnectionParams() {
        return this.params;
    }

    /** @return a new {@code SimpleGoogleAuth2Params} wrapper around the unresolved {@link #params}. */
    @Override
    @Nonnull
    public ConnectionWithGoogleAuthCredentials.GoogleAuth2Params getGoogleAuth2NonResolvedParams() {
        return new ConnectionWithGoogleAuthCredentials.SimpleGoogleAuth2Params(this.params);
    }

    /** Proxy settings are looked up through {@code ExternalInfrasUtils} for this connection. */
    @Override
    public ProxySettings getProxySettingsFromConnection() {
        return ExternalInfrasUtils.getProxy(this);
    }

    /** A hardcoded model is enabled when its {@code allow*} flag in {@link #params} is set. */
    @Override
    protected boolean isHardcodedModelEnabled(HardcodedVertexModel vtxModel) {
        return vtxModel.allowedModel.apply(this.params);
    }

    @Override
    protected DKULogger getLogger() {
        return logger;
    }

    /** @return every hardcoded model, in enum declaration order (enabled or not). */
    @Override
    protected List<HardcodedVertexModel> listRawHardcodedModels() {
        return Arrays.asList(HardcodedVertexModel.values());
    }

    /** @return the live custom-model list from {@link #params} (not a copy). */
    @Override
    protected List<CustomVertexModel> listRawCustomModels() {
        return this.params.customModels;
    }

    /** Builds the runtime model for a hardcoded entry and applies the default hardcoded-model settings. */
    @Override
    protected VertexModel loadRawHardcodedModel(HardcodedVertexModel hardcodedModel) {
        VertexModel model = hardcodedModel.toModel();
        this.loadDefaultHardcodedModelSettings(hardcodedModel, model);
        return model;
    }

    /** Builds the runtime model for a custom entry and applies the default custom-model settings. */
    @Override
    protected VertexModel loadRawCustomModel(CustomVertexModel rawCustomModel) {
        VertexModel model = rawCustomModel.toModel();
        this.loadDefaultCustomModelSettings(rawCustomModel, model);
        return model;
    }

    /** True only for per-user credentials combined with the OAUTH auth type. */
    @Override
    public boolean actuallyHasPerUserOAuth2Credential() {
        return this.credentialsMode == DSSConnection.CredentialsMode.PER_USER && this.params.authType == ConnectionWithGoogleAuthCredentials.AuthType.OAUTH;
    }

    /** Decrypts the OAuth2 client secret and the app secret content, if they are encrypted. */
    @Override
    public void decryptFields(PasswordEncryptionService cryptoService) {
        this.params.oauth2ClientSecret = cryptoService.decryptIfEncrypted(this.params.oauth2ClientSecret);
        this.params.appSecretContent = cryptoService.decryptIfEncrypted(this.params.appSecretContent);
    }

    /**
     * Encrypts the OAuth2 client secret, and the app secret content only when it is a non-blank
     * JSON document.
     */
    @Override
    public void encryptFields(PasswordEncryptionService cryptoService, GeneralSettingsDAO.SecuritySettings unused) {
        this.params.oauth2ClientSecret = cryptoService.encryptIfNotEncryptedOrEmpty(this.params.oauth2ClientSecret);
        // Non-JSON app secret content is left in clear.
        // NOTE(review): presumably a key-file path rather than key material — confirm.
        if (StringUtils.isNotBlank(this.params.appSecretContent) && ConnectionWithGoogleAuthCredentials.isJson(this.params.appSecretContent)) {
            this.params.appSecretContent = cryptoService.encryptIfNotEncryptedOrEmpty(this.params.appSecretContent);
        }
    }

    /**
     * Persisted settings of a Vertex AI connection.
     *
     * <p>Each {@code allow*} flag gates the matching {@link HardcodedVertexModel}. All flags default
     * to {@code false} except {@link #allowTextEmb} and {@link #allowTextMultilangEmb}.
     */
    public static class VertexConnectionParams
    extends AbstractLLMConnection.AbstractLLMConnectionParams
    implements ConnectionWithGoogleAuthCredentials.IGoogleAuth2Params {
        public int maxParallelism = 8;
        public String project;
        public String region = "us-central1";
        // Hardcoded-model allow flags.
        public boolean allowGeminiFlash20 = false;
        public boolean allowGeminiFlashLite20 = false;
        public boolean allowGeminiFlash25 = false;
        public boolean allowGeminiPro25 = false;
        public boolean allowGeminiFlash15 = false;
        public boolean allowGeminiPro15 = false;
        public boolean allowGeminiFlash20Exp = false;
        public boolean allowGeminiFlash20ThinkingExp = false;
        public boolean allowImagen3 = false;
        public boolean allowImagen3Fast = false;
        public boolean allowGeminiTextEmb = false;
        public boolean allowTextEmb = true;
        public boolean allowTextEmb005 = false;
        public boolean allowTextMultilangEmb = true;
        public boolean allowMultimodalEmb = false;
        public List<String> otherModelIds = new ArrayList<String>();
        public List<CustomVertexModel> customModels = new ArrayList<CustomVertexModel>();
        // Google authentication settings.
        public ConnectionWithGoogleAuthCredentials.AuthType authType = ConnectionWithGoogleAuthCredentials.AuthType.KEYPAIR;
        public String oauth2ClientId;
        /** Encrypted/decrypted by {@code encryptFields}/{@code decryptFields} on the connection. */
        public String oauth2ClientSecret;
        public String oauth2Scope;
        public String oauth2AuthorizationEndpoint;
        public String oauth2tokenEndpoint;
        public boolean refreshTokenRotation;
        /** Encrypted by the connection only when it is non-blank JSON. */
        public String appSecretContent;
        public String serviceAccountEmail;
        public AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = new AbstractLLMConnection.HTTPBasedLLMNetworkSettings();

        @Override
        public ConnectionWithGoogleAuthCredentials.AuthType getAuthType() {
            return this.authType;
        }

        @Override
        public String getOauth2ClientId() {
            return this.oauth2ClientId;
        }

        @Override
        public String getOauth2ClientSecret() {
            return this.oauth2ClientSecret;
        }

        @Override
        public String getOauth2Scope() {
            return this.oauth2Scope;
        }

        @Override
        public String getOauth2AuthorizationEndpoint() {
            return this.oauth2AuthorizationEndpoint;
        }

        @Override
        public String getOauth2TokenEndpoint() {
            return this.oauth2tokenEndpoint;
        }

        @Override
        public String getAppSecretContent() {
            return this.appSecretContent;
        }

        @Override
        public String getServiceAccountEmail() {
            return this.serviceAccountEmail;
        }

        /** @return the Google Cloud Platform OAuth2 scope used when none is configured. */
        @Override
        public String getDefaultOauth2Scope() {
            return "https://www.googleapis.com/auth/cloud-platform";
        }

        @Override
        public boolean getRefreshTokenRotation() {
            return this.refreshTokenRotation;
        }
    }

    /**
     * Built-in catalogue of Vertex models.
     *
     * <p>Each entry carries:
     * <ul>
     *   <li>its Vertex model id and display name;</li>
     *   <li>its {@link VertexModelType};</li>
     *   <li>an accessor for the allow-flag that enables it;</li>
     *   <li>an optional {@code maxTokensLimit}, which is {@code null} unless given (only the text
     *       embedding models set it, to 2048).</li>
     * </ul>
     */
    public static enum HardcodedVertexModel implements AbstractLLMConnection.IHardcodedConnectionModel<VertexModel>
    {
        GEMINI_FLASH_20("gemini-2.0-flash", "Gemini 2.0 Flash", VertexModelType.GEMINI_CHAT, p -> p.allowGeminiFlash20),
        GEMINI_FLASH_LITE_20("gemini-2.0-flash-lite", "Gemini 2.0 Flash-Lite", VertexModelType.GEMINI_CHAT, p -> p.allowGeminiFlashLite20),
        GEMINI_FLASH_25("gemini-2.5-flash", "Gemini 2.5 Flash", VertexModelType.GEMINI_CHAT, p -> p.allowGeminiFlash25),
        GEMINI_PRO_25("gemini-2.5-pro", "Gemini 2.5 Pro", VertexModelType.GEMINI_CHAT, p -> p.allowGeminiPro25),
        GEMINI_FLASH_15("gemini-1.5-flash", "Gemini 1.5 Flash (deprecated / scheduled for retirement on Sept 24th, 2025)", VertexModelType.GEMINI_CHAT, p -> p.allowGeminiFlash15),
        GEMINI_PRO_15("gemini-1.5-pro", "Gemini 1.5 Pro (deprecated / scheduled for retirement on Sept 24th, 2025)", VertexModelType.GEMINI_CHAT, p -> p.allowGeminiPro15),
        GEMINI_FLASH_20_THINKING_EXP("gemini-2.0-flash-thinking-exp-01-21", "Gemini 2.0 Flash Thinking Mode Experimental", VertexModelType.GEMINI_CHAT, p -> p.allowGeminiFlash20ThinkingExp),
        IMAGEN_3_001("imagen-3.0-generate-001", "Imagen 3", VertexModelType.IMAGE_GENERATION, p -> p.allowImagen3),
        IMAGEN_3_FAST_001("imagen-3.0-fast-generate-001", "Imagen 3 Fast", VertexModelType.IMAGE_GENERATION, p -> p.allowImagen3Fast),
        GEMINI_EMBEDDING("gemini-embedding-001", "Gemini Text embedding", VertexModelType.TEXT_EMBEDDING_EXTRACTION, p -> p.allowGeminiTextEmb, 2048),
        TEXT_EMBEDDING("text-embedding-004", "Text embedding (update 004)", VertexModelType.TEXT_EMBEDDING_EXTRACTION, p -> p.allowTextEmb, 2048),
        TEXT_EMBEDDING_005("text-embedding-005", "Text embedding (update 005)", VertexModelType.TEXT_EMBEDDING_EXTRACTION, p -> p.allowTextEmb005, 2048),
        TEXT_MULTILANG_EMBEDDING("text-multilingual-embedding-002", "Text embedding (multilingual)", VertexModelType.TEXT_EMBEDDING_EXTRACTION, p -> p.allowTextMultilangEmb, 2048),
        MULTIMODAL_EMBEDDING("multimodalembedding@001", "Multimodal embedding", VertexModelType.TEXT_IMAGE_EMBEDDING_EXTRACTION, p -> p.allowMultimodalEmb),
        GEMINI_FLASH_20_EXP("gemini-2.0-flash-exp", "Gemini 2.0 Flash Experimental (deprecated)", VertexModelType.GEMINI_CHAT, p -> p.allowGeminiFlash20Exp);

        public final String id;
        public final String displayName;
        public final VertexModelType modelType;
        /** {@code null} when the entry does not specify a limit. */
        public final Integer maxTokensLimit;
        /** Reads the allow-flag for this model from the connection params. */
        public final Function<VertexConnectionParams, Boolean> allowedModel;

        private HardcodedVertexModel(String id, String displayName, VertexModelType modelType, Function<VertexConnectionParams, Boolean> allowedModel) {
            this(id, displayName, modelType, allowedModel, null);
        }

        private HardcodedVertexModel(String id, String displayName, VertexModelType modelType, Function<VertexConnectionParams, Boolean> allowedModel, Integer maxTokensLimit) {
            this.id = id;
            this.displayName = displayName;
            this.modelType = modelType;
            this.allowedModel = allowedModel;
            this.maxTokensLimit = maxTokensLimit;
        }

        /** @return a fresh {@link VertexModel} carrying this entry's id, name, type and token limit. */
        @Override
        public VertexModel toModel() {
            VertexModel model = new VertexModel();
            model.id = this.id;
            model.displayName = this.displayName;
            model.modelType = this.modelType;
            model.maxTokensLimit = this.maxTokensLimit;
            return model;
        }
    }

    /** Runtime representation of a Vertex model (hardcoded or custom). */
    public static class VertexModel
    extends AbstractLLMConnection.BaseModel {
        /** May be {@code null} for a custom model whose type was never set. */
        public VertexModelType modelType;

        /**
         * Not supported for Vertex.
         * Use {@link #getVertexEstimatedCompletionCost(List, LLMClient.SimpleCompletionResponse)} instead.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public double getEstimatedCompletionCost(Integer promptCount, Integer completionCount) {
            throw new UnsupportedOperationException("Vertex Model does not support tokens cost estimation " + String.valueOf(this.getClass()));
        }

        /**
         * Estimates the cost of one completion.
         *
         * <p>When both per-1k-token prices are known for this model id, the reported prompt and
         * completion token counts are priced (a missing count contributes 0). Otherwise the estimate
         * falls back to per-1k-character prices:
         * <ul>
         *   <li>the prompt is measured as all message texts joined with newlines, with spaces removed;</li>
         *   <li>the completion is measured as the length of the response text.</li>
         * </ul>
         * If character prices are unknown too, the estimate is 0.
         */
        public double getVertexEstimatedCompletionCost(List<LLMClient.ChatMessage> chatMessages, LLMClient.SimpleCompletionResponse scr) {
            Double promptCost = VertexPricing.getVertexPromptCostPer1kTokens(this.id);
            Double completionCost = VertexPricing.getVertexAICompletionCostPer1kTokens(this.id);
            if (promptCost != null && completionCost != null) {
                double totalCost = 0.0;
                if (scr.promptTokens != null) {
                    totalCost += promptCost * (double)scr.promptTokens.intValue() / 1000.0;
                }
                if (scr.completionTokens != null) {
                    totalCost += completionCost * (double)scr.completionTokens.intValue() / 1000.0;
                }
                return totalCost;
            }
            // Character-based fallback. A null message text counts as empty (Collectors.joining would
            // otherwise insert the literal "null"), and a null response text counts as 0 chars.
            int inputChars = chatMessages.stream()
                    .map(message -> message.getTextEvenIfNotTextOnly())
                    .map(text -> text == null ? "" : text)
                    .collect(Collectors.joining("\n"))
                    .replace(" ", "")
                    .length();
            int outputChars = scr.text == null ? 0 : scr.text.length();
            promptCost = VertexPricing.getVertexPromptCostPer1kChars(this.id);
            completionCost = VertexPricing.getVertexAICompletionCostPer1kChars(this.id);
            if (promptCost == null || completionCost == null) {
                return 0.0;
            }
            return (promptCost * (double)inputChars + completionCost * (double)outputChars) / 1000.0;
        }

        /**
         * Not supported for Vertex.
         * Use {@link #getVertexEstimatedEmbeddingCost(LLMClient.EmbeddingQuery)} instead.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public double getEstimatedEmbeddingCost(Integer promptTokens, int nb_images) {
            throw new UnsupportedOperationException("Vertex Model does not support tokens cost estimation " + String.valueOf(this.getClass()));
        }

        /**
         * Estimates the cost of one embedding query: text is priced per 1k characters, and a single
         * image (when the query has one) is priced per image. A price that is unavailable for this
         * model id contributes 0.
         */
        public double getVertexEstimatedEmbeddingCost(LLMClient.EmbeddingQuery query) {
            int textChars = query.hasText() ? query.text.length() : 0;
            // A query carries at most one image; a boolean cannot be cast to a number in Java.
            int nbImages = query.hasImage() ? 1 : 0;
            Double textCostPer1kChars = VertexPricing.getVertexAIEmbeddingCostPer1KChars(this.id);
            Double costPerImage = VertexPricing.getVertexAIEmbeddingCostPerImage(this.id);
            double textCost = textCostPer1kChars == null ? 0.0 : (double)textChars * textCostPer1kChars / 1000.0;
            double imageCost = costPerImage == null ? 0.0 : (double)nbImages * costPerImage;
            return textCost + imageCost;
        }

        /**
         * Flat per-image price for the known Imagen models (presumably USD — confirm). Returns 0 for a
         * blank id, and logs a warning and returns 0 for any other id.
         */
        @Override
        public double getEstimatedImageGenerationCost(LLMClient.ImageGenerationQuery query) {
            if (StringUtils.isBlank(this.id)) {
                return 0.0;
            }
            if (HardcodedVertexModel.IMAGEN_3_001.id.equals(this.id)) {
                return 0.04;
            }
            if (HardcodedVertexModel.IMAGEN_3_FAST_001.id.equals(this.id)) {
                return 0.02;
            }
            logger.warn((Object)("Unknown pricing for VertexAI model: " + this.id));
            return 0.0;
        }

        /**
         * Describes this model's capabilities.
         *
         * <p>System-message support comes from the model type, and image-input support from the
         * IMAGE_INPUT purpose. The temperature and top-K ranges are shared constants. A model with no
         * type reports no system-message support.
         */
        @Override
        public AbstractLLMConnection.ModelCapabilities getModelCapabilities() {
            AbstractLLMConnection.ModelCapabilities capabilities = new AbstractLLMConnection.ModelCapabilities();
            capabilities.handlesSystemMessage = this.modelType != null && this.modelType.handleSystemPrompts;
            capabilities.supportsImageInputs = this.canBeUsedForPurpose(AbstractLLMConnection.LLMUsagePurpose.IMAGE_INPUT);
            capabilities.temperatureRange = TEMPERATURE_RANGE;
            capabilities.topKRange = TOP_K_RANGE;
            return capabilities;
        }

        @Override
        public LLMStructuredRef asStructuredRef(String connection) {
            return LLMStructuredRef.forVertexConnection(connection, this.getId());
        }

        /** @return whether this model's type lists {@code purpose}; {@code false} when the type is unset. */
        @Override
        public boolean canBeUsedForPurpose(@Nonnull AbstractLLMConnection.LLMUsagePurpose purpose) {
            return this.modelType != null && this.modelType.matchesPurpose(purpose);
        }
    }

    /** User-declared Vertex model, converted to a {@link VertexModel} at load time. */
    public static class CustomVertexModel
    extends AbstractLLMConnection.CustomModel<VertexModel> {
        public String id;
        public String displayName;
        public VertexModelType modelType;

        /** Copies the common custom-model settings, then this entry's id, display name and type. */
        @Override
        public VertexModel toModel() {
            VertexModel model = new VertexModel();
            model.loadFromCustomModel(this);
            model.id = this.id;
            model.displayName = this.displayName;
            model.modelType = this.modelType;
            return model;
        }
    }

    /** Kind of Vertex model: the usage purposes it serves and whether it accepts system prompts. */
    public static enum VertexModelType {
        GEMINI_CHAT(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET_WITH_VISION, true),
        TEXT_EMBEDDING_EXTRACTION(EnumSet.of(AbstractLLMConnection.LLMUsagePurpose.TEXT_EMBEDDING_EXTRACTION), false),
        TEXT_IMAGE_EMBEDDING_EXTRACTION(EnumSet.of(AbstractLLMConnection.LLMUsagePurpose.TEXT_EMBEDDING_EXTRACTION, AbstractLLMConnection.LLMUsagePurpose.IMAGE_EMBEDDING_EXTRACTION), false),
        IMAGE_GENERATION(EnumSet.of(AbstractLLMConnection.LLMUsagePurpose.IMAGE_GENERATION), false);

        public final Set<AbstractLLMConnection.LLMUsagePurpose> matchingPurposes;
        public final boolean handleSystemPrompts;

        private VertexModelType(Set<AbstractLLMConnection.LLMUsagePurpose> purposes, boolean handleSystemPrompts) {
            this.matchingPurposes = purposes;
            this.handleSystemPrompts = handleSystemPrompts;
        }

        /** @return whether {@code purpose} is one of this type's matching purposes. */
        public boolean matchesPurpose(AbstractLLMConnection.LLMUsagePurpose purpose) {
            return this.matchingPurposes.contains(purpose);
        }
    }
}

