/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.openai;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.AzureOpenAIConnection;
import com.dataiku.dip.connections.ConnectionWithAzureAuthCredentials;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.connections.OpenAIConnection;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.externalinfras.azureml.AzureMLUtils;
import com.dataiku.dip.externalinfras.azureml.http.AzureMLHttpClient;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.online.ISavedModelDeployer;
import com.dataiku.dip.llm.online.LLMChatMessageUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.RemoteFineTuningClient;
import com.dataiku.dip.llm.online.openai.AzureClientBase;
import com.dataiku.dip.llm.online.openai.AzureOpenAIFineTuningClient;
import com.dataiku.dip.llm.online.openai.AzureOpenAISavedModelDeployer;
import com.dataiku.dip.llm.online.openai.OpenAIImageHandling;
import com.dataiku.dip.llm.online.openai.RawOpenAIClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.DKULogger;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

/**
 * LLM client for Azure OpenAI deployments configured through an {@link AzureOpenAIConnection}.
 *
 * <p>Model requests go through a {@link RawOpenAIClient} built in the constructor from the
 * connection's fully resolved Azure credentials (OAuth2 app token or API key). Saved-model
 * deployment goes through an {@link AzureMLHttpClient} that is created lazily, on the first call to
 * {@link #newSavedModelDeployer(AuthCtx)}, from the connection's Azure ML connection.
 */
public class AzureOpenAIClient
extends AzureClientBase<AzureOpenAIConnection.AzureOpenAIModel> {
    /** Connect timeout passed to the Azure ML HTTP client (presumably milliseconds — confirm against AzureMLUtils). */
    private static final int AZURE_ML_DEFAULT_CONNECT_TIMEOUT = 120000;
    /** Socket timeout passed to the Azure ML HTTP client (presumably milliseconds — confirm against AzureMLUtils). */
    private static final int AZURE_ML_DEFAULT_SOCKET_TIMEOUT = 60000;
    /** Host suffix of Azure OpenAI endpoints; stripped to recover the resource name from an endpoint URL. */
    private static final String AZURE_OPENAI_HOST_SUFFIX = ".openai.azure.com";

    private static final DKULogger logger = DKULogger.getLogger("dku.llm.azureopenai");

    private final AzureOpenAIConnection connection;
    private final RawOpenAIClient raw;
    /** Created on the first call to newSavedModelDeployer when an Azure ML connection is configured; null otherwise. */
    private AzureMLHttpClient azureMLHttpClient;

    /**
     * Builds a client for one Azure OpenAI deployment.
     *
     * @param connection  the Azure OpenAI connection holding resource name, headers, proxy and credentials
     * @param modelHandle handle giving the enriched model ref and the deployment model definition
     * @param authCtx     authentication context used to resolve the connection credentials
     * @param projectKey  optional project key forwarded to the raw client
     * @throws IllegalArgumentException if the resolved credentials use an auth type this client does not support
     */
    public AzureOpenAIClient(AzureOpenAIConnection connection, LLMModelHandle<AzureOpenAIConnection.AzureOpenAIModel> modelHandle, AuthCtx authCtx, @Nullable String projectKey) throws DKUSecurityException, IOException, AzureMLUtils.AzureAuthenticationException {
        super(modelHandle.getEnrichedRef(), modelHandle.getModel(), connection.params.networkSettings);
        this.connection = connection;
        boolean trustAllSSLCertificates = connection.getDkuPropertiesAsParams().getBoolParam("dku.connection.llm.trustAllSSLCertificates", false);
        ConnectionWithAzureAuthCredentials.SerializableAzureAuthCredentials creds = connection.getFullyResolvedCredentials_fsLike(new ConnectionWithBasicCredential.CredentialResolutionContext(authCtx, null), ConnectionWithAzureAuthCredentials.SerializableAzureAuthCredentials.class);
        AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = this.queryRunner.getHttpClientNetworkSettings();
        switch (creds.authType) {
            case OAUTH2_APP: {
                this.raw = RawOpenAIClient.forAzureWithOAuthToken(connection.params.resourceName, creds.oauth2AccessToken, connection.params.customHeaders, projectKey, networkSettings, connection.getProxySettings(), trustAllSSLCertificates);
                break;
            }
            case KEY: {
                this.raw = RawOpenAIClient.forAzureWithAPIKey(connection.params.resourceName, creds.key, connection.params.customHeaders, projectKey, networkSettings, connection.getProxySettings(), trustAllSSLCertificates);
                break;
            }
            default: {
                // Fail fast: without a raw client every request would otherwise NPE later.
                throw new IllegalArgumentException("Unsupported authentication type for Azure OpenAI connection '" + connection.name + "': " + creds.authType);
            }
        }
    }

    /**
     * Formats the prompt. Deployments of type COMPLETION_CHAT_NO_SYSTEM_PROMPT have their "system"
     * messages converted to "user" messages before the base formatting is applied.
     */
    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
        if (this.getModel().deploymentType == OpenAIConnection.OpenAIModelType.COMPLETION_CHAT_NO_SYSTEM_PROMPT) {
            chatMessages = LLMChatMessageUtils.convertMessageRole(chatMessages, "system", "user");
        }
        return super.getFormattedPrompt(chatMessages);
    }

    /** Returns the Azure OpenAI deployment model this client targets. */
    public AzureOpenAIConnection.AzureOpenAIModel getModel() {
        return (AzureOpenAIConnection.AzureOpenAIModel)this.model;
    }

    /** Returns the raw OpenAI client built in the constructor. */
    @Override
    public RawOpenAIClient getRaw() {
        return this.raw;
    }

    /** Returns the model's {@code useChatApi} flag. */
    @Override
    public boolean isChatModel() {
        return this.getModel().useChatApi;
    }

    @Override
    public String getLogPrefix() {
        return "AzureOpenAI";
    }

    /** Returns the deployment id of the model. */
    @Override
    public String getAzureModelId() {
        return this.getModel().deploymentId;
    }

    /** Returns the image handling mode configured on the model. */
    @Override
    public OpenAIImageHandling getImageHandlingMode() {
        return this.getModel().imageHandlingMode;
    }

    /** Always true for Azure OpenAI. */
    @Override
    public boolean requiresCostLimiting() {
        return true;
    }

    @Override
    public String getProviderId() {
        return "AzureOpenAI";
    }

    @Override
    public AbstractLLMConnection getConnection() {
        return this.connection;
    }

    /** Returns the connection's configured {@code maxParallelism}. */
    @Override
    public int getMaxParallelism() {
        return this.connection.params.maxParallelism;
    }

    @Override
    public RemoteFineTuningClient newFineTuningClient() {
        return new AzureOpenAIFineTuningClient(this);
    }

    /**
     * Creates a deployer for saved models on this Azure OpenAI resource.
     *
     * <p>The Azure ML HTTP client is created on first use from the connection's Azure ML connection
     * and reused on subsequent calls.
     *
     * @param authCtx authentication context used to create the Azure ML client
     * @return a deployer bound to the connection's subscription, resource group and resource name
     * @throws IllegalArgumentException if no Azure ML client could be created, or if the subscription
     *                                  id, resource group or resource name is missing
     */
    @Override
    public ISavedModelDeployer newSavedModelDeployer(AuthCtx authCtx) throws AzureMLUtils.AzureAuthenticationException, IOException, DKUSecurityException {
        if (this.connection.params.azureMLConnection != null && this.azureMLHttpClient == null) {
            this.azureMLHttpClient = AzureMLUtils.getAzureMLClient_NT(authCtx, this.connection.params.azureMLConnection, AZURE_ML_DEFAULT_CONNECT_TIMEOUT, AZURE_ML_DEFAULT_SOCKET_TIMEOUT);
        }
        if (this.azureMLHttpClient == null) {
            throw new IllegalArgumentException("Error creating the Azure ML Http client. Make sure you have an Azure ML connection in your Azure Open AI Connection");
        }
        if (this.connection.params.subscriptionId == null || this.connection.params.subscriptionId.isEmpty()) {
            throw new IllegalArgumentException("Please define a subscription id in your Azure Open AI connection '" + this.connection.name + "' to access your model deployments.");
        }
        if (this.connection.params.resourceGroup == null || this.connection.params.resourceGroup.isEmpty()) {
            throw new IllegalArgumentException("Please define a resource group in your Azure Open AI connection '" + this.connection.name + "' to access your model deployments.");
        }
        String configuredResourceName = this.connection.params.resourceName;
        if (configuredResourceName == null || configuredResourceName.isEmpty()) {
            throw new IllegalArgumentException("Please define a resource name in your Azure Open AI connection '" + this.connection.name + "' to access your model deployments.");
        }
        return new AzureOpenAISavedModelDeployer(this.azureMLHttpClient, this.connection.params.subscriptionId, this.connection.params.resourceGroup, toAzureResourceName(configuredResourceName));
    }

    /**
     * Returns the Azure resource name to use for management calls.
     *
     * <p>Values that do not start with {@code http://} or {@code https://} are returned unchanged.
     * For endpoint URLs, the host is extracted (scheme, port, path, query and fragment dropped).
     * If the host ends with {@code .openai.azure.com} (case-insensitive), that suffix is removed, so
     * {@code https://myres.openai.azure.com/openai} and {@code https://myres.openai.azure.com/}
     * both yield {@code myres}.
     */
    private static String toAzureResourceName(String resourceNameOrEndpoint) {
        if (!resourceNameOrEndpoint.startsWith("http://") && !resourceNameOrEndpoint.startsWith("https://")) {
            return resourceNameOrEndpoint;
        }
        String afterScheme = resourceNameOrEndpoint.substring(resourceNameOrEndpoint.indexOf("://") + 3);
        int hostEnd = afterScheme.length();
        for (char delimiter : new char[] {'/', ':', '?', '#'}) {
            int index = afterScheme.indexOf(delimiter);
            if (index >= 0 && index < hostEnd) {
                hostEnd = index;
            }
        }
        String host = afterScheme.substring(0, hostEnd);
        int suffixStart = host.length() - AZURE_OPENAI_HOST_SUFFIX.length();
        if (suffixStart > 0 && host.regionMatches(true, suffixStart, AZURE_OPENAI_HOST_SUFFIX, 0, AZURE_OPENAI_HOST_SUFFIX.length())) {
            return host.substring(0, suffixStart);
        }
        return host;
    }

    @Override
    protected DKULogger getLogger() {
        return logger;
    }
}

