package org.apache.druid.benchmark.query;

import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import org.apache.druid.collections.StupidPool;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.InputRow;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.java.util.common.FileUtils;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.offheap.OffheapBufferGenerator;
import org.apache.druid.query.FinalizeResultsQueryRunner;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryRunnerFactory;
import org.apache.druid.query.QueryToolChest;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.LongMaxAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesSerde;
import org.apache.druid.query.context.ResponseContext;
import org.apache.druid.query.ordering.StringComparators;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.apache.druid.query.topn.DimensionTopNMetricSpec;
import org.apache.druid.query.topn.TopNQuery;
import org.apache.druid.query.topn.TopNQueryBuilder;
import org.apache.druid.query.topn.TopNQueryConfig;
import org.apache.druid.query.topn.TopNQueryQueryToolChest;
import org.apache.druid.query.topn.TopNQueryRunnerFactory;
import org.apache.druid.segment.IncrementalIndexSegment;
import org.apache.druid.segment.IndexIO;
import org.apache.druid.segment.IndexMergerV9;
import org.apache.druid.segment.IndexSpec;
import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.QueryableIndexSegment;
import org.apache.druid.segment.column.ColumnConfig;
import org.apache.druid.segment.generator.DataGenerator;
import org.apache.druid.segment.generator.GeneratorBasicSchemas;
import org.apache.druid.segment.generator.GeneratorSchemaInfo;
import org.apache.druid.segment.incremental.IncrementalIndex;
import org.apache.druid.segment.serde.ComplexMetrics;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.timeline.SegmentId;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

/**
 * JMH benchmark that runs TopN queries against generated segments.
 *
 * <p>{@link #setup()} generates {@code numSegments} in-memory incremental indexes of
 * {@code rowsPerSegment} rows each, persists them to a temporary directory and loads them back as
 * queryable indexes. The benchmark methods then run the query selected by {@code schemaAndQuery}
 * (formatted as {@code <schema>.<query>}) against a single incremental index, a single queryable
 * index, or every queryable index merged on a thread pool.
 */
@State(Scope.Benchmark)
@Fork(1)
@Warmup(iterations = 10)
@Measurement(iterations = 25)
public class TopNBenchmark {

    private static final Logger log = new Logger(TopNBenchmark.class);
    private static final int RNG_SEED = 9999;
    private static final IndexMergerV9 INDEX_MERGER_V9;
    private static final IndexIO INDEX_IO;
    public static final ObjectMapper JSON_MAPPER;

    /** Query builders keyed by schema name, then by query name; populated by {@link #setupQueries()}. */
    private static final Map<String, Map<String, TopNQueryBuilder>> SCHEMA_QUERY_MAP = new LinkedHashMap<>();

    static {
        NullHandling.initializeForTests();
        JSON_MAPPER = new DefaultObjectMapper();
        INDEX_IO = new IndexIO(JSON_MAPPER, new ColumnConfig() {
            @Override
            public int columnCacheSizeBytes() {
                return 0;
            }
        });
        INDEX_MERGER_V9 = new IndexMergerV9(JSON_MAPPER, INDEX_IO, OffHeapMemorySegmentWriteOutMediumFactory.instance());
    }

    @Param({"1"})
    private int numSegments;

    @Param({"750000"})
    private int rowsPerSegment;

    /** Selects the schema and the query to run, written as {@code <schema>.<query>}. */
    @Param({"basic.A", "basic.numericSort", "basic.alphanumericSort"})
    private String schemaAndQuery;

    @Param({"10"})
    private int threshold;

    private List<IncrementalIndex> incIndexes;
    private List<QueryableIndex> qIndexes;
    private QueryRunnerFactory factory;
    private GeneratorSchemaInfo schemaInfo;
    private TopNQueryBuilder queryBuilder;
    private TopNQuery query;
    private File tmpDir;
    private ExecutorService executorService;

    /**
     * Registers the query builders of the {@code "basic"} schema in {@link #SCHEMA_QUERY_MAP}.
     *
     * <p>Three queries are registered: {@code "A"} (ordered by the {@code sumFloatNormal} metric with
     * five aggregators) plus {@code "numericSort"} and {@code "alphanumericSort"} (ordered by the
     * dimension value with a single aggregator). Calling this again replaces the previous entry.
     */
    private void setupQueries() {
        GeneratorSchemaInfo basicSchema = GeneratorBasicSchemas.SCHEMA_MAP.get("basic");
        Map<String, TopNQueryBuilder> basicQueries = new LinkedHashMap<>();

        List<AggregatorFactory> queryAAggs = new ArrayList<>();
        queryAAggs.add(new LongSumAggregatorFactory("sumLongSequential", "sumLongSequential"));
        queryAAggs.add(new LongMaxAggregatorFactory("maxLongUniform", "maxLongUniform"));
        queryAAggs.add(new DoubleSumAggregatorFactory("sumFloatNormal", "sumFloatNormal"));
        queryAAggs.add(new DoubleMinAggregatorFactory("minFloatZipf", "minFloatZipf"));
        queryAAggs.add(new HyperUniquesAggregatorFactory("hyperUniquesMet", "hyper"));
        basicQueries.put(
            "A",
            newQueryOn("dimSequential")
                .metric("sumFloatNormal")
                .intervals(dataIntervalSpec(basicSchema))
                .aggregators(queryAAggs));

        basicQueries.put(
            "numericSort",
            newQueryOn("dimUniform")
                .metric(new DimensionTopNMetricSpec(null, StringComparators.NUMERIC))
                .intervals(dataIntervalSpec(basicSchema))
                .aggregators(sumLongSequentialOnly()));

        basicQueries.put(
            "alphanumericSort",
            newQueryOn("dimUniform")
                .metric(new DimensionTopNMetricSpec(null, StringComparators.ALPHANUMERIC))
                .intervals(dataIntervalSpec(basicSchema))
                .aggregators(sumLongSequentialOnly()));

        SCHEMA_QUERY_MAP.put("basic", basicQueries);
    }

    /** Starts a builder with the data source, granularity and dimension shared by every benchmark query. */
    private static TopNQueryBuilder newQueryOn(String dimension) {
        return new TopNQueryBuilder().dataSource("blah").granularity(Granularities.ALL).dimension(dimension);
    }

    /** Creates a fresh segment spec covering the data interval of the given schema. */
    private static MultipleIntervalSegmentSpec dataIntervalSpec(GeneratorSchemaInfo schema) {
        return new MultipleIntervalSegmentSpec(Collections.singletonList(schema.getDataInterval()));
    }

    /** Creates a fresh aggregator list holding only the {@code sumLongSequential} long sum. */
    private static List<AggregatorFactory> sumLongSequentialOnly() {
        List<AggregatorFactory> aggregators = new ArrayList<>();
        aggregators.add(new LongSumAggregatorFactory("sumLongSequential", "sumLongSequential"));
        return aggregators;
    }

    /**
     * Prepares everything the benchmark methods need: the selected query, the generated incremental
     * indexes, their persisted queryable counterparts, the merge thread pool and the runner factory.
     *
     * @throws IOException if a row cannot be added to an index or a segment cannot be persisted or loaded
     * @throws IllegalArgumentException if {@code schemaAndQuery} is malformed or names an unknown schema or query
     */
    @Setup
    public void setup() throws IOException {
        log.info("SETUP CALLED AT %d", System.currentTimeMillis());
        ComplexMetrics.registerSerde("hyperUnique", new HyperUniquesSerde());
        setupQueries();

        // Validate the parameter before allocating the thread pool so a bad value cannot leak it.
        resolveSchemaAndQuery();
        this.executorService = Execs.multiThreaded(this.numSegments, "TopNThreadPool");

        this.incIndexes = new ArrayList<>();
        for (int segment = 0; segment < this.numSegments; segment++) {
            this.incIndexes.add(generateIncrementalIndex(segment));
        }

        this.tmpDir = FileUtils.createTempDir();
        log.info("Using temp dir: %s", this.tmpDir.getAbsolutePath());
        this.qIndexes = new ArrayList<>();
        // NOTE(review): every segment is persisted into the same directory, exactly as before;
        // confirm this is safe for the index merger when numSegments > 1.
        for (IncrementalIndex incIndex : this.incIndexes) {
            File persisted = INDEX_MERGER_V9.persist(incIndex, this.tmpDir, new IndexSpec(), null);
            this.qIndexes.add(INDEX_IO.loadIndex(persisted));
        }

        this.factory = new TopNQueryRunnerFactory(
            new StupidPool<>(
                "TopNBenchmark-compute-bufferPool",
                new OffheapBufferGenerator("compute", 250000000),
                0,
                Integer.MAX_VALUE),
            new TopNQueryQueryToolChest(new TopNQueryConfig()),
            QueryBenchmarkUtil.NOOP_QUERYWATCHER);
    }

    /**
     * Splits {@code schemaAndQuery} on {@code '.'}, looks up the schema and the query builder, applies
     * the configured threshold and builds the query.
     */
    private void resolveSchemaAndQuery() {
        String[] parts = this.schemaAndQuery.split("\\.");
        if (parts.length < 2) {
            throw new IllegalArgumentException(
                "schemaAndQuery must look like <schema>.<query>, got: " + this.schemaAndQuery);
        }
        String schemaName = parts[0];
        String queryName = parts[1];

        this.schemaInfo = GeneratorBasicSchemas.SCHEMA_MAP.get(schemaName);
        Map<String, TopNQueryBuilder> schemaQueries = SCHEMA_QUERY_MAP.get(schemaName);
        if (this.schemaInfo == null || schemaQueries == null) {
            throw new IllegalArgumentException("Unknown schema: " + schemaName);
        }
        this.queryBuilder = schemaQueries.get(queryName);
        if (this.queryBuilder == null) {
            throw new IllegalArgumentException("Unknown query for schema " + schemaName + ": " + queryName);
        }
        this.queryBuilder.threshold(this.threshold);
        this.query = this.queryBuilder.build();
    }

    /**
     * Generates {@code rowsPerSegment} rows for one segment and adds them to a new incremental index.
     * The generator seed is offset by the segment number so each segment gets its own seed.
     */
    private IncrementalIndex generateIncrementalIndex(int segment) throws IOException {
        log.info("Generating rows for segment %d", segment);
        DataGenerator generator = new DataGenerator(
            this.schemaInfo.getColumnSchemas(),
            RNG_SEED + segment,
            this.schemaInfo.getDataInterval(),
            this.rowsPerSegment);
        IncrementalIndex index = makeIncIndex();
        for (int row = 0; row < this.rowsPerSegment; row++) {
            InputRow inputRow = generator.nextRow();
            if (row % 10000 == 0) {
                log.info("%d rows generated.", row);
            }
            index.add(inputRow);
        }
        return index;
    }

    /**
     * Releases what {@link #setup()} allocated: stops the merge thread pool, closes the queryable and
     * incremental indexes, and finally deletes the temporary directory even if closing failed.
     *
     * @throws IOException if an index cannot be closed or the directory cannot be deleted
     */
    @TearDown
    public void tearDown() throws IOException {
        try {
            if (this.executorService != null) {
                this.executorService.shutdownNow();
            }
            // Close the persisted indexes before their backing directory is removed.
            if (this.qIndexes != null) {
                for (QueryableIndex index : this.qIndexes) {
                    index.close();
                }
            }
            if (this.incIndexes != null) {
                for (IncrementalIndex index : this.incIndexes) {
                    index.close();
                }
            }
        } finally {
            if (this.tmpDir != null) {
                FileUtils.deleteDirectory(this.tmpDir);
            }
        }
    }

    /** Builds an empty on-heap incremental index using the schema's aggregators, capped at {@code rowsPerSegment} rows. */
    private IncrementalIndex makeIncIndex() {
        return new IncrementalIndex.Builder()
            .setSimpleTestingIndexSchema(this.schemaInfo.getAggsArray())
            .setMaxRowCount(this.rowsPerSegment)
            .buildOnheap();
    }

    /**
     * Runs {@code query} on {@code queryRunner} after applying the tool chest's pre-merge decoration,
     * result merging and result finalization, and returns the results as a list.
     */
    // The factory and tool chest are used as raw types, so the generic wiring here is unchecked.
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static <T> List<T> runQuery(QueryRunnerFactory queryRunnerFactory, QueryRunner queryRunner, Query<T> query) {
        QueryToolChest toolChest = queryRunnerFactory.getToolchest();
        QueryRunner<T> finalizingRunner = new FinalizeResultsQueryRunner<>(
            toolChest.mergeResults(toolChest.preMergeQueryDecoration(queryRunner)),
            toolChest);
        return finalizingRunner.run(QueryPlus.wrap(query), ResponseContext.createEmpty()).toList();
    }

    /** Runs the query against the first incremental index. */
    @Benchmark
    @BenchmarkMode({Mode.AverageTime})
    @OutputTimeUnit(TimeUnit.MICROSECONDS)
    @SuppressWarnings({"unchecked", "rawtypes"}) // raw factory field
    public void querySingleIncrementalIndex(Blackhole blackhole) {
        SegmentId segmentId = SegmentId.dummy("incIndex");
        QueryRunner runner = QueryBenchmarkUtil.makeQueryRunner(
            this.factory,
            segmentId,
            new IncrementalIndexSegment(this.incIndexes.get(0), segmentId));
        blackhole.consume(runQuery(this.factory, runner, this.query));
    }

    /** Runs the query against the first persisted queryable index. */
    @Benchmark
    @BenchmarkMode({Mode.AverageTime})
    @OutputTimeUnit(TimeUnit.MICROSECONDS)
    @SuppressWarnings({"unchecked", "rawtypes"}) // raw factory field
    public void querySingleQueryableIndex(Blackhole blackhole) {
        SegmentId segmentId = SegmentId.dummy("qIndex");
        QueryRunner runner = QueryBenchmarkUtil.makeQueryRunner(
            this.factory,
            segmentId,
            new QueryableIndexSegment(this.qIndexes.get(0), segmentId));
        blackhole.consume(runQuery(this.factory, runner, this.query));
    }

    /**
     * Runs the query against every persisted queryable index: one decorated runner per segment,
     * merged by the factory on the benchmark's thread pool, then finalized and post-merge decorated.
     */
    @Benchmark
    @BenchmarkMode({Mode.AverageTime})
    @OutputTimeUnit(TimeUnit.MICROSECONDS)
    @SuppressWarnings({"unchecked", "rawtypes"}) // raw factory field and tool chest
    public void queryMultiQueryableIndex(Blackhole blackhole) {
        List<QueryRunner> segmentRunners = new ArrayList<>();
        QueryToolChest toolChest = this.factory.getToolchest();
        for (int i = 0; i < this.numSegments; i++) {
            SegmentId segmentId = SegmentId.dummy("qIndex " + i);
            QueryRunner segmentRunner = QueryBenchmarkUtil.makeQueryRunner(
                this.factory,
                segmentId,
                new QueryableIndexSegment(this.qIndexes.get(i), segmentId));
            segmentRunners.add(toolChest.preMergeQueryDecoration(segmentRunner));
        }

        QueryRunner mergedRunner = toolChest.postMergeQueryDecoration(
            new FinalizeResultsQueryRunner<>(
                toolChest.mergeResults(this.factory.mergeRunners(this.executorService, segmentRunners)),
                toolChest));
        blackhole.consume(mergedRunner.run(QueryPlus.wrap(this.query), ResponseContext.createEmpty()).toList());
    }
}
