/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dss.shadelib.org.apache.iceberg.parquet;

import com.dataiku.dss.shadelib.org.apache.iceberg.Schema;
import com.dataiku.dss.shadelib.org.apache.iceberg.expressions.Binder;
import com.dataiku.dss.shadelib.org.apache.iceberg.expressions.Bound;
import com.dataiku.dss.shadelib.org.apache.iceberg.expressions.BoundReference;
import com.dataiku.dss.shadelib.org.apache.iceberg.expressions.Expression;
import com.dataiku.dss.shadelib.org.apache.iceberg.expressions.ExpressionVisitors;
import com.dataiku.dss.shadelib.org.apache.iceberg.expressions.Expressions;
import com.dataiku.dss.shadelib.org.apache.iceberg.expressions.Literal;
import com.dataiku.dss.shadelib.org.apache.iceberg.parquet.ParquetUtil;
import com.dataiku.dss.shadelib.org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import com.dataiku.dss.shadelib.org.apache.iceberg.relocated.com.google.common.collect.Maps;
import com.dataiku.dss.shadelib.org.apache.iceberg.relocated.com.google.common.collect.Sets;
import com.dataiku.dss.shadelib.org.apache.iceberg.types.Type;
import com.dataiku.dss.shadelib.org.apache.iceberg.types.TypeUtil;
import com.dataiku.dss.shadelib.org.apache.iceberg.types.Types;
import com.dataiku.dss.shadelib.org.apache.iceberg.util.DecimalUtil;
import com.dataiku.dss.shadelib.org.apache.iceberg.util.UUIDUtil;
import com.dataiku.dss.shadelib.org.apache.parquet.column.values.bloomfilter.BloomFilter;
import com.dataiku.dss.shadelib.org.apache.parquet.hadoop.BloomFilterReader;
import com.dataiku.dss.shadelib.org.apache.parquet.hadoop.metadata.BlockMetaData;
import com.dataiku.dss.shadelib.org.apache.parquet.hadoop.metadata.ColumnChunkMetaData;
import com.dataiku.dss.shadelib.org.apache.parquet.io.api.Binary;
import com.dataiku.dss.shadelib.org.apache.parquet.schema.LogicalTypeAnnotation;
import com.dataiku.dss.shadelib.org.apache.parquet.schema.MessageType;
import com.dataiku.dss.shadelib.org.apache.parquet.schema.PrimitiveType;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ParquetBloomRowGroupFilter {
    private static final Logger LOG = LoggerFactory.getLogger(ParquetBloomRowGroupFilter.class);
    private final Schema schema;
    private final Expression expr;
    private final boolean caseSensitive;
    private static final boolean ROWS_MIGHT_MATCH = true;
    private static final boolean ROWS_CANNOT_MATCH = false;

    public ParquetBloomRowGroupFilter(Schema schema, Expression unbound) {
        this(schema, unbound, true);
    }

    public ParquetBloomRowGroupFilter(Schema schema, Expression unbound, boolean caseSensitive) {
        this.schema = schema;
        Types.StructType struct = schema.asStruct();
        this.expr = Binder.bind(struct, Expressions.rewriteNot(unbound), caseSensitive);
        this.caseSensitive = caseSensitive;
    }

    public boolean shouldRead(MessageType fileSchema, BlockMetaData rowGroup, BloomFilterReader bloomReader) {
        return new BloomEvalVisitor().eval(fileSchema, rowGroup, bloomReader);
    }

    private class BloomEvalVisitor
    extends ExpressionVisitors.BoundExpressionVisitor<Boolean> {
        private BloomFilterReader bloomReader;
        private Set<Integer> fieldsWithBloomFilter = null;
        private Map<Integer, ColumnChunkMetaData> columnMetaMap = null;
        private Map<Integer, BloomFilter> bloomCache = null;
        private Map<Integer, PrimitiveType> parquetPrimitiveTypes = null;
        private Map<Integer, Type> types = null;

        private BloomEvalVisitor() {
        }

        private boolean eval(MessageType fileSchema, BlockMetaData rowGroup, BloomFilterReader bloomFilterReader) {
            this.bloomReader = bloomFilterReader;
            this.fieldsWithBloomFilter = Sets.newHashSet();
            this.columnMetaMap = Maps.newHashMap();
            this.bloomCache = Maps.newHashMap();
            this.parquetPrimitiveTypes = Maps.newHashMap();
            this.types = Maps.newHashMap();
            for (ColumnChunkMetaData meta : rowGroup.getColumns()) {
                PrimitiveType colType = fileSchema.getType(meta.getPath().toArray()).asPrimitiveType();
                if (colType.getId() == null) continue;
                int id = colType.getId().intValue();
                Type icebergType = ParquetBloomRowGroupFilter.this.schema.findType(id);
                if (!ParquetUtil.hasNoBloomFilterPages(meta)) {
                    this.fieldsWithBloomFilter.add(id);
                }
                this.columnMetaMap.put(id, meta);
                this.parquetPrimitiveTypes.put(id, colType);
                this.types.put(id, icebergType);
            }
            Set<Integer> filterRefs = Binder.boundReferences(ParquetBloomRowGroupFilter.this.schema.asStruct(), ImmutableList.of(ParquetBloomRowGroupFilter.this.expr), ParquetBloomRowGroupFilter.this.caseSensitive);
            if (!filterRefs.isEmpty()) {
                Sets.SetView<Integer> overlappedBloomFilters = Sets.intersection(this.fieldsWithBloomFilter, filterRefs);
                if (overlappedBloomFilters.isEmpty()) {
                    return true;
                }
                LOG.debug("Using Bloom filters for columns with IDs: {}", overlappedBloomFilters);
            }
            return ExpressionVisitors.visitEvaluator(ParquetBloomRowGroupFilter.this.expr, this);
        }

        @Override
        public Boolean alwaysTrue() {
            return true;
        }

        @Override
        public Boolean alwaysFalse() {
            return false;
        }

        @Override
        public Boolean not(Boolean result) {
            throw new UnsupportedOperationException("This path shouldn't be reached.");
        }

        @Override
        public Boolean and(Boolean leftResult, Boolean rightResult) {
            return leftResult != false && rightResult != false;
        }

        @Override
        public Boolean or(Boolean leftResult, Boolean rightResult) {
            return leftResult != false || rightResult != false;
        }

        @Override
        public <T> Boolean isNull(BoundReference<T> ref) {
            return true;
        }

        @Override
        public <T> Boolean notNull(BoundReference<T> ref) {
            return true;
        }

        @Override
        public <T> Boolean isNaN(BoundReference<T> ref) {
            return true;
        }

        @Override
        public <T> Boolean notNaN(BoundReference<T> ref) {
            return true;
        }

        @Override
        public <T> Boolean lt(BoundReference<T> ref, Literal<T> lit) {
            return true;
        }

        @Override
        public <T> Boolean ltEq(BoundReference<T> ref, Literal<T> lit) {
            return true;
        }

        @Override
        public <T> Boolean gt(BoundReference<T> ref, Literal<T> lit) {
            return true;
        }

        @Override
        public <T> Boolean gtEq(BoundReference<T> ref, Literal<T> lit) {
            return true;
        }

        @Override
        public <T> Boolean eq(BoundReference<T> ref, Literal<T> lit) {
            int id = ref.fieldId();
            if (!this.fieldsWithBloomFilter.contains(id)) {
                return true;
            }
            BloomFilter bloom = this.loadBloomFilter(id);
            Type type = this.types.get(id);
            T value = lit.value();
            return this.shouldRead(this.parquetPrimitiveTypes.get(id), value, bloom, type);
        }

        @Override
        public <T> Boolean notEq(BoundReference<T> ref, Literal<T> lit) {
            return true;
        }

        @Override
        public <T> Boolean in(BoundReference<T> ref, Set<T> literalSet) {
            int id = ref.fieldId();
            if (!this.fieldsWithBloomFilter.contains(id)) {
                return true;
            }
            BloomFilter bloom = this.loadBloomFilter(id);
            Type type = this.types.get(id);
            for (T e : literalSet) {
                if (!this.shouldRead(this.parquetPrimitiveTypes.get(id), e, bloom, type)) continue;
                return true;
            }
            return false;
        }

        @Override
        public <T> Boolean notIn(BoundReference<T> ref, Set<T> literalSet) {
            return true;
        }

        @Override
        public <T> Boolean startsWith(BoundReference<T> ref, Literal<T> lit) {
            return true;
        }

        @Override
        public <T> Boolean notStartsWith(BoundReference<T> ref, Literal<T> lit) {
            return true;
        }

        private BloomFilter loadBloomFilter(int id) {
            if (this.bloomCache.containsKey(id)) {
                return this.bloomCache.get(id);
            }
            ColumnChunkMetaData columnChunkMetaData = this.columnMetaMap.get(id);
            BloomFilter bloomFilter = this.bloomReader.readBloomFilter(columnChunkMetaData);
            if (bloomFilter == null) {
                throw new IllegalStateException("Failed to read required bloom filter for id: " + id);
            }
            this.bloomCache.put(id, bloomFilter);
            return bloomFilter;
        }

        private <T> boolean shouldRead(PrimitiveType primitiveType, T value, BloomFilter bloom, Type type) {
            switch (primitiveType.getPrimitiveTypeName()) {
                case INT32: {
                    switch (type.typeId()) {
                        case DECIMAL: {
                            BigDecimal decimalValue = (BigDecimal)value;
                            long hashValue = bloom.hash(decimalValue.unscaledValue().intValue());
                            return bloom.findHash(hashValue);
                        }
                        case INTEGER: 
                        case DATE: {
                            long hashValue = bloom.hash(((Number)value).intValue());
                            return bloom.findHash(hashValue);
                        }
                    }
                    return true;
                }
                case INT64: {
                    switch (type.typeId()) {
                        case DECIMAL: {
                            BigDecimal decimalValue = (BigDecimal)value;
                            long hashValue = bloom.hash(decimalValue.unscaledValue().longValue());
                            return bloom.findHash(hashValue);
                        }
                        case LONG: 
                        case TIME: 
                        case TIMESTAMP: {
                            long hashValue = bloom.hash(((Number)value).longValue());
                            return bloom.findHash(hashValue);
                        }
                    }
                    return true;
                }
                case FLOAT: {
                    long hashValue = bloom.hash(((Number)value).floatValue());
                    return bloom.findHash(hashValue);
                }
                case DOUBLE: {
                    long hashValue = bloom.hash(((Number)value).doubleValue());
                    return bloom.findHash(hashValue);
                }
                case FIXED_LEN_BYTE_ARRAY: 
                case BINARY: {
                    switch (type.typeId()) {
                        case STRING: {
                            long hashValue = bloom.hash(Binary.fromCharSequence((CharSequence)value));
                            return bloom.findHash(hashValue);
                        }
                        case BINARY: 
                        case FIXED: {
                            long hashValue = bloom.hash(Binary.fromConstantByteBuffer((ByteBuffer)value));
                            return bloom.findHash(hashValue);
                        }
                        case DECIMAL: {
                            LogicalTypeAnnotation.DecimalLogicalTypeAnnotation metadata = (LogicalTypeAnnotation.DecimalLogicalTypeAnnotation)primitiveType.getLogicalTypeAnnotation();
                            int scale = metadata.getScale();
                            int precision = metadata.getPrecision();
                            byte[] requiredBytes = new byte[TypeUtil.decimalRequiredBytes(precision)];
                            byte[] binary = DecimalUtil.toReusedFixLengthBytes(precision, scale, (BigDecimal)value, requiredBytes);
                            long hashValue = bloom.hash(Binary.fromConstantByteArray(binary));
                            return bloom.findHash(hashValue);
                        }
                        case UUID: {
                            long hashValue = bloom.hash(Binary.fromConstantByteArray(UUIDUtil.convert((UUID)value)));
                            return bloom.findHash(hashValue);
                        }
                    }
                    return true;
                }
            }
            return true;
        }

        @Override
        public <T> Boolean handleNonReference(Bound<T> term) {
            return true;
        }
    }
}

