/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.connections;

import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.activity.UsageSummaryModel;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.connections.ConnectionWithAzureAuthCredentials;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.connections.OpenAIConnection;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.openai.OpenAIChatAPI;
import com.dataiku.dip.llm.online.openai.OpenAIImageHandling;
import com.dataiku.dip.llm.online.openai.OpenAIPricing;
import com.dataiku.dip.llm.online.openai.OpenAIRerankingHandlingMode;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.utils.Params;
import com.dataiku.dip.variables.VariablesContext;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;

public abstract class AbstractAzureAIConnection<D extends AbstractAzureAIDeployment, HM extends AbstractLLMConnection.IHardcodedConnectionModel<AbstractAzureAIModel>, CM extends AbstractLLMConnection.CustomModel<AbstractAzureAIModel>>
extends AbstractLLMConnection<AbstractAzureAIModel, HM, CM>
implements ConnectionWithAzureAuthCredentials {
    /** Scope string returned by {@link #getDefaultAuthScope()}. */
    private static final String DEFAULT_OAUTH_SCOPE = "https://cognitiveservices.azure.com/.default";
    /** Scope string returned by {@link #getDefaultAuthUserImpersonationScope()}. */
    private static final String DEFAULT_OAUTH_USER_IMPERSONATION_SCOPE = "https://cognitiveservices.azure.com/user_impersonation offline_access";
    /**
     * Settings of this connection (deployments, custom models, credentials, ...).
     * Instantiated with the diamond operator so the declared {@code <D>} parameterization
     * is kept instead of falling back to a raw type.
     */
    public BaseAzureAIConnectionParams<D> params = new BaseAzureAIConnectionParams<>();

    /**
     * Exposes this connection's settings object under its generic LLM-connection type.
     *
     * @return the object currently held in {@link #params}
     */
    @Override
    public AbstractLLMConnection.AbstractLLMConnectionParams getLLMConnectionParams() {
        return params;
    }

    /**
     * Exposes this connection's settings object as Azure authentication parameters.
     *
     * @return the object currently held in {@link #params}, returned as-is
     */
    @Override
    public ConnectionWithAzureAuthCredentials.IAzureAuthParams getAzureAuth2NonResolvedParams() {
        return params;
    }

    /**
     * Lists every user-declared model of this connection: first all entries of
     * {@code params.availableDeployments}, then all entries of {@code params.customModels},
     * each in its original order.
     *
     * <p>The result list is typed as {@code List<CM>} rather than a list of raw
     * {@code CustomModel}, which is not assignable to the declared return type.
     * The casts to {@code CM} are unchecked: both element types extend
     * {@code CustomModel<AbstractAzureAIModel>}, which is the bound of {@code CM}.
     *
     * @return a new mutable list containing deployments followed by custom models
     */
    @Override
    @SuppressWarnings("unchecked")
    protected List<CM> listRawCustomModels() {
        List<CM> customModels = new ArrayList<>();
        for (D deployment : this.params.availableDeployments) {
            customModels.add((CM) deployment);
        }
        for (AbstractCustomAzureAIModel customModel : this.params.customModels) {
            customModels.add((CM) customModel);
        }
        return customModels;
    }

    /**
     * Converts a raw custom model entry into a model and applies the default custom-model
     * settings to it.
     *
     * @param rawCustomModel a deployment or a custom model entry
     * @return the model produced by {@code rawCustomModel.toModel()}, after
     *         {@code loadDefaultCustomModelSettings} has been called on it
     * @throws IllegalStateException if the entry is neither an {@link AbstractAzureAIDeployment}
     *         nor an {@link AbstractCustomAzureAIModel} (this includes {@code null})
     */
    @Override
    protected AbstractAzureAIModel loadRawCustomModel(CM rawCustomModel) {
        boolean isDeployment = rawCustomModel instanceof AbstractAzureAIDeployment;
        boolean isCustomModel = rawCustomModel instanceof AbstractCustomAzureAIModel;
        if (!(isDeployment || isCustomModel)) {
            throw new IllegalStateException("Unexpected custom model: " + rawCustomModel);
        }
        AbstractAzureAIModel loaded = (AbstractAzureAIModel)((AbstractLLMConnection.CustomModel)rawCustomModel).toModel();
        this.loadDefaultCustomModelSettings(loaded);
        return loaded;
    }

    /**
     * Returns the proxy settings of this connection.
     *
     * @return whatever {@code getProxySettings()} returns
     */
    @Override
    public ProxySettings getProxySettingsFromConnection() {
        ProxySettings proxySettings = this.getProxySettings();
        return proxySettings;
    }

    /**
     * Returns the default OAuth scope for this connection type.
     *
     * @return the {@code DEFAULT_OAUTH_SCOPE} constant
     */
    @Override
    public String getDefaultAuthScope() {
        return AbstractAzureAIConnection.DEFAULT_OAUTH_SCOPE;
    }

    /**
     * Returns the default OAuth scope used for user impersonation.
     *
     * @return the {@code DEFAULT_OAUTH_USER_IMPERSONATION_SCOPE} constant
     */
    @Override
    public String getDefaultAuthUserImpersonationScope() {
        return AbstractAzureAIConnection.DEFAULT_OAUTH_USER_IMPERSONATION_SCOPE;
    }

    /**
     * Converts the connection's DKU properties into a {@link Params} object.
     *
     * @return the result of {@code CustomDatabaseProperty.toParams} applied to
     *         {@code getDkuProperties()}
     */
    @Override
    public Params getDkuPropertiesAsParams() {
        Params asParams = AbstractSQLConnection.CustomDatabaseProperty.toParams(this.getDkuProperties());
        return asParams;
    }

    /**
     * Tells whether this connection must be resolved on the backend.
     *
     * @return {@code true} when refresh token rotation is enabled; otherwise the
     *         superclass answer (which is only consulted in that second case)
     */
    @Override
    public boolean mustResolveOnBackend() {
        if (this.hasRefreshTokenRotation()) {
            return true;
        }
        return super.mustResolveOnBackend();
    }

    /**
     * No-op: this connection type performs no parameter expansion here.
     *
     * @param vc variables context; not used by this implementation
     */
    @Override
    public void expandParametersInPlaceAtDAOLevelUsingGlobalContextOnly(VariablesContext vc) {
    }

    /**
     * Resolves the Azure authentication credentials of this connection and returns them
     * as an instance of the requested class.
     *
     * @param ctx   credential resolution context, forwarded to
     *              {@code getFullyResolvedAzureAuthCredentials}
     * @param clazz requested credential class; asserted (only when assertions are enabled)
     *              to be assignable from {@code SerializableAzureAuthCredentials}
     * @param <T>   requested credential type
     * @return the resolved credentials, cast with {@link Class#cast(Object)}
     * @throws DKUSecurityException declared by the overridden method
     * @throws IOException          declared by the overridden method
     * @throws SQLException         declared by the overridden method
     */
    @Override
    protected <T> T getFullyResolvedCredentials_internal(ConnectionWithBasicCredential.CredentialResolutionContext ctx, Class<T> clazz) throws DKUSecurityException, IOException, SQLException {
        assert clazz.isAssignableFrom(ConnectionWithAzureAuthCredentials.SerializableAzureAuthCredentials.class);
        return clazz.cast(this.getFullyResolvedAzureAuthCredentials(ctx));
    }

    /**
     * Tells whether the connection test should be skipped for the given model reference.
     *
     * @param llmRef structured reference to the model
     * @return {@code true} when the reference carries no deployment
     */
    @Override
    public boolean ignoreConnectionTest(LLMStructuredRef llmRef) {
        boolean hasDeployment = llmRef.deployment != null;
        return !hasDeployment;
    }

    /**
     * Builds an identifier for a model by prefixing its id with its kind.
     *
     * @param model the model to identify
     * @return {@code "deployment:" + id} for deployments, {@code "model:" + id} otherwise
     */
    @Override
    protected String generateUniqueModelIdentifier(AbstractAzureAIModel model) {
        String prefix;
        if (model.isDeployment()) {
            prefix = "deployment:";
        } else {
            prefix = "model:";
        }
        return prefix + model.getId();
    }

    /**
     * Returns the type of this connection; left to concrete subclasses.
     *
     * @return the connection type string
     */
    @Override
    public abstract String getType();

    /**
     * Encrypts the secret fields of {@link #params} in place: the API key, the app secret
     * and the custom headers. Does nothing unless {@code securitySettings.secureSecretKeys}
     * is set.
     *
     * @param cryptoService    service used to encrypt the values
     * @param securitySettings security settings holding the {@code secureSecretKeys} flag
     */
    @Override
    protected void encryptLocalFields(PasswordEncryptionService cryptoService, GeneralSettingsDAO.SecuritySettings securitySettings) {
        if (!securitySettings.secureSecretKeys) {
            return;
        }
        String encryptedApiKey = cryptoService.encryptIfNotEncryptedOrEmpty(this.params.apiKey);
        this.params.apiKey = encryptedApiKey;
        String encryptedAppSecret = cryptoService.encryptIfNotEncryptedOrEmpty(this.params.appSecret);
        this.params.appSecret = encryptedAppSecret;
        this.params.encryptProperties(this.params.customHeaders, cryptoService);
    }

    /**
     * Decrypts the secret fields of {@link #params} in place: the API key, the app secret
     * and the custom headers.
     *
     * @param cryptoService service used to decrypt the values
     */
    @Override
    protected void decryptLocalFields(PasswordEncryptionService cryptoService) {
        String clearApiKey = cryptoService.decryptIfEncrypted(this.params.apiKey);
        this.params.apiKey = clearApiKey;
        String clearAppSecret = cryptoService.decryptIfEncrypted(this.params.appSecret);
        this.params.appSecret = clearAppSecret;
        this.params.decryptProperties(this.params.customHeaders, cryptoService);
    }

    /**
     * Adds this connection's deployments to a usage summary.
     *
     * <p>For each entry of {@code params.availableDeployments}:
     * <ul>
     *   <li>image generation deployments are not reported;</li>
     *   <li>deployments with a blank underlying model name increment the "custom models"
     *       counter (text embedding or completion);</li>
     *   <li>other deployments add their underlying model name to the "standard models"
     *       collection (text embedding or completion).</li>
     * </ul>
     * Every deployment type other than text embedding and image generation is reported
     * on the completion side.
     *
     * @param lcs summary object updated in place
     */
    @Override
    public void fillModelsForGlobalSummaryReport(UsageSummaryModel.LLMConnectionSummary lcs) {
        for (D deployment : this.params.availableDeployments) {
            OpenAIConnection.OpenAIModelType type = deployment.deploymentType;
            if (type == OpenAIConnection.OpenAIModelType.IMAGE_GENERATION) {
                continue;
            }
            boolean isEmbedding = type == OpenAIConnection.OpenAIModelType.TEXT_EMBEDDING_EXTRACTION;
            String underlyingModel = deployment.underlyingModelName;
            if (StringUtils.isBlank(underlyingModel)) {
                if (isEmbedding) {
                    ++lcs.enabledTextEmbeddingCustomModels;
                } else {
                    ++lcs.enabledCompletionCustomModels;
                }
            } else if (isEmbedding) {
                lcs.enabledTextEmbeddingStandardModels.add(underlyingModel);
            } else {
                lcs.enabledCompletionStandardModels.add(underlyingModel);
            }
        }
    }

    /**
     * Extends the superclass consistency checkables with the resource name.
     *
     * @return the map returned by the superclass, with {@code params.resourceName} stored
     *         under the key {@code "Resource name / URL"}
     */
    @Override
    public Map<String, Object> getConsistencyCheckables() {
        Map<String, Object> checkables = super.getConsistencyCheckables();
        checkables.put("Resource name / URL", this.params.resourceName);
        return checkables;
    }

    /**
     * Settings of an Azure AI connection. All fields are public and mutable; the getters
     * below expose the authentication-related ones through
     * {@link ConnectionWithAzureAuthCredentials.IAzureAuthParams}.
     *
     * @param <D> concrete deployment type listed in {@link #availableDeployments}
     */
    public static class BaseAzureAIConnectionParams<D extends AbstractAzureAIDeployment>
    extends AbstractLLMConnection.AbstractLLMConnectionParams
    implements ConnectionWithAzureAuthCredentials.IAzureAuthParams {
        public AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = new AbstractLLMConnection.HTTPBasedLLMNetworkSettings();
        /** Reported as "Resource name / URL" by the connection's consistency checkables. */
        public String resourceName;
        public List<D> availableDeployments = new ArrayList<D>();
        public List<AbstractCustomAzureAIModel> customModels = new ArrayList<AbstractCustomAzureAIModel>();
        /** Secret; encrypted/decrypted by the connection's encryptLocalFields/decryptLocalFields. */
        public String apiKey;
        /** Authentication mode; defaults to {@link AuthType#API_KEY}. */
        public AuthType authType = AuthType.API_KEY;
        public String tenantId;
        public String appId;
        /** Secret; encrypted/decrypted by the connection's encryptLocalFields/decryptLocalFields. */
        public String appSecret;
        public String tokenEndpoint;
        public String authorizationEndpoint;
        public boolean refreshTokenRotation;
        public int maxParallelism = 8;
        /** Encrypted/decrypted through encryptProperties/decryptProperties by the connection. */
        public List<AbstractSQLConnection.CustomDatabaseProperty> customHeaders = new ArrayList<AbstractSQLConnection.CustomDatabaseProperty>();
        public boolean allowDavinciFinetuning;
        public boolean allowBabbageFinetuning;
        public boolean allowGPT35Turbo0125Finetuning;
        public boolean allowGPT35Turbo0613Finetuning;
        public boolean allowGPT35Turbo1106Finetuning;
        public boolean allowGPT4OMiniFinetuning;
        public boolean allowGPT4OFinetuning;
        public boolean allowGPT41MiniFinetuning;
        public boolean allowGPT41Finetuning;
        public boolean allowGPT41NanoFinetuning;
        public boolean allowGPTO4Mini;
        public String subscriptionId;
        public String resourceGroup;
        public String azureMLConnection;

        /**
         * Maps this connection's {@link AuthType} onto the shared Azure authentication type.
         *
         * <p>The switch is exhaustive over the enum, so no explicit {@code default} branch
         * is written: the compiler rejects the code if a constant is added without a
         * mapping. Case labels use the unqualified constant names, which every Java
         * version supporting switch expressions accepts.
         *
         * @return {@code KEY} for {@code API_KEY}, {@code OAUTH2_APP} for {@code OAUTH2_APP}
         * @throws NullPointerException if {@link #authType} is {@code null}
         */
        @Override
        public ConnectionWithAzureAuthCredentials.AuthType getAuthType() {
            return switch (this.authType) {
                case API_KEY -> ConnectionWithAzureAuthCredentials.AuthType.KEY;
                case OAUTH2_APP -> ConnectionWithAzureAuthCredentials.AuthType.OAUTH2_APP;
            };
        }

        /** @return the value of {@link #apiKey} */
        @Override
        public String getKey() {
            return this.apiKey;
        }

        /** @return the value of {@link #tenantId} */
        @Override
        public String getOauth2TenantId() {
            return this.tenantId;
        }

        /** @return the value of {@link #appId} */
        @Override
        public String getOauth2AppId() {
            return this.appId;
        }

        /** @return the value of {@link #appSecret} */
        @Override
        public String getOauth2AppSecret() {
            return this.appSecret;
        }

        /** @return the value of {@link #authorizationEndpoint} */
        @Override
        public String getOauth2AuthorizationEndpoint() {
            return this.authorizationEndpoint;
        }

        /** @return the value of {@link #tokenEndpoint} */
        @Override
        public String getOauth2TokenEndpoint() {
            return this.tokenEndpoint;
        }

        /** @return the value of {@link #refreshTokenRotation} */
        @Override
        public boolean getRefreshTokenRotation() {
            return this.refreshTokenRotation;
        }
    }

    /**
     * A deployment declared on an Azure AI connection. {@link #toModel()} turns it into an
     * {@link AbstractAzureAIModel} whose id, deployment id and display name are all the
     * deployment {@link #name}.
     */
    public static abstract class AbstractAzureAIDeployment
    extends AbstractLLMConnection.CustomModel<AbstractAzureAIModel> {
        public String name;
        /** Required by {@link #toModel()}, which fails when it is {@code null}. */
        public OpenAIConnection.OpenAIModelType deploymentType;
        /** Used by {@link #toModel()} to look up default costs when none are set explicitly. */
        public String underlyingModelName;
        /** Initialized from {@link #getDefaultImageHandlingMode()} when the instance is created. */
        @Nullable
        public OpenAIImageHandling imageHandlingMode = this.getDefaultImageHandlingMode();
        @Nullable
        public OpenAIRerankingHandlingMode rerankingHandlingMode;
        @Nullable
        public OpenAIChatAPI api = OpenAIChatAPI.CHAT_COMPLETIONS;
        public boolean isReasoning = false;

        /** @return the type label inserted in the error message of {@link #toModel()} */
        protected abstract String getType();

        /** @return the initial value of {@link #imageHandlingMode} */
        protected abstract OpenAIImageHandling getDefaultImageHandlingMode();

        /** @return the max-tokens API mode copied onto the model by {@link #toModel()} */
        abstract AzureOpenAIMaxTokensAPIMode getMaxTokensAPIMode();

        /** @return a fresh model instance for {@link #toModel()} to populate */
        protected abstract AbstractAzureAIModel initModel();

        /**
         * Builds the model describing this deployment.
         *
         * <p>Explicitly configured costs take precedence; otherwise the defaults are looked
         * up in {@link OpenAIPricing} from {@link #underlyingModelName}. Type-specific
         * settings (image handling, reranking, chat API and reasoning flag) are copied
         * only for the matching deployment types.
         *
         * @return the populated model; it is never marked as fine-tunable
         */
        @Override
        public AbstractAzureAIModel toModel() {
            if (this.deploymentType == null) {
                throw ErrorContext.isef((String)"Undefined deployment type for %s deployment %s", (Object)this.getType(), (Object[])new Object[]{this.name});
            }

            AbstractAzureAIModel result = this.initModel();
            result.loadFromCustomModel(this);

            // The deployment name serves as id, deployment id and display name.
            result.id = this.name;
            result.deploymentId = this.name;
            result.displayName = this.name;
            result.deploymentType = this.deploymentType;
            result.maxTokensAPIMode = this.getMaxTokensAPIMode();
            result.useChatApi = this.deploymentType.useChatAPI;
            result.canBeFineTuned = false;
            result.underlyingModelName = this.underlyingModelName;

            // Costs: embedding deployments carry an embedding cost, all others prompt/completion costs.
            boolean isEmbedding = this.deploymentType == OpenAIConnection.OpenAIModelType.TEXT_EMBEDDING_EXTRACTION;
            if (!isEmbedding) {
                result.promptCost = this.promptCost == null ? OpenAIPricing.getAzureOpenAIPromptCostPer1KTokens(this.underlyingModelName) : this.promptCost;
                result.completionCost = this.completionCost == null ? OpenAIPricing.getAzureOpenAICompletionCostPer1KTokens(this.underlyingModelName) : this.completionCost;
            } else {
                result.embeddingCost = this.embeddingCost == null ? OpenAIPricing.getAzureOpenAIEmbeddingCostPer1KTokens(this.underlyingModelName) : this.embeddingCost;
            }

            // Settings that only apply to some deployment types.
            switch (this.deploymentType) {
                case IMAGE_GENERATION -> {
                    if (this.imageHandlingMode != null) {
                        result.imageHandlingMode = this.imageHandlingMode;
                    }
                }
                case RERANKING -> {
                    if (this.rerankingHandlingMode != null) {
                        result.rerankingHandlingMode = this.rerankingHandlingMode;
                        result.rerankingRequestCost = this.rerankingRequestCost;
                        result.rerankingTokenCost = 0.0;
                        result.rerankingDocumentCost = 0.0;
                    }
                }
                case COMPLETION_CHAT, COMPLETION_CHAT_MULTIMODAL, COMPLETION_CHAT_NO_SYSTEM_PROMPT -> {
                    result.api = this.api;
                    result.isReasoning = this.isReasoning;
                }
                default -> {
                    // No extra settings for the remaining deployment types.
                }
            }
            return result;
        }
    }

    /**
     * A custom (non-deployment) model declared on an Azure AI connection.
     */
    public static abstract class AbstractCustomAzureAIModel
    extends AbstractLLMConnection.CustomModel<AbstractAzureAIModel> {
        /** Copied to both the id and the underlying model name of the produced model. */
        public String id;
        public String displayName;
        public boolean useChatAPI;

        /** @return a fresh model instance for {@link #toModel()} to populate */
        protected abstract AbstractAzureAIModel initModel();

        /**
         * Builds the model describing this custom entry.
         *
         * @return the populated model, with a {@code null} deployment type
         */
        @Override
        public AbstractAzureAIModel toModel() {
            AbstractAzureAIModel result = this.initModel();
            result.loadFromCustomModel(this);
            // A null deployment type is what AbstractAzureAIModel.isDeployment() reports as "not a deployment".
            result.deploymentType = null;
            result.id = this.id;
            result.displayName = this.displayName;
            result.underlyingModelName = this.id;
            result.useChatApi = this.useChatAPI;
            result.canBeFineTuned = this.canBeFineTuned;
            return result;
        }
    }

    /**
     * Model exposed by an Azure AI connection. It represents a deployment when
     * {@link #deploymentType} is set, and a non-deployment model otherwise.
     */
    public static abstract class AbstractAzureAIModel
    extends AbstractLLMConnection.BaseModel
    implements LLMModelHandle.FineTuneableModel<AbstractAzureAIModel> {
        /** Only reported by {@link #getBaseModelId()} when {@link #isDKUFineTuned} is set. */
        @Nullable
        public String baseModelId;
        /** {@code null} for models that are not deployments. */
        @Nullable
        public OpenAIConnection.OpenAIModelType deploymentType;
        @Nullable
        public AzureOpenAIMaxTokensAPIMode maxTokensAPIMode;
        @Nullable
        public String deploymentId;
        public String underlyingModelName;
        public boolean useChatApi;
        public boolean isDKUFineTuned = false;
        @Nullable
        public OpenAIImageHandling imageHandlingMode;
        @Nullable
        public OpenAIRerankingHandlingMode rerankingHandlingMode;
        @Nullable
        public OpenAIChatAPI api;
        public boolean isReasoning;

        /** @return {@code true} when a deployment type is set on this model */
        public boolean isDeployment() {
            return null != this.deploymentType;
        }

        /**
         * Describes what this model supports.
         *
         * @return a new capabilities object: cross-language output enabled, system
         *         messages handled when the chat API is used, image inputs supported when
         *         the model can be used for the {@code IMAGE_INPUT} purpose, the OpenAI
         *         temperature and top-K ranges, and this model's reasoning flag
         */
        @Override
        public AbstractLLMConnection.ModelCapabilities getModelCapabilities() {
            AbstractLLMConnection.ModelCapabilities caps = new AbstractLLMConnection.ModelCapabilities();
            caps.canGenerateCrossLanguageOutput = true;
            caps.handlesSystemMessage = this.useChatApi;
            boolean acceptsImages = this.canBeUsedForPurpose(AbstractLLMConnection.LLMUsagePurpose.IMAGE_INPUT);
            caps.supportsImageInputs = acceptsImages;
            caps.temperatureRange = OpenAIConnection.TEMPERATURE_RANGE;
            caps.topKRange = OpenAIConnection.TOP_K_RANGE;
            caps.isReasoning = this.isReasoning;
            return caps;
        }

        /**
         * Validates this model.
         *
         * @return a reason when the model id is blank, an empty optional otherwise
         */
        @Override
        public Optional<String> getInvalidityReason() {
            return StringUtils.isBlank(this.id)
                    ? Optional.of("Empty model/deployment name")
                    : Optional.empty();
        }

        /**
         * Tells whether this model can serve the given purpose.
         *
         * @param purpose the intended usage
         * @return for deployments, whether the deployment type matches the purpose; for
         *         other models, {@code true} only for fine-tuning when the model can be
         *         fine-tuned
         */
        @Override
        public boolean canBeUsedForPurpose(@Nonnull AbstractLLMConnection.LLMUsagePurpose purpose) {
            if (!this.isDeployment()) {
                return this.canBeFineTuned() && purpose.equals(AbstractLLMConnection.LLMUsagePurpose.FINE_TUNING);
            }
            return this.deploymentType.matchesPurpose(purpose);
        }

        /**
         * @return {@link #baseModelId} when {@link #isDKUFineTuned} is set, {@code null} otherwise
         */
        @Override
        public String getBaseModelId() {
            if (!this.isDKUFineTuned) {
                return null;
            }
            return this.baseModelId;
        }
    }

    public static enum AuthType {
        OAUTH2_APP,
        API_KEY;

    }

    public static enum AzureOpenAIMaxTokensAPIMode {
        MODERN,
        LEGACY;

    }
}

