Skip to content

Commit fecfc24

Browse files
author
Atticus Kuhn
committed
add optimistaions
1 parent 0bb6614 commit fecfc24

File tree

7 files changed

+126
-51
lines changed

7 files changed

+126
-51
lines changed

PartIiProject/Optimisations/Term2Utils.lean

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,26 @@ def liftSubst2 {ctx ctx' : List Ty} {a b : Ty} (σ : Subst ctx ctx') :
7878
Subst (a :: b :: ctx) (a :: b :: ctx') :=
7979
liftSubst (a := a) (liftSubst (a := b) σ)
8080

81+
mutual
82+
/-- A "dummy" term of a given type, intended only for unreachable substitution branches. -/
83+
def defaultTerm2 {ctx : List Ty} : {ty : Ty} → Term2 ctx ty
84+
| .bool => .constBool false
85+
| .int => .constInt 0
86+
| .real => .constReal 0.0
87+
| .maxProduct => .promote (.mk SourceLocation.unknown (.constReal 0.0))
88+
| .date => .builtin (.DateLit 0) (.mk SourceLocation.unknown (.constRecord .nil))
89+
| .string => .constString ""
90+
| .record l => .constRecord (defaultFields2 (ctx := ctx) l)
91+
| .dict _ _ => .emptyDict
92+
93+
def defaultLoc2 {ctx : List Ty} : {ty : Ty} → TermLoc2 ctx ty
94+
| ty => .mk SourceLocation.unknown (defaultTerm2 (ctx := ctx) (ty := ty))
95+
96+
def defaultFields2 {ctx : List Ty} : (l : List Ty) → TermFields2 ctx l
97+
| [] => .nil
98+
| t :: ts => .cons (defaultLoc2 (ctx := ctx) (ty := t)) (defaultFields2 (ctx := ctx) ts)
99+
end
100+
81101
mutual
82102
def substTerm2 {ctx ctx' : List Ty} {ty : Ty}
83103
(σ : Subst ctx ctx') : Term2 ctx ty → Term2 ctx' ty
@@ -126,11 +146,11 @@ mutual
126146
| .lookup _ d k => mentionsIndexLoc d i || mentionsIndexLoc k i
127147
| .not e => mentionsIndexLoc e i
128148
| .ite c t f => mentionsIndexLoc c i || mentionsIndexLoc t i || mentionsIndexLoc f i
129-
| .letin bound body => mentionsIndexLoc bound i || mentionsIndexLoc body i
149+
| .letin bound body => mentionsIndexLoc bound i || mentionsIndexLoc body (i + 1)
130150
| .add _ t1 t2 => mentionsIndexLoc t1 i || mentionsIndexLoc t2 i
131151
| @Term2.mul _ _ _ _ _ _ _ _ t1 t2 => mentionsIndexLoc t1 i || mentionsIndexLoc t2 i
132152
| .promote e => mentionsIndexLoc e i
133-
| .sum _ d body => mentionsIndexLoc d i || mentionsIndexLoc body i
153+
| .sum _ d body => mentionsIndexLoc d i || mentionsIndexLoc body (i + 2)
134154
| @Term2.proj _ _ _ record _ _ => mentionsIndexLoc record i
135155
| .builtin _ arg => mentionsIndexLoc arg i
136156

PartIiProject/Optimisations/VerticalLoopFusion.lean

Lines changed: 66 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -6,50 +6,6 @@ namespace PartIiProject.Optimisations
66

77
open PartIiProject.Optimisations.Term2
88

9-
private def varMem? {ctx : List Ty} {ty : Ty} (tm : Term2 ctx ty) : Option (Mem ty ctx) :=
10-
match tm with
11-
| .var m => some m
12-
| .constInt _ => none
13-
| .constReal _ => none
14-
| .constBool _ => none
15-
| .constString _ => none
16-
| .constRecord _ => none
17-
| .emptyDict => none
18-
| .dictInsert _ _ _ => none
19-
| .lookup _ _ _ => none
20-
| .not _ => none
21-
| .ite _ _ _ => none
22-
| .letin _ _ => none
23-
| .add _ _ _ => none
24-
| @Term2.mul _ _ _ _ _ _ _ _ _ _ => none
25-
| .promote _ => none
26-
| .sum _ _ _ => none
27-
| @Term2.proj _ _ _ _ _ _ => none
28-
| .builtin _ _ => none
29-
30-
private def singletonDictInsert?
31-
{ctx : List Ty} {ty : Ty} (tm : Term2 ctx ty) :
32-
Option (Σ dom : Ty, Σ range : Ty, TermLoc2 ctx dom × TermLoc2 ctx range) :=
33-
match tm with
34-
| @Term2.dictInsert _ dom range k v _d => some ⟨dom, ⟨range, (k, v)⟩⟩
35-
| .var _ => none
36-
| .constInt _ => none
37-
| .constReal _ => none
38-
| .constBool _ => none
39-
| .constString _ => none
40-
| .constRecord _ => none
41-
| .emptyDict => none
42-
| .lookup _ _ _ => none
43-
| .not _ => none
44-
| .ite _ _ _ => none
45-
| .letin _ _ => none
46-
| .add _ _ _ => none
47-
| @Term2.mul _ _ _ _ _ _ _ _ _ _ => none
48-
| .promote _ => none
49-
| .sum _ _ _ => none
50-
| @Term2.proj _ _ _ _ _ _ => none
51-
| .builtin _ _ => none
52-
539
/--
5410
Vertical loop fusion, specialized to the two common "singleton dict" shapes:
5511
@@ -61,10 +17,71 @@ Vertical loop fusion, specialized to the two common "singleton dict" shapes:
6117
`let y = sum(<x,x_v> in e1) { x -> f1(x_v) } in sum(<x,x_v> in y) { x -> f2(x_v) }`
6218
`↦ sum(<x,x_v> in e1) { x -> f2(f1(x_v)) }`
6319
-/
64-
def verticalLoopFusion2 : Optimisation
65-
:= fun {ctx} {ty} t =>
66-
match t with
67-
| t@Term2.letin (⟨_, .sum a dict ⟨ _, .dictInsert x y z⟩ ⟩ ) let_in_body => .some t
68-
| _ => .none
20+
def verticalLoopFusionKeyMap2 : Optimisation :=
21+
fun {ctx} {ty} t =>
22+
match t with
23+
| Term2.letin
24+
(.mk _ (Term2.sum _ e₁ (.mk _ (.dictInsert k₁ v₁ (.mk _ .emptyDict)))))
25+
(.mk _ (Term2.sum a₂
26+
(.mk _ (.var (.head _)))
27+
(.mk bodyLoc (.dictInsert k₂ v₂ (.mk emptyLoc .emptyDict))))) =>
28+
match v₁.term, v₂.term with
29+
| .var (.tail _ (.head _)), .var (.tail _ (.head _)) =>
30+
if Term2.mentionsIndexLoc k₁ 1 || Term2.mentionsIndexLoc k₂ 1 || Term2.mentionsIndexLoc k₂ 2 then
31+
none
32+
else
33+
let σ : Term2.Subst (_ :: _ :: (.dict _ _) :: ctx) (_ :: _ :: ctx) :=
34+
fun {ty} m =>
35+
match m with
36+
| .head _ => k₁.term
37+
| .tail _ m =>
38+
match m with
39+
| .head _ => .var (.tail _ (.head ctx))
40+
| .tail _ m =>
41+
match m with
42+
| .head _ => Term2.defaultTerm2
43+
| .tail _ m => .var (.tail _ (.tail _ m))
44+
let k₂' := Term2.substLoc2 σ k₂
45+
let v₂' := Term2.substLoc2 σ v₂
46+
let emptyFused : TermLoc2 (_ :: _ :: ctx) (.dict _ _) := .mk emptyLoc .emptyDict
47+
let fusedBody : TermLoc2 (_ :: _ :: ctx) (.dict _ _) :=
48+
.mk bodyLoc (.dictInsert k₂' v₂' emptyFused)
49+
some (Term2.sum a₂ e₁ fusedBody)
50+
| _, _ => none
51+
| _ => none
52+
53+
def verticalLoopFusionValueMap2 : Optimisation :=
54+
fun {ctx} {ty} t =>
55+
match t with
56+
| Term2.letin
57+
(.mk _ (Term2.sum _ e₁ (.mk _ (.dictInsert k₁ v₁ (.mk _ .emptyDict)))))
58+
(.mk _ (Term2.sum a₂
59+
(.mk _ (.var (.head _)))
60+
(.mk bodyLoc (.dictInsert k₂ v₂ (.mk emptyLoc .emptyDict))))) =>
61+
match k₁.term, k₂.term with
62+
| .var (.head _), .var (.head _) =>
63+
if Term2.mentionsIndexLoc v₁ 0 || Term2.mentionsIndexLoc v₂ 0 || Term2.mentionsIndexLoc v₂ 2 then
64+
none
65+
else
66+
let σ : Term2.Subst (_ :: _ :: (.dict _ _) :: ctx) (_ :: _ :: ctx) :=
67+
fun {ty} m =>
68+
match m with
69+
| .head _ => k₁.term
70+
| .tail _ m =>
71+
match m with
72+
| .head _ => v₁.term
73+
| .tail _ m =>
74+
match m with
75+
| .head _ => Term2.defaultTerm2
76+
| .tail _ m => .var (.tail _ (.tail _ m))
77+
let k₂' := Term2.substLoc2 σ k₂
78+
let v₂' := Term2.substLoc2 σ v₂
79+
let emptyFused : TermLoc2 (_ :: _ :: ctx) (.dict _ _) := .mk emptyLoc .emptyDict
80+
let fusedBody : TermLoc2 (_ :: _ :: ctx) (.dict _ _) :=
81+
.mk bodyLoc (.dictInsert k₂' v₂' emptyFused)
82+
some (Term2.sum a₂ e₁ fusedBody)
83+
| _, _ => none
84+
| _ => none
85+
6986

7087
end PartIiProject.Optimisations

Tests/GuardMsgs.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import PartIiProject.SyntaxSDQLProg
2+
import Tests.Optimisations.VerticalLoopFusion
23

34
open PartIiProject
45

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import PartIiProject.Optimisations
2+
import PartIiProject.SyntaxSDQLProg
3+
4+
open PartIiProject
5+
open PartIiProject.Optimisations
6+
7+
namespace Tests.Optimisations.VerticalLoopFusion
8+
9+
open ToCore2 in
10+
unsafe def optimiseCoreTerm (p : SProg2) : String :=
11+
let core := trProg2 p
12+
let term' := applyOptimisationsLoc [verticalLoopFusionKeyMap2, verticalLoopFusionValueMap2] core.term
13+
Term2.showTermLoc2 [] term'
14+
15+
/-- info: "sum(x, y in {1 -> 10} ++ {} + {2 -> 20} ++ {}) {x + 1 + 2 -> y} ++ {}" -/
16+
#guard_msgs in
17+
#eval optimiseCoreTerm
18+
([SDQLProg2 { { int -> int } }|
19+
let y = sum( <x, x_v> <- ({ 1 -> 10 } + { 2 -> 20 }) ) { x + 1 -> x_v } in
20+
sum( <x, x_v> <- y ) { x + 2 -> x_v }
21+
] : SProg2)
22+
23+
/-- info: "sum(x, y in {1 -> 10} ++ {} + {2 -> 20} ++ {}) {x -> y + 1 + 2} ++ {}" -/
24+
#guard_msgs in
25+
#eval optimiseCoreTerm
26+
([SDQLProg2 { { int -> int } }|
27+
let y = sum( <x, x_v> <- ({ 1 -> 10 } + { 2 -> 20 }) ) { x -> x_v + 1 } in
28+
sum( <x, x_v> <- y ) { x -> x_v + 2 }
29+
] : SProg2)
30+
31+
end Tests.Optimisations.VerticalLoopFusion
32+

docs/activeContext.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ Latest changes:
4646
- Refactored the Rust AST to be DeBruijn-indexed (`Expr : Nat → Type`, vars are `Fin ctx`) and replaced stringly-typed runtime calls with `RuntimeFn`; updated `PartIiProject/CodegenRust.lean` accordingly.
4747
- Added a performance benchmarking runner `Performance.lean` (flake app `performanceComparsion`) that compares runtime (ms) of `sdql-rs` binaries vs Lean-generated Rust binaries, including microbenchmarks and TPCH cases.
4848
- Fixed a dependent-pattern-matching blocker in optimisation passes by refactoring `Term2.mul`/`Term2.proj` to carry typeclass witnesses (`has_tensor`/`has_proj`) instead of computed indices (`tensor` / `List.getD`) directly.
49+
- Added a small `Term2` optimisation framework (`PartIiProject/Optimisations/Apply.lean`) where each rewrite is a non-recursive `Optimisation` and `applyOptimisations{,Loc}` performs the recursive traversal + (fuel-bounded) fixpoint iteration.
50+
- Implemented vertical loop fusion over `Term2` as two separate rewrites in `PartIiProject/Optimisations/VerticalLoopFusion.lean` (`verticalLoopFusionKeyMap2` and `verticalLoopFusionValueMap2`).
51+
- Added/confirmed `#guard_msgs` coverage for vertical loop fusion in `Tests/Optimisations/VerticalLoopFusion.lean` (pulled in via `Tests/GuardMsgs.lean`).
4952

5053
Next steps (proposed):
5154

docs/progress.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ What works:
4545
- Performance comparison: `Performance.lean` executable `performanceComparsion` benchmarks runtime (ms) of `sdql-rs` reference binaries vs Lean-generated Rust binaries.
4646
- Surface/core terms are DeBruijn-indexed: surface terms in `SurfaceCore2.lean`, core terms in `Term2.lean`, with lowering in `ToCore2`.
4747
- Optimisation-friendly `Term2` indices: `mul`/`proj` carry `has_tensor`/`has_proj` witnesses to avoid dependent-elimination failures when pattern-matching in optimisation passes.
48+
- Optimisations over `Term2`: `PartIiProject/Optimisations/Apply.lean` provides a recursive driver for non-recursive `Optimisation` rewrites; `PartIiProject/Optimisations/VerticalLoopFusion.lean` implements key-map and value-map vertical loop fusion with `#guard_msgs` regression tests.
4849

4950
What's left to build:
5051

docs/systemPatterns.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ Notable patterns:
110110
- For optimisation passes that pattern-match on `Term2`, avoid computed indices in inductive families:
111111
- `Term2.mul` carries a `has_tensor t1 t2 t3` witness (typeclass) instead of returning `Term2 ctx (tensor t1 t2)` directly.
112112
- `Term2.proj` carries a `has_proj l i t` witness instead of returning `Term2 ctx (l.getD i Ty.int)` directly.
113+
- Optimisation passes are structured as local, non-recursive rewrites over core terms (`PartIiProject/Optimisations/Apply.lean`): each `Optimisation` is `Term2 ctx ty → Option (Term2 ctx ty)`, and `applyOptimisations{,Loc}` provides the recursive traversal and fuel-bounded fixpoint iteration.
113114
- Addition and scaling are encoded as explicit evidence, guiding typing and compilation.
114115
- Lookups and sums rely on the additive identity of the result to stay total and align with sparse semantics.
115116
- Tests compare Rust program output against expected strings or a reference binary. Rust programs use `SDQLShow::show(&result)`.

0 commit comments

Comments
 (0)