/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.lambda.endpoints.predict;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.externalml.mlflow.MLFlowInputOutput;
import com.dataiku.dip.io.SocketBlockLinkException;
import com.dataiku.dip.kernels.IDSSKernelBase;
import com.dataiku.dip.utils.JSON;
import com.dataiku.kernels.AbstractLambdaPythonKernel;
import com.dataiku.lambda.ServiceGenContext;
import com.dataiku.lambda.endpoints.predict.MLFlowPyfuncPredictionKernel;
import com.dataiku.lambda.endpoints.predict.PredictionModelPredictionStep;
import com.dataiku.lambda.endpoints.predictcommon.AbstractPythonPredictionStep;
import com.dataiku.lambda.endpoints.predictcommon.PipelineMessage;
import com.dataiku.lambda.model.api.PredictionResponse;
import com.dataiku.lambda.model.serverconfig.PredictionEndpointConfig;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonNull;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class MLFlowPyfuncPredictionStep
extends PredictionModelPredictionStep {
    private final PredictionEndpointConfig endpointConfig;

    public MLFlowPyfuncPredictionStep(PredictionEndpointConfig ep, ServiceGenContext ctx, File modelFolder) throws IOException {
        super(ep, ctx, modelFolder);
        this.endpointConfig = ep;
    }

    @Override
    public String getKernelNamePrefix() {
        return "mlflowpyfuncpred-";
    }

    @Override
    public AbstractLambdaPythonKernel initializeKernel() {
        return new MLFlowPyfuncPredictionKernel(this.link, this.ctx, this.endpointConfig.id, (List<String>)this.codePaths, this.codeEnvFolder, this.pluginResourcesEnv, (File)this.kernelWorkDir, ApplicationConfigurator.isDevLambdaServer(), (GeneralSettingsDAO.CGrouppableProcessType)(ApplicationConfigurator.isDevLambdaServer() ? GeneralSettingsDAO.CGrouppableProcessType.LAMBDA_DEV_SERVER : null));
    }

    @Override
    public PredictionResponse process(PipelineMessage message) throws Exception {
        AbstractPythonPredictionStep.InternalPredictionResponse iret;
        MLFlowPredictQuery query = new MLFlowPredictQuery();
        query.mlFlowOutputStyle = this.endpointConfig.mlFlowOutputStyle;
        if (message.explanations != null) {
            query.explanations = JSON.toJsonObject((Object)message.explanations, (String[])new String[0]);
        }
        for (int i = 0; i < message.itemsToPredict.size(); ++i) {
            JsonObject features = message.itemsToPredict.get((int)i).features;
            for (String string : features.keySet()) {
                JsonArray arr = query.columns.get(string);
                if (arr != null) continue;
                arr = new JsonArray();
                for (int j = 0; j < i; ++j) {
                    arr.add((JsonElement)new JsonNull());
                }
                query.columns.put(string, arr);
            }
            for (Map.Entry entry : query.columns.entrySet()) {
                JsonElement itemVal = features.get((String)entry.getKey());
                if (itemVal == null) {
                    ((JsonArray)entry.getValue()).add((JsonElement)new JsonNull());
                    continue;
                }
                ((JsonArray)entry.getValue()).add(itemVal);
            }
        }
        logger.trace(() -> "Sending query to prediction handler: " + JSON.log((Object)query));
        try {
            iret = (AbstractPythonPredictionStep.InternalPredictionResponse)this.link.execute((Object)query, AbstractPythonPredictionStep.InternalPredictionResponse.class, "Failed to get prediction");
        }
        catch (SocketBlockLinkException e) {
            e.withLogTail((IDSSKernelBase)this.kernel);
            throw this.kernel.maybeRethrowAsProcessDied((IOException)((Object)e));
        }
        if (logger.isDebugEnabled()) {
            logger.debug((Object)("Got prediction result: " + JSON.log((Object)iret)));
        }
        PredictionResponse ret = this.toResponse(message, iret);
        logger.trace(() -> "RET:" + JSON.log((Object)ret));
        return ret;
    }

    static class MLFlowPredictQuery {
        public JsonObject explanations;
        Map<String, JsonArray> columns = new LinkedHashMap<String, JsonArray>();
        MLFlowInputOutput.MLFlowPredictionOutputStyle mlFlowOutputStyle;

        MLFlowPredictQuery() {
        }
    }
}

