/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.aigenerations.recipes;

import com.dataiku.dip.CodedRuntimeException;
import com.dataiku.dip.aigenerations.AIRecipeGenerationService;
import com.dataiku.dip.aigenerations.recipes.AIRecipe;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.dataflow.exec.join.JoinRecipePayloadParams;
import com.dataiku.dip.dataflow.exec.joinlike.ColumnDesc;
import com.dataiku.dip.dataflow.exec.joinlike.ConditionsMode;
import com.dataiku.dip.dataflow.exec.joinlike.JoinInputDescBase;
import com.dataiku.dip.dataflow.exec.joinlike.JoinType;
import com.dataiku.dip.i18n.TranslationService;
import com.dataiku.dip.recipes.consistency.RecipeCodes;
import com.dataiku.dip.server.SpringUtils;
import com.google.gson.Gson;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Join flavour of {@link AIRecipe}.
 *
 * <p>Holds the join conditions found in {@link #params}, validates them against the
 * datasets of the generation context ({@link #validate}) and converts them into a
 * {@link JoinRecipePayloadParams} payload ({@link #generateRecipeMetaPayload}).
 */
public class AIJoin extends AIRecipe {
    public static final String type = "join";

    /** Join type names accepted by {@link #validate}; any other value is rejected. */
    private static final Set<String> SUPPORTED_JOIN_TYPES = Set.of("INNER", "LEFT", "RIGHT", "FULL", "CROSS");

    /** Message of the error raised when fewer than two usable input datasets remain. */
    private static final String NOT_ENOUGH_INPUTS_MESSAGE = "Join requires at least two distinct input datasets to work";

    public Params params;

    /** Names of the datasets used as join inputs, in registration order; filled by {@link #enrichParams}. */
    private List<String> virtualInputs;

    /**
     * Checks the join conditions against the context datasets and collects warnings.
     *
     * <p>Side effects: {@code params.join_conditions} is replaced by the list of conditions
     * accepted by {@link JoinCondition#isValidCondition()}, and {@code params.input_datasets}
     * is replaced by the distinct dataset names those conditions mention.
     *
     * @param datasets          datasets available in the generation context
     * @param contextProjectKey project key used to compute each dataset's smart name
     * @param lang              language passed to the translation service for the warnings
     * @return the warnings gathered during validation (possibly empty)
     * @throws CodedRuntimeException if there are no join conditions, fewer than two valid
     *                               input datasets, or an unsupported join type
     */
    @Override
    public List<AIRecipeGenerationService.CreationMessage> validate(List<Dataset> datasets, String contextProjectKey, String lang) {
        List<AIRecipeGenerationService.CreationMessage> creationMessages = new ArrayList<>();
        if (this.params == null || this.params.join_conditions == null) {
            throw new CodedRuntimeException((InfoMessage.MessageCode)RecipeCodes.ERR_RECIPE_CREATION_FROM_TEXT_TO_RECIPE, NOT_ENOUGH_INPUTS_MESSAGE);
        }

        // Drop null entries and conditions that do not name two distinct datasets.
        this.params.join_conditions = this.params.join_conditions.stream()
                .filter(joinCond -> joinCond != null && joinCond.isValidCondition())
                .collect(Collectors.toList());
        Set<String> joinConditionDatasetNames = this.params.join_conditions.stream()
                .flatMap(joinCond -> Stream.of(joinCond.dataset_left, joinCond.dataset_right))
                .collect(Collectors.toSet());
        this.params.input_datasets = new ArrayList<>(joinConditionDatasetNames);

        List<String> validDatasets = AIRecipe.AIRecipeParams.getValidDatasetNames(datasets, this.params.input_datasets, contextProjectKey);
        if (this.params.join_conditions.isEmpty() || validDatasets.size() < 2) {
            throw new CodedRuntimeException((InfoMessage.MessageCode)RecipeCodes.ERR_RECIPE_CREATION_FROM_TEXT_TO_RECIPE, NOT_ENOUGH_INPUTS_MESSAGE);
        }
        this.params.checkSingleOutputDataset(type);

        for (JoinCondition joinCondition : this.params.join_conditions) {
            // Explicit null check: immutable sets reject contains(null).
            boolean supported = joinCondition.join_type != null && SUPPORTED_JOIN_TYPES.contains(joinCondition.join_type);
            if (!supported) {
                throw new CodedRuntimeException((InfoMessage.MessageCode)RecipeCodes.ERR_RECIPE_CREATION_FROM_TEXT_TO_RECIPE, String.format("The type of the join suggested %s is not supported by the feature, supported type of join are: INNER, RIGHT, LEFT, FULL & CROSS", joinCondition.join_type));
            }
        }

        List<String> contextDatasetNames = datasets.stream().map(d -> d.getSmartName(contextProjectKey)).collect(Collectors.toList());
        Set<String> validDatasetsSet = new HashSet<>(validDatasets);
        TranslationService ts = (TranslationService)SpringUtils.getBean(TranslationService.class);
        for (String contextDatasetName : contextDatasetNames) {
            if (validDatasetsSet.contains(contextDatasetName)) {
                continue;
            }
            creationMessages.add(new AIRecipeGenerationService.CreationMessage(AIRecipeGenerationService.CreationMessage.Level.WARNING, ts.translate(lang, "RIGHT_PANEL.TABS.GENERATE_RECIPE.JOIN.DATASET_NOT_PART_OF_INPUT_DATASET", String.format("Dataset %s part of the context will not be part of the join inputs datasets", contextDatasetName), "dataset", contextDatasetName)));
        }

        Map<String, Set<String>> datasetToColumnNames = this.getDatasetNameToSchemaColumn(datasets, contextProjectKey);
        Set<String> datasetNameNotExisting = new HashSet<>();
        // Datasets reachable so far; a condition whose left dataset is not in here is reported as skipped.
        // NOTE(review): this set is seeded from the first condition even when its datasets are unknown,
        // whereas enrichParams seeds from the first condition whose datasets are both known - confirm
        // whether the two should agree.
        Set<String> leftDatasetJoinPossibility = new HashSet<>();
        JoinCondition firstCondition = this.params.join_conditions.get(0);
        leftDatasetJoinPossibility.add(firstCondition.dataset_left);
        leftDatasetJoinPossibility.add(firstCondition.dataset_right);

        for (JoinCondition joinCondition : this.params.join_conditions) {
            String left = joinCondition.dataset_left;
            String right = joinCondition.dataset_right;
            boolean leftKnown = datasetToColumnNames.containsKey(left);
            boolean rightKnown = datasetToColumnNames.containsKey(right);

            // Set.add returns true only on first insertion, so each unknown dataset is reported once.
            if (!leftKnown && datasetNameNotExisting.add(left)) {
                creationMessages.add(AIJoin.unknownDatasetWarning(ts, lang, left));
            }
            if (!rightKnown && datasetNameNotExisting.add(right)) {
                creationMessages.add(AIJoin.unknownDatasetWarning(ts, lang, right));
            }

            if (leftKnown && rightKnown) {
                Set<String> leftColumns = datasetToColumnNames.get(left);
                Set<String> rightColumns = datasetToColumnNames.get(right);
                for (JoinKey joinKey : AIJoin.nonNullKeys(joinCondition.join_keys)) {
                    if (leftColumns.contains(joinKey.column_left) && rightColumns.contains(joinKey.column_right)) {
                        continue;
                    }
                    creationMessages.add(new AIRecipeGenerationService.CreationMessage(AIRecipeGenerationService.CreationMessage.Level.WARNING, ts.translate(lang, "RIGHT_PANEL.TABS.GENERATE_RECIPE.JOIN.COLUMN_DOES_NOT_EXIST", String.format("Specified column does not exist for the join condition on: %s and %s", joinKey.column_left, joinKey.column_right), "joinKeyLeft", joinKey.column_left, "joinKeyRight", joinKey.column_right)));
                }
            }

            if (leftKnown && !leftDatasetJoinPossibility.contains(left)) {
                creationMessages.add(new AIRecipeGenerationService.CreationMessage(AIRecipeGenerationService.CreationMessage.Level.WARNING, ts.translate(lang, "RIGHT_PANEL.TABS.GENERATE_RECIPE.JOIN.NOT_LINKED_TO_PREVIOUS_DATASETS", String.format("The join condition on dataset (%s) with dataset (%s) will be skipped because dataset (%s) is not linked with previous join conditions. Consider creating multiple join recipes instead.", left, right, left), "datasetLeft", left, "datasetRight", right)));
                continue;
            }
            leftDatasetJoinPossibility.add(left);
            leftDatasetJoinPossibility.add(right);
        }
        return creationMessages;
    }

    /** Builds the warning emitted when a join condition names a dataset absent from the context. */
    private static AIRecipeGenerationService.CreationMessage unknownDatasetWarning(TranslationService ts, String lang, String datasetName) {
        return new AIRecipeGenerationService.CreationMessage(AIRecipeGenerationService.CreationMessage.Level.WARNING, ts.translate(lang, "RIGHT_PANEL.TABS.GENERATE_RECIPE.JOIN.DATASET_NOT_PART_OF_SELECTED_DATASET", String.format("The dataset %s is not part of the selected datasets", datasetName), "dataset", datasetName));
    }

    /** Returns the given keys without null entries; a null list (e.g. no keys supplied) yields an empty list. */
    private static List<JoinKey> nonNullKeys(List<JoinKey> joinKeys) {
        if (joinKeys == null) {
            return List.of();
        }
        return joinKeys.stream().filter(key -> key != null).collect(Collectors.toList());
    }

    /** Maps each dataset's smart name (relative to {@code contextProjectKey}) to the set of its schema column names. */
    private Map<String, Set<String>> getDatasetNameToSchemaColumn(List<Dataset> datasets, String contextProjectKey) {
        return datasets.stream().collect(Collectors.toMap(
                d -> d.getSmartName(contextProjectKey),
                d -> d.getSchema().columns.stream().map(SchemaColumn::getName).collect(Collectors.toSet())));
    }

    /**
     * Validates this join and builds the creation payload.
     *
     * @param recipeGenerationContext supplies the datasets, project key and language
     * @return the validation messages, the input/output dataset names and the JSON-serialized
     *         {@link JoinRecipePayloadParams}
     * @throws CodedRuntimeException if {@link #validate} rejects the join
     */
    @Override
    public AIRecipeGenerationService.AIMetaCreation generateRecipeMetaPayload(AIRecipeGenerationService.RecipeGenerationContext recipeGenerationContext) {
        AIRecipeGenerationService.AIMetaCreation metaCreation = new AIRecipeGenerationService.AIMetaCreation();
        metaCreation.messages = this.validate(recipeGenerationContext.datasets, recipeGenerationContext.contextProjectKey, recipeGenerationContext.lang);
        JoinRecipePayloadParams joinRecipePayloadParams = new JoinRecipePayloadParams();
        this.enrichParams(joinRecipePayloadParams, recipeGenerationContext.datasets, recipeGenerationContext.contextProjectKey);
        metaCreation.datasetInputNames = this.virtualInputs;
        metaCreation.datasetOutputNames = List.of(this.params.output_dataset);
        metaCreation.payload = new Gson().toJson((Object)joinRecipePayloadParams);
        return metaCreation;
    }

    /**
     * Fills {@code recipeParams} with one virtual input per distinct dataset and one join
     * descriptor per usable join condition, and records the input names in registration order.
     *
     * <p>A condition is ignored when one of its datasets is not among {@code datasets}, or,
     * once a first condition has been accepted, when its left dataset has not been registered
     * by an earlier condition.
     *
     * @param recipeParams      payload to populate
     * @param datasets          datasets available in the generation context
     * @param contextProjectKey project key used to compute each dataset's smart name
     */
    public void enrichParams(JoinRecipePayloadParams recipeParams, List<Dataset> datasets, String contextProjectKey) {
        Map<String, Set<String>> datasetToColumnNames = this.getDatasetNameToSchemaColumn(datasets, contextProjectKey);
        Map<String, Integer> datasetNameToIndex = new HashMap<>();
        this.virtualInputs = new ArrayList<>();
        boolean firstConditionAccepted = false;
        for (JoinCondition joinCondition : this.params.join_conditions) {
            String left = joinCondition.dataset_left;
            String right = joinCondition.dataset_right;
            if (!datasetToColumnNames.containsKey(left) || !datasetToColumnNames.containsKey(right)) {
                continue;
            }
            if (!firstConditionAccepted) {
                // The first usable condition registers both of its datasets, as inputs 0 and 1.
                firstConditionAccepted = true;
                this.registerInput(recipeParams, datasetNameToIndex, left);
                this.registerInput(recipeParams, datasetNameToIndex, right);
                this.addJoinCondition(recipeParams, 0, 1, joinCondition.join_type, joinCondition.join_keys, datasetToColumnNames, left, right);
                continue;
            }
            if (!datasetNameToIndex.containsKey(left)) {
                continue;
            }
            int indexLeft = datasetNameToIndex.get(left);
            // Register the right dataset only the first time it is seen, so that
            // recipeParams.virtualInputs stays parallel to this.virtualInputs and to the
            // indexes used in the join descriptors.
            if (!datasetNameToIndex.containsKey(right)) {
                this.registerInput(recipeParams, datasetNameToIndex, right);
            }
            int indexRight = datasetNameToIndex.get(right);
            this.addJoinCondition(recipeParams, indexLeft, indexRight, joinCondition.join_type, joinCondition.join_keys, datasetToColumnNames, left, right);
        }
    }

    /** Appends {@code datasetName} as the next join input: name list, payload virtual input and name-to-index map. */
    private void registerInput(JoinRecipePayloadParams recipeParams, Map<String, Integer> datasetNameToIndex, String datasetName) {
        int index = this.virtualInputs.size();
        this.virtualInputs.add(datasetName);
        this.addVirtualJoinInput(recipeParams, index);
        datasetNameToIndex.put(datasetName, index);
    }

    /** Adds a virtual input for table {@code index} using the AUTO_NON_CONFLICTING output column selection mode. */
    private void addVirtualJoinInput(JoinRecipePayloadParams recipeParams, int index) {
        JoinRecipePayloadParams.InputDesc inputDesc = JoinRecipePayloadParams.InputDesc.ofTable(index);
        inputDesc.outputColumnsSelectionMode = JoinInputDescBase.OutputColumnsSelectionMode.AUTO_NON_CONFLICTING;
        recipeParams.virtualInputs.add(inputDesc);
    }

    /**
     * Adds a join descriptor between two inputs, with one equality matching condition per join
     * key whose columns both exist. Keys naming an unknown column are silently left out here;
     * {@link #validate} is where they are reported.
     */
    private void addJoinCondition(JoinRecipePayloadParams recipeParams, int indexLeft, int indexRight, String joinType, List<JoinKey> joinKeys, Map<String, Set<String>> datasetToColumnNames, String datasetLeft, String datasetRight) {
        JoinRecipePayloadParams.JoinDesc joinDesc = JoinRecipePayloadParams.JoinDesc.of(indexLeft, AIJoin.getJoinType(joinType), indexRight, ConditionsMode.AND);
        Set<String> leftColumns = datasetToColumnNames.get(datasetLeft);
        Set<String> rightColumns = datasetToColumnNames.get(datasetRight);
        for (JoinKey joinKey : AIJoin.nonNullKeys(joinKeys)) {
            if (!leftColumns.contains(joinKey.column_left) || !rightColumns.contains(joinKey.column_right)) {
                continue;
            }
            JoinRecipePayloadParams.MatchingCondition mc = new JoinRecipePayloadParams.MatchingCondition();
            mc.column1 = new ColumnDesc(indexLeft, joinKey.column_left);
            mc.column2 = new ColumnDesc(indexRight, joinKey.column_right);
            mc.type = JoinRecipePayloadParams.MatchingType.EQ;
            joinDesc.on.add(mc);
        }
        recipeParams.joins.add(joinDesc);
    }

    /** Converts a join type name to {@link JoinType}; null or unrecognised names map to {@code INNER}. */
    private static JoinType getJoinType(String joinType) {
        if (joinType == null) {
            // switch on a null String would throw; treat it like any other unrecognised value.
            return JoinType.INNER;
        }
        switch (joinType) {
            case "LEFT": {
                return JoinType.LEFT;
            }
            case "RIGHT": {
                return JoinType.RIGHT;
            }
            case "FULL": {
                return JoinType.FULL;
            }
            case "CROSS": {
                return JoinType.CROSS;
            }
        }
        return JoinType.INNER;
    }

    /** Parameters of the join: the inherited recipe parameters plus the list of join conditions. */
    public static class Params
    extends AIRecipe.AIRecipeParams {
        public List<JoinCondition> join_conditions;
    }

    /** One join between two datasets: the dataset names, the key pairs and the join type name. */
    public static class JoinCondition {
        public String dataset_left;
        public String dataset_right;
        public List<JoinKey> join_keys;
        public String join_type;

        /** @return true when both dataset names are set and differ from each other */
        public boolean isValidCondition() {
            return this.dataset_left != null && this.dataset_right != null && !this.dataset_right.equals(this.dataset_left);
        }
    }

    /** A pair of column names, one from each side of a join condition. */
    public static class JoinKey {
        public String column_left;
        public String column_right;
    }
}

