/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.lambda.client;

import com.dataiku.common.server.SerializedError;
import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.DKUApp;
import com.dataiku.dip.apideployer.datamodel.actual.APIServiceDeploymentHeavyStatus;
import com.dataiku.dip.apideployer.datamodel.config.AbstractFullyManagedAPIDeploymentInfra;
import com.dataiku.dip.connections.VertexAIModelDeploymentConnection;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.externalinfras.ExternalInfrasUtils;
import com.dataiku.dip.externalinfras.vertexai.VertexAIUtils;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.com.google.auth.oauth2.GoogleCredentials;
import com.dataiku.dss.shadelib.com.google.cloud.aiplatform.v1.PredictResponse;
import com.dataiku.dss.shadelib.com.google.protobuf.MessageOrBuilder;
import com.dataiku.dss.shadelib.com.google.protobuf.util.JsonFormat;
import com.dataiku.lambda.client.BaseLambdaAPIClient;
import com.dataiku.lambda.model.serverconfig.LambdaEndpointConfig;
import com.dataiku.lambda.model.studioconfig.ApiEndpointQuery;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.Collection;
import org.apache.log4j.Logger;

/**
 * Static helpers used to send API endpoint queries to a Vertex AI endpoint and to
 * reshape queries into the {@code {"instances": [...]}} envelope used by this client.
 *
 * <p>The class only exposes static methods; its constructor is private.
 */
public class VertexAILambdaAPIClient
implements BaseLambdaAPIClient {
    private static final String FEATURES_PROPERTY = "features";
    private static final String INSTANCES_PROPERTY = "instances";
    private static final Logger logger = Logger.getLogger("dku.lambda.test.vertex_ai");

    private VertexAILambdaAPIClient() {
    }

    /**
     * Runs every query of {@code testQueries} against the given Vertex AI endpoint.
     *
     * <p>The connection referenced by {@code infra.authConnection} is resolved and Google
     * credentials are obtained once, before any query is sent; failures at that stage
     * propagate to the caller. Each query is then processed independently: any
     * {@link Exception} raised while building, sending or parsing one query is logged and
     * stored in that query's {@code error} field, and the remaining queries are still run.
     *
     * @param authCtx            authentication context used to resolve the connection and credentials
     * @param infra              infrastructure whose {@code authConnection} designates the connection to use
     * @param gcpProject         GCP project passed to {@code VertexAIUtils.from}
     * @param gcpRegion          GCP region passed to {@code VertexAIUtils.from}
     * @param vertexAIEndpointId identifier of the Vertex AI endpoint to call
     * @param endpoint           endpoint summary; its {@code type} selects how the query body is built
     * @param testQueries        queries to run; each must carry an {@code instances} JSON array
     * @param forTest            not used by this method
     * @return one {@code ResponseOrError} per query, in iteration order
     */
    public static BaseLambdaAPIClient.ApiEndpointResponses runQueries_NT(AuthCtx authCtx, AbstractFullyManagedAPIDeploymentInfra infra, String gcpProject, String gcpRegion, String vertexAIEndpointId, APIServiceDeploymentHeavyStatus.EndpointSummary endpoint, Collection<ApiEndpointQuery> testQueries, boolean forTest) throws IOException, URISyntaxException, DKUSecurityException {
        VertexAIModelDeploymentConnection connection = (VertexAIModelDeploymentConnection)ExternalInfrasUtils.getAndCheckConnection(authCtx, infra.authConnection);
        GoogleCredentials credentials = VertexAIUtils.getGoogleCredentials_NT(authCtx, connection);
        BaseLambdaAPIClient.ApiEndpointResponses apiEndpointResponses = new BaseLambdaAPIClient.ApiEndpointResponses();
        for (ApiEndpointQuery tq : testQueries) {
            BaseLambdaAPIClient.ResponseOrError roe = new BaseLambdaAPIClient.ResponseOrError();
            roe.query = tq;
            try {
                JsonArray vertexQuery = VertexAILambdaAPIClient.createQuery(tq, endpoint.type);
                // Only serialize the payload for logging when the message will actually be emitted.
                if (logger.isDebugEnabled()) {
                    String vertexQueryJSON = JSON.json((Object)vertexQuery);
                    logger.debug("Submit query with name `" + tq.name + "` and body `" + vertexQueryJSON + "` to Vertex AI endpoint " + vertexAIEndpointId);
                }
                // Kept inside the try block so that a failure here is reported per query.
                VertexAIUtils utils = VertexAIUtils.from(gcpRegion, gcpProject, credentials, connection);
                PredictResponse response = utils.predict_NT(vertexAIEndpointId, vertexQuery);
                String responseString = JsonFormat.printer().print((MessageOrBuilder)response);
                roe.response = JSON.parse((String)responseString, JsonObject.class);
            }
            catch (Exception e) {
                logger.warn("Failure while trying to send Query " + JSON.json((Object)tq.q) + " to Vertex AI endpoint " + vertexAIEndpointId + ".", e);
                roe.error = new SerializedError((Throwable)e, !ApplicationConfigurator.hideErrorStacks(), !DKUApp.hideErrorStacks(), !ApplicationConfigurator.hideLogTails());
            }
            apiEndpointResponses.responses.add(roe);
        }
        return apiEndpointResponses;
    }

    /**
     * Rewrites a query into the {@code {"instances": [...]}} envelope.
     *
     * <ul>
     *   <li>{@code STD_PREDICTION}, {@code STD_CAUSAL_PREDICTION}, {@code STD_CLUSTERING},
     *       {@code CUSTOM_PREDICTION}: if the query has a {@code features} JSON object, that
     *       object becomes the single instance; otherwise the whole query does.</li>
     *   <li>{@code PY_FUNCTION}: the whole query becomes the single instance.</li>
     *   <li>Any other type: the query is returned unchanged.</li>
     * </ul>
     *
     * @param query        query to convert
     * @param endpointType type of the endpoint the query targets
     * @return a new query carrying the same name and the wrapped body, or {@code query}
     *         itself when the endpoint type requires no conversion
     */
    public static ApiEndpointQuery updateQueryToVertexAIExpectedFormat(ApiEndpointQuery query, LambdaEndpointConfig.EndpointType endpointType) {
        switch (endpointType) {
            case STD_PREDICTION: 
            case STD_CAUSAL_PREDICTION: 
            case STD_CLUSTERING: 
            case CUSTOM_PREDICTION: {
                JsonObject extractedQuery = VertexAILambdaAPIClient.extractFeaturesContent(query.q);
                JsonObject updatedQuery = VertexAILambdaAPIClient.embedIntoInstancesDict(extractedQuery);
                return new ApiEndpointQuery(query.name, updatedQuery);
            }
            case PY_FUNCTION: {
                JsonObject updatedQuery = VertexAILambdaAPIClient.embedIntoInstancesDict(query.q);
                return new ApiEndpointQuery(query.name, updatedQuery);
            }
            default: {
                return query;
            }
        }
    }

    /**
     * Returns the {@code features} member of {@code query} when it exists and is a JSON
     * object; otherwise returns {@code query} itself.
     */
    private static JsonObject extractFeaturesContent(JsonObject query) {
        if (query.has(FEATURES_PROPERTY) && query.get(FEATURES_PROPERTY).isJsonObject()) {
            return query.get(FEATURES_PROPERTY).getAsJsonObject();
        }
        return query;
    }

    /**
     * Builds {@code {"instances": [query]}}: a new object whose {@code instances} member is
     * a one-element array containing {@code query}.
     */
    private static JsonObject embedIntoInstancesDict(JsonObject query) {
        JsonArray instances = new JsonArray();
        instances.add(query);
        JsonObject result = new JsonObject();
        result.add(INSTANCES_PROPERTY, instances);
        return result;
    }

    /**
     * Builds the array of instances to send for {@code tq}, depending on the endpoint type.
     *
     * @throws IllegalArgumentException if the endpoint type is not one of the prediction,
     *         clustering or Python function types handled here, or if the query body has no
     *         usable {@code instances} array
     */
    private static JsonArray createQuery(ApiEndpointQuery tq, LambdaEndpointConfig.EndpointType endpointType) {
        switch (endpointType) {
            case STD_PREDICTION: 
            case STD_CAUSAL_PREDICTION: 
            case STD_CLUSTERING: 
            case CUSTOM_PREDICTION: 
            case PY_FUNCTION: {
                return VertexAILambdaAPIClient.createVertexPredictionQuery(tq.q);
            }
            // CUSTOM_R_PREDICTION, STD_FORECAST, DATASETS_LOOKUP, R_FUNCTION, SQL_QUERY and any
            // type not listed above are rejected with an Exception (not an Error) so that the
            // per-query catch in runQueries_NT can record the failure for that query only.
            default: {
                throw new IllegalArgumentException("Unsupported endpoint type " + endpointType);
            }
        }
    }

    /**
     * Extracts the {@code instances} array from a query body.
     *
     * @throws IllegalArgumentException if {@code content} is null, has no {@code instances}
     *         member, or that member is not a JSON array
     */
    private static JsonArray createVertexPredictionQuery(JsonObject content) {
        if (content == null || !content.has(INSTANCES_PROPERTY)) {
            throw new IllegalArgumentException("'instances' field is required");
        }
        JsonElement instances = content.get(INSTANCES_PROPERTY);
        if (!instances.isJsonArray()) {
            throw new IllegalArgumentException("'instances' field must be a JSON array");
        }
        return instances.getAsJsonArray();
    }
}

