From c9cde27c5163a3054a37735484c988e565093593 Mon Sep 17 00:00:00 2001 From: Nikola Grcevski Date: Mon, 16 Dec 2019 16:59:41 -0500 Subject: [PATCH] Refactor results trimming to avoid rebuilding the maps. --- .../core/operator/CombineGroupByOperator.java | 48 +++++++------- .../AggregationGroupByTrimmingService.java | 36 +++++++++++ .../queries/BaseMultiValueQueriesTest.java | 2 +- ...gmentAggregationMultiValueQueriesTest.java | 28 ++++++++- ...AggregationGroupByTrimmingServiceTest.java | 63 +++++++++++++++---- 5 files changed, 140 insertions(+), 37 deletions(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOperator.java index 794dc197fc7d..5c6b3e481bba 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOperator.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOperator.java @@ -103,8 +103,7 @@ public CombineGroupByOperator(List operators, BrokerRequest brokerRequ */ @Override protected IntermediateResultsBlock getNextBlock() { - ConcurrentHashMap resultsMap = new ConcurrentHashMap<>(); - AtomicInteger numGroups = new AtomicInteger(); + List> results = new ArrayList<>(); ConcurrentLinkedQueue mergedProcessingExceptions = new ConcurrentLinkedQueue<>(); AggregationFunctionContext[] aggregationFunctionContexts = @@ -113,6 +112,7 @@ protected IntermediateResultsBlock getNextBlock() { AggregationFunction[] aggregationFunctions = new AggregationFunction[numAggregationFunctions]; for (int i = 0; i < numAggregationFunctions; i++) { aggregationFunctions[i] = aggregationFunctionContexts[i].getAggregationFunction(); + results.add(new ConcurrentHashMap<>(1000, 0.2f, 1000)); } // We use a CountDownLatch to track if all Futures are finished by the query timeout, and cancel the unfinished @@ -154,26 +154,27 @@ public void runJob() { // Merge aggregation group-by result. AggregationGroupByResult aggregationGroupByResult = intermediateResultsBlock.getAggregationGroupByResult(); if (aggregationGroupByResult != null) { - // Iterate over the group-by keys, for each key, update the group-by result in the resultsMap. - Iterator groupKeyIterator = aggregationGroupByResult.getGroupKeyIterator(); - while (groupKeyIterator.hasNext()) { - GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next(); - resultsMap.compute(groupKey._stringKey, (key, value) -> { - if (value == null) { - if (numGroups.getAndIncrement() < _interSegmentNumGroupsLimit) { - value = new Object[numAggregationFunctions]; - for (int i = 0; i < numAggregationFunctions; i++) { - value[i] = aggregationGroupByResult.getResultForKey(groupKey, i); + int index = 0; + for (Map resultsMap : results) { + // Iterate over the group-by keys, for each key, update the group-by result in the resultsMap. + Iterator groupKeyIterator = aggregationGroupByResult.getGroupKeyIterator(); + final int i = index; + AtomicInteger numGroups = new AtomicInteger(); + while (groupKeyIterator.hasNext()) { + GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next(); + resultsMap.compute(groupKey._stringKey, (key, value) -> { + if (value == null) { + if (numGroups.getAndIncrement() < _interSegmentNumGroupsLimit) { + return aggregationGroupByResult.getResultForKey(groupKey, i); } + } else { + return aggregationFunctions[i].merge(value, aggregationGroupByResult.getResultForKey(groupKey, i)); } - } else { - for (int i = 0; i < numAggregationFunctions; i++) { - value[i] = aggregationFunctions[i] - .merge(value[i], aggregationGroupByResult.getResultForKey(groupKey, i)); - } - } - return value; - }); + return value; + }); + } + + index ++; } } } catch (Exception e) { @@ -200,8 +201,11 @@ public void runJob() { // Trim the results map. AggregationGroupByTrimmingService aggregationGroupByTrimmingService = new AggregationGroupByTrimmingService(aggregationFunctions, (int) _brokerRequest.getGroupBy().getTopN()); + + int resultSize = (numAggregationFunctions == 0) ? 0 : results.get(0).size(); + List> trimmedResults = - aggregationGroupByTrimmingService.trimIntermediateResultsMap(resultsMap); + aggregationGroupByTrimmingService.trimIntermediateResults(results); IntermediateResultsBlock mergedBlock = new IntermediateResultsBlock(aggregationFunctionContexts, trimmedResults, true); @@ -227,7 +231,7 @@ public void runJob() { // TODO: this value should be set in the inner-segment operators. Setting it here might cause false positive as we // are comparing number of groups across segments with the groups limit for each segment. - if (resultsMap.size() >= _innerSegmentNumGroupsLimit) { + if (resultSize >= _innerSegmentNumGroupsLimit) { mergedBlock.setNumGroupsLimitReached(true); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/AggregationGroupByTrimmingService.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/AggregationGroupByTrimmingService.java index a16e344cbdc2..c7248dc0924c 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/AggregationGroupByTrimmingService.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/AggregationGroupByTrimmingService.java @@ -115,6 +115,42 @@ public List> trimIntermediateResultsMap(@Nonnull Map> trimIntermediateResults(@Nonnull List> intermediateResults) { + int numAggregationFunctions = _aggregationFunctions.length; + + if (intermediateResults.size() == 0) { + return intermediateResults; + } + + int numGroups = intermediateResults.get(0).size(); + if (numGroups > _trimThreshold) { + List> trimmedResultMaps = new ArrayList<>(numAggregationFunctions); + + // Trim the result only if number of groups is larger than the threshold + Sorter[] sorters = new Sorter[numAggregationFunctions]; + for (int i = 0; i < numAggregationFunctions; i++) { + AggregationFunction aggregationFunction = _aggregationFunctions[i]; + Sorter sorter = getSorter(_trimSize, aggregationFunction, aggregationFunction.isIntermediateResultComparable()); + for (Map.Entry entry : intermediateResults.get(i).entrySet()) { + sorter.add(entry.getKey(), entry.getValue()); + } + sorters[i] = sorter; + } + + // Dump trimmed results into maps + for (int i = 0; i < numAggregationFunctions; i++) { + Map trimmedResultMap = new HashMap<>(_trimSize); + sorters[i].dumpToMap(trimmedResultMap); + trimmedResultMaps.add(trimmedResultMap); + } + + return trimmedResultMaps; + } + + return intermediateResults; + } + /** * Given an array of maps from group key to final result for each aggregation function, trim the results to topN size. */ diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/BaseMultiValueQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/BaseMultiValueQueriesTest.java index 9fc8d6c4cf07..02ec9b4cdfc7 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/BaseMultiValueQueriesTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/BaseMultiValueQueriesTest.java @@ -61,7 +61,7 @@ * */ public abstract class BaseMultiValueQueriesTest extends BaseQueriesTest { - private static final String AVRO_DATA = "data" + File.separator + "test_data-mv.avro"; + private static final String AVRO_DATA = "data/test_data-mv.avro"; private static final String SEGMENT_NAME = "testTable_1756015683_1756015683"; private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(), "MultiValueQueriesTest"); diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java index 98855b1335a7..f5aa63f8af85 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java @@ -19,15 +19,19 @@ package org.apache.pinot.queries; import java.io.Serializable; +import java.util.Iterator; +import java.util.List; import java.util.function.Function; + +import org.apache.pinot.common.response.broker.AggregationResult; import org.apache.pinot.common.response.broker.BrokerResponseNative; +import org.apache.pinot.common.response.broker.GroupByResult; import org.apache.pinot.spi.utils.BytesUtils; import org.apache.pinot.core.plan.maker.InstancePlanMakerImplV2; import org.apache.pinot.core.startree.hll.HllUtil; import org.testng.annotations.Test; -import static org.testng.Assert.assertFalse; -import static org.testng.Assert.assertTrue; +import static org.testng.Assert.*; public class InterSegmentAggregationMultiValueQueriesTest extends BaseMultiValueQueriesTest { @@ -407,4 +411,24 @@ public void testNumGroupsLimit() { brokerResponse = getBrokerResponseForQuery(query, new InstancePlanMakerImplV2(1000, 1000)); assertTrue(brokerResponse.isNumGroupsLimitReached()); } + + @Test + public void testNumGroupsMultiLimit() { + String query = "SELECT COUNT(*), SUM(column1) FROM testTable GROUP BY column7"; + + BrokerResponseNative brokerResponse = getBrokerResponseForQuery(query); + assertFalse(brokerResponse.isNumGroupsLimitReached()); + + List results = brokerResponse.getAggregationResults(); + Iterator resultsIter = results.get(0).getGroupByResult().iterator(); + assertEquals(resultsIter.next().getValue(), "199756"); + assertEquals(resultsIter.next().getValue(), "29944"); + + resultsIter = results.get(1).getGroupByResult().iterator(); + assertEquals(resultsIter.next().getValue(), "190754303720564.00000"); + assertEquals(resultsIter.next().getValue(), "31917445702108.00000"); + + brokerResponse = getBrokerResponseForQuery(query, new InstancePlanMakerImplV2(5, 5)); + assertTrue(brokerResponse.isNumGroupsLimitReached()); + } } diff --git a/pinot-core/src/test/java/org/apache/pinot/query/aggregation/groupby/AggregationGroupByTrimmingServiceTest.java b/pinot-core/src/test/java/org/apache/pinot/query/aggregation/groupby/AggregationGroupByTrimmingServiceTest.java index ce4a1b7d478d..15a5e1f749b1 100644 --- a/pinot-core/src/test/java/org/apache/pinot/query/aggregation/groupby/AggregationGroupByTrimmingServiceTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/query/aggregation/groupby/AggregationGroupByTrimmingServiceTest.java @@ -81,6 +81,21 @@ public void setUp() { _trimmingService = new AggregationGroupByTrimmingService(AGGREGATION_FUNCTIONS, GROUP_BY_TOP_N); } + private int checkTrimmedResults(List> trimmedIntermediateResultMaps) { + Map trimmedSumResultMap = trimmedIntermediateResultMaps.get(0); + Map trimmedDistinctCountResultMap = trimmedIntermediateResultMaps.get(1); + int trimSize = trimmedSumResultMap.size(); + Assert.assertEquals(trimmedDistinctCountResultMap.size(), trimSize, ERROR_MESSAGE); + for (int i = NUM_GROUPS - trimSize; i < NUM_GROUPS; i++) { + String group = _groups.get(i); + Assert.assertEquals(((Double) trimmedSumResultMap.get(group)).intValue(), i, ERROR_MESSAGE); + Assert.assertEquals(((IntOpenHashSet) trimmedDistinctCountResultMap.get(group)).size(), + i / (NUM_GROUPS / MAX_SIZE_OF_SET) + 1, ERROR_MESSAGE); + } + + return trimSize; + } + @SuppressWarnings("unchecked") @Test public void testTrimming() { @@ -95,24 +110,15 @@ public void testTrimming() { } List> trimmedIntermediateResultMaps = _trimmingService.trimIntermediateResultsMap(intermediateResultsMap); - Map trimmedSumResultMap = trimmedIntermediateResultMaps.get(0); - Map trimmedDistinctCountResultMap = trimmedIntermediateResultMaps.get(1); - int trimSize = trimmedSumResultMap.size(); - Assert.assertEquals(trimmedDistinctCountResultMap.size(), trimSize, ERROR_MESSAGE); - for (int i = NUM_GROUPS - trimSize; i < NUM_GROUPS; i++) { - String group = _groups.get(i); - Assert.assertEquals(((Double) trimmedSumResultMap.get(group)).intValue(), i, ERROR_MESSAGE); - Assert.assertEquals(((IntOpenHashSet) trimmedDistinctCountResultMap.get(group)).size(), - i / (NUM_GROUPS / MAX_SIZE_OF_SET) + 1, ERROR_MESSAGE); - } + int trimSize = checkTrimmedResults(trimmedIntermediateResultMaps); // Test Broker side trimming Map finalDistinctCountResultMap = new HashMap<>(trimSize); - for (Map.Entry entry : trimmedDistinctCountResultMap.entrySet()) { + for (Map.Entry entry : trimmedIntermediateResultMaps.get(1).entrySet()) { finalDistinctCountResultMap.put(entry.getKey(), ((IntOpenHashSet) entry.getValue()).size()); } List[] groupByResultLists = - _trimmingService.trimFinalResults(new Map[]{trimmedSumResultMap, finalDistinctCountResultMap}); + _trimmingService.trimFinalResults(new Map[]{trimmedIntermediateResultMaps.get(0), finalDistinctCountResultMap}); List sumGroupByResultList = groupByResultLists[0]; List distinctCountGroupByResultList = groupByResultLists[1]; for (int i = 0; i < GROUP_BY_TOP_N; i++) { @@ -132,6 +138,39 @@ public void testTrimming() { } } + @Test + public void testTrimmingResults() { + // Test Server side trimming + List> intermediateResults = new ArrayList<>(); + for (int i = 0; i < 2; i++) { + intermediateResults.add(new HashMap<>(NUM_GROUPS)); + } + + for (int i = 0; i < NUM_GROUPS; i++) { + IntOpenHashSet set = new IntOpenHashSet(); + for (int j = 0; j <= i; j += NUM_GROUPS / MAX_SIZE_OF_SET) { + set.add(j); + } + intermediateResults.get(0).put(_groups.get(i), (double) i); + intermediateResults.get(1).put(_groups.get(i), set); + } + + List> trimmedIntermediateResultMaps = + _trimmingService.trimIntermediateResults(intermediateResults); + Map trimmedSumResultMap = trimmedIntermediateResultMaps.get(0); + Map trimmedDistinctCountResultMap = trimmedIntermediateResultMaps.get(1); + int trimSize = trimmedSumResultMap.size(); + Assert.assertEquals(trimmedDistinctCountResultMap.size(), trimSize, ERROR_MESSAGE); + for (int i = NUM_GROUPS - trimSize; i < NUM_GROUPS; i++) { + String group = _groups.get(i); + Assert.assertEquals(((Double) trimmedSumResultMap.get(group)).intValue(), i, ERROR_MESSAGE); + Assert.assertEquals(((IntOpenHashSet) trimmedDistinctCountResultMap.get(group)).size(), + i / (NUM_GROUPS / MAX_SIZE_OF_SET) + 1, ERROR_MESSAGE); + } + + checkTrimmedResults(trimmedIntermediateResultMaps); + } + private static String buildGroupString(List group) { StringBuilder groupStringBuilder = new StringBuilder(); for (int i = 0; i < NUM_GROUP_KEYS; i++) {