/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.dip.analysis.ml.prediction.TabularPredictionParamsExpander;
import com.dataiku.dip.analysis.model.core.ResolvedPreprocessingParams;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedCausalPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.CausalPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.TabularPredictionPreprocessingParams;
import com.dataiku.dip.utils.JSON;
import com.google.common.collect.Lists;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;

/**
 * Params expander for causal prediction ML tasks.
 *
 * <p>Builds the modeling sets for every configured causal learning method (meta-learners combined
 * with each causal base learner, plus a causal forest) and resolves causal preprocessing params.
 */
public class CausalPredictionParamsExpander extends TabularPredictionParamsExpander {
    /**
     * Creates an expander for the given causal task.
     *
     * <p>Only {@code GridSearchCrossValidationMode.KFOLD} is passed to the superclass as the
     * allowed grid-search cross-validation mode set.
     *
     * @param task the causal prediction ML task to expand
     * @param sessionId the session identifier, forwarded to the superclass
     */
    public CausalPredictionParamsExpander(PredictionMLTask.CausalPredictionMLTask task, String sessionId) {
        super(task, sessionId, Collections.singleton(PredictionModelingParams.GridSearchCrossValidationMode.KFOLD));
    }

    /**
     * Expands the modeling params of {@code algo}, tags each resulting modeling set with the
     * given causal method / meta-learner, and registers the sets on {@code ws}.
     *
     * <p>If the algorithm reports {@code oneModelPerPreprocessingSet()}, each modeling set is
     * registered on its own (right after being tagged); otherwise all sets are registered
     * together in a single call once every set has been tagged.
     *
     * @param ws the work set receiving the modeling sets
     * @param algo the algorithm whose modeling params are expanded
     * @param method the causal learning method stamped on each modeling set
     * @param metaLearner the meta-learner stamped on each modeling set; {@code null} when the
     *     method does not use one (e.g. causal forest)
     */
    private void addCausalModelingSets(WorkSet ws, PreTrainPredictionModelingParams.Algorithm algo, PredictionModelingParams.CausalLearningMethod method, @Nullable PredictionModelingParams.CausalMetaLearner metaLearner) {
        List<WorkSet.ModelingSet> modelingSets = algo.meta.expandModeling(this.task.modeling, this.task, this.getGsFolds());
        // Loop-invariant: decide once whether sets are registered individually or as a batch.
        boolean oneModelPerPreprocessingSet = algo.meta.oneModelPerPreprocessingSet();
        for (WorkSet.ModelingSet modelingSet : modelingSets) {
            PreTrainPredictionModelingParams predictionModelingParams = (PreTrainPredictionModelingParams) modelingSet.modelingParams;
            predictionModelingParams.setCausalMethodFields(method, metaLearner);
            if (oneModelPerPreprocessingSet) {
                // Mutable single-element list, typed as List<WorkSet.ModelingSet>.
                this.addModelingSets(Lists.newArrayList(modelingSet), ws, algo.meta.autoCompleter());
            }
        }
        if (!oneModelPerPreprocessingSet) {
            this.addModelingSets(modelingSets, ws, algo.meta.autoCompleter());
        }
    }

    /**
     * Registers all causal modeling sets on {@code ws}.
     *
     * <p>For every meta-learner in {@code task.modeling.meta_learners}, each algorithm in
     * {@code PreTrainPredictionModelingParams.CAUSAL_BASE_LEARNERS} is added with the
     * {@code META_LEARNER} method. A {@code CAUSAL_FOREST} algorithm (no meta-learner) is
     * always added last.
     *
     * @param ws the work set receiving the modeling sets
     */
    @Override
    protected void addModelingSets(WorkSet ws) {
        for (PredictionModelingParams.CausalMetaLearner metaLearner : this.task.modeling.meta_learners) {
            for (PreTrainPredictionModelingParams.Algorithm baseLearner : PreTrainPredictionModelingParams.CAUSAL_BASE_LEARNERS) {
                this.addCausalModelingSets(ws, baseLearner, PredictionModelingParams.CausalLearningMethod.META_LEARNER, metaLearner);
            }
        }
        this.addCausalModelingSets(ws, PreTrainPredictionModelingParams.Algorithm.CAUSAL_FOREST, PredictionModelingParams.CausalLearningMethod.CAUSAL_FOREST, null);
    }

    /**
     * Resolves causal preprocessing params from the task's tabular preprocessing params.
     *
     * <p>{@code per_feature} is deep-copied via {@code JSON.deepCopy}. {@code target_remapping},
     * {@code feature_generation} and {@code feature_selection_params} are assigned by reference,
     * so they are shared with {@code ppp}. The treatment variable is taken from the task, and
     * dimensionality reduction is always disabled.
     *
     * @param ppp the preprocessing params; must be a {@link CausalPredictionPreprocessingParams}
     * @return the resolved causal preprocessing params
     * @throws IllegalArgumentException if {@code ppp} is null or not a
     *     {@link CausalPredictionPreprocessingParams}
     */
    @Override
    protected ResolvedCausalPredictionPreprocessingParams expandResolvedPredictionPreprocessingParams(TabularPredictionPreprocessingParams ppp) {
        if (!(ppp instanceof CausalPredictionPreprocessingParams)) {
            // instanceof is false for null; avoid an NPE while building the message.
            String actualType = ppp == null ? "null" : ppp.getClass().getName();
            throw new IllegalArgumentException("Invalid preprocessing params type: " + actualType);
        }
        CausalPredictionPreprocessingParams cppp = (CausalPredictionPreprocessingParams) ppp;
        ResolvedCausalPredictionPreprocessingParams rppp = new ResolvedCausalPredictionPreprocessingParams();
        // NOTE(review): raw Map cast kept because JSON.deepCopy's signature is not visible here.
        rppp.per_feature = (Map) JSON.deepCopy((Object) cppp.per_feature);
        rppp.treatment_variable = ((PredictionMLTask.CausalPredictionMLTask) this.task).treatmentVariable;
        rppp.drop_missing_treatment_values = cppp.dropMissingTreatmentValues;
        rppp.reduce = new ResolvedPreprocessingParams.ReductionParams();
        rppp.reduce.enabled = false;
        rppp.target_remapping = cppp.target_remapping;
        rppp.feature_generation = cppp.feature_generation;
        rppp.feature_selection_params = cppp.feature_selection_params;
        return rppp;
    }
}

