/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.lambda.endpoints.predict;

import com.codahale.metrics.Timer;
import com.dataiku.dip.DSSMetrics;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.datalayer.memimpl.MemRow;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.lambda.APINodeMetrics;
import com.dataiku.lambda.ServiceGenContext;
import com.dataiku.lambda.endpoints.predict.JavaPredictionStep;
import com.dataiku.lambda.endpoints.predict.PredictionPipeline;
import com.dataiku.lambda.endpoints.predict.SQLLeftJoinEnrichStep;
import com.dataiku.lambda.endpoints.predictcommon.PipelineMessage;
import com.dataiku.lambda.endpoints.predictcommon.PredictionEndpointHandlerBase;
import com.dataiku.lambda.model.api.PredictionResponse;
import com.dataiku.lambda.model.serverconfig.BundledSMVersion;
import com.dataiku.lambda.model.serverconfig.PredictionEndpointConfig;
import com.dataiku.scoring.builders.Build;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.util.List;
import org.apache.log4j.Logger;

public class PredictionEndpointHandler
extends PredictionEndpointHandlerBase<PredictionEndpointConfig, PredictionPipeline> {
    private Timer prepareTimer;
    private PredictionPipeline.Factory pipelineFactory;
    private BundledSMVersion smVersion;
    private static Logger logger = Logger.getLogger((String)"dku.lambda.prediction.handler");

    public PredictionEndpointHandler(PredictionEndpointConfig config) {
        super(config);
        SpringUtils.getInstance().autowire((Object)this);
    }

    @Override
    public void _init(ServiceGenContext context) throws Exception {
        this.prepareTimer = APINodeMetrics.endpointTimer(context.getServiceId(), ((PredictionEndpointConfig)this.config).id, "prepare");
        this.smVersion = context.getModelConfig(((PredictionEndpointConfig)this.config).modelId);
        File modelFolder = context.getModelFolder(((PredictionEndpointConfig)this.config).modelId);
        switch (((PredictionEndpointConfig)this.config).savedModelType.savedModelHandlingType) {
            case INTERNAL: {
                if (((PredictionEndpointConfig)this.config).useJava) {
                    if (this.checkJavaCompatibility(modelFolder)) {
                        logger.info((Object)"Found java compatible model, loading it up.");
                        JavaPredictionStep predict = new JavaPredictionStep(modelFolder);
                        this.pipelineFactory = new PredictionPipeline.JavaFactory((PredictionEndpointConfig)this.config, context, predict);
                        return;
                    }
                    logger.error((Object)"Failed to read a java pipeline, falling back to a python pipeline.");
                } else {
                    logger.info((Object)"Creating a python kernel factory.");
                }
                this.pipelineFactory = new PredictionPipeline.PythonFactory((PredictionEndpointConfig)this.config, context);
                break;
            }
            case EXTERNAL_MLFLOW: {
                this.pipelineFactory = new PredictionPipeline.MLFlowPyfuncFactory((PredictionEndpointConfig)this.config, context);
                break;
            }
            case PYTHON_AGENT: 
            case PLUGIN_AGENT: 
            case TOOLS_USING_AGENT: {
                throw new IllegalArgumentException("Agents are not supported in API node");
            }
            case LLM_GENERIC: {
                throw new IllegalArgumentException("Fine tuned LLMs are not supported in API node");
            }
            case RETRIEVAL_AUGMENTED_LLM: {
                throw new IllegalArgumentException("Augmented LLMs are not supported in API node");
            }
        }
    }

    private boolean checkJavaCompatibility(File modelFolder) {
        if (new File(modelFolder, "dss_pipeline_meta.json").exists() && new File(modelFolder, "dss_pipeline_model.gz").exists()) {
            try {
                Build.DssPipelineMeta meta = Build.pipelineMeta((URL)modelFolder.toURI().toURL());
                if (meta.isValid) {
                    return true;
                }
            }
            catch (IOException e) {
                logger.warn((Object)"Failed to read java pipeline", (Throwable)e);
                return false;
            }
        }
        return false;
    }

    @Override
    public PredictionPipeline instantiatePipeline() throws Exception {
        return this.pipelineFactory.build();
    }

    @Override
    protected PredictionEndpointHandlerBase.EnrichedPredictionResponse predict(long startTimeN, PipelineMessage pm) throws Exception {
        this.requestsMeter.mark();
        long preprocessedN = System.nanoTime();
        PredictionPipeline pipeline = (PredictionPipeline)this.pool.acquire();
        try {
            PredictionResponse ret;
            long acquiredN = System.nanoTime();
            try (DSSMetrics.TimeCtx tctx = DSSMetrics.timeCtx((Timer)this.enrichTimer);){
                for (SQLLeftJoinEnrichStep step : pipeline.sqlEnrich) {
                    step.process(pm);
                }
            }
            List<JsonObject> postEnrich = this.collectPostEnrichDataIfNeeded(pm, ((PredictionEndpointConfig)this.config).auditPostEnrichData, ((PredictionEndpointConfig)this.config).returnPostEnrichData);
            long enrichedN = System.nanoTime();
            if (pipeline.prepare != null) {
                try (DSSMetrics.TimeCtx tctx = DSSMetrics.timeCtx((Timer)this.prepareTimer);){
                    pipeline.prepare.process(pm.table);
                }
            }
            long preparedN = System.nanoTime();
            int i = 0;
            for (MemRow row : pm.table.rows) {
                if (row.isDeleted()) {
                    pm.prePredictIgnoreReasons.set(i, PredictionResponse.IgnoreReason.IGNORED_BY_SCRIPT);
                }
                ++i;
            }
            try (DSSMetrics.TimeCtx tctx = DSSMetrics.timeCtx((Timer)this.predictTimer);){
                ret = pipeline.predict.process(pm);
            }
            long predictedN = System.nanoTime();
            this.pool.release(pipeline);
            long finishedN = System.nanoTime();
            ret.timing.preProcessing = (preprocessedN - startTimeN) / 1000L;
            ret.timing.wait = (acquiredN - preprocessedN) / 1000L;
            ret.timing.enrich = (enrichedN - acquiredN) / 1000L;
            ret.timing.preparation = (preparedN - enrichedN) / 1000L;
            ret.timing.prediction = (predictedN - preparedN) / 1000L;
            ret.timing.postProcessing = (finishedN - predictedN) / 1000L;
            this.enrichPredictionResponseWithPostEnrichData(ret, postEnrich, ((PredictionEndpointConfig)this.config).auditPostEnrichData, ((PredictionEndpointConfig)this.config).returnPostEnrichData);
            PredictionEndpointHandlerBase.EnrichedPredictionResponse epr = new PredictionEndpointHandlerBase.EnrichedPredictionResponse();
            if (this.smVersion != null) {
                JsonObject savedModel = new JsonObject();
                if (this.smVersion.originalProjectKey != null) {
                    savedModel.addProperty("projectKey", this.smVersion.originalProjectKey);
                    savedModel.addProperty("savedModelId", this.smVersion.originalSavedModelId);
                    savedModel.addProperty("modelVersion", this.smVersion.originalSavedModelVersion);
                    FullModelId fmi = new FullModelId(this.smVersion.originalProjectKey, this.smVersion.originalSavedModelId, this.smVersion.originalSavedModelVersion);
                    savedModel.addProperty("fullModelId", fmi.toString());
                }
                epr.globalAdditionalAuditElements.add("savedModel", (JsonElement)savedModel);
            }
            this.successRequestMeter.mark();
            epr.response = ret;
            return epr;
        }
        catch (Throwable t) {
            this.pool.destroyOnError(pipeline);
            throw t;
        }
    }
}

