/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.agents.tools.ml;

import com.dataiku.dip.agents.tools.AgentTool;
import com.dataiku.dip.agents.tools.AgentToolMeta;
import com.dataiku.dip.agents.tools.AgentToolParams;
import com.dataiku.dip.agents.tools.AgentToolRunner;
import com.dataiku.dip.agents.tools.utils.JsonSchema;
import com.dataiku.dip.agents.tools.utils.JsonSchemaElement;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.prediction.PredictionResultsReader;
import com.dataiku.dip.analysis.model.CompatibilityWithReason;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.exceptions.ForbiddenObjectException;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.scoring.Predictor;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.ITaggingService;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.TransactionContext;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.lambda.model.api.PredictionResponse;
import com.dataiku.scoring.builders.ObservationBuilder;
import com.dataiku.scoring.util.RawObservation;
import com.google.common.collect.Lists;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

public class ClassicalPredictionModelPredictTool {
    /**
     * Metadata and factory for the {@code ClassicalPredictionModelPredict} agent tool.
     *
     * <p>The tool scores a single record with the active version of one saved model. The model is
     * referenced by {@link Params#smRef} and resolved against the tool's project with
     * {@code AnyLoc.resolveSmart}. This object provides:
     * <ul>
     *   <li>the LLM-facing descriptor, with one optional property per input feature;</li>
     *   <li>the confirmation text shown before a call;</li>
     *   <li>a sample query;</li>
     *   <li>the {@link Runner} that performs the prediction.</li>
     * </ul>
     */
    public static final AgentToolMeta META = new AgentToolMeta(true){

        @Override
        public String getType() {
            return "ClassicalPredictionModelPredict";
        }

        @Override
        public Class<? extends AgentToolParams> paramsClass() {
            return Params.class;
        }

        /**
         * Declares the selected saved model as the tool's only dependency.
         *
         * @return a mutable list with one SAVED_MODEL dependency, or an empty list when no model is selected
         */
        @Override
        public List<SavedModel.AgentDependency> getDependencies(AgentTool tool) {
            Params p = tool.getParamsCopyAs(Params.class);
            // Built as a typed list: the former (Object[]) cast made Lists.newArrayList infer ArrayList<Object>.
            List<SavedModel.AgentDependency> dependencies = new ArrayList<>();
            if (p.smRef != null) {
                dependencies.add(new SavedModel.AgentDependency(ITaggingService.TaggableType.SAVED_MODEL, p.smRef));
            }
            return dependencies;
        }

        /**
         * Checks that the selected saved model is available in the tool's project.
         *
         * <p>The check delegates to {@code ProjectsService.failIfLocNotAvailableInProject}. When no
         * model is selected, it only logs a warning.
         */
        @Override
        public void checkAccessDependency(AuthCtx authCtx, AgentTool tool) throws IOException, ForbiddenObjectException {
            Params p = tool.getParamsCopyAs(Params.class);
            if (p.smRef == null) {
                logger.warn((Object)"No model selected. Skipping access check to dependency.");
                return;
            }
            AnyLoc smLoc = AnyLoc.resolveSmart(tool.projectKey, p.smRef);
            ((ProjectsService)SpringUtils.getBean(ProjectsService.class)).failIfLocNotAvailableInProject(ITaggingService.TaggableType.SAVED_MODEL, smLoc, tool.projectKey);
        }

        /**
         * Builds the descriptor the LLM sees.
         *
         * <p>The description names the model, and the input schema has a single {@code "record"}
         * object. That object has one optional property per input feature: a number for NUMERIC
         * features and a string otherwise. When no model is selected, only an error description is
         * returned and no input schema is set.
         *
         * @throws IOException if the saved model or its model details cannot be read
         */
        @Override
        public AgentToolMeta.ToolDescriptor getResultingDescriptor(AuthCtx authCtx, String projectKey, AgentTool tool) throws IOException {
            TransactionContext.assertNoAttachedTransaction();
            Params p = tool.getParamsCopyAs(Params.class);
            AgentToolMeta.ToolDescriptor td = new AgentToolMeta.ToolDescriptor(tool.name);
            String baseDescription = "Predicts a record using a ML prediction model";
            if (p.smRef == null) {
                td.description = baseDescription + " (error: no model selected).";
                return td;
            }
            SavedModel sm = this.getSavedModel(tool.projectKey, p.smRef);
            StringBuilder description = new StringBuilder(baseDescription)
                    .append(": ").append(sm.name).append("\n\n")
                    .append("Provide the record to predict as a single JSON dictionary called \"record\", with one key per column.\n");
            JsonSchemaElement record = JsonSchemaElement.object("The record to predict");
            this.getInputFeaturesStream(sm).forEach(e -> {
                String featureDescription = "Value for " + e.getKey() + ". Optional. ";
                if (this.isNumeric(e)) {
                    record.properties.put(e.getKey(), JsonSchemaElement.number(featureDescription));
                } else {
                    record.properties.put(e.getKey(), JsonSchemaElement.string(featureDescription));
                }
            });
            if (StringUtils.isNotBlank((String)tool.additionalDescriptionForLLM)) {
                description.append("\n\n").append(tool.additionalDescriptionForLLM);
            }
            td.description = description.toString();
            td.inputSchema = JsonSchema.newObject("https://dataiku.com/agents/tools/ml/predict/input", "Input for record prediction tool");
            td.inputSchema.properties.put("record", record);
            return td;
        }

        /**
         * Builds the confirmation message shown before the tool runs.
         *
         * @throws IllegalArgumentException if no saved model is selected
         */
        @Override
        public AgentToolMeta.ToolCallDescription getToolCallDescription_NT(AuthCtx authCtx, String projectKey, AgentTool tool, LLMClient.FunctionTool descriptor, AgentToolRunner.AgentToolInput input) throws Exception {
            TransactionContext.assertNoAttachedTransaction();
            Params p = tool.getParamsCopyAs(Params.class);
            // Same guard as loadSampleQuery: without it a null reference reached AnyLoc.resolveSmart.
            if (p.smRef == null) {
                throw new IllegalArgumentException("No saved model selected.");
            }
            SavedModel sm = this.getSavedModel(tool.projectKey, p.smRef);
            // NOTE(review): the text mentions "the following record", but the record from `input` is not
            // included. sm.name is also inserted into HTML markup unescaped; confirm how this is rendered.
            String description = String.format("I'm about to use model <b>%s</b> to predict the outcome for the following record.%n", sm.name)
                    + "\n"
                    + "Do you want to proceed?";
            return new AgentToolMeta.ToolCallDescription(description);
        }

        /**
         * Builds a sample tool input: {@code {"record": {...}}}.
         *
         * <p>Each input feature is set to {@code 1} when NUMERIC and to {@code "<Your value here>"}
         * otherwise.
         *
         * @throws IllegalArgumentException if no saved model is selected
         */
        @Override
        public JsonObject loadSampleQuery(AuthCtx authCtx, String projectKey, AgentTool tool) throws Exception {
            Params p = tool.getParamsCopyAs(Params.class);
            if (p.smRef == null) {
                throw new IllegalArgumentException("No saved model selected.");
            }
            SavedModel sm = this.getSavedModel(tool.projectKey, p.smRef);
            JsonObject quickTestRecord = new JsonObject();
            this.getInputFeaturesStream(sm).forEach(e -> {
                if (this.isNumeric(e)) {
                    quickTestRecord.addProperty(e.getKey(), (Number)1);
                } else {
                    quickTestRecord.addProperty(e.getKey(), "<Your value here>");
                }
            });
            JsonObject sampleInput = new JsonObject();
            sampleInput.add("record", quickTestRecord);
            return sampleInput;
        }

        @Override
        public AgentToolRunner buildRunner(AuthCtx authCtx, String projectKey, AgentTool tool, boolean devKernel) throws CodedException {
            return new Runner(authCtx, tool.projectKey, tool.getParamsCopyAs(Params.class));
        }

        /**
         * Loads the saved model referenced by {@code smRef}, resolved against {@code sourceProjectKey}.
         *
         * <p>The lookup runs inside the transaction obtained from
         * {@code TransactionService.retrieveOrBeginRead()}.
         */
        private SavedModel getSavedModel(String sourceProjectKey, String smRef) throws IOException {
            try (Transaction t = ((TransactionService)SpringUtils.getBean(TransactionService.class)).retrieveOrBeginRead()) {
                AnyLoc smLoc = AnyLoc.resolveSmart(sourceProjectKey, smRef);
                return (SavedModel)((SavedModelsDAO)SpringUtils.getBean(SavedModelsDAO.class)).getMandatory(smLoc);
            }
        }

        /** Tells whether the model's preprocessing declares this feature as NUMERIC. */
        private boolean isNumeric(Map.Entry<String, FeaturePreprocessingParams> feature) {
            return feature.getValue().type == FeaturePreprocessingParams.FeatureType.NUMERIC;
        }

        /**
         * Streams the (name, preprocessing) pairs of the features whose role is INPUT in the
         * model's active version.
         */
        private Stream<Map.Entry<String, FeaturePreprocessingParams>> getInputFeaturesStream(SavedModel sm) throws IOException {
            FullModelId fmi = new FullModelId(sm.projectKey, sm.id, sm.getActiveVersion());
            Map<String, FeaturePreprocessingParams> featureParamsMap = PredictionResultsReader.makeModelDetails(fmi).getPreprocessing().per_feature;
            if (featureParamsMap == null) {
                return Stream.empty();
            }
            // Calling equals on the constant is null-safe: features without a role are not treated as inputs.
            return featureParamsMap.entrySet().stream()
                    .filter(e -> e.getValue() != null && FeaturePreprocessingParams.Role.INPUT.equals(e.getValue().role));
        }
    };
    /** Logger for the classical prediction-model agent tool (category "dku.agents.tools.ml.prediction"). */
    private static final DKULogger logger = DKULogger.getLogger("dku.agents.tools.ml.prediction");

    /**
     * Executes the prediction tool against the active version of the configured saved model.
     *
     * <p>{@link #init()} resolves the model and loads a {@link Predictor} from its model folder.
     * {@link #run} then scores the single record passed under the {@code "record"} argument.
     */
    public static class Runner
    implements AgentToolRunner {
        private final Params params;
        private final String sourceProjectKey;
        private final AuthCtx authCtx;
        private FullModelId fmi;
        private Predictor predictor;
        @Autowired
        private SavedModelsDAO smDAO;
        @Autowired
        private TransactionService transactionService;

        /**
         * @param authCtx          caller's authentication context (stored; not currently used by this runner)
         * @param sourceProjectKey project the tool belongs to; {@code p.smRef} is resolved against it
         * @param p                tool parameters
         */
        public Runner(AuthCtx authCtx, String sourceProjectKey, Params p) {
            this.authCtx = authCtx;
            this.sourceProjectKey = sourceProjectKey;
            this.params = p;
        }

        /**
         * Autowires dependencies, resolves the saved model and loads a Java predictor for its active version.
         *
         * @throws IllegalArgumentException if no model is configured, or if the model is not compatible
         *                                  with Java scoring
         * @throws IOException              if the model or its details cannot be read
         */
        @Override
        public void init() throws IOException {
            SpringUtils.getInstance().autowire((Object)this);
            if (StringUtils.isBlank(this.params.smRef)) {
                throw new IllegalArgumentException("Model to use is not specified in tool");
            }
            SavedModel sm;
            try (Transaction t = this.transactionService.beginRead()) {
                sm = (SavedModel)this.smDAO.getMandatory(AnyLoc.resolveSmart(this.sourceProjectKey, this.params.smRef));
            }
            this.fmi = new FullModelId(sm.projectKey, sm.id, sm.getActiveVersion());
            File modelFolder = this.fmi.getModelFolder();
            // Check Java scoring compatibility first so the user gets the reported reason.
            CompatibilityWithReason javaCompatibility = PredictionResultsReader.makeModelDetails(this.fmi).javaScoreCompatibility;
            if (!javaCompatibility.compatible) {
                throw new IllegalArgumentException("Model can't be used with this tool: " + javaCompatibility.reason);
            }
            this.predictor = new Predictor(modelFolder);
        }

        /**
         * Scores the JSON object passed as the {@code "record"} argument.
         *
         * <p>Each JSON primitive is passed to the predictor's normalizing/coercing builder under its key.
         * Numbers are passed as {@link Number}. Other primitives (strings, booleans) are passed through
         * {@code getAsString()}. Non-primitive values (JSON null, arrays, objects) are skipped, so the
         * builder receives no value for them.
         *
         * @return an output holding {@code {"prediction": ...}} for regression and classification
         *         responses, or an empty object (with a logged warning) for any other response type
         * @throws IllegalStateException if {@link #init()} has not completed successfully
         */
        @Override
        public AgentToolRunner.AgentToolOutput run(AgentToolRunner.AgentToolInput input) throws Exception {
            if (this.predictor == null) {
                throw new IllegalStateException("Prediction tool runner is not initialized: init() must succeed before run()");
            }
            JsonObject record = this.safeReadObjectArgument(input, "record");
            ObservationBuilder.NormalizingCoercingBuilder builder = this.predictor.getNormalizingCoercingBuilder();
            for (Map.Entry<String, JsonElement> entry : record.entrySet()) {
                JsonElement value = entry.getValue();
                if (!value.isJsonPrimitive()) {
                    continue;
                }
                JsonPrimitive prim = value.getAsJsonPrimitive();
                if (prim.isNumber()) {
                    builder.with(entry.getKey(), toNumber(prim));
                } else {
                    builder.with(entry.getKey(), prim.getAsString());
                }
            }
            RawObservation observation = builder.build();
            PredictionResponse.PredictionResponseItem item = this.predictor.predict(observation);
            JF.ObjectBuilder ob = JF.obj();
            if (item instanceof PredictionResponse.RegressionResponseItem) {
                ob.with("prediction", (Number)((PredictionResponse.RegressionResponseItem)item).prediction);
            } else if (item instanceof PredictionResponse.ClassificationResponseItem) {
                ob.with("prediction", ((PredictionResponse.ClassificationResponseItem)item).prediction);
            } else {
                logger.warn((Object)("Unsupported prediction response type, returning no prediction: "
                        + (item == null ? "null" : item.getClass().getName())));
            }
            AgentToolRunner.AgentToolOutput ret = new AgentToolRunner.AgentToolOutput();
            ret.output = ob.get();
            return ret;
        }

        /**
         * Converts a numeric JSON primitive to the {@link Number} passed to the builder.
         *
         * <p>Returns a boxed {@code double}. If the text cannot be parsed as a double, it falls back
         * to a boxed {@code long}. {@code NumberFormatException} is the only failure
         * {@code getAsDouble} can raise for a numeric primitive.
         */
        private static Number toNumber(JsonPrimitive prim) {
            try {
                return prim.getAsDouble();
            } catch (NumberFormatException notADouble) {
                return prim.getAsLong();
            }
        }

        /** Releases nothing explicitly. NOTE(review): confirm that Predictor needs no cleanup. */
        @Override
        public void close() throws Exception {
        }
    }

    /** User-configurable parameters of the classical prediction-model tool. */
    public static class Params implements AgentToolParams {
        /**
         * Reference to the saved model to score with. It is resolved against the tool's project via
         * {@code AnyLoc.resolveSmart}. It is {@code null} or blank when no model has been selected.
         */
        public String smRef;
    }
}

