Skip to content

Commit 0bb6614

Browse files
author
Atticus Kuhn
committed
Refactor Term2 mul/proj to avoid dependent elimination
Use typeclass witnesses (has_tensor/has_proj) instead of computed indices (tensor/getD) so optimisation passes can pattern-match on Term2 without solver failures. Update codegen and docs, and add initial Optimisations modules.
1 parent d7a6c70 commit 0bb6614

File tree

10 files changed

+429
-12
lines changed

10 files changed

+429
-12
lines changed

PartIiProject/CodegenRust.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ mutual
139139
.letIn (compileLoc2 bound) (compileLoc2 body)
140140
| .add a t1 t2 =>
141141
Compile.compileAdd a (compileLoc2 t1) (compileLoc2 t2)
142-
| @Term2.mul _ _ _ _ s1 s2 e1 e2 =>
142+
| @Term2.mul _ _ _ _ _ s1 s2 _ e1 e2 =>
143143
let lhs := compileLoc2 e1
144144
let rhs := compileLoc2 e2
145145
match s1, s2 with
@@ -206,7 +206,7 @@ mutual
206206
let accResult : Rust.ExprLoc (n+1) :=
207207
Rust.ExprLoc.withUnknownLoc (.var ⟨0, Nat.succ_pos n⟩)
208208
.block stmts accResult
209-
| .proj _ r i => .tupleProj (compileLoc2 r) i
209+
| @Term2.proj _ _ _ r i _ => .tupleProj (compileLoc2 r) i
210210
| .builtin b a =>
211211
Compile.compileBuiltin b (compileLoc2 a)
212212

PartIiProject/Optimisations.lean

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import PartIiProject.Optimisations.Apply
2+
import PartIiProject.Optimisations.VerticalLoopFusion
3+
4+
namespace PartIiProject.Optimisations
5+
-- Re-export the main API from submodules.
6+
end PartIiProject.Optimisations
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import PartIiProject.Optimisations.Term2Utils
2+
3+
open PartIiProject
4+
5+
namespace PartIiProject.Optimisations
6+
7+
abbrev Optimisation : Type :=
8+
{ctx : List Ty} → {ty : Ty} → Term2 ctx ty → Option (Term2 ctx ty)
9+
10+
def tryOptimisations {ctx : List Ty} {ty : Ty} (opts : List Optimisation) (t : Term2 ctx ty) :
11+
Option (Term2 ctx ty) :=
12+
match opts with
13+
| [] => none
14+
| o :: os =>
15+
match o t with
16+
| some t' => some t'
17+
| none => tryOptimisations os t
18+
19+
mutual
20+
def applyOptimisationsOnceTerm {ctx : List Ty} {ty : Ty} (opts : List Optimisation) :
21+
Term2 ctx ty → Term2 ctx ty × Bool
22+
| .var m => (.var m, false)
23+
| .constInt n => (.constInt n, false)
24+
| .constReal r => (.constReal r, false)
25+
| .constBool b => (.constBool b, false)
26+
| .constString s => (.constString s, false)
27+
| .constRecord fields =>
28+
let (fields', ch) := applyOptimisationsOnceFields opts fields
29+
let t' := Term2.constRecord fields'
30+
match tryOptimisations opts t' with
31+
| some t'' => (t'', true)
32+
| none => (t', ch)
33+
| .emptyDict => (.emptyDict, false)
34+
| .dictInsert k v d =>
35+
let (k', chK) := applyOptimisationsOnceLoc opts k
36+
let (v', chV) := applyOptimisationsOnceLoc opts v
37+
let (d', chD) := applyOptimisationsOnceLoc opts d
38+
let ch := chK || chV || chD
39+
let t' := Term2.dictInsert k' v' d'
40+
match tryOptimisations opts t' with
41+
| some t'' => (t'', true)
42+
| none => (t', ch)
43+
| .lookup aRange d k =>
44+
let (d', chD) := applyOptimisationsOnceLoc opts d
45+
let (k', chK) := applyOptimisationsOnceLoc opts k
46+
let ch := chD || chK
47+
let t' := Term2.lookup aRange d' k'
48+
match tryOptimisations opts t' with
49+
| some t'' => (t'', true)
50+
| none => (t', ch)
51+
| .not e =>
52+
let (e', chE) := applyOptimisationsOnceLoc opts e
53+
let t' := Term2.not e'
54+
match tryOptimisations opts t' with
55+
| some t'' => (t'', true)
56+
| none => (t', chE)
57+
| .ite c t f =>
58+
let (c', chC) := applyOptimisationsOnceLoc opts c
59+
let (t', chT) := applyOptimisationsOnceLoc opts t
60+
let (f', chF) := applyOptimisationsOnceLoc opts f
61+
let ch := chC || chT || chF
62+
let term' := Term2.ite c' t' f'
63+
match tryOptimisations opts term' with
64+
| some term'' => (term'', true)
65+
| none => (term', ch)
66+
| .letin bound body =>
67+
let (bound', chB) := applyOptimisationsOnceLoc opts bound
68+
let (body', chBody) := applyOptimisationsOnceLoc opts body
69+
let ch := chB || chBody
70+
let t' := Term2.letin bound' body'
71+
match tryOptimisations opts t' with
72+
| some t'' => (t'', true)
73+
| none => (t', ch)
74+
| .add a t1 t2 =>
75+
let (t1', ch1) := applyOptimisationsOnceLoc opts t1
76+
let (t2', ch2) := applyOptimisationsOnceLoc opts t2
77+
let ch := ch1 || ch2
78+
let t' := Term2.add a t1' t2'
79+
match tryOptimisations opts t' with
80+
| some t'' => (t'', true)
81+
| none => (t', ch)
82+
| @Term2.mul _ sc t1Ty t2Ty t3 s1 s2 inst e1 e2 =>
83+
let (t1', ch1) := applyOptimisationsOnceLoc opts e1
84+
let (t2', ch2) := applyOptimisationsOnceLoc opts e2
85+
let ch := ch1 || ch2
86+
let t' := @Term2.mul _ sc t1Ty t2Ty t3 s1 s2 inst t1' t2'
87+
match tryOptimisations opts t' with
88+
| some t'' => (t'', true)
89+
| none => (t', ch)
90+
| .promote e =>
91+
let (e', chE) := applyOptimisationsOnceLoc opts e
92+
let t' := Term2.promote e'
93+
match tryOptimisations opts t' with
94+
| some t'' => (t'', true)
95+
| none => (t', chE)
96+
| .sum a d body =>
97+
let (d', chD) := applyOptimisationsOnceLoc opts d
98+
let (body', chBody) := applyOptimisationsOnceLoc opts body
99+
let ch := chD || chBody
100+
let t' := Term2.sum a d' body'
101+
match tryOptimisations opts t' with
102+
| some t'' => (t'', true)
103+
| none => (t', ch)
104+
| @Term2.proj _ l t record i inst =>
105+
let (record', chR) := applyOptimisationsOnceLoc opts record
106+
let t' := @Term2.proj _ l t record' i inst
107+
match tryOptimisations opts t' with
108+
| some t'' => (t'', true)
109+
| none => (t', chR)
110+
| .builtin f arg =>
111+
let (arg', chA) := applyOptimisationsOnceLoc opts arg
112+
let t' := Term2.builtin f arg'
113+
match tryOptimisations opts t' with
114+
| some t'' => (t'', true)
115+
| none => (t', chA)
116+
117+
def applyOptimisationsOnceLoc {ctx : List Ty} {ty : Ty} (opts : List Optimisation) :
118+
TermLoc2 ctx ty → TermLoc2 ctx ty × Bool
119+
| .mk loc inner =>
120+
let (inner', ch) := applyOptimisationsOnceTerm opts inner
121+
(.mk loc inner', ch)
122+
123+
def applyOptimisationsOnceFields {ctx : List Ty} (opts : List Optimisation) :
124+
{l : List Ty} → TermFields2 ctx l → TermFields2 ctx l × Bool
125+
| [], .nil => (.nil, false)
126+
| _ :: _, .cons h t =>
127+
let (h', chH) := applyOptimisationsOnceLoc opts h
128+
let (t', chT) := applyOptimisationsOnceFields opts t
129+
(.cons h' t', chH || chT)
130+
end
131+
132+
def applyOptimisationOnce {ctx : List Ty} {ty : Ty} (opt : Optimisation) (t : Term2 ctx ty) :
133+
Term2 ctx ty :=
134+
(applyOptimisationsOnceTerm [opt] t).fst
135+
136+
def applyOptimisationOnceLoc {ctx : List Ty} {ty : Ty} (opt : Optimisation) (t : TermLoc2 ctx ty) :
137+
TermLoc2 ctx ty :=
138+
(applyOptimisationsOnceLoc [opt] t).fst
139+
140+
partial def applyOptimisations {ctx : List Ty} {ty : Ty}
141+
(opts : List Optimisation) (t : Term2 ctx ty) (fuel : Nat := 32) : Term2 ctx ty :=
142+
match fuel with
143+
| 0 => t
144+
| fuel + 1 =>
145+
let (t', changed) := applyOptimisationsOnceTerm opts t
146+
if changed then
147+
applyOptimisations opts t' fuel
148+
else
149+
t'
150+
151+
partial def applyOptimisationsLoc {ctx : List Ty} {ty : Ty}
152+
(opts : List Optimisation) (t : TermLoc2 ctx ty) (fuel : Nat := 32) : TermLoc2 ctx ty :=
153+
match fuel with
154+
| 0 => t
155+
| fuel + 1 =>
156+
let (t', changed) := applyOptimisationsOnceLoc opts t
157+
if changed then
158+
applyOptimisationsLoc opts t' fuel
159+
else
160+
t'
161+
162+
def applyOptimisation {ctx : List Ty} {ty : Ty} (opt : Optimisation) (t : Term2 ctx ty) (fuel : Nat := 32) :
163+
Term2 ctx ty :=
164+
applyOptimisations [opt] t fuel
165+
166+
def applyOptimisationLoc {ctx : List Ty} {ty : Ty}
167+
(opt : Optimisation) (t : TermLoc2 ctx ty) (fuel : Nat := 32) : TermLoc2 ctx ty :=
168+
applyOptimisationsLoc [opt] t fuel
169+
170+
end PartIiProject.Optimisations
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import PartIiProject.Term2
2+
3+
open PartIiProject
4+
5+
namespace PartIiProject.Optimisations
6+
7+
namespace Mem
8+
9+
def index {α : Type} {a : α} {ctx : List α} : Mem a ctx → Nat
10+
| .head _ => 0
11+
| .tail _ m => index m + 1
12+
13+
end Mem
14+
15+
namespace Term2
16+
17+
abbrev Renaming (ctx ctx' : List Ty) : Type :=
18+
∀ {ty : Ty}, Mem ty ctx → Mem ty ctx'
19+
20+
abbrev Subst (ctx ctx' : List Ty) : Type :=
21+
∀ {ty : Ty}, Mem ty ctx → Term2 ctx' ty
22+
23+
def liftRen {ctx ctx' : List Ty} {a : Ty} (ρ : Renaming ctx ctx') :
24+
Renaming (a :: ctx) (a :: ctx')
25+
| _, .head _ => .head _
26+
| _, .tail _ m => .tail _ (ρ m)
27+
28+
def liftRen2 {ctx ctx' : List Ty} {a b : Ty} (ρ : Renaming ctx ctx') :
29+
Renaming (a :: b :: ctx) (a :: b :: ctx') :=
30+
liftRen (a := a) (liftRen (a := b) ρ)
31+
32+
mutual
33+
def renameTerm2 {ctx ctx' : List Ty} {ty : Ty}
34+
(ρ : Renaming ctx ctx') : Term2 ctx ty → Term2 ctx' ty
35+
| .var m => .var (ρ m)
36+
| .constInt n => .constInt n
37+
| .constReal r => .constReal r
38+
| .constBool b => .constBool b
39+
| .constString s => .constString s
40+
| .constRecord fields => .constRecord (renameFields2 ρ fields)
41+
| .emptyDict => .emptyDict
42+
| .dictInsert k v d => .dictInsert (renameLoc2 ρ k) (renameLoc2 ρ v) (renameLoc2 ρ d)
43+
| .lookup aRange d k => .lookup aRange (renameLoc2 ρ d) (renameLoc2 ρ k)
44+
| .not e => .not (renameLoc2 ρ e)
45+
| .ite c t f => .ite (renameLoc2 ρ c) (renameLoc2 ρ t) (renameLoc2 ρ f)
46+
| .letin bound body => .letin (renameLoc2 ρ bound) (renameLoc2 (liftRen ρ) body)
47+
| .add a t1 t2 => .add a (renameLoc2 ρ t1) (renameLoc2 ρ t2)
48+
| @Term2.mul _ sc t1 t2 t3 s1 s2 inst e1 e2 =>
49+
@Term2.mul _ sc t1 t2 t3 s1 s2 inst (renameLoc2 ρ e1) (renameLoc2 ρ e2)
50+
| .promote e => .promote (renameLoc2 ρ e)
51+
| .sum a d body => .sum a (renameLoc2 ρ d) (renameLoc2 (liftRen2 ρ) body)
52+
| @Term2.proj _ l t record i inst =>
53+
@Term2.proj _ l t (renameLoc2 ρ record) i inst
54+
| .builtin f arg => .builtin f (renameLoc2 ρ arg)
55+
56+
def renameLoc2 {ctx ctx' : List Ty} {ty : Ty}
57+
(ρ : Renaming ctx ctx') : TermLoc2 ctx ty → TermLoc2 ctx' ty
58+
| .mk loc inner => .mk loc (renameTerm2 ρ inner)
59+
60+
def renameFields2 {ctx ctx' : List Ty} (ρ : Renaming ctx ctx') :
61+
{l : List Ty} → TermFields2 ctx l → TermFields2 ctx' l
62+
| [], .nil => .nil
63+
| _ :: _, .cons h t => .cons (renameLoc2 ρ h) (renameFields2 ρ t)
64+
end
65+
66+
def wkRen {ctx : List Ty} {a : Ty} : Renaming ctx (a :: ctx) :=
67+
fun m => .tail _ m
68+
69+
def wk {ctx : List Ty} {ty : Ty} {a : Ty} : Term2 ctx ty → Term2 (a :: ctx) ty :=
70+
renameTerm2 wkRen
71+
72+
def liftSubst {ctx ctx' : List Ty} {a : Ty} (σ : Subst ctx ctx') :
73+
Subst (a :: ctx) (a :: ctx')
74+
| _, .head _ => .var (.head _)
75+
| _, .tail _ m => wk (a := a) (σ m)
76+
77+
def liftSubst2 {ctx ctx' : List Ty} {a b : Ty} (σ : Subst ctx ctx') :
78+
Subst (a :: b :: ctx) (a :: b :: ctx') :=
79+
liftSubst (a := a) (liftSubst (a := b) σ)
80+
81+
mutual
82+
def substTerm2 {ctx ctx' : List Ty} {ty : Ty}
83+
(σ : Subst ctx ctx') : Term2 ctx ty → Term2 ctx' ty
84+
| .var m => σ m
85+
| .constInt n => .constInt n
86+
| .constReal r => .constReal r
87+
| .constBool b => .constBool b
88+
| .constString s => .constString s
89+
| .constRecord fields => .constRecord (substFields2 σ fields)
90+
| .emptyDict => .emptyDict
91+
| .dictInsert k v d => .dictInsert (substLoc2 σ k) (substLoc2 σ v) (substLoc2 σ d)
92+
| .lookup aRange d k => .lookup aRange (substLoc2 σ d) (substLoc2 σ k)
93+
| .not e => .not (substLoc2 σ e)
94+
| .ite c t f => .ite (substLoc2 σ c) (substLoc2 σ t) (substLoc2 σ f)
95+
| .letin bound body => .letin (substLoc2 σ bound) (substLoc2 (liftSubst σ) body)
96+
| .add a t1 t2 => .add a (substLoc2 σ t1) (substLoc2 σ t2)
97+
| @Term2.mul _ sc t1 t2 t3 s1 s2 inst e1 e2 =>
98+
@Term2.mul _ sc t1 t2 t3 s1 s2 inst (substLoc2 σ e1) (substLoc2 σ e2)
99+
| .promote e => .promote (substLoc2 σ e)
100+
| .sum a d body => .sum a (substLoc2 σ d) (substLoc2 (liftSubst2 σ) body)
101+
| @Term2.proj _ l t record i inst =>
102+
@Term2.proj _ l t (substLoc2 σ record) i inst
103+
| .builtin f arg => .builtin f (substLoc2 σ arg)
104+
105+
def substLoc2 {ctx ctx' : List Ty} {ty : Ty}
106+
(σ : Subst ctx ctx') : TermLoc2 ctx ty → TermLoc2 ctx' ty
107+
| .mk loc inner => .mk loc (substTerm2 σ inner)
108+
109+
def substFields2 {ctx ctx' : List Ty} (σ : Subst ctx ctx') :
110+
{l : List Ty} → TermFields2 ctx l → TermFields2 ctx' l
111+
| [], .nil => .nil
112+
| _ :: _, .cons h t => .cons (substLoc2 σ h) (substFields2 σ t)
113+
end
114+
115+
mutual
116+
def mentionsIndex {ctx : List Ty} {ty : Ty} (t : Term2 ctx ty) (i : Nat) : Bool :=
117+
match t with
118+
| .var m => (Mem.index m == i)
119+
| .constInt _ => false
120+
| .constReal _ => false
121+
| .constBool _ => false
122+
| .constString _ => false
123+
| .constRecord fields => mentionsIndexFields fields i
124+
| .emptyDict => false
125+
| .dictInsert k v d => mentionsIndexLoc k i || mentionsIndexLoc v i || mentionsIndexLoc d i
126+
| .lookup _ d k => mentionsIndexLoc d i || mentionsIndexLoc k i
127+
| .not e => mentionsIndexLoc e i
128+
| .ite c t f => mentionsIndexLoc c i || mentionsIndexLoc t i || mentionsIndexLoc f i
129+
| .letin bound body => mentionsIndexLoc bound i || mentionsIndexLoc body i
130+
| .add _ t1 t2 => mentionsIndexLoc t1 i || mentionsIndexLoc t2 i
131+
| @Term2.mul _ _ _ _ _ _ _ _ t1 t2 => mentionsIndexLoc t1 i || mentionsIndexLoc t2 i
132+
| .promote e => mentionsIndexLoc e i
133+
| .sum _ d body => mentionsIndexLoc d i || mentionsIndexLoc body i
134+
| @Term2.proj _ _ _ record _ _ => mentionsIndexLoc record i
135+
| .builtin _ arg => mentionsIndexLoc arg i
136+
137+
def mentionsIndexLoc {ctx : List Ty} {ty : Ty} (t : TermLoc2 ctx ty) (i : Nat) : Bool :=
138+
match t with
139+
| .mk _ inner => mentionsIndex inner i
140+
141+
def mentionsIndexFields {ctx : List Ty} {l : List Ty} (fs : TermFields2 ctx l) (i : Nat) : Bool :=
142+
match fs with
143+
| .nil => false
144+
| .cons h t => mentionsIndexLoc h i || mentionsIndexFields t i
145+
end
146+
147+
end Term2
148+
149+
end PartIiProject.Optimisations

0 commit comments

Comments
 (0)