/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.pivot.backend.dss;

import com.dataiku.dip.io.ColumnBlock;
import com.dataiku.dip.io.LinoReader;
import com.dataiku.dip.pivot.backend.common.PivotPostprocessor;
import com.dataiku.dip.pivot.backend.common.ResponseValidator;
import com.dataiku.dip.pivot.backend.common.highcardinality.BinsAndTensorsSafetyChecks;
import com.dataiku.dip.pivot.backend.common.highcardinality.PostPruneSafetyChecks;
import com.dataiku.dip.pivot.backend.dss.AxisHandler;
import com.dataiku.dip.pivot.backend.dss.DataTensor;
import com.dataiku.dip.pivot.backend.dss.LongDataTensor;
import com.dataiku.dip.pivot.backend.dss.MultipassPivotTableBuilder;
import com.dataiku.dip.pivot.backend.dss.PivotTableAggrBuilder;
import com.dataiku.dip.pivot.backend.dss.PivotUtils;
import com.dataiku.dip.pivot.backend.dss.aggregators.AbstractAggregator;
import com.dataiku.dip.pivot.backend.dss.aggregators.TensorCustomAggregator;
import com.dataiku.dip.pivot.backend.model.Aggregation;
import com.dataiku.dip.pivot.backend.model.AxisDef;
import com.dataiku.dip.pivot.backend.model.AxisElt;
import com.dataiku.dip.pivot.backend.model.AxisSortPrune;
import com.dataiku.dip.pivot.backend.model.DateAxisElt;
import com.dataiku.dip.pivot.backend.model.NumericalAxisElt;
import com.dataiku.dip.pivot.backend.model.NumericalAxisParams;
import com.dataiku.dip.pivot.backend.model.PivotTableAggregatedRequest;
import com.dataiku.dip.pivot.backend.model.PivotTableResponse;
import com.dataiku.dip.pivot.backend.model.PivotTableTensorRequest;
import com.dataiku.dip.pivot.backend.model.PivotTableTensorResponse;
import com.dataiku.dip.shaker.filter.FilteringExecutor;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.variables.VariablesContext;
import com.google.common.collect.Lists;
import com.google.refine.udaf.UdafUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.apache.commons.lang.NotImplementedException;

/**
 * Pivot builder producing a dense tensor response ({@link PivotTableTensorResponse}) from LINO
 * column blocks.
 *
 * <p>Data is processed in two passes:
 * <ol>
 *   <li>pass 1 ({@link #addPass1}): every {@link AxisHandler} observes its axis column so that it
 *       can define its bins;</li>
 *   <li>pass 2 ({@link #addPass2}): each retained row is mapped to one bin per axis, and its count
 *       and measures are accumulated into tensors indexed by those bin coordinates.</li>
 * </ol>
 * {@link #end()} then sorts and prunes every axis and copies the surviving cells (plus an optional
 * "others" bin per axis) into the response.
 *
 * <p>When sub-totals are requested, every axis gets one extra trailing bin that receives every row,
 * so each row is accumulated into all combinations of "own bin" / "total bin".
 */
public class TensorPivotBuilder
extends PivotTableAggrBuilder
implements MultipassPivotTableBuilder {
    /** Label of the synthetic per-axis "total" element added when sub-totals are computed. */
    private static final String TOTAL_VALUE_LABEL = "___dku_total_value___";
    private static final DKULogger logger = DKULogger.getLogger("dku.shaker.pivot");

    /** Whether each axis gets an extra trailing "total" bin (copied from the request). */
    private final boolean computeSubTotals;
    protected PivotTableTensorRequest request;
    /** Number of axes, i.e. {@code handlers.length}. */
    protected int numAxes;
    /** One handler per axis, in axis order; they define the bins of each axis. */
    protected AxisHandler[] handlers;
    private final VariablesContext variablesContext;
    private final LinoReader linoReader;
    /** Rows seen in pass 2 before the request filters are applied. */
    private long beforeFilterRecords = 0L;
    /** Rows seen in pass 2 that passed the request filters. */
    private long afterFilterRecords = 0L;
    /** Row counts per cell and per-axis marginals; allocated at the end of pass 1. */
    private LongDataTensor countTensor;
    /** One aggregator per requested aggregation, in request order; filled at the end of pass 1. */
    public final List<AbstractAggregator<?>> aggregators = new ArrayList<>();

    public TensorPivotBuilder(PivotTableTensorRequest request, AxisHandler[] handlers) {
        this(request, handlers, null, null);
    }

    public TensorPivotBuilder(PivotTableTensorRequest request, AxisHandler[] handlers, VariablesContext variablesContext, LinoReader linoReader) {
        this.request = request;
        this.handlers = handlers;
        this.variablesContext = variablesContext;
        this.linoReader = linoReader;
        this.filtered = request.isFiltered();
        this.computeSubTotals = request.computeSubTotals;
        this.numAxes = handlers.length;
    }

    /** Prepares the LINO-level filters of the request before the first pass. */
    @Override
    public void linoInit(LinoReader linoReader) {
        logger.info((Object) "Init PivotTableTensorBuilder");
        this.buildLinoFilters(this.request, linoReader);
    }

    /**
     * Pass 1: lets each axis handler observe its column for one block.
     *
     * @param axisBlocks one column block per axis, in axis order
     * @param linoReader reader used to evaluate the request filters
     * @param blockIdx   index of the block being processed
     */
    @Override
    public void addPass1(ColumnBlock[] axisBlocks, LinoReader linoReader, int blockIdx) throws IOException {
        boolean[] rowFilter = null;
        if (this.filtered) {
            rowFilter = new boolean[axisBlocks[0].nbRecords()];
            this.filterLinoBlock(this.request, linoReader, blockIdx, rowFilter);
        }
        for (int i = 0; i < this.handlers.length; ++i) {
            this.handlers[i].observe(axisBlocks[i], rowFilter);
        }
    }

    /**
     * Ends pass 1: sizes every axis from its handler's bin count, plus one trailing "total" bin
     * when sub-totals are computed, and allocates the tensors.
     */
    @Override
    public void endPass1() throws BinsAndTensorsSafetyChecks.MemoryLimitExceededException {
        int[] numBins = new int[this.numAxes];
        int extraTotalBin = this.computeSubTotals ? 1 : 0;
        for (int i = 0; i < numBins.length; ++i) {
            numBins[i] = this.handlers[i].getNbBins() + extraTotalBin;
        }
        this.endPass1(numBins);
    }

    /**
     * Checks the tensor size against safety limits, then allocates the count tensor and one
     * aggregator per requested aggregation.
     *
     * @param numBins length of each axis of the tensors to allocate
     * @throws BinsAndTensorsSafetyChecks.MemoryLimitExceededException if the tensor would be too large
     */
    protected void endPass1(int[] numBins) throws BinsAndTensorsSafetyChecks.MemoryLimitExceededException {
        BinsAndTensorsSafetyChecks.failIfTensorWouldBeTooLargeOrTooManyBins((PivotTableAggregatedRequest) this.request, numBins, this.variablesContext);
        logger.info((Object) "End of binning pass");
        this.countTensor = new LongDataTensor(numBins);
        for (Aggregation aggregation : this.request.aggregations) {
            this.aggregators.add(PivotUtils.buildAggregator(this.variablesContext, this.linoReader, aggregation, new Object[]{numBins}));
        }
    }

    /**
     * Fills {@code axisBins[i][row]} with the bin of each row on axis {@code i}.
     *
     * @return the index of the trailing "total" bin of every axis when sub-totals are computed,
     *         {@code null} otherwise
     */
    protected int[] getAxisBins(ColumnBlock[] axisBlocks, boolean[] filters, int[][] axisBins) {
        int[] totalBinPerAxis = new int[this.numAxes];
        for (int i = 0; i < this.numAxes; ++i) {
            this.handlers[i].getBins(axisBlocks[i], axisBins[i], filters);
            // The total bin is the last slot of the axis (see endPass1).
            totalBinPerAxis[i] = this.countTensor.axisLengths[i] - 1;
        }
        return this.computeSubTotals ? totalBinPerAxis : null;
    }

    /**
     * Pass 2: accumulates counts and aggregations for one block.
     *
     * <p>Rows rejected by the filters, or with bin {@code -1} on any axis, are skipped. With
     * sub-totals, each remaining row is also accumulated into every combination where any subset
     * of its coordinates is replaced by the axis total bin.
     *
     * @param axisBlocks  one column block per axis, all with the same number of records
     * @param aggrColumns one column block per aggregation, in request order
     */
    @Override
    public void addPass2(ColumnBlock[] axisBlocks, List<ColumnBlock> aggrColumns, LinoReader linoReader, int blockIdx) throws IOException {
        this.updateLinoFilterFacets(this.request, linoReader, blockIdx);
        this.setCustomAggregatorsContext(linoReader, blockIdx);
        int numRows = axisBlocks[0].nbRecords();
        this.beforeFilterRecords += numRows;
        boolean[] rowFilter = null;
        if (this.filtered) {
            rowFilter = new boolean[numRows];
            this.filterLinoBlock(this.request, linoReader, blockIdx, rowFilter);
            this.afterFilterRecords += FilteringExecutor.countTrue(rowFilter);
        } else {
            this.afterFilterRecords += numRows;
        }
        int[][] axisBins = new int[this.numAxes][];
        for (int i = 0; i < this.numAxes; ++i) {
            assert (axisBlocks[i].nbRecords() == numRows);
            axisBins[i] = new int[axisBlocks[i].nbRecords()];
        }
        int[] totalBins = this.getAxisBins(axisBlocks, rowFilter, axisBins);
        final int nbAggregations = this.request.aggregations.size();
        for (int r = 0; r < numRows; ++r) {
            if (rowFilter != null && !rowFilter[r]) {
                continue;
            }
            int[] coords = new int[this.numAxes];
            boolean hasUnsetBin = false;
            for (int i = 0; i < this.numAxes; ++i) {
                coords[i] = axisBins[i][r];
                if (coords[i] == -1) {
                    hasUnsetBin = true;
                }
            }
            if (hasUnsetBin) {
                continue;
            }
            final int row = r;
            this.cartesianProduct(coords, totalBins, cellCoords -> {
                this.addRowCountData(cellCoords);
                for (int a = 0; a < nbAggregations; ++a) {
                    this.aggregators.get(a).handleBlock(aggrColumns.get(a), row, cellCoords);
                }
            });
        }
    }

    /** Gives custom (UDAF) aggregators access to the current block and to the count tensor. */
    private void setCustomAggregatorsContext(LinoReader linoReader, int blockIdx) {
        for (int a = 0; a < this.request.aggregations.size(); ++a) {
            AbstractAggregator<?> aggregator = this.aggregators.get(a);
            if (aggregator instanceof TensorCustomAggregator) {
                UdafUtils.setBlocksIfCustomAggregation(aggregator, linoReader, blockIdx);
                ((TensorCustomAggregator) aggregator).countTensor = this.countTensor;
            }
        }
    }

    /**
     * Calls {@code function} on {@code coords} when {@code totalBins} is {@code null}. Otherwise
     * calls it once for each of the 2^n coordinate vectors where every axis takes either its own
     * bin or its total bin.
     */
    private void cartesianProduct(int[] coords, int[] totalBins, ThrowingConsumer<int[], IOException> function) throws IOException {
        if (totalBins == null) {
            function.accept(coords);
            return;
        }
        List<List<Integer>> choicesPerAxis = new ArrayList<>(coords.length);
        for (int i = 0; i < coords.length; ++i) {
            List<Integer> choices = new ArrayList<>(2);
            choices.add(coords[i]);
            choices.add(totalBins[i]);
            choicesPerAxis.add(choices);
        }
        for (List<Integer> combination : Lists.cartesianProduct(choicesPerAxis)) {
            int[] cell = new int[coords.length];
            for (int j = 0; j < cell.length; ++j) {
                cell[j] = combination.get(j);
            }
            function.accept(cell);
        }
    }

    /** Increments the per-axis marginals and the cell count at {@code coords}. */
    private void addRowCountData(int[] coords) {
        for (int i = 0; i < this.numAxes; ++i) {
            this.countTensor.axes[i][coords[i]]++;
        }
        this.countTensor.increment(coords);
    }

    /**
     * Finalizes the aggregators, sorts and prunes every axis, and builds the response: counts,
     * aggregations, optional "others" labels, measure modes, inter-measures, facets and record
     * counts.
     */
    @Override
    public PivotTableTensorResponse end() throws IOException {
        logger.info((Object) "End of accumulation phase");
        for (AbstractAggregator<?> aggregator : this.aggregators) {
            aggregator.end();
        }
        AxisHandler.Axis[] axes = new AxisHandler.Axis[this.numAxes];
        for (int i = 0; i < this.numAxes; ++i) {
            axes[i] = this.sortAndPruneAxis(this.request.axes[i], this.handlers[i], this.request.axes[i].sortPrune, i, this.computeSubTotals);
        }
        logger.debug((Object) "End of prune");
        PivotTableTensorResponse resp = this.buildEmptyResponse(axes);
        this.copyCountData(this.countTensor, resp.counts, axes);
        for (int aggrId = 0; aggrId < this.request.aggregations.size(); ++aggrId) {
            resp.aggregations.add(this.copyAggrData(this.aggregators.get(aggrId), resp.counts.axisLengths, resp.counts, axes));
        }
        for (int i = 0; i < this.numAxes; ++i) {
            boolean othersRequested = this.request.axes[i].sortPrune.generateOthersCategory;
            boolean somethingCut = axes[i].nbNotCutoff != axes[i].elts.size();
            // The "others" slot sits right after the kept labels; only label it when non-empty.
            if (othersRequested && somethingCut && resp.counts.axes[i][resp.axisLabels[i].size()] != 0L) {
                resp.axisLabels[i].add(this.createOtherBin());
            }
        }
        PivotPostprocessor.computeMeasureModesTensor(this.request, resp);
        PivotPostprocessor.computeIntermeasuresTensor(this.request, resp);
        logger.info((Object) "Building facets");
        this.computeFilterFacets(this.request, resp);
        logger.info((Object) "End of building output data");
        resp.setRecordCounts(this.beforeFilterRecords, this.afterFilterRecords);
        ResponseValidator.validateNonEmptyAxes(resp.counts);
        return resp;
    }

    @Override
    public void cleanup() {
        PivotUtils.cleanupAggregators(this.aggregators);
    }

    /**
     * Applies the {@code maxValues} limit of every axis and creates a response whose labels are
     * the kept elements and whose count tensor is sized accordingly.
     *
     * <p>The "total" element is never cut by {@code maxValues}. As a side effect,
     * {@code generateOthersCategory} on the request is reset to {@code false} for every axis where
     * {@code maxValues} cut nothing.
     */
    private PivotTableTensorResponse buildEmptyResponse(AxisHandler.Axis[] axes) {
        PivotTableTensorResponse resp = new PivotTableTensorResponse(this.numAxes, this.request);
        resp.engine = PivotTableResponse.PivotEngine.LINO;
        for (int i = 0; i < this.numAxes; ++i) {
            AxisHandler.Axis axis = axes[i];
            boolean needsOthersBin = false;
            if (this.request.axes[i].sortPrune.maxValues > 0L) {
                // Compared as long: narrowing maxValues to int could overflow to a negative limit.
                long maxValues = this.request.axes[i].sortPrune.maxValues;
                long kept = 0L;
                for (AxisElt elt : axis.elts) {
                    if (elt.cutoffed || Objects.equals(elt.label, TOTAL_VALUE_LABEL)) {
                        continue;
                    }
                    if (++kept > maxValues) {
                        elt.cutoffed = true;
                        needsOthersBin = true;
                    }
                }
                // nbNotCutoff must equal the number of kept elements: forAllCells maps kept
                // elements to 0..nbNotCutoff-1 and cut ones to the "others" slot nbNotCutoff.
                // Recounting avoids counting the total element twice when sub-totals are on.
                axis.nbNotCutoff = countNotCutoff(axis);
            }
            for (AxisElt elt : axis.elts) {
                if (!elt.cutoffed) {
                    resp.axisLabels[i].add(elt);
                }
            }
            if (!needsOthersBin) {
                this.request.axes[i].sortPrune.generateOthersCategory = false;
            }
        }
        PostPruneSafetyChecks.checkTensorResponse(this.request, resp);
        int[] axisLengths = new int[this.numAxes];
        for (int i = 0; i < this.numAxes; ++i) {
            axisLengths[i] = this.request.axes[i].sortPrune.generateOthersCategory ? axes[i].nbNotCutoff + 1 : axes[i].nbNotCutoff;
        }
        resp.counts = new LongDataTensor(axisLengths);
        return resp;
    }

    /** Returns the number of elements of {@code axis} that are not cut off. */
    private static int countNotCutoff(AxisHandler.Axis axis) {
        int count = 0;
        for (AxisElt elt : axis.elts) {
            if (!elt.cutoffed) {
                ++count;
            }
        }
        return count;
    }

    /**
     * Copies per-axis marginals of kept elements, then merges (sums) every cell of
     * {@code origTensor} into its target cell in {@code destTensor}. Cells falling in the overflow
     * slot of an axis with the "others" category enabled are summed into that slot and its
     * marginal.
     */
    private void copyCountData(final LongDataTensor origTensor, final LongDataTensor destTensor, final AxisHandler.Axis[] axes) throws IOException {
        for (int i = 0; i < this.numAxes; ++i) {
            int targetBin = 0;
            for (AxisElt elt : axes[i].elts) {
                if (elt.cutoffed) {
                    continue;
                }
                destTensor.axes[i][targetBin] = origTensor.axes[i][elt.binIndex];
                if (++targetBin >= axes[i].nbNotCutoff) {
                    break;
                }
            }
        }
        final LongDataTensor.LongMerger countMerger = Long::sum;
        TensorPivotBuilder.forAllCells(axes, new CoordinateConsumer() {

            @Override
            public void consume(int[] origCoordinates, int[] targetCoordinates) {
                long newValue = origTensor.get(origCoordinates);
                for (int i = 0; i < TensorPivotBuilder.this.numAxes; ++i) {
                    if (targetCoordinates[i] != axes[i].nbNotCutoff) {
                        continue;
                    }
                    // NOTE(review): this early return happens after the "others" marginals of
                    // earlier axes may already have been incremented, so the result depends on
                    // axis order — confirm whether that is intended.
                    if (!TensorPivotBuilder.this.request.axes[i].sortPrune.generateOthersCategory) {
                        return;
                    }
                    destTensor.axes[i][targetCoordinates[i]] = countMerger.merge(destTensor.axes[i][targetCoordinates[i]], newValue);
                }
                destTensor.merge(targetCoordinates, countMerger, newValue);
            }
        }, true);
    }

    /**
     * Re-indexes the results of {@code aggr} onto the pruned axes, then finalizes them.
     *
     * @return the merged data tensor of the aggregator
     */
    private DataTensor<?> copyAggrData(final AbstractAggregator<?> aggr, int[] axisLengths, LongDataTensor mergedCounts, final AxisHandler.Axis[] axes) throws IOException {
        aggr.initMerge(axisLengths);
        TensorPivotBuilder.forAllCells(axes, new CoordinateConsumer() {

            @Override
            public void consume(int[] origCoordinates, int[] targetCoordinates, int[] numCutoffs) throws IOException {
                aggr.mergeTensorAndAxes(TensorPivotBuilder.this.request, origCoordinates, targetCoordinates, axes);
            }
        }, true);
        aggr.mergeEnd(axes, mergedCounts);
        aggr.postProcessEndPhase2();
        return aggr.getMergeDT();
    }

    /**
     * Visits every combination of axis elements.
     *
     * <p>For each combination the consumer receives:
     * <ul>
     *   <li>the original bin indexes;</li>
     *   <li>the target coordinates: the rank among kept elements, or {@code nbNotCutoff} for cut
     *       elements;</li>
     *   <li>per axis, the number of cut elements seen so far.</li>
     * </ul>
     * Cut elements are skipped when {@code includeCutOffs} is false. The three arrays are reused
     * between calls.
     */
    public static void forAllCells(AxisHandler.Axis[] axes, CoordinateConsumer consumer, boolean includeCutOffs) throws IOException {
        TensorPivotBuilder.forAllCells(axes, includeCutOffs, consumer, 0, new int[axes.length], new int[axes.length], new int[axes.length]);
    }

    /** Recursive worker of {@link #forAllCells(AxisHandler.Axis[], CoordinateConsumer, boolean)}. */
    public static void forAllCells(AxisHandler.Axis[] axes, boolean includeCutOffs, CoordinateConsumer consumer, int currentAxisIdx, int[] origCoordinates, int[] targetCoordinates, int[] numCutOffs) throws IOException {
        AxisHandler.Axis axis = axes[currentAxisIdx];
        boolean hasNextAxis = currentAxisIdx < axes.length - 1;
        numCutOffs[currentAxisIdx] = 0;
        for (int x = 0; x < axis.elts.size(); ++x) {
            AxisElt elt = axis.elts.get(x);
            if (elt.cutoffed) {
                numCutOffs[currentAxisIdx]++;
                if (!includeCutOffs) {
                    continue;
                }
                targetCoordinates[currentAxisIdx] = axis.nbNotCutoff;
            } else {
                targetCoordinates[currentAxisIdx] = x - numCutOffs[currentAxisIdx];
            }
            origCoordinates[currentAxisIdx] = elt.binIndex;
            if (hasNextAxis) {
                TensorPivotBuilder.forAllCells(axes, includeCutOffs, consumer, currentAxisIdx + 1, origCoordinates, targetCoordinates, numCutOffs);
            } else {
                consumer.consume(origCoordinates, targetCoordinates, numCutOffs);
            }
        }
    }

    /**
     * Builds the element list of one axis and applies its sort and prune rules.
     *
     * <p>Elements get their handler order as {@code binIndex}, plus a trailing total element when
     * {@code computeAxisTotal} is set. The axis is then sorted by {@code sortPrune.sortType}.
     * Elements are cut off when:
     * <ul>
     *   <li>their value for a filtered measure lies outside that filter's [min, max] range;</li>
     *   <li>the axis is alphanumeric, or numerical without binning, and the element has a zero
     *       row count.</li>
     * </ul>
     *
     * @return the axis, with {@code nbNotCutoff} counting the elements that were not cut
     */
    protected AxisHandler.Axis sortAndPruneAxis(AxisDef def, AxisHandler handler, AxisSortPrune sortPrune, int axisIdx, boolean computeAxisTotal) {
        AxisHandler.Axis ret = new AxisHandler.Axis();
        ArrayList<AxisElt> axisElts = new ArrayList<AxisElt>();
        int nextBinIndex = 0;
        for (AxisElt elt : handler.getAxisElts()) {
            elt.binIndex = nextBinIndex++;
            axisElts.add(elt);
        }
        if (computeAxisTotal) {
            this.createAllAxisElt(axisElts, nextBinIndex);
        }
        ret.elts = axisElts;
        if (sortPrune == null) {
            ret.nbNotCutoff = ret.elts.size();
            return ret;
        }
        logger.info((Object) ("Do SortPrune on axis " + axisIdx + " have " + ret.elts.size() + " elts"));
        if (sortPrune.sortType != null) {
            switch (sortPrune.sortType) {
                case ORIGINAL: {
                    break;
                }
                case NATURAL: {
                    this.sortAxisOnNatural(def, ret, handler);
                    break;
                }
                case AGGREGATION: {
                    this.sortAxisOnAggregation(sortPrune, ret, axisIdx);
                    break;
                }
                case CUSTOM: {
                    this.applyCustomSorting(ret, sortPrune.customSortingValues);
                    break;
                }
                default: {
                    throw new NotImplementedException();
                }
            }
        }
        for (AxisSortPrune.MeasureFilter filter : sortPrune.filters) {
            AbstractAggregator<?> aggr = this.aggregators.get(filter.measureFilterId);
            logger.info((Object) "Measure cutoff");
            for (int i = 0; i < ret.elts.size(); ++i) {
                AxisElt elt = ret.elts.get(i);
                double measureValue = aggr.getOutDT().getAxisAsDouble(axisIdx, elt.binIndex);
                logger.info((Object) ("For elt " + i + " measureValue= " + measureValue));
                // Written with negations so that a NaN measure is never cut off.
                if (!(measureValue > filter.maxValue) && !(measureValue < filter.minValue)) {
                    continue;
                }
                elt.cutoffed = true;
                logger.info((Object) " -->Cut it off");
            }
        }
        boolean pruneEmptyElements = def.type == AxisDef.Type.ALPHANUM
                || (def.type == AxisDef.Type.NUMERICAL && def.numParams.mode == NumericalAxisParams.BinningMode.NONE);
        if (pruneEmptyElements) {
            for (AxisElt elt : ret.elts) {
                if (this.countTensor.axes[axisIdx][elt.binIndex] == 0L) {
                    elt.cutoffed = true;
                }
            }
        }
        for (AxisElt elt : ret.elts) {
            if (!elt.cutoffed) {
                ++ret.nbNotCutoff;
            }
        }
        return ret;
    }

    /**
     * Appends the synthetic "total" element, using the same element type as the last existing
     * element. Date and numerical totals get maximal sort values.
     */
    private void createAllAxisElt(List<AxisElt> axisElts, int idx) {
        AxisElt lastElt = axisElts.isEmpty() ? null : axisElts.get(axisElts.size() - 1);
        AxisElt totalElt;
        if (lastElt instanceof DateAxisElt) {
            DateAxisElt dateTotal = new DateAxisElt();
            dateTotal.tsValue = Long.MAX_VALUE;
            dateTotal.sortValue = Double.MAX_VALUE;
            totalElt = dateTotal;
        } else if (lastElt instanceof NumericalAxisElt) {
            NumericalAxisElt numericalTotal = new NumericalAxisElt();
            numericalTotal.sortValue = Double.MAX_VALUE;
            totalElt = numericalTotal;
        } else {
            totalElt = new AxisElt();
        }
        totalElt.label = TOTAL_VALUE_LABEL;
        totalElt.binIndex = idx;
        totalElt.cutoffed = false;
        axisElts.add(totalElt);
    }

    /**
     * Delegates sorting of the axis to the aggregator selected by {@code aggregationSortId}.
     * Passes {@code -1} when ascending and {@code 1} otherwise; the interpretation of this factor
     * belongs to {@code sortAxis}.
     */
    private void sortAxisOnAggregation(AxisSortPrune sortPrune, AxisHandler.Axis ret, int axisIdx) {
        int asc = sortPrune.sortAscending ? -1 : 1;
        logger.info((Object) "Aggregation sort");
        AbstractAggregator<?> aggr = this.aggregators.get(sortPrune.aggregationSortId);
        aggr.sortAxis(ret, asc, axisIdx);
    }

    /** A {@link java.util.function.Consumer}-like callback that may throw a checked exception. */
    @FunctionalInterface
    public static interface ThrowingConsumer<T, E extends Exception> {
        public void accept(T var1) throws E;
    }

    /**
     * Callback for {@link #forAllCells}. Override either overload; by default the three-argument
     * form delegates to the two-argument form, which does nothing.
     */
    public static abstract class CoordinateConsumer {
        public void consume(int[] origCoordinates, int[] targetCoordinates) throws IOException {
        }

        public void consume(int[] origCoordinates, int[] targetCoordinates, int[] numCutoffs) throws IOException {
            this.consume(origCoordinates, targetCoordinates);
        }
    }
}

