Skip to content

Commit b32ae5b

Browse files
author
Atticus Kuhn
committed
add optimisation performance comparison
1 parent c25062a commit b32ae5b

File tree

9 files changed

+419
-148
lines changed

9 files changed

+419
-148
lines changed
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
import PartIiProject.Bench.Common
2+
import PartIiProject.Bench.Table
3+
import PartIiProject.Optimisations
4+
import PartIiProject.SyntaxSDQLProg
5+
import Lean
6+
import Std
7+
8+
open PartIiProject
9+
open PartIiProject.Optimisations
10+
open System
11+
12+
namespace OptimisationPerformanceComparison
13+
14+
/-- Scratch directory that receives the generated Rust sources and compiled binaries; `main` deletes and recreates it on every run. -/
def outDir : FilePath := FilePath.mk ".sdql-opt-perf-out"
15+
/-- Path of the Rust runtime source handed to `compileProg2ToBin` (relative, so presumably resolved against the working directory). -/
def runtimeSrc : FilePath := FilePath.mk "sdql_runtime.rs"
16+
17+
/-- Number of timed runs per binary; `runCase` reports the mean over these runs. -/
def iters : Nat := 3
18+
/-- Dictionary size printed in the table preamble. NOTE: the SDQL programs bind their own `dictN = 200000` literal, so the two must be kept in sync by hand. -/
def dictN : Int := 200000
19+
/-- Inner dictionary size for the memoization cases, printed in the table preamble. NOTE: the SDQL programs repeat `memoN = 100000` as a literal; keep both in sync by hand. -/
def memoN : Int := 100000
20+
/-- Outer loop bound for the memoization cases, printed in the table preamble. NOTE: the SDQL programs repeat `memoM = 1000` as a literal; keep both in sync by hand. -/
def memoM : Int := 1000
21+
22+
/-- One benchmark: an SDQL program plus the optimisation passes whose effect is measured. -/
unsafe structure BenchCase where
  /-- Case identifier; also used as the stem of the generated `_unopt` / `_opt` artefacts. -/
  name : String
  /-- Surface program that is translated, optionally optimised, and compiled. -/
  prog : SProg2
  /-- Passes given to `applyOptimisationsLoc` to build the optimised variant. -/
  opts : List Optimisation
  /-- Extra environment variables supplied when timing both binaries. -/
  env : List (String × String) := []
27+
28+
/-- Timing result for a single benchmark case. -/
structure Reading where
  /-- Name of the case the timings belong to. -/
  name : String
  /-- Mean wall-clock time of the unoptimised binary, in milliseconds. -/
  unoptMs : Nat
  /-- Mean wall-clock time of the optimised binary, in milliseconds. -/
  optMs : Nat
32+
33+
/-- Floor of the arithmetic mean of `xs`; an empty list yields `0`. -/
def meanNat (xs : List Nat) : Nat :=
  if xs.isEmpty then
    0
  else
    -- Addition on `Nat` is commutative, so the fold direction does not matter.
    xs.foldr (· + ·) 0 / xs.length
39+
40+
/-- Run `binPath` `iters` times under `env` and return the floor of the mean wall-clock time (ms).

Returns `.error` when `iters = 0`, or with the message of the first run that fails. -/
def timeBinaryAvgMs (binPath : FilePath) (iters : Nat) (env : List (String × String)) :
    IO (Except String Nat) := do
  if iters == 0 then
    return .error "iters must be > 0"
  let mut totalMs : Nat := 0
  for _ in [0:iters] do
    match ← PartIiProject.Bench.timeBinaryMs binPath env with
    | .error e => return .error e
    | .ok ms => totalMs := totalMs + ms
  -- Reaching this point means exactly `iters` samples were summed.
  return .ok (totalMs / iters)
50+
51+
/-- Lift a fallible IO action into `ExceptT`, prepending `pre` to its error message. -/
private def withErrorPrefix {α : Type} (pre : String) (act : IO (Except String α)) :
    ExceptT String IO α :=
  ExceptT.mk do
    match ← act with
    | .ok a => pure (.ok a)
    | .error e => pure (.error (pre ++ e))

/-- Compile the unoptimised and optimised variants of `b`, time both, and return the readings.

Steps run in order (compile unopt, compile opt, time unopt, time opt); the first failure aborts
the case and is reported with a prefix naming the failing step. -/
unsafe def runCase (b : BenchCase) : IO (Except String Reading) := ExceptT.run do
  let plain := ToCore2.trProg2 b.prog
  let tuned : Prog2 := { plain with term := applyOptimisationsLoc b.opts plain.term }

  -- Both binaries are built before either one is timed.
  let unoptBin ← withErrorPrefix "Unoptimised compile failed:\n"
    (PartIiProject.Bench.compileProg2ToBin outDir runtimeSrc s!"{b.name}_unopt" plain)
  let optBin ← withErrorPrefix "Optimised compile failed:\n"
    (PartIiProject.Bench.compileProg2ToBin outDir runtimeSrc s!"{b.name}_opt" tuned)

  let unoptMs ← withErrorPrefix "Unoptimised run failed:\n"
    (timeBinaryAvgMs unoptBin iters b.env)
  let optMs ← withErrorPrefix "Optimised run failed:\n"
    (timeBinaryAvgMs optBin iters b.env)

  return { name := b.name, unoptMs := unoptMs, optMs := optMs }
77+
78+
/-- Vertical loop fusion (key map) benchmark: builds `d` from `range(dictN)`, then remaps the keys
in two consecutive `sum` loops (`x + 1`, then `x + 2`) that a fusion pass could merge. -/
unsafe def p_vertical_loop_fusion_key_map : SProg2 :=
  [SDQLProg2 { { int -> int } }|
    let dictN = 200000 in
    let d = sum( <i, _> <- range(dictN) ) { i -> i } in
    let y = sum( <x, x_v> <- d ) { x + 1 -> x_v } in
    sum( <x, x_v> <- y ) { x + 2 -> x_v }
  ]
85+
86+
/-- Vertical loop fusion (value map) benchmark: builds `d` from `range(dictN)`, then remaps the
values in two consecutive `sum` loops (`x_v + 1`, then `x_v + 2`). -/
unsafe def p_vertical_loop_fusion_value_map : SProg2 :=
  [SDQLProg2 { { int -> int } }|
    let dictN = 200000 in
    let d = sum( <i, _> <- range(dictN) ) { i -> i } in
    let y = sum( <x, x_v> <- d ) { x -> x_v + 1 } in
    sum( <x, x_v> <- y ) { x -> x_v + 2 }
  ]
93+
94+
/-- Horizontal loop fusion benchmark: two independent `sum` loops over the same dictionary `d`
(`v` and `v + 1`) whose results are added together. -/
unsafe def p_horizontal_loop_fusion : SProg2 :=
  [SDQLProg2 { int }|
    let dictN = 200000 in
    let d = sum( <i, _> <- range(dictN) ) { i -> i } in
    let y1 = sum( <_, v> <- d ) v in
    let y2 = sum( <_, v> <- d ) (v + 1) in
    y1 + y2
  ]
102+
103+
/-- Loop factorization (constant on the left) benchmark: sums `2 * v` over the values of `d`. -/
unsafe def p_loop_factorization_left : SProg2 :=
  [SDQLProg2 { int }|
    let dictN = 200000 in
    let d = sum( <i, _> <- range(dictN) ) { i -> i } in
    sum( <_, v> <- d ) (2 * v)
  ]
109+
110+
/-- Loop factorization (constant on the right) benchmark: sums `v * 2` over the values of `d`. -/
unsafe def p_loop_factorization_right : SProg2 :=
  [SDQLProg2 { int }|
    let dictN = 200000 in
    let d = sum( <i, _> <- range(dictN) ) { i -> i } in
    sum( <_, v> <- d ) (v * 2)
  ]
116+
117+
/-- Loop-invariant code motion benchmark: the loop body binds `y = 5`, which does not depend on
the loop variables, before adding it to each value. -/
unsafe def p_loop_invariant_code_motion : SProg2 :=
  [SDQLProg2 { int }|
    let dictN = 200000 in
    let d = sum( <i, _> <- range(dictN) ) { i -> i } in
    sum( <_, v> <- d ) (let y = 5 in v + y)
  ]
123+
124+
/-- Memoization lookup + hoisting: without hoisting this is not a win, so the benchmark case
also includes code motion.

For each `i` in `range(memoM)` the inner loop scans all of `d` and keeps `v` where `k == i`. -/
unsafe def p_loop_memoization_lookup : SProg2 :=
  [SDQLProg2 { int }|
    let memoN = 100000 in
    let memoM = 1000 in
    let d = sum( <i, _> <- range(memoN) ) { i -> i } in
    sum( <i, _> <- range(memoM) ) (sum( <k, v> <- d ) (if k == i then v))
  ]
132+
133+
/-- Partition + hoisting: the then-branch depends on `i` (`v + i`), so lookup-memoization cannot
hoist `f`.

Same nested-loop shape as the lookup case: the inner loop scans `d` for entries with `k == i`. -/
unsafe def p_loop_memoization_partition : SProg2 :=
  [SDQLProg2 { int }|
    let memoN = 100000 in
    let memoM = 1000 in
    let d = sum( <i, _> <- range(memoN) ) { i -> i } in
    sum( <i, _> <- range(memoM) ) (sum( <k, v> <- d ) (if k == i then v + i))
  ]
141+
142+
/-- All benchmark cases, in the order they are run and printed. Each entry pairs one of the
programs above with the optimisation passes applied to build its optimised variant. -/
unsafe def mkCases : List BenchCase :=
  [ { name := "vertical_loop_fusion_key_map"
      prog := p_vertical_loop_fusion_key_map
      opts := [verticalLoopFusionKeyMap2, verticalLoopFusionValueMap2] }
  , { name := "vertical_loop_fusion_value_map"
      prog := p_vertical_loop_fusion_value_map
      opts := [verticalLoopFusionKeyMap2, verticalLoopFusionValueMap2] }
  , { name := "horizontal_loop_fusion"
      prog := p_horizontal_loop_fusion
      opts := [horizontalLoopFusion2] }
  , { name := "loop_factorization_left"
      prog := p_loop_factorization_left
      opts := [loopFactorizationLeft2] }
  , { name := "loop_factorization_right"
      prog := p_loop_factorization_right
      opts := [loopFactorizationRight2] }
  , { name := "loop_invariant_code_motion"
      prog := p_loop_invariant_code_motion
      opts := [loopInvariantCodeMotion2] }
    -- The memoization passes are paired with code motion.
  , { name := "loop_memoization_lookup"
      prog := p_loop_memoization_lookup
      opts := [loopMemoizationLookup2, loopInvariantCodeMotion2] }
  , { name := "loop_memoization_partition"
      prog := p_loop_memoization_partition
      opts := [loopMemoizationPartition2, loopInvariantCodeMotion2] }
  ]
176+
177+
/-- Run every case in `mkCases`, print the comparison table, and list failures on stderr.

Returns exit code `1` if any case failed and `0` otherwise. Command-line arguments are ignored. -/
unsafe def main (_args : List String) : IO UInt32 := do
  -- Always start from an empty output directory.
  if (← outDir.pathExists) then
    IO.FS.removeDirAll outDir
  IO.FS.createDirAll outDir

  let mut collected : Array Reading := #[]
  let mut failed : Array (String × String) := #[]
  for c in mkCases do
    match ← runCase c with
    | .error msg => failed := failed.push (c.name, msg)
    | .ok reading => collected := collected.push reading

  PartIiProject.Bench.printMsComparisonTable
    s!"SDQL optimisation performance comparison (mean of {iters} run(s); wall-clock ms)"
    "unopt" "opt" "opt/unopt"
    (collected.toList.map fun r =>
      { name := r.name, leftMs := r.unoptMs, rightMs := r.optMs })
    (preamble := [s!"Params: dictN={dictN}, memoN={memoN}, memoM={memoM}"])

  if failed.isEmpty then
    return 0
  IO.eprintln ""
  IO.eprintln "Failures:"
  for (nm, err) in failed do
    IO.eprintln s!"- {nm}: {err}"
  return 1
204+
205+
end OptimisationPerformanceComparison
206+
207+
/-- Executable entry point; forwards its arguments to `OptimisationPerformanceComparison.main`. -/
unsafe def main (args : List String) : IO UInt32 := do
  let exitCode ← OptimisationPerformanceComparison.main args
  return exitCode

PartIiProject/Bench/Common.lean

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import PartIiProject.CodegenRust
2+
import PartIiProject.SyntaxSDQLProg
3+
import Lean
4+
import Std
5+
6+
open PartIiProject
7+
open System
8+
9+
namespace PartIiProject.Bench
10+
11+
/-- Write the text `s` to the file at `p` (thin wrapper around `IO.FS.writeFile`). -/
def writeFile (p : FilePath) (s : String) : IO Unit := do
  IO.FS.writeFile p s
13+
14+
/-- Copy the file at `src` to `dst`, overwriting `dst`.

The copy is done on raw bytes. Going through `IO.FS.readFile` would decode the contents as a
UTF-8 `String` and fail on any file that is not valid UTF-8, which a generic copy must not do. -/
def copyFile (src dst : FilePath) : IO Unit := do
  let bytes ← IO.FS.readBinFile src
  IO.FS.writeBinFile dst bytes
17+
18+
/-- Join `p` onto the current working directory. -/
def absPath (p : FilePath) : IO FilePath := do
  return (← IO.currentDir) / p
21+
22+
/-- Run `cmd args` to completion and return `(exit code, stdout, stderr)`.

`envVars` are set on top of the inherited environment; `cwd?` optionally sets the working directory. -/
def runProc (cmd : String) (args : Array String) (cwd? : Option FilePath := none)
    (envVars : List (String × String) := []) : IO (Nat × String × String) := do
  let overrides : Array (String × Option String) :=
    envVars.toArray.map fun (k, v) => (k, some v)
  let result ← IO.Process.output
    { cmd := cmd, args := args, cwd := cwd?, env := overrides, inheritEnv := Bool.true }
  pure (result.exitCode.toNat, result.stdout, result.stderr)
27+
28+
/-- Run `cmd args` with stdout discarded and return `(exit code, stderr)`.

`envVars` are set on top of the inherited environment; `cwd?` optionally sets the working directory. -/
def runProcDiscardStdout (cmd : String) (args : Array String) (cwd? : Option FilePath := none)
    (envVars : List (String × String) := []) : IO (Nat × String) := do
  let env : Array (String × Option String) := envVars.toArray.map (fun (k, v) => (k, some v))
  let child ← IO.Process.spawn {
    cmd := cmd,
    args := args,
    cwd := cwd?,
    env := env,
    inheritEnv := Bool.true,
    stdout := .null,
    stderr := .piped,
  }
  -- Drain stderr to EOF *before* waiting. Waiting first deadlocks as soon as the child writes
  -- more than the OS pipe buffer holds: the child blocks on the write while we block on `wait`.
  let err ← child.stderr.readToEnd
  let code := (← child.wait).toNat
  return (code, err)
43+
44+
/-- Run `cmd args` with stdout and stderr discarded and return `(exit code, elapsed ms)`.

The clock starts before the process is spawned and stops once `wait` returns, so the elapsed
time includes process start-up. -/
def runTimedMs (cmd : String) (args : Array String) (cwd? : Option FilePath := none)
    (envVars : List (String × String) := []) : IO (Nat × Nat) := do
  let overrides : Array (String × Option String) :=
    envVars.toArray.map fun (k, v) => (k, some v)
  let t0 ← IO.monoMsNow
  let proc ← IO.Process.spawn
    { cmd := cmd, args := args, cwd := cwd?, env := overrides, stdout := .null, stderr := .null }
  let status ← proc.wait
  let t1 ← IO.monoMsNow
  return (status.toNat, t1 - t0)
59+
60+
/-- Compile `rsPath` to `binPath` with `rustc -O -C target-cpu=native`.

Returns `.error` with rustc's stderr when the exit code is non-zero. -/
def compileRust (rsPath binPath : FilePath) : IO (Except String Unit) := do
  let rustcArgs := #["-O", "-C", "target-cpu=native", "-o", binPath.toString, rsPath.toString]
  let (status, _, diagnostics) ← runProc "rustc" rustcArgs
  return if status == 0 then .ok () else .error diagnostics
66+
67+
/-- Ensure `outDir` exists and copy `runtimeSrc` into it under the fixed name `sdql_runtime.rs`. -/
def copyRuntime (runtimeSrc outDir : FilePath) : IO Unit := do
  IO.FS.createDirAll outDir
  let target : FilePath := outDir / "sdql_runtime.rs"
  copyFile runtimeSrc target
70+
71+
/-- Render `cp` to Rust, write it to `outDir/<name>.rs`, and compile it to `outDir/<name>.bin`.

The runtime source is copied into `outDir` first. Returns the binary path on success, or the
compiler's error output on failure. -/
unsafe def compileProg2ToBin (outDir runtimeSrc : FilePath) (name : String) (cp : Prog2) :
    IO (Except String FilePath) := do
  copyRuntime runtimeSrc outDir
  let rsPath := outDir / s!"{name}.rs"
  let binPath := outDir / s!"{name}.bin"
  writeFile rsPath (PartIiProject.renderRustProg2Shown cp)
  -- On success, swap the unit result for the path of the produced binary.
  return (← compileRust rsPath binPath).map fun _ => binPath
81+
82+
/-- Translate the surface program `sp` with `ToCore2.trProg2`, then compile it via `compileProg2ToBin`. -/
unsafe def compileSProg2ToBin (outDir runtimeSrc : FilePath) (name : String) (sp : SProg2) :
    IO (Except String FilePath) :=
  let core := ToCore2.trProg2 sp
  compileProg2ToBin outDir runtimeSrc name core
85+
86+
/-- Run the binary at `binPath` once (no arguments) and return its wall-clock time in ms.

Returns `.error` when the process exits with a non-zero code. -/
def timeBinaryMs (binPath : FilePath) (envVars : List (String × String) := []) : IO (Except String Nat) := do
  let exe := binPath.toString
  let (status, elapsedMs) ← runTimedMs exe #[] (cwd? := none) (envVars := envVars)
  match status with
  | 0 => return .ok elapsedMs
  | _ => return .error s!"Non-zero exit code {status} for {binPath.toString}"
91+
92+
/-- Right-align `s` in a field of `width` characters by prepending spaces; longer strings are returned unchanged. -/
def padLeft (width : Nat) (s : String) : String :=
  -- `Nat` subtraction saturates at 0, so an over-long `s` receives no padding.
  "".pushn ' ' (width - s.length) ++ s
94+
95+
/-- Left-align `s` in a field of `width` characters by appending spaces; longer strings are returned unchanged. -/
def padRight (width : Nat) (s : String) : String :=
  -- `Nat` subtraction saturates at 0, so an over-long `s` receives no padding.
  s ++ "".pushn ' ' (width - s.length)
97+
98+
/-- Decimal rendering of `n`, zero-padded on the left to at least three digits. -/
def pad3 (n : Nat) : String :=
  let digits := toString n
  -- `Nat` subtraction saturates at 0, so numbers with 3+ digits are left as they are.
  "".pushn '0' (3 - digits.length) ++ digits
101+
102+
/-- `numer / denom` in thousandths, rounded down; `0` when `denom` is zero. -/
def ratioMilli (numer denom : Nat) : Nat :=
  match denom with
  | 0 => 0
  | d + 1 => (numer * 1000) / (d + 1)
104+
105+
/-- Format `numer / denom` with three decimals and a trailing `×` (e.g. `0.500×`), or `"n/a"` when `denom` is zero. -/
def ratioString (numer denom : Nat) : String :=
  match denom with
  | 0 => "n/a"
  | _ =>
    let milli := ratioMilli numer denom
    s!"{milli / 1000}.{pad3 (milli % 1000)}×"
112+
113+
end PartIiProject.Bench

PartIiProject/Bench/Table.lean

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import PartIiProject.Bench.Common
2+
import Std
3+
4+
namespace PartIiProject.Bench
5+
6+
/-- One table row: a label and the two millisecond timings being compared. -/
structure MsComparisonRow where
  /-- Text shown in the first column. -/
  name : String
  /-- Timing shown in the left numeric column; the denominator of the printed ratio. -/
  leftMs : Nat
  /-- Timing shown in the right numeric column; the numerator of the printed ratio. -/
  rightMs : Nat
10+
11+
/-- Print a left/right timing table to stdout, with a `right / left` ratio column and a totals row.

* `title`, `preamble`: printed first, one line each.
* `leftLabel`, `rightLabel`, `ratioLabel`: headers of the three numeric columns.
* `rows`: the data; when empty, only `"No benchmarks ran."` is printed.
* `nameHeader`, `totalLabel`: header of the name column and label of the totals row.
* `colW`: minimum width of each numeric column.

The name column is sized to fit every row name, `nameHeader` *and* `totalLabel`, so the totals
row stays aligned. The horizontal rules are derived from the column separator so they always
span exactly the width of a (non-overflowing) row. -/
def printMsComparisonTable
    (title : String)
    (leftLabel rightLabel ratioLabel : String)
    (rows : List MsComparisonRow)
    (preamble : List String := [])
    (nameHeader : String := "case")
    (totalLabel : String := "TOTAL")
    (colW : Nat := 12) : IO Unit := do
  if rows.isEmpty then
    IO.println "No benchmarks ran."
    return

  -- Separator placed before each of the three numeric columns.
  let gap := " "
  let nameW :=
    rows.foldl (fun w r => max w r.name.length) (max nameHeader.length totalLabel.length)
  let rule := String.mk (List.replicate (nameW + 3 * (gap.length + colW)) '-')
  let line (name left right ratio : String) : String :=
    padRight nameW name ++ gap ++
    padLeft colW left ++ gap ++
    padLeft colW right ++ gap ++
    padLeft colW ratio

  IO.println title
  for l in preamble do
    IO.println l

  IO.println (line nameHeader leftLabel rightLabel ratioLabel)
  IO.println rule

  for r in rows do
    IO.println (line r.name s!"{r.leftMs}ms" s!"{r.rightMs}ms" (ratioString r.rightMs r.leftMs))

  let totalLeft := rows.foldl (fun s r => s + r.leftMs) 0
  let totalRight := rows.foldl (fun s r => s + r.rightMs) 0
  IO.println rule
  IO.println (line totalLabel s!"{totalLeft}ms" s!"{totalRight}ms"
    (ratioString totalRight totalLeft))
48+
49+
end PartIiProject.Bench
50+

0 commit comments

Comments
 (0)