Skip to content

Commit f0bed3a

Browse files
committed
support count pushdown. tpcds 14/ tpch13
ds14 增加了agg push,执行时间 4.7 -> 4.8 h13 增加了 agg push,应该让p6 恢复到 p4 的成绩,从10sec 恢复到 7 sec
1 parent 43a1f70 commit f0bed3a

File tree

10 files changed

+612
-123
lines changed

10 files changed

+612
-123
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushDownAggregation.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,12 @@
4141
import org.apache.doris.nereids.trees.expressions.NamedExpression;
4242
import org.apache.doris.nereids.trees.expressions.Slot;
4343
import org.apache.doris.nereids.trees.expressions.SlotReference;
44+
import org.apache.doris.nereids.trees.expressions.functions.Function;
4445
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
46+
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
4547
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
4648
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
49+
import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait;
4750
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
4851
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
4952
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
@@ -82,6 +85,7 @@ public class PushDownAggregation extends DefaultPlanRewriter<JobContext> impleme
8285
public final EagerAggRewriter writer = new EagerAggRewriter();
8386

8487
private final Set<Class> pushDownAggFunctionSet = Sets.newHashSet(
88+
Count.class,
8589
Sum.class,
8690
Max.class,
8791
Min.class);
@@ -148,7 +152,7 @@ public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, JobConte
148152
AggregateFunction aggFunction = (AggregateFunction) obj;
149153
if (pushDownAggFunctionSet.contains(aggFunction.getClass())
150154
&& !aggFunction.isDistinct()) {
151-
if (aggFunction.child(0) instanceof If) {
155+
if (aggFunction.arity() > 0 && aggFunction.child(0) instanceof If) {
152156
If body = (If) (aggFunction).child(0);
153157
Set<Slot> valueSlots = Sets.newHashSet(body.getTrueValue().getInputSlots());
154158
valueSlots.addAll(body.getFalseValue().getInputSlots());
@@ -226,10 +230,20 @@ public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, JobConte
226230
// -> T2 [...]
227231
// for min(A), replaceMap: A->minA
228232
// for sum(A), replaceMap: A->sumA
229-
Map<Expression, Slot> replaceMap = new HashMap<>();
233+
// for count(A), replaceMap: count(A)->sum(countA), because count needs rollup to sum
234+
Map<Expression, Expression> replaceMap = new HashMap<>();
230235
List<AggregateFunction> relatedAggFunc = aggFunctionsForOutputExpressions.get(ne);
231236
for (AggregateFunction func : relatedAggFunc) {
232-
replaceMap.put(func.child(0), pushDownContext.getAliasMap().get(func).toSlot());
237+
Slot pushedDownSlot = pushDownContext.getAliasMap().get(func).toSlot();
238+
if (func instanceof Count) {
239+
// For count(A), after pushdown we have count(A) as x,
240+
// and the top agg should use sum(x) instead of count(x)
241+
Function rollUpFunc = ((RollUpTrait) func).constructRollUp(pushedDownSlot);
242+
replaceMap.put(func, rollUpFunc);
243+
} else if (func.arity() > 0) {
244+
// For sum/max/min, replace the child expression with the pushed down slot
245+
replaceMap.put(func.child(0), pushedDownSlot);
246+
}
233247
}
234248
NamedExpression replaceAliasExpr = (NamedExpression) ExpressionUtils.replace(ne, replaceMap);
235249
replaceAliasExpr = (NamedExpression) ExpressionUtils.rebuildSignature(replaceAliasExpr);

fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2233,6 +2233,10 @@ public static int getEagerAggregationMode() {
22332233
}
22342234
}
22352235

2236+
public void setEagerAggregationMode(int mode) {
2237+
this.eagerAggregationMode = mode;
2238+
}
2239+
22362240
@VariableMgr.VarAttr(name = "eager_aggregation_on_join", needForward = true)
22372241
public boolean eagerAggregationOnJoin = false;
22382242

fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -771,9 +771,7 @@ private void executeByNereids(TUniqueId queryId) throws Exception {
771771
new AnalysisException(e.getMessage(), e));
772772
} catch (Exception | Error e) {
773773
// Maybe our bug
774-
if (LOG.isDebugEnabled()) {
775-
LOG.debug("Command({}) process failed.", originStmt.originStmt, e);
776-
}
774+
LOG.info("Command({}) process failed.", originStmt.originStmt, e);
777775
context.getState().setError(ErrorCode.ERR_UNKNOWN_ERROR, e.getMessage());
778776
throw new NereidsException("Command (" + originStmt.originStmt + ") process failed.",
779777
new AnalysisException(e.getMessage() == null ? e.toString() : e.getMessage(), e));
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.rules.rewrite.eageraggregation;
19+
20+
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
21+
import org.apache.doris.nereids.util.PlanChecker;
22+
import org.apache.doris.utframe.TestWithFeService;
23+
24+
import org.junit.jupiter.api.Test;
25+
26+
class EagerAggRewriterTest extends TestWithFeService implements MemoPatternMatchSupported {
27+
@Override
28+
protected void runBeforeAll() throws Exception {
29+
createDatabase("test");
30+
connectContext.setDatabase("default_cluster:test");
31+
createTables(
32+
"CREATE TABLE IF NOT EXISTS t1 (\n"
33+
+ " id1 int not null,\n"
34+
+ " name varchar(20)\n"
35+
+ ")\n"
36+
+ "DUPLICATE KEY(id1)\n"
37+
+ "DISTRIBUTED BY HASH(id1) BUCKETS 10\n"
38+
+ "PROPERTIES (\"replication_num\" = \"1\")\n",
39+
"CREATE TABLE IF NOT EXISTS t2 (\n"
40+
+ " id2 int not null,\n"
41+
+ " name varchar(20)\n"
42+
+ ")\n"
43+
+ "DUPLICATE KEY(id2)\n"
44+
+ "DISTRIBUTED BY HASH(id2) BUCKETS 10\n"
45+
+ "PROPERTIES (\"replication_num\" = \"1\")\n"
46+
);
47+
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
48+
}
49+
50+
@Test
51+
void testNotPushAggCaseWhenToNullableSideOfOuterJoin() {
52+
connectContext.getSessionVariable().setEagerAggregationMode(1);
53+
connectContext.getSessionVariable().setDisableJoinReorder(true);
54+
try {
55+
// RIGHT JOIN: agg function (case-when) references left side columns,
56+
// left side is nullable, should NOT be pushed below the join
57+
String sql = "select max(case when t1.name is not null then 'aaa' end) from t1 right join t2 on t1.id1 = t2.id2"
58+
+ " group by t1.id1";
59+
PlanChecker.from(connectContext)
60+
.analyze(sql)
61+
.rewrite()
62+
.nonMatch(logicalJoin(logicalAggregate(), any()))
63+
.printlnTree();
64+
65+
// LEFT JOIN: agg function(case-when) references right side columns,
66+
// right side is nullable, should NOT be pushed below the join
67+
sql = "select max(case when t2.name is null then 'xxx' end) from t1 left join t2"
68+
+ " on t1.id1 = t2.id2 group by t1.id1";
69+
PlanChecker.from(connectContext)
70+
.analyze(sql)
71+
.rewrite()
72+
.nonMatch(logicalJoin(any(), logicalAggregate()))
73+
.printlnTree();
74+
// RIGHT JOIN: agg function (not-case-when) references left side columns,
75+
// left side is nullable, can be pushed below the join
76+
sql = "select max(t2.name) from t1 left join t2"
77+
+ " on t1.id1 = t2.id2 group by t1.id1";
78+
PlanChecker.from(connectContext)
79+
.analyze(sql)
80+
.rewrite()
81+
.matches(logicalJoin(any(), logicalAggregate()))
82+
.printlnTree();
83+
} finally {
84+
connectContext.getSessionVariable().setEagerAggregationMode(0);
85+
}
86+
}
87+
88+
@Test
89+
void testPushDownCount() {
90+
// Test count pushdown: count(a) should be pushed down and
91+
// the top aggregation should use sum to aggregate the count results
92+
// Before: agg(count(name), groupby(id2))
93+
// -> join(t1.id1=t2.id2)
94+
// -> t1(id1, name)
95+
// -> t2(id2)
96+
// After: agg(sum(x), groupby(id2))
97+
// -> join(t1.id1=t2.id2)
98+
// -> agg(count(name) as x, groupby(id1))
99+
// -> t1(id1, name)
100+
// -> t2(id2)
101+
connectContext.getSessionVariable().setEagerAggregationMode(1);
102+
try {
103+
String sql = "select count(t1.name), t2.id2 from t1 join t2 on t1.id1 = t2.id2 group by t2.id2";
104+
PlanChecker.from(connectContext)
105+
.analyze(sql)
106+
.rewrite()
107+
.matches(logicalAggregate(logicalProject(logicalJoin(logicalAggregate(), any()))))
108+
.printlnTree();
109+
} finally {
110+
connectContext.getSessionVariable().setEagerAggregationMode(0);
111+
}
112+
}
113+
}

regression-test/data/nereids_p0/eager_agg/eager_agg.out

Lines changed: 158 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ UnUsed:
2424
SyntaxError:
2525

2626
-- !a_exe --
27+
2024 66.00 54.00
28+
2025 50.00 42.00
2729

2830
-- !a2 --
2931
PhysicalResultSink
@@ -50,6 +52,8 @@ UnUsed:
5052
SyntaxError:
5153

5254
-- !a2_exe --
55+
2024 120.00
56+
2025 92.00
5357

5458
-- !sum_min_max --
5559
PhysicalResultSink
@@ -76,8 +80,10 @@ UnUsed:
7680
SyntaxError:
7781

7882
-- !sum_min_max_exe --
83+
2024 66.00 11.00 16.00
84+
2025 50.00 20.00 22.00
7985

80-
-- !avg_count --
86+
-- !avg --
8187
PhysicalResultSink
8288
--hashAgg[GLOBAL]
8389
----PhysicalDistribute[DistributionSpecHash]
@@ -98,7 +104,146 @@ Used: leading({ ss ws } dt )
98104
UnUsed:
99105
SyntaxError:
100106

101-
-- !avg_count_exe --
107+
-- !avg_exe --
108+
2025 25.0000
109+
2024 16.5000
110+
111+
-- !count_column --
112+
PhysicalResultSink
113+
--hashAgg[GLOBAL]
114+
----PhysicalDistribute[DistributionSpecHash]
115+
------hashAgg[LOCAL]
116+
--------PhysicalProject
117+
----------hashJoin[INNER_JOIN shuffle] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
118+
------------PhysicalProject
119+
--------------hashJoin[INNER_JOIN shuffleBucket] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
120+
----------------PhysicalProject
121+
------------------PhysicalOlapScan[store_sales(ss)]
122+
----------------hashAgg[GLOBAL]
123+
------------------PhysicalDistribute[DistributionSpecHash]
124+
--------------------hashAgg[LOCAL]
125+
----------------------PhysicalProject
126+
------------------------PhysicalOlapScan[web_sales(ws)]
127+
------------PhysicalProject
128+
--------------PhysicalOlapScan[date_dim(dt)]
129+
130+
Hint log:
131+
Used: leading({ ss ws } dt )
132+
UnUsed:
133+
SyntaxError:
134+
135+
-- !count_column_exe --
136+
2024 4
137+
2025 2
138+
139+
-- !count_star --
140+
PhysicalResultSink
141+
--hashAgg[GLOBAL]
142+
----PhysicalDistribute[DistributionSpecHash]
143+
------hashAgg[LOCAL]
144+
--------PhysicalProject
145+
----------hashJoin[INNER_JOIN shuffle] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
146+
------------PhysicalProject
147+
--------------hashJoin[INNER_JOIN bucketShuffle] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
148+
----------------hashAgg[GLOBAL]
149+
------------------PhysicalDistribute[DistributionSpecHash]
150+
--------------------hashAgg[LOCAL]
151+
----------------------PhysicalProject
152+
------------------------PhysicalOlapScan[store_sales(ss)]
153+
----------------PhysicalProject
154+
------------------PhysicalOlapScan[web_sales(ws)]
155+
------------PhysicalProject
156+
--------------PhysicalOlapScan[date_dim(dt)]
157+
158+
Hint log:
159+
Used: leading({ ss ws } dt )
160+
UnUsed:
161+
SyntaxError:
162+
163+
-- !count_star_exe --
164+
2025 2
165+
2024 4
166+
167+
-- !count_distinct --
168+
PhysicalResultSink
169+
--hashAgg[GLOBAL]
170+
----PhysicalDistribute[DistributionSpecHash]
171+
------hashAgg[LOCAL]
172+
--------PhysicalProject
173+
----------hashJoin[INNER_JOIN shuffle] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
174+
------------PhysicalProject
175+
--------------hashJoin[INNER_JOIN shuffle] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
176+
----------------PhysicalProject
177+
------------------PhysicalOlapScan[store_sales(ss)]
178+
----------------PhysicalProject
179+
------------------PhysicalOlapScan[web_sales(ws)]
180+
------------PhysicalProject
181+
--------------PhysicalOlapScan[date_dim(dt)]
182+
183+
Hint log:
184+
Used: leading({ ss ws } dt )
185+
UnUsed:
186+
SyntaxError:
187+
188+
-- !count_distinct_exe --
189+
2024 2
190+
2025 1
191+
192+
-- !count_sum_mixed --
193+
PhysicalResultSink
194+
--hashAgg[GLOBAL]
195+
----PhysicalDistribute[DistributionSpecHash]
196+
------hashAgg[LOCAL]
197+
--------PhysicalProject
198+
----------hashJoin[INNER_JOIN bucketShuffle] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
199+
------------hashAgg[GLOBAL]
200+
--------------PhysicalDistribute[DistributionSpecHash]
201+
----------------hashAgg[LOCAL]
202+
------------------PhysicalProject
203+
--------------------hashJoin[INNER_JOIN shuffle] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
204+
----------------------PhysicalProject
205+
------------------------PhysicalOlapScan[store_sales(ss)]
206+
----------------------PhysicalProject
207+
------------------------PhysicalOlapScan[web_sales(ws)]
208+
------------PhysicalProject
209+
--------------PhysicalOlapScan[date_dim(dt)]
210+
211+
Hint log:
212+
Used: leading({ ss ws } dt )
213+
UnUsed:
214+
SyntaxError:
215+
216+
-- !count_sum_mixed_exe --
217+
2025 2 42.00
218+
2024 4 54.00
219+
220+
-- !count_star_sum_mixed --
221+
PhysicalResultSink
222+
--hashAgg[GLOBAL]
223+
----PhysicalDistribute[DistributionSpecHash]
224+
------hashAgg[LOCAL]
225+
--------PhysicalProject
226+
----------hashJoin[INNER_JOIN shuffle] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
227+
------------PhysicalProject
228+
--------------hashJoin[INNER_JOIN bucketShuffle] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
229+
----------------hashAgg[GLOBAL]
230+
------------------PhysicalDistribute[DistributionSpecHash]
231+
--------------------hashAgg[LOCAL]
232+
----------------------PhysicalProject
233+
------------------------PhysicalOlapScan[store_sales(ss)]
234+
----------------PhysicalProject
235+
------------------PhysicalOlapScan[web_sales(ws)]
236+
------------PhysicalProject
237+
--------------PhysicalOlapScan[date_dim(dt)]
238+
239+
Hint log:
240+
Used: leading({ ss ws } dt )
241+
UnUsed:
242+
SyntaxError:
243+
244+
-- !count_star_sum_mixed_exe --
245+
2024 4 54.00
246+
2025 2 42.00
102247

103248
-- !groupkey_push_SS_JOIN_D --
104249
PhysicalResultSink
@@ -126,6 +271,11 @@ UnUsed:
126271
SyntaxError:
127272

128273
-- !groupkey_push_SS_JOIN_D_exe --
274+
2024 10.00 12.00
275+
2024 15.00 17.00
276+
2024 25.00 29.00
277+
2025 18.00 21.00
278+
2025 20.00 23.00
129279

130280
-- !groupkey_push --
131281
PhysicalResultSink
@@ -153,6 +303,10 @@ UnUsed:
153303
SyntaxError:
154304

155305
-- !groupkey_push_exe --
306+
2024 20.00 22.00
307+
2024 30.00 32.00
308+
2025 18.00 20.00
309+
2025 20.00 22.00
156310

157311
-- !sum_if_push --
158312
PhysicalResultSink
@@ -178,7 +332,8 @@ UnUsed:
178332
SyntaxError:
179333

180334
-- !sum_if_push_exe --
181-
1 \N \N \N \N \N \N
335+
1 30.50 \N \N \N \N \N
336+
2 \N 22.00 \N \N \N \N
182337

183338
-- !check_nullable --
184339
PhysicalResultSink

0 commit comments

Comments
 (0)