/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.exec.vector.aggregation;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import org.apache.hadoop.hive.common.type.DataTypePhysicalVariation;
import org.apache.hadoop.hive.ql.exec.vector.VectorRandomBatchSource;
import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource;
import org.apache.hadoop.hive.ql.exec.vector.aggregation.AggregationBase;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCount;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hadoop.hive.serde2.io.ShortWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableShortObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;

/**
 * Randomized tests for aggregation evaluators over several argument types.
 *
 * <p>Each aggregation/type combination is always run in PARTIAL1 mode. Depending on the
 * aggregation, it is additionally run in COMPLETE mode, and the PARTIAL1 results are merged again
 * in PARTIAL2 and/or FINAL mode. Executing and checking each mode is delegated to
 * {@code executeAggregationTests} / {@code verifyAggregationResults} inherited from
 * {@link AggregationBase}.
 */
public class TestVectorAggregation extends AggregationBase {

    /** Number of random rows generated for every aggregation run. */
    private static final int TEST_ROW_COUNT = 100000;

    /**
     * Grouping keys for PARTIAL1/COMPLETE input rows are drawn from [1, DATA_AGGR_MAX_KEY_COUNT];
     * this value is also handed to AggregationBase as the max key count of those runs.
     */
    private static final int DATA_AGGR_MAX_KEY_COUNT = 20000;

    /** Merge runs group onto DATA_AGGR_MAX_KEY_COUNT / MERGE_REDUCTION_FACTOR keys. */
    private static final int MERGE_REDUCTION_FACTOR = 16;

    /** Variance-family aggregation names exercised by the testVariance* methods. */
    private static final Set<String> varianceNames = new HashSet<String>();

    static {
        varianceNames.add("variance");
        varianceNames.add("var_samp");
        varianceNames.add("std");
        varianceNames.add("stddev_samp");
    }

    private static final TypeInfo[] integerTypeInfos = new TypeInfo[] {
        TypeInfoFactory.byteTypeInfo, TypeInfoFactory.shortTypeInfo,
        TypeInfoFactory.intTypeInfo, TypeInfoFactory.longTypeInfo
    };

    private static final TypeInfo[] floatingTypeInfos = new TypeInfo[] {
        TypeInfoFactory.doubleTypeInfo
    };

    private static final TypeInfo[] decimalTypeInfos = new TypeInfo[] {
        new DecimalTypeInfo(38, 18), new DecimalTypeInfo(25, 2), new DecimalTypeInfo(19, 4),
        new DecimalTypeInfo(18, 10), new DecimalTypeInfo(17, 3), new DecimalTypeInfo(12, 2),
        new DecimalTypeInfo(7, 1)
    };

    private static final TypeInfo[] stringFamilyTypeInfos = new TypeInfo[] {
        TypeInfoFactory.stringTypeInfo, new CharTypeInfo(25), new CharTypeInfo(10),
        new VarcharTypeInfo(20), new VarcharTypeInfo(15)
    };

    @Test
    public void testAvgIntegers() throws Exception {
        Random random = new Random(7743L);
        doIntegerTests("avg", random);
    }

    @Test
    public void testAvgFloating() throws Exception {
        Random random = new Random(7743L);
        doFloatingTests("avg", random);
    }

    @Test
    public void testAvgDecimal() throws Exception {
        Random random = new Random(7743L);
        doDecimalTests("avg", random, false);
    }

    @Test
    public void testAvgDecimal64() throws Exception {
        Random random = new Random(7743L);
        doDecimalTests("avg", random, true);
    }

    @Test
    public void testAvgTimestamp() throws Exception {
        Random random = new Random(7743L);
        doTests(random, "avg", TypeInfoFactory.timestampTypeInfo);
    }

    @Test
    public void testCount() throws Exception {
        Random random = new Random(7743L);
        doTests(random, "count", TypeInfoFactory.shortTypeInfo);
        doTests(random, "count", TypeInfoFactory.longTypeInfo);
        doTests(random, "count", TypeInfoFactory.doubleTypeInfo);
        doTests(random, "count", new DecimalTypeInfo(18, 10));
        doTests(random, "count", TypeInfoFactory.stringTypeInfo);
    }

    @Test
    public void testCountStar() throws Exception {
        Random random = new Random(7743L);
        doTests(random, "count", TypeInfoFactory.shortTypeInfo, true, false);
        doTests(random, "count", TypeInfoFactory.longTypeInfo, true, false);
        doTests(random, "count", TypeInfoFactory.doubleTypeInfo, true, false);
        doTests(random, "count", new DecimalTypeInfo(18, 10), true, false);
        doTests(random, "count", TypeInfoFactory.stringTypeInfo, true, false);
    }

    @Test
    public void testMax() throws Exception {
        Random random = new Random(7743L);
        doIntegerTests("max", random);
        doFloatingTests("max", random);
        doDecimalTests("max", random, false);
        doDecimalTests("max", random, true);
        doTests(random, "max", TypeInfoFactory.timestampTypeInfo);
        doTests(random, "max", TypeInfoFactory.intervalDayTimeTypeInfo);
        doStringFamilyTests("max", random);
    }

    @Test
    public void testMin() throws Exception {
        Random random = new Random(7743L);
        doIntegerTests("min", random);
        doFloatingTests("min", random);
        doDecimalTests("min", random, false);
        doDecimalTests("min", random, true);
        doTests(random, "min", TypeInfoFactory.timestampTypeInfo);
        doTests(random, "min", TypeInfoFactory.intervalDayTimeTypeInfo);
        doStringFamilyTests("min", random);
    }

    @Test
    public void testSum() throws Exception {
        Random random = new Random(7743L);
        doTests(random, "sum", TypeInfoFactory.shortTypeInfo);
        doTests(random, "sum", TypeInfoFactory.longTypeInfo);
        doTests(random, "sum", TypeInfoFactory.doubleTypeInfo);
        doDecimalTests("sum", random, false);
        doDecimalTests("sum", random, true);
        doTests(random, "sum", TypeInfoFactory.timestampTypeInfo);
    }

    @Ignore
    @Test
    public void testBloomFilter() throws Exception {
        Random random = new Random(7743L);
        doIntegerTests("bloom_filter", random);
        doFloatingTests("bloom_filter", random);
        doDecimalTests("bloom_filter", random, false);
        doTests(random, "bloom_filter", TypeInfoFactory.timestampTypeInfo);
        doStringFamilyTests("bloom_filter", random);
    }

    @Test
    public void testVarianceIntegers() throws Exception {
        Random random = new Random(7743L);
        for (String aggregationName : varianceNames) {
            doIntegerTests(aggregationName, random);
        }
    }

    @Test
    public void testVarianceFloating() throws Exception {
        Random random = new Random(7743L);
        for (String aggregationName : varianceNames) {
            doFloatingTests(aggregationName, random);
        }
    }

    @Test
    public void testVarianceDecimal() throws Exception {
        Random random = new Random(7743L);
        for (String aggregationName : varianceNames) {
            doDecimalTests(aggregationName, random, false);
        }
    }

    @Test
    public void testVarianceTimestamp() throws Exception {
        Random random = new Random(7743L);
        for (String aggregationName : varianceNames) {
            doTests(random, aggregationName, TypeInfoFactory.timestampTypeInfo);
        }
    }

    /** Runs {@code aggregationName} over every integer type in {@link #integerTypeInfos}. */
    private void doIntegerTests(String aggregationName, Random random) throws Exception {
        for (TypeInfo typeInfo : integerTypeInfos) {
            doTests(random, aggregationName, typeInfo);
        }
    }

    /** Runs {@code aggregationName} over every floating type in {@link #floatingTypeInfos}. */
    private void doFloatingTests(String aggregationName, Random random) throws Exception {
        for (TypeInfo typeInfo : floatingTypeInfos) {
            doTests(random, aggregationName, typeInfo);
        }
    }

    /**
     * Runs {@code aggregationName} over every decimal type in {@link #decimalTypeInfos}.
     *
     * @param tryDecimal64 request the DECIMAL_64 physical variation for types whose precision allows it
     */
    private void doDecimalTests(String aggregationName, Random random, boolean tryDecimal64) throws Exception {
        for (TypeInfo typeInfo : decimalTypeInfos) {
            doTests(random, aggregationName, typeInfo, false, tryDecimal64);
        }
    }

    /** Runs {@code aggregationName} over every string-family type in {@link #stringFamilyTypeInfos}. */
    private void doStringFamilyTests(String aggregationName, Random random) throws Exception {
        for (TypeInfo typeInfo : stringFamilyTypeInfos) {
            doTests(random, aggregationName, typeInfo);
        }
    }

    /**
     * Returns true only when DECIMAL_64 was requested, {@code typeInfo} is a decimal type, and its
     * precision is accepted by {@link HiveDecimalWritable#isPrecisionDecimal64(int)}.
     */
    private boolean checkDecimal64(boolean tryDecimal64, TypeInfo typeInfo) {
        if (!tryDecimal64 || !(typeInfo instanceof DecimalTypeInfo)) {
            return false;
        }
        DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfo;
        return HiveDecimalWritable.isPrecisionDecimal64(decimalTypeInfo.getPrecision());
    }

    /**
     * Draws a random integer in [1, maxSize] whose probability decreases linearly: value k is chosen
     * with weight (maxSize - k + 1), so 1 is the most likely value and maxSize the least likely.
     *
     * @param random source of randomness; exactly one {@code nextInt} call is made
     * @param maxSize largest value that can be returned; must be between 1 and 65535 so the total
     *     weight maxSize * (maxSize + 1) / 2 fits in an int
     * @return a value in [1, maxSize]
     * @throws IllegalArgumentException if maxSize is out of range
     */
    public static int getLinearRandomNumber(Random random, int maxSize) {
        // Widen to long first: the int product maxSize * (maxSize + 1) overflows for maxSize > 46340.
        long totalWeight = (long) maxSize * ((long) maxSize + 1L) / 2L;
        if (maxSize < 1 || totalWeight > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("maxSize must be between 1 and 65535, got " + maxSize);
        }
        int remaining = random.nextInt((int) totalWeight);
        int linearRandomNumber = 0;
        // Walk the weights maxSize, maxSize-1, ..., 1 until the draw is used up.
        for (int weight = maxSize; remaining >= 0; --weight) {
            remaining -= weight;
            ++linearRandomNumber;
        }
        return linearRandomNumber;
    }

    /**
     * Re-aggregates the PARTIAL1 results in a merge mode (PARTIAL2 or FINAL).
     *
     * <p>The non-null row-wise results in {@code partial1ResultsArray[0]} become the col1 values of
     * {@link #TEST_ROW_COUNT} new rows, cycling through them as needed. The rows are regrouped onto
     * {@code dataAggrMaxKeyCount / reductionFactor} SMALLINT keys. The evaluator is then run and
     * verified in {@code mergeUdafEvaluatorMode}.
     *
     * @throws IllegalArgumentException if {@code dataAggrMaxKeyCount / reductionFactor} is not positive
     * @throws IllegalStateException if none of the first {@code dataAggrMaxKeyCount} PARTIAL1 results is non-null
     */
    private void doMerge(GenericUDAFEvaluator.Mode mergeUdafEvaluatorMode, Random random, String aggregationName,
            TypeInfo typeInfo, VectorRandomRowSource.GenerationSpec keyGenerationSpec, List<String> columns,
            String[] columnNames, int dataAggrMaxKeyCount, int reductionFactor, TypeInfo partial1OutputTypeInfo,
            Object[] partial1ResultsArray) throws Exception {
        if (reductionFactor <= 0 || dataAggrMaxKeyCount / reductionFactor <= 0) {
            throw new IllegalArgumentException("Invalid merge key count: " + dataAggrMaxKeyCount
                    + " / " + reductionFactor);
        }
        int mergeMaxKeyCount = dataAggrMaxKeyCount / reductionFactor;

        // Merge schema: the same SMALLINT key column, plus col1 typed as the PARTIAL1 output.
        // Neither column is generated; both are filled in below.
        ArrayList<VectorRandomRowSource.GenerationSpec> mergeAggrGenerationSpecList = new ArrayList<>();
        ArrayList<DataTypePhysicalVariation> mergeDataTypePhysicalVariationList = new ArrayList<>();
        mergeAggrGenerationSpecList.add(keyGenerationSpec);
        mergeDataTypePhysicalVariationList.add(DataTypePhysicalVariation.NONE);
        mergeAggrGenerationSpecList.add(
                VectorRandomRowSource.GenerationSpec.createOmitGeneration(partial1OutputTypeInfo));
        mergeDataTypePhysicalVariationList.add(DataTypePhysicalVariation.NONE);

        ArrayList<ExprNodeDesc> mergeParameters = new ArrayList<>();
        mergeParameters.add(new ExprNodeColumnDesc(partial1OutputTypeInfo, "col1", "table", false));
        ObjectInspector[] mergeParameterObjectInspectors = getParameterObjectInspectors(mergeParameters);

        VectorRandomRowSource mergeRowSource = new VectorRandomRowSource();
        mergeRowSource.initGenerationSpecSchema(random, mergeAggrGenerationSpecList, 0,
                /* allowNull */ false, true, mergeDataTypePhysicalVariationList);
        Object[][] mergeRandomRows = mergeRowSource.randomRows(TEST_ROW_COUNT);

        Object[] partial1Results = (Object[]) partial1ResultsArray[0];
        int partial1Key = 0;
        for (Object[] mergeRow : mergeRandomRows) {
            partial1Key = findNextNonNullKey(partial1Results, partial1Key, dataAggrMaxKeyCount);
            short mergeKey = (short) (partial1Key % mergeMaxKeyCount);
            mergeRow[0] = new ShortWritable(mergeKey);
            mergeRow[1] = partial1Results[partial1Key];
            partial1Key++;
        }

        VectorRandomBatchSource mergeBatchSource =
                VectorRandomBatchSource.createInterestingBatches(random, mergeRowSource, mergeRandomRows, null);

        GenericUDAFEvaluator mergeEvaluator = getEvaluator(aggregationName, typeInfo);
        ObjectInspector mergeReturnOI = mergeEvaluator.init(mergeUdafEvaluatorMode, mergeParameterObjectInspectors);
        TypeInfo mergeOutputTypeInfo = TypeInfoUtils.getTypeInfoFromObjectInspector(mergeReturnOI);

        Object[] mergeResultsArray = new Object[AggregationBase.AggregationTestMode.count];
        executeAggregationTests(aggregationName, partial1OutputTypeInfo, mergeEvaluator, mergeOutputTypeInfo,
                mergeUdafEvaluatorMode, mergeMaxKeyCount, columns, columnNames, mergeParameters, mergeRandomRows,
                mergeRowSource, mergeBatchSource, false, mergeResultsArray);
        verifyAggregationResults(partial1OutputTypeInfo, mergeOutputTypeInfo, mergeMaxKeyCount,
                mergeUdafEvaluatorMode, mergeResultsArray);
    }

    /**
     * Returns the first key at or after {@code startKey} (wrapping to 0 at {@code maxKeyCount}) whose
     * PARTIAL1 result is non-null.
     *
     * @throws IllegalStateException if all {@code maxKeyCount} slots are null (previously an endless loop)
     */
    private static int findNextNonNullKey(Object[] partial1Results, int startKey, int maxKeyCount) {
        int key = startKey;
        for (int probe = 0; probe < maxKeyCount; ++probe) {
            if (key >= maxKeyCount) {
                key = 0;
            }
            if (partial1Results[key] != null) {
                return key;
            }
            key++;
        }
        throw new IllegalStateException("No non-null PARTIAL1 result among " + maxKeyCount + " keys");
    }

    /** Builds a standard writable ObjectInspector for each parameter expression's type. */
    private static ObjectInspector[] getParameterObjectInspectors(List<ExprNodeDesc> parameters) {
        ObjectInspector[] objectInspectors = new ObjectInspector[parameters.size()];
        for (int i = 0; i < objectInspectors.length; ++i) {
            objectInspectors[i] =
                    TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(parameters.get(i).getTypeInfo());
        }
        return objectInspectors;
    }

    /**
     * Generates {@link #TEST_ROW_COUNT} rows from {@code rowSource}, then overwrites column 0 of each
     * row with a SMALLINT key from {@link #getLinearRandomNumber(Random, int)} in [1, maxKeyCount].
     *
     * @param keyObjectInspector must be a {@link WritableShortObjectInspector}
     */
    private static Object[][] generateKeyedRows(Random random, VectorRandomRowSource rowSource,
            ObjectInspector keyObjectInspector, int maxKeyCount) {
        Object[][] rows = rowSource.randomRows(TEST_ROW_COUNT);
        WritableShortObjectInspector shortKeyInspector = (WritableShortObjectInspector) keyObjectInspector;
        for (Object[] row : rows) {
            short shortKey = (short) getLinearRandomNumber(random, maxKeyCount);
            row[0] = shortKeyInspector.create(shortKey);
        }
        return rows;
    }

    /**
     * Returns true when {@code aggregationName} is additionally tested in COMPLETE mode.
     *
     * @throws RuntimeException for an aggregation name this test does not know
     */
    private static boolean hasDifferentCompleteExpr(String aggregationName) {
        if (varianceNames.contains(aggregationName)) {
            return true;
        }
        switch (aggregationName) {
            case "avg":
                return true;
            case "bloom_filter":
            case "count":
            case "max":
            case "min":
            case "sum":
                return false;
            default:
                throw new RuntimeException("Unexpected aggregation name " + aggregationName);
        }
    }

    /**
     * Returns true when the PARTIAL1 results of {@code aggregationName} are also merged in PARTIAL2
     * mode. This applies to exactly the aggregations that get a separate COMPLETE run.
     *
     * @throws RuntimeException for an aggregation name this test does not know
     */
    private static boolean hasDifferentPartial2Expr(String aggregationName) {
        return hasDifferentCompleteExpr(aggregationName);
    }

    /**
     * Returns true when the PARTIAL1 results of {@code aggregationName} are also merged in FINAL mode.
     *
     * @throws RuntimeException for an aggregation name this test does not know
     */
    private static boolean hasDifferentFinalExpr(String aggregationName) {
        if (varianceNames.contains(aggregationName)) {
            return true;
        }
        switch (aggregationName) {
            case "avg":
            case "bloom_filter":
            case "count":
                return true;
            case "max":
            case "min":
            case "sum":
                return false;
            default:
                throw new RuntimeException("Unexpected aggregation name " + aggregationName);
        }
    }

    /** Runs {@code aggregationName} on one argument of {@code typeInfo} (no COUNT(*), no DECIMAL_64). */
    private void doTests(Random random, String aggregationName, TypeInfo typeInfo) throws Exception {
        doTests(random, aggregationName, typeInfo, false, false);
    }

    /**
     * Runs one aggregation over random (SMALLINT key, {@code typeInfo} value) rows.
     *
     * <p>PARTIAL1 is always executed and verified. COMPLETE, PARTIAL2 and FINAL runs follow when the
     * corresponding {@code hasDifferent*Expr} helper says so. The PARTIAL2/FINAL runs consume the
     * PARTIAL1 results via {@link #doMerge}.
     *
     * @param isCountStar run the evaluator with no arguments and count-all-columns enabled (count only)
     * @param tryDecimal64 request the DECIMAL_64 physical variation when the decimal precision allows it
     */
    private void doTests(Random random, String aggregationName, TypeInfo typeInfo, boolean isCountStar,
            boolean tryDecimal64) throws Exception {
        // Schema: col0 is a SMALLINT key (not generated; assigned by generateKeyedRows), col1 is the data.
        ArrayList<VectorRandomRowSource.GenerationSpec> dataAggrGenerationSpecList = new ArrayList<>();
        ArrayList<DataTypePhysicalVariation> explicitDataTypePhysicalVariationList = new ArrayList<>();

        PrimitiveTypeInfo keyTypeInfo = TypeInfoFactory.shortTypeInfo;
        VectorRandomRowSource.GenerationSpec keyGenerationSpec =
                VectorRandomRowSource.GenerationSpec.createOmitGeneration(keyTypeInfo);
        dataAggrGenerationSpecList.add(keyGenerationSpec);
        explicitDataTypePhysicalVariationList.add(DataTypePhysicalVariation.NONE);

        boolean decimal64Enable = checkDecimal64(tryDecimal64, typeInfo);
        dataAggrGenerationSpecList.add(VectorRandomRowSource.GenerationSpec.createSameType(typeInfo));
        explicitDataTypePhysicalVariationList.add(
                decimal64Enable ? DataTypePhysicalVariation.DECIMAL_64 : DataTypePhysicalVariation.NONE);

        List<String> columns = new ArrayList<>();
        columns.add("col0");
        columns.add("col1");
        String[] columnNames = columns.toArray(new String[0]);

        ExprNodeColumnDesc dataAggrCol1Expr = new ExprNodeColumnDesc(typeInfo, "col1", "table", false);
        ArrayList<ExprNodeDesc> dataAggrParameters = new ArrayList<>();
        if (!isCountStar) {
            // COUNT(*) is invoked without arguments; every other case aggregates col1.
            dataAggrParameters.add(dataAggrCol1Expr);
        }
        ObjectInspector[] dataAggrParameterObjectInspectors = getParameterObjectInspectors(dataAggrParameters);

        ObjectInspector keyObjectInspector = VectorRandomRowSource.getObjectInspector(keyTypeInfo);

        // ---- PARTIAL1 ----
        VectorRandomRowSource partial1RowSource = new VectorRandomRowSource();
        // bloom_filter input is generated without NULLs; all other aggregations see NULLs.
        boolean allowNull = !aggregationName.equals("bloom_filter");
        partial1RowSource.initGenerationSpecSchema(random, dataAggrGenerationSpecList, 0, allowNull, true,
                explicitDataTypePhysicalVariationList);
        Object[][] partial1RandomRows =
                generateKeyedRows(random, partial1RowSource, keyObjectInspector, DATA_AGGR_MAX_KEY_COUNT);
        VectorRandomBatchSource partial1BatchSource =
                VectorRandomBatchSource.createInterestingBatches(random, partial1RowSource, partial1RandomRows, null);

        GenericUDAFEvaluator partial1Evaluator = getEvaluator(aggregationName, typeInfo);
        if (isCountStar) {
            Assert.assertTrue(partial1Evaluator instanceof GenericUDAFCount.GenericUDAFCountEvaluator);
            GenericUDAFCount.GenericUDAFCountEvaluator countEvaluator =
                    (GenericUDAFCount.GenericUDAFCountEvaluator) partial1Evaluator;
            countEvaluator.setCountAllColumns(true);
        }
        GenericUDAFEvaluator.Mode partial1UdafEvaluatorMode = GenericUDAFEvaluator.Mode.PARTIAL1;
        ObjectInspector partial1ReturnOI =
                partial1Evaluator.init(partial1UdafEvaluatorMode, dataAggrParameterObjectInspectors);
        TypeInfo partial1OutputTypeInfo = TypeInfoUtils.getTypeInfoFromObjectInspector(partial1ReturnOI);

        Object[] partial1ResultsArray = new Object[AggregationBase.AggregationTestMode.count];
        executeAggregationTests(aggregationName, typeInfo, partial1Evaluator, partial1OutputTypeInfo,
                partial1UdafEvaluatorMode, DATA_AGGR_MAX_KEY_COUNT, columns, columnNames, dataAggrParameters,
                partial1RandomRows, partial1RowSource, partial1BatchSource, tryDecimal64, partial1ResultsArray);
        verifyAggregationResults(typeInfo, partial1OutputTypeInfo, DATA_AGGR_MAX_KEY_COUNT,
                partial1UdafEvaluatorMode, partial1ResultsArray);

        // ---- COMPLETE (fresh random input) ----
        if (hasDifferentCompleteExpr(aggregationName)) {
            VectorRandomRowSource completeRowSource = new VectorRandomRowSource();
            // NULLs are always allowed here; bloom_filter never reaches this branch.
            completeRowSource.initGenerationSpecSchema(random, dataAggrGenerationSpecList, 0,
                    /* allowNull */ true, true, explicitDataTypePhysicalVariationList);
            Object[][] completeRandomRows =
                    generateKeyedRows(random, completeRowSource, keyObjectInspector, DATA_AGGR_MAX_KEY_COUNT);
            VectorRandomBatchSource completeBatchSource = VectorRandomBatchSource.createInterestingBatches(
                    random, completeRowSource, completeRandomRows, null);

            GenericUDAFEvaluator completeEvaluator = getEvaluator(aggregationName, typeInfo);
            GenericUDAFEvaluator.Mode completeUdafEvaluatorMode = GenericUDAFEvaluator.Mode.COMPLETE;
            ObjectInspector completeReturnOI =
                    completeEvaluator.init(completeUdafEvaluatorMode, dataAggrParameterObjectInspectors);
            TypeInfo completeOutputTypeInfo = TypeInfoUtils.getTypeInfoFromObjectInspector(completeReturnOI);

            Object[] completeResultsArray = new Object[AggregationBase.AggregationTestMode.count];
            executeAggregationTests(aggregationName, typeInfo, completeEvaluator, completeOutputTypeInfo,
                    completeUdafEvaluatorMode, DATA_AGGR_MAX_KEY_COUNT, columns, columnNames, dataAggrParameters,
                    completeRandomRows, completeRowSource, completeBatchSource, tryDecimal64, completeResultsArray);
            verifyAggregationResults(typeInfo, completeOutputTypeInfo, DATA_AGGR_MAX_KEY_COUNT,
                    completeUdafEvaluatorMode, completeResultsArray);
        }

        // ---- PARTIAL2 over the PARTIAL1 results ----
        if (hasDifferentPartial2Expr(aggregationName)) {
            doMerge(GenericUDAFEvaluator.Mode.PARTIAL2, random, aggregationName, typeInfo, keyGenerationSpec,
                    columns, columnNames, DATA_AGGR_MAX_KEY_COUNT, MERGE_REDUCTION_FACTOR, partial1OutputTypeInfo,
                    partial1ResultsArray);
        }

        // ---- FINAL over the PARTIAL1 results ----
        if (hasDifferentFinalExpr(aggregationName)) {
            doMerge(GenericUDAFEvaluator.Mode.FINAL, random, aggregationName, typeInfo, keyGenerationSpec,
                    columns, columnNames, DATA_AGGR_MAX_KEY_COUNT, MERGE_REDUCTION_FACTOR, partial1OutputTypeInfo,
                    partial1ResultsArray);
        }
    }
}

