/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.ml.prediction.CustomPythonPredictionAlgoDesc;
import com.dataiku.dip.analysis.ml.prediction.CustomPythonPredictionAlgoService;
import com.dataiku.dip.analysis.ml.prediction.LoadedCustomPythonPredictionAlgoDesc;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import com.dataiku.dip.server.SpringUtils;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang.StringUtils;

/**
 * Algorithm metadata for prediction algorithms supplied by plugins as custom Python code.
 *
 * <p>Every operation resolves the plugin algorithm descriptor at call time through the
 * {@link CustomPythonPredictionAlgoService} Spring bean, keyed by the plugin id and element id stored
 * in {@link PredictionModelingParams.CustomPythonPluginParams}. Search descriptions and grid lengths
 * are only computed when the descriptor declares {@code GridSearchMode.MANAGED}; in {@code CUSTOM}
 * mode the modeling set is merely flagged with {@code pluginAlgoCustomGridSearch}.
 */
public class PyCustomPluginMeta extends PyMemoryAlgorithmMeta {

    /**
     * Returns the configured name of the plugin algorithm, or its plugin element id when the name is
     * blank.
     *
     * @param rpmp pre-train params whose {@code plugin_python_grid} identifies the algorithm
     * @return a non-blank name when one is configured, otherwise the element id
     */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.CustomPythonPluginParams customAlgoParams = rpmp.plugin_python_grid;
        return StringUtils.isNotBlank(customAlgoParams.name) ? customAlgoParams.name : customAlgoParams.elementId;
    }

    /**
     * Describes the hyperparameter search space before training.
     *
     * <p>Each descriptor parameter flagged {@code gridParam} and present in the params is reported:
     * JSON-array values as multi-valued params (one JSON string per element), any other value as a
     * single-valued param.
     *
     * @param rpmp pre-train params whose {@code plugin_python_grid} holds the search space
     * @return the description, or {@code null} when the algorithm does not use MANAGED grid search
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.CustomPythonPluginParams customAlgoParams = rpmp.plugin_python_grid;
        LoadedCustomPythonPredictionAlgoDesc modelDesc = this.loadModelDesc(customAlgoParams);
        if (!isManagedGridSearch(modelDesc)) {
            return null;
        }
        ModelTrainInfo.PreSearchDescription ret = new ModelTrainInfo.PreSearchDescription(rpmp).withGridLength(this.getGridLength(customAlgoParams));
        for (CustomPythonPredictionAlgoDesc.MLParamDesc param : modelDesc.desc.params) {
            if (!param.gridParam || !customAlgoParams.params.has(param.name)) {
                continue;
            }
            JsonElement value = customAlgoParams.params.get(param.name);
            if (value.isJsonArray()) {
                ArrayList<String> values = new ArrayList<String>();
                for (JsonElement elem : value.getAsJsonArray()) {
                    values.add(elem.toString());
                }
                ret.withMVParam(param.name, values);
            } else {
                ret.withSVParam(param.name, value.toString());
            }
        }
        return ret;
    }

    /**
     * Describes the hyperparameter values selected by the search after training.
     *
     * @param descBefore the pre-search description (unused here)
     * @param before     pre-train params identifying the plugin algorithm
     * @param after      post-train params whose {@code plugin_python} holds the selected values
     * @return the description (possibly without params when {@code after.plugin_python} is
     *         {@code null}), or {@code null} when the algorithm does not use MANAGED grid search
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        LoadedCustomPythonPredictionAlgoDesc modelDesc = this.loadModelDesc(before.plugin_python_grid);
        if (!isManagedGridSearch(modelDesc)) {
            return null;
        }
        ModelTrainInfo.PostSearchDescription ret = new ModelTrainInfo.PostSearchDescription();
        JsonObject optimizedParams = after.plugin_python;
        if (optimizedParams != null) {
            for (CustomPythonPredictionAlgoDesc.MLParamDesc param : modelDesc.desc.params) {
                if (param.gridParam && optimizedParams.has(param.name)) {
                    ret.withSVParam(param.name, optimizedParams.get(param.name).toString());
                }
            }
        }
        return ret;
    }

    /**
     * Performs no validation: custom plugin algorithm parameters are not checked here.
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
    }

    /**
     * Expands every enabled plugin algorithm of {@code pmp.plugin_python} into one modeling set.
     *
     * <p>Algorithms whose plugin element can no longer be found are skipped with a warning.
     *
     * @param pmp     ML task modeling params
     * @param task    the ML task (unused here)
     * @param gsFolds number of search folds, used to estimate the number of trainings
     * @return one modeling set per enabled and available plugin algorithm
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        List<WorkSet.ModelingSet> out = new ArrayList<WorkSet.ModelingSet>();
        // Resolved lazily so that no Spring lookup happens when no plugin algorithm is enabled.
        CustomPythonPredictionAlgoService pythonPredictionModelService = null;
        for (PredictionModelingParams.CustomPythonPluginParams customAlgoParams : pmp.plugin_python.values()) {
            if (!customAlgoParams.enabled) {
                continue;
            }
            if (pythonPredictionModelService == null) {
                pythonPredictionModelService = this.getPythonPredictionModelService();
            }
            if (!pythonPredictionModelService.exists(customAlgoParams.pluginId, customAlgoParams.elementId)) {
                logger.warn((Object)("Not training model for plugin algorithm [pluginId=" + customAlgoParams.pluginId + ",elementID=" + customAlgoParams.elementId + "]. Plugin element has not been found"));
                continue;
            }
            LoadedCustomPythonPredictionAlgoDesc modelDesc = pythonPredictionModelService.get(customAlgoParams.pluginId, customAlgoParams.elementId);
            PreTrainPredictionModelingParams rcmp = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.CUSTOM_PLUGIN, pmp);
            // NOTE(review): this aliases the object held in pmp.plugin_python, so the two capability
            // flags below are also written back into the ML task params — confirm this is intended.
            rcmp.plugin_python_grid = customAlgoParams;
            rcmp.plugin_python_grid.acceptsSparseMatrix = modelDesc.desc.acceptsSparseMatrix;
            rcmp.plugin_python_grid.supportsSampleWeights = modelDesc.desc.supportsSampleWeights;
            this.checkAndUpdateSearchStrategy(pmp, rcmp);
            WorkSet.ModelingSet ms = new WorkSet.ModelingSet(rcmp);
            if (isManagedGridSearch(modelDesc)) {
                rcmp.gridLength = this.getGridLength(customAlgoParams);
            } else if (CustomPythonPredictionAlgoDesc.GridSearchMode.CUSTOM.equals(modelDesc.desc.gridSearchMode)) {
                rcmp.pluginAlgoCustomGridSearch = true;
                ms.pluginAlgoCustomGridSearch = true;
            }
            // One training per grid point and fold, plus one final training. Computed in long and
            // clamped so that a very large grid cannot wrap around to a negative estimate.
            ms.estimatedTrains = rcmp.gridLength > 1 ? (int) Math.min((long) rcmp.gridLength * gsFolds + 1, Integer.MAX_VALUE) : 1;
            out.add(ms);
        }
        return out;
    }

    /**
     * Computes the number of grid points of a custom plugin search space.
     *
     * <p>This is the product of the sizes of all non-empty JSON-array values of descriptor parameters
     * flagged {@code gridParam}. Scalar values and empty arrays count as one point. The result
     * saturates at {@link Integer#MAX_VALUE} instead of overflowing.
     *
     * @param space must be a {@link PredictionModelingParams.CustomPythonPluginParams}
     * @return the grid length, at least 1
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.CustomPythonPluginParams customAlgoParams = (PredictionModelingParams.CustomPythonPluginParams)space;
        LoadedCustomPythonPredictionAlgoDesc modelDesc = this.loadModelDesc(customAlgoParams);
        int gridLength = 1;
        for (CustomPythonPredictionAlgoDesc.MLParamDesc param : modelDesc.desc.params) {
            if (!param.gridParam || !customAlgoParams.params.has(param.name)) {
                continue;
            }
            JsonElement value = customAlgoParams.params.get(param.name);
            if (!value.isJsonArray()) {
                continue;
            }
            int size = value.getAsJsonArray().size();
            if (size > 0) {
                gridLength = saturatingMultiply(gridLength, size);
            }
        }
        return gridLength;
    }

    /**
     * Builds pre-train params whose search space is pinned to the values selected by a previous search.
     *
     * <p>For each selected value whose key exists in the grid params: an array-valued grid param is
     * replaced by a one-element array holding the selected value; any other value is overwritten.
     *
     * @param optimized  post-train params holding the selected values in {@code plugin_python}
     * @param usedToTrain pre-train params used for the original training
     * @return a copy of {@code usedToTrain} (via {@code getCopyWithGridStrategy}) with pinned values
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams ret = this.getCopyWithGridStrategy(usedToTrain);
        // NOTE(review): whether this params object is shared with usedToTrain depends on how deeply
        // getCopyWithGridStrategy copies — the updates below mutate it in place.
        PredictionModelingParams.CustomPythonPluginParams customAlgoParams = ret.plugin_python_grid;
        JsonObject optimizedParams = optimized.plugin_python;
        if (optimizedParams == null) {
            return ret;
        }
        for (Map.Entry<String, JsonElement> entry : optimizedParams.entrySet()) {
            String key = entry.getKey();
            if (!customAlgoParams.params.has(key)) {
                continue;
            }
            if (customAlgoParams.params.get(key).isJsonArray()) {
                JsonArray singleValueGrid = new JsonArray();
                singleValueGrid.add(entry.getValue());
                customAlgoParams.params.add(key, singleValueGrid);
            } else {
                customAlgoParams.params.add(key, entry.getValue());
            }
        }
        return ret;
    }

    /**
     * Writes the pinned search space (see {@link #regridifyToPreTrain}) back into the ML task params,
     * enabled, under the descriptor's type key.
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrain = this.regridifyToPreTrain(optimized, usedToTrain);
        PredictionModelingParams.CustomPythonPluginParams customAlgoParams = preTrain.plugin_python_grid;
        customAlgoParams.enabled = true;
        target.plugin_python.put(this.loadModelDesc(customAlgoParams).getType(), customAlgoParams);
    }

    /**
     * Writes the search space used for training back into the ML task params, enabled, under the
     * descriptor's type key.
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        PredictionModelingParams.CustomPythonPluginParams customAlgoParams = usedToTrain.plugin_python_grid;
        // NOTE(review): this flips `enabled` on usedToTrain's own params object and stores that same
        // instance in target — no copy is made.
        customAlgoParams.enabled = true;
        target.plugin_python.put(this.loadModelDesc(customAlgoParams).getType(), customAlgoParams);
    }

    /**
     * Returns whether the plugin descriptor declares that the algorithm outputs probabilities.
     */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return this.loadModelDesc(rpmp.plugin_python_grid).desc.hasProbabilities;
    }

    /** Looks up the plugin algorithm service from the Spring context. */
    private CustomPythonPredictionAlgoService getPythonPredictionModelService() {
        return (CustomPythonPredictionAlgoService)SpringUtils.getBean(CustomPythonPredictionAlgoService.class);
    }

    /**
     * Resolves the plugin algorithm descriptor identified by the given params.
     *
     * <p>NOTE(review): unlike {@link #expandModeling}, callers of this helper do not check
     * {@code exists()} first; behavior for a missing plugin element depends on
     * {@code CustomPythonPredictionAlgoService.get} — confirm.
     */
    private LoadedCustomPythonPredictionAlgoDesc loadModelDesc(PredictionModelingParams.CustomPythonPluginParams customAlgoParams) {
        return this.getPythonPredictionModelService().get(customAlgoParams.pluginId, customAlgoParams.elementId);
    }

    /** Returns whether the descriptor declares MANAGED grid search mode. */
    private static boolean isManagedGridSearch(LoadedCustomPythonPredictionAlgoDesc modelDesc) {
        return CustomPythonPredictionAlgoDesc.GridSearchMode.MANAGED.equals(modelDesc.desc.gridSearchMode);
    }

    /** Multiplies two non-negative ints, saturating at {@link Integer#MAX_VALUE} instead of wrapping. */
    private static int saturatingMultiply(int a, int b) {
        long product = (long) a * b;
        return product > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) product;
    }

    @Override
    public boolean oneModelPerPreprocessingSet() {
        return true;
    }

    /** Plugin algorithms support at most exhaustive grid search. */
    @Override
    public PredictionModelingParams.GridSearchParams.Strategy getMaximumSupportedSearchStrategy() {
        return PredictionModelingParams.GridSearchParams.Strategy.GRID;
    }
}

