/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.dataflow.exec.autofeaturegeneration;

import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.dataflow.exec.autofeaturegeneration.AutoFeatureGenerationRecipePayloadParams;
import com.dataiku.dip.dataflow.exec.autofeaturegeneration.ColumnForComputation;
import com.dataiku.dip.dataflow.exec.autofeaturegeneration.TimeWindow;
import com.dataiku.dip.dataflow.exec.autofeaturegeneration.TimeWindowQueryVisitor;
import com.dataiku.dip.dataflow.exec.autofeaturegeneration.VariableType;
import com.dataiku.dip.dataflow.exec.autofeaturegeneration.Visitable;
import com.dataiku.dip.dataflow.exec.autofeaturegeneration.Visitor;
import com.dataiku.dip.dataflow.exec.grouping.GroupingRecipePayloadParams;
import com.dataiku.dip.dataflow.exec.join.JoinRecipePayloadParams;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ErrorContext;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;

/**
 * Graph of the input datasets of an auto feature generation recipe.
 *
 * <p>Every input dataset is a {@link Node} identified by an integer id: the primary dataset is node
 * {@code 0} and the secondary datasets are numbered from {@code 1} in the order returned by
 * {@code params.getSecondaryDatasets()}. Every relationship that carries a type contributes two
 * directed {@link Edge}s, one stored in the adjacency list of each of its two endpoints, so the graph
 * can be walked from the root in either direction of a relationship.
 */
public class RelationshipGraph {
    AutoFeatureGenerationRecipePayloadParams params;
    HashMap<Integer, Node> nodes = new HashMap<>();
    Map<Integer, List<Edge>> adjacentEdges = new HashMap<>();
    /** Schema column used as time index of the root dataset; stays {@code null} when the root has none. */
    SchemaColumn cutoffColumn;
    private static final DKULogger logger = DKULogger.getLogger("dku.recipes.autofeaturegeneration.relationshipgraph");

    /**
     * Builds the graph described by {@code params}.
     *
     * @param params recipe parameters; {@code params}, {@code params.virtualInputs} and
     *               {@code params.relationships} must not be {@code null}
     * @param datasetMap datasets by name, used to resolve the schema columns of the inputs
     * @throws NullPointerException if one of the mandatory parameters is {@code null}
     */
    public RelationshipGraph(AutoFeatureGenerationRecipePayloadParams params, Map<String, Dataset> datasetMap) {
        // Validate before touching any state so a failure carries an explicit message.
        Preconditions.checkNotNull(params, "No params");
        Preconditions.checkNotNull(params.virtualInputs, "No inputs");
        Preconditions.checkNotNull(params.relationships, "No relationships");
        this.params = params;
        this.build(datasetMap);
    }

    /**
     * Creates the nodes (root first, then secondary inputs) and the edges of every typed relationship,
     * and makes sure the join columns of each relationship are present in the schema of both endpoints.
     */
    private void build(Map<String, Dataset> datasetsMap) {
        this.addRoot(this.params.getPrimaryDataset(), datasetsMap);
        int virtualInputIndex = 1;
        for (AutoFeatureGenerationRecipePayloadParams.InputDesc input : this.params.getSecondaryDatasets()) {
            this.addNode(input, false, virtualInputIndex, datasetsMap);
            ++virtualInputIndex;
        }
        for (AutoFeatureGenerationRecipePayloadParams.RelationshipDesc relationship : this.params.relationships) {
            if (relationship.type == null) {
                // Relationships without a type are ignored.
                continue;
            }
            // addEdge validates that both endpoints exist, so the lookups below cannot return null.
            this.addEdge(relationship);
            Node table1 = this.nodes.get(relationship.table1);
            Node table2 = this.nodes.get(relationship.table2);
            for (JoinRecipePayloadParams.MatchingCondition matchingCondition : relationship.on) {
                table1.addJoinColumnToSchema(matchingCondition.column1.name, datasetsMap);
                table2.addJoinColumnToSchema(matchingCondition.column2.name, datasetsMap);
            }
        }
    }

    /** Returns the node of the primary dataset (id {@code 0}). */
    public Node getRoot() {
        return this.nodes.get(0);
    }

    /**
     * Walks the graph from the root and lets {@code visitor} visit every edge leading to a node that
     * has not been reached yet.
     */
    public void makeTimeWindowFilterJoins(TimeWindowQueryVisitor visitor) {
        HashSet<Integer> seenInputs = new HashSet<Integer>();
        this.makeTimeWindowFilterJoinsRecursive(visitor, this.getRoot(), seenInputs);
    }

    /**
     * Depth-first step of {@link #makeTimeWindowFilterJoins}: each edge is visited <em>before</em> the
     * traversal descends into its related node.
     *
     * @param visitor visitor applied to the traversed edges
     * @param current node whose outgoing edges are explored
     * @param seenInputs ids of the nodes already reached; updated in place
     */
    public void makeTimeWindowFilterJoinsRecursive(TimeWindowQueryVisitor visitor, Node current, Set<Integer> seenInputs) {
        int currentId = current.getId();
        seenInputs.add(currentId);
        for (Edge edge : this.adjacentEdges.get(currentId)) {
            if (seenInputs.contains(edge.relatedNode)) {
                continue;
            }
            logger.infoV("Joining %s to %s for potential time window filtering", new Object[]{edge.getCurrentNodeName(), edge.getRelatedNodeName()});
            edge.accept(visitor);
            this.makeTimeWindowFilterJoinsRecursive(visitor, this.nodes.get(edge.relatedNode), seenInputs);
        }
    }

    /**
     * Walks the graph from the root and lets {@code visitor} visit the edges and nodes bottom-up.
     */
    public void makeFeatures(Visitor visitor) {
        HashSet<Integer> seenInputs = new HashSet<Integer>();
        this.makeFeaturesRecursive(visitor, this.getRoot(), seenInputs);
    }

    /**
     * Depth-first step of {@link #makeFeatures}: the related node of an edge is fully processed
     * <em>before</em> the edge itself is visited, and a node is visited after all of its edges.
     */
    private void makeFeaturesRecursive(Visitor visitor, Node input, Set<Integer> seenInputs) {
        int inputId = input.getId();
        seenInputs.add(inputId);
        for (Edge edge : this.adjacentEdges.get(inputId)) {
            if (seenInputs.contains(edge.relatedNode)) {
                continue;
            }
            this.makeFeaturesRecursive(visitor, this.nodes.get(edge.relatedNode), seenInputs);
            edge.accept(visitor);
        }
        input.accept(visitor);
    }

    /**
     * Returns a new schema holding the columns of the root schema whose name matches one of the
     * root's columns for computation, in root schema order.
     */
    public Schema getRootSchema() {
        Node root = this.getRoot();
        // A set gives constant-time membership tests while filtering the schema columns.
        Set<String> computedColumnNames = root.getColumnsForComputation().stream()
                .map(computed -> computed.name)
                .collect(Collectors.toSet());
        Schema finalSchema = new Schema();
        for (SchemaColumn schemaColumn : root.getSchema().getColumns()) {
            if (computedColumnNames.contains(schemaColumn.getName())) {
                finalSchema.addColumn(schemaColumn);
            }
        }
        return finalSchema;
    }

    /** Registers a node under {@code nodeId} together with an empty adjacency list. */
    private void addNode(AutoFeatureGenerationRecipePayloadParams.InputDesc inputDataset, boolean isRoot, int nodeId, Map<String, Dataset> datasetMap) {
        this.nodes.put(nodeId, new Node(inputDataset, isRoot, nodeId, datasetMap));
        this.adjacentEdges.put(nodeId, new ArrayList<Edge>());
    }

    /**
     * Registers the primary dataset as node {@code 0}. When it declares a time index column, that
     * column becomes the graph's cutoff column and the root's cutoff time column.
     */
    private void addRoot(AutoFeatureGenerationRecipePayloadParams.InputDesc inputDataset, Map<String, Dataset> datasetMap) {
        this.addNode(inputDataset, true, 0, datasetMap);
        Node root = this.getRoot();
        if (root.hasTimeIndexColumn()) {
            this.cutoffColumn = root.findSchemaColumnInDatasetMap(root.getTimeIndexColumn(), datasetMap);
            root.setCutoffTimeColumn(this.cutoffColumn);
        }
    }

    /**
     * Returns the node registered under {@code nodeId}, rejecting relationship endpoints that do not
     * correspond to any input of the graph.
     */
    private Node requireRelationshipNode(int nodeId) {
        Node node = this.nodes.get(nodeId);
        if (node == null) {
            throw ErrorContext.iae("A relationship refers to input #" + nodeId + " which is not part of the graph");
        }
        return node;
    }

    /**
     * Adds the two directed edges of {@code relationship}: one from table1 to table2 using the
     * relationship's conditions, one from table2 to table1 using the reversed conditions. The edge
     * leaving the "one" side towards the "many" side is a {@link BackwardEdge}; all others are
     * {@link ForwardEdge}s.
     */
    private void addEdge(AutoFeatureGenerationRecipePayloadParams.RelationshipDesc relationship) {
        if (relationship.type == null) {
            throw ErrorContext.iae("The relationship type is missing");
        }
        int node1 = relationship.table1;
        int node2 = relationship.table2;
        // Check both endpoints up front so that no adjacency list is modified for a broken relationship.
        this.requireRelationshipNode(node1);
        this.requireRelationshipNode(node2);
        List<JoinRecipePayloadParams.MatchingCondition> reversedConditions = relationship.reverseConditions();
        Edge edgeFromNode1;
        Edge edgeFromNode2;
        switch (relationship.type) {
            case MANY_TO_ONE: {
                edgeFromNode1 = new ForwardEdge(node1, node2, relationship.on);
                edgeFromNode2 = new BackwardEdge(node2, node1, reversedConditions);
                break;
            }
            case ONE_TO_ONE: {
                edgeFromNode1 = new ForwardEdge(node1, node2, relationship.on);
                edgeFromNode2 = new ForwardEdge(node2, node1, reversedConditions);
                break;
            }
            case ONE_TO_MANY: {
                edgeFromNode1 = new BackwardEdge(node1, node2, relationship.on);
                edgeFromNode2 = new ForwardEdge(node2, node1, reversedConditions);
                break;
            }
            default: {
                // Fail fast instead of storing null edges that would break the traversals later on.
                throw ErrorContext.iae("Unsupported relationship type: " + relationship.type);
            }
        }
        this.adjacentEdges.get(node1).add(edgeFromNode1);
        this.adjacentEdges.get(node2).add(edgeFromNode2);
    }

    /**
     * An input dataset of the graph, with the columns available for feature computation and the
     * schema that grows as computed and join columns are added.
     */
    public static class Node
    implements Visitable {
        private final AutoFeatureGenerationRecipePayloadParams.InputDesc dataset;
        /** Selected columns plus every column added through {@link #addComputedColumns}. */
        private final List<ColumnForComputation> columnsForComputations;
        /** Snapshot of the initially selected columns; not affected by {@link #addComputedColumns}. */
        private final List<ColumnForComputation> columnsForEntityComputation;
        private final boolean isRoot;
        private final int id;
        private final List<TimeWindow> timeWindows;
        private String cutoffTimeColumn;
        private final Schema schema;

        /**
         * @param inputDataset description of the input dataset
         * @param isRoot whether this node is the primary dataset
         * @param nodeId id of the node in the graph
         * @param datasetMap datasets by name, used to resolve the types of the selected columns
         */
        public Node(AutoFeatureGenerationRecipePayloadParams.InputDesc inputDataset, boolean isRoot, int nodeId, Map<String, Dataset> datasetMap) {
            this.dataset = inputDataset;
            this.columnsForComputations = this.getInitialColumnsForComputation(datasetMap);
            this.columnsForEntityComputation = new ArrayList<ColumnForComputation>(this.columnsForComputations);
            this.isRoot = isRoot;
            this.id = nodeId;
            // An input without time windows is treated as having none rather than failing on null.
            this.timeWindows = inputDataset.timeWindows == null
                    ? new ArrayList<TimeWindow>()
                    : new ArrayList<TimeWindow>(inputDataset.timeWindows);
            this.schema = new Schema();
            for (ColumnForComputation column : this.columnsForComputations) {
                this.schema.addColumn(column.name, column.schemaType);
            }
        }

        /**
         * Looks up {@code columnName} in the schema of this node's dataset as found in {@code datasetMap}.
         *
         * @throws IllegalArgumentException (via {@code ErrorContext.iae}) if the dataset is not in the
         *         map or does not have such a column
         */
        private SchemaColumn findSchemaColumnInDatasetMap(String columnName, Map<String, Dataset> datasetMap) {
            Dataset inputDataset = datasetMap.get(this.dataset.name);
            if (inputDataset == null) {
                throw ErrorContext.iae("Dataset " + this.dataset.name + " is not part of the input datasets");
            }
            SchemaColumn schemaColumn = inputDataset.getSchema().getColumn(columnName);
            if (schemaColumn == null) {
                throw ErrorContext.iae("The column '" + columnName + "' does not belong to the dataset '" + this.dataset.name + "'");
            }
            return schemaColumn;
        }

        /** Builds one {@link ColumnForComputation} per selected column, typed from the dataset schema. */
        private List<ColumnForComputation> getInitialColumnsForComputation(Map<String, Dataset> datasetMap) {
            List<ColumnForComputation> initialColumns = new ArrayList<ColumnForComputation>();
            for (AutoFeatureGenerationRecipePayloadParams.Column selectedColumn : this.dataset.getSelectedColumns()) {
                SchemaColumn schemaColumn = this.findSchemaColumnInDatasetMap(selectedColumn.name, datasetMap);
                initialColumns.add(new ColumnForComputation(selectedColumn.name, selectedColumn.variableType, schemaColumn.getType(), null));
            }
            return initialColumns;
        }

        public int getId() {
            return this.id;
        }

        public AutoFeatureGenerationRecipePayloadParams.InputDesc getDataset() {
            return this.dataset;
        }

        public boolean isRoot() {
            return this.isRoot;
        }

        public String getDatasetName() {
            return this.dataset.name;
        }

        public String getDatasetLabel() {
            return this.dataset.originLabel;
        }

        /** Returns the dataset label with every {@code '.'} replaced by {@code '_'}. */
        public String getDatasetSanitizedLabel() {
            return this.dataset.originLabel.replace(".", "_");
        }

        /** Returns the live list of columns for computation (not a copy). */
        public List<ColumnForComputation> getColumnsForComputation() {
            return this.columnsForComputations;
        }

        /**
         * Appends {@code newColumns} to the columns for computation and adds a matching column to the
         * schema for each of them, carrying the column description as comment when it is not blank.
         */
        public void addComputedColumns(List<ColumnForComputation> newColumns) {
            this.columnsForComputations.addAll(newColumns);
            for (ColumnForComputation computed : newColumns) {
                SchemaColumn newSchemaColumn = new SchemaColumn(computed.name, computed.schemaType);
                if (StringUtils.isNotBlank(computed.description)) {
                    newSchemaColumn.withComment(computed.description);
                }
                this.schema.addColumn(newSchemaColumn);
            }
        }

        public List<ColumnForComputation> getColumnsForEntityComputation() {
            return this.columnsForEntityComputation;
        }

        public Schema getSchema() {
            return this.schema;
        }

        public List<String> getSchemaColumnNames() {
            return this.schema.getColumns().stream().map(SchemaColumn::getName).collect(Collectors.toList());
        }

        public SchemaColumn getSchemaColumn(String columnName) {
            return this.schema.getColumn(columnName);
        }

        /** Adds the dataset column {@code name} to the schema unless a column of that name is already there. */
        public void addJoinColumnToSchema(String name, Map<String, Dataset> datasetsMap) {
            if (!this.schema.hasColumn(name)) {
                this.schema.addColumn(this.findSchemaColumnInDatasetMap(name, datasetsMap));
            }
        }

        /**
         * Records the name of the cutoff time column and adds it (name and type) to the schema when
         * the schema does not have a column of that name yet.
         */
        public void setCutoffTimeColumn(SchemaColumn cutoffTimeColumn) {
            String cutoffName = cutoffTimeColumn.getName();
            this.cutoffTimeColumn = cutoffName;
            if (!this.schema.hasColumn(cutoffName)) {
                this.schema.addColumn(cutoffName, cutoffTimeColumn.getType());
            }
        }

        public String getCutoffTimeColumn() {
            return this.cutoffTimeColumn;
        }

        public boolean hasTimeIndexColumn() {
            return StringUtils.isNotBlank(this.dataset.timeIndexColumn);
        }

        public String getTimeIndexColumn() {
            return this.dataset.timeIndexColumn;
        }

        public List<TimeWindow> getTimeWindows() {
            return this.timeWindows;
        }

        /** Returns the suffix of the first time window, or an empty string when there is none. */
        public String getSingleTimeWindowSuffix() {
            if (this.timeWindows == null || this.timeWindows.isEmpty()) {
                return "";
            }
            return this.timeWindows.get(0).getSuffix();
        }

        @Override
        public void accept(Visitor visitor) {
            visitor.visit(this);
        }
    }

    /**
     * Directed link from {@link #currentNode} to {@link #relatedNode}, with the join conditions
     * between the two. Node ids are resolved against the enclosing graph.
     */
    public abstract class Edge
    implements Visitable {
        public final int currentNode;
        public final int relatedNode;
        public final List<JoinRecipePayloadParams.MatchingCondition> matchingConditions;
        protected final List<VariableType> TYPES_TO_AGGREGATE = Arrays.asList(VariableType.NUMERIC, VariableType.CATEGORY);

        private Edge(int currentNode, int relatedNode, List<JoinRecipePayloadParams.MatchingCondition> matchingConditions) {
            this.currentNode = currentNode;
            this.relatedNode = relatedNode;
            this.matchingConditions = matchingConditions;
        }

        public String getCurrentNodeName() {
            return this.getNode(this.currentNode).getDatasetName();
        }

        public String getCurrentNodeSanitizedLabel() {
            return this.getNode(this.currentNode).getDatasetSanitizedLabel();
        }

        public String getRelatedNodeName() {
            return this.getNode(this.relatedNode).getDatasetName();
        }

        public String getRelatedNodeLabel() {
            return this.getNode(this.relatedNode).getDatasetLabel();
        }

        public String getRelatedNodeSanitizedLabel() {
            return this.getNode(this.relatedNode).getDatasetSanitizedLabel();
        }

        public List<ColumnForComputation> getColumnsForComputationFromRelatedNode() {
            return this.getColumnsForComputation(this.relatedNode);
        }

        public Schema getSchemaFromCurrentNode() {
            return this.getCurrentNode().getSchema();
        }

        public Node getCurrentNode() {
            return this.getNode(this.currentNode);
        }

        public Node getRelatedNode() {
            return this.getNode(this.relatedNode);
        }

        /**
         * Returns the graph node registered under {@code nodeId}.
         *
         * @throws IllegalArgumentException (via {@code ErrorContext.iae}) if the graph has no such node
         */
        public Node getNode(int nodeId) {
            Node node = RelationshipGraph.this.nodes.get(nodeId);
            if (node == null) {
                throw ErrorContext.iae("The requested node is not part of the graph");
            }
            return node;
        }

        public void addColumnsToCurrentNode(List<ColumnForComputation> computedColumns) {
            this.getNode(this.currentNode).addComputedColumns(computedColumns);
        }

        /**
         * Returns the type of {@code columnName} in the schema of the related node.
         *
         * @throws IllegalArgumentException (via {@code ErrorContext.iae}) if the related node's schema
         *         has no such column
         */
        public Type getColumnSchemaTypeFromRelatedNode(String columnName) {
            Node node = this.getNode(this.relatedNode);
            SchemaColumn schemaColumn = node.getSchemaColumn(columnName);
            if (schemaColumn == null) {
                throw ErrorContext.iae("The column '" + columnName + "' does not belong to the columns of the dataset '" + node.getDatasetName() + "'");
            }
            return schemaColumn.getType();
        }

        protected List<ColumnForComputation> getColumnsForComputation(int nodeId) {
            return this.getNode(nodeId).getColumnsForComputation();
        }

        /** Returns the columns for computation of {@code node} whose variable type is in {@code typesToKeep}. */
        protected List<ColumnForComputation> filterColumnsForComputationFromNode(List<VariableType> typesToKeep, int node) {
            List<ColumnForComputation> columns = new ArrayList<ColumnForComputation>();
            for (ColumnForComputation column : this.getColumnsForComputation(node)) {
                if (typesToKeep.contains(column.variableType)) {
                    columns.add(column);
                }
            }
            return columns;
        }

        protected List<ColumnForComputation> filterColumnsForComputationFromRelatedNode(List<VariableType> typesToKeep) {
            return this.filterColumnsForComputationFromNode(typesToKeep, this.relatedNode);
        }

        public String getSingleTimeWindowSuffix() {
            return this.getRelatedNode().getSingleTimeWindowSuffix();
        }
    }

    /** Edge created for the "many to one" and "one to one" directions of a relationship. */
    public class ForwardEdge
    extends Edge {
        public ForwardEdge(int source, int destination, List<JoinRecipePayloadParams.MatchingCondition> matchingConditions) {
            super(source, destination, matchingConditions);
        }

        @Override
        public void accept(Visitor visitor) {
            visitor.visit(this);
        }
    }

    /**
     * Edge created for the "one to many" direction of a relationship: the related node is on the
     * "many" side, so this edge exposes what is needed to group and aggregate the related node.
     */
    public class BackwardEdge
    extends Edge {
        public BackwardEdge(int node1, int node2, List<JoinRecipePayloadParams.MatchingCondition> matchingConditions) {
            super(node1, node2, matchingConditions);
        }

        @Override
        public void accept(Visitor visitor) {
            visitor.visit(this);
        }

        /** Returns the related node's columns for computation whose variable type is numeric or category. */
        public List<ColumnForComputation> getColumnsToAggregate() {
            return this.filterColumnsForComputationFromRelatedNode(this.TYPES_TO_AGGREGATE);
        }

        /**
         * Returns the grouping keys to use on the related node: one key per distinct {@code column2}
         * of the matching conditions, followed by the related node's cutoff time column when it has
         * one that is not already a key.
         */
        public ArrayList<GroupingRecipePayloadParams.GroupingKey> getGroupingKeys() {
            ArrayList<GroupingRecipePayloadParams.GroupingKey> keys = new ArrayList<GroupingRecipePayloadParams.GroupingKey>();
            for (JoinRecipePayloadParams.MatchingCondition condition : this.matchingConditions) {
                String relatedGroupColumn = condition.column2.name;
                if (keys.stream().anyMatch(key -> key.column.equals(relatedGroupColumn))) {
                    // Several conditions may use the same column; keep a single key for it.
                    continue;
                }
                GroupingRecipePayloadParams.GroupingKey groupKey = new GroupingRecipePayloadParams.GroupingKey(relatedGroupColumn);
                groupKey.type = this.getColumnSchemaTypeFromRelatedNode(relatedGroupColumn);
                keys.add(groupKey);
            }
            String cutoffTimeColumn = this.getRelatedNode().getCutoffTimeColumn();
            if (StringUtils.isNotBlank(cutoffTimeColumn) && keys.stream().noneMatch(key -> key.column.equals(cutoffTimeColumn))) {
                GroupingRecipePayloadParams.GroupingKey groupKey = new GroupingRecipePayloadParams.GroupingKey(cutoffTimeColumn);
                SchemaColumn graphCutoffColumn = RelationshipGraph.this.cutoffColumn;
                // The graph-level cutoff column is only set when the root has a time index; fall back
                // to the related node's own schema instead of dereferencing null.
                groupKey.type = graphCutoffColumn != null
                        ? graphCutoffColumn.getType()
                        : this.getColumnSchemaTypeFromRelatedNode(cutoffTimeColumn);
                keys.add(groupKey);
            }
            return keys;
        }
    }
}

