1+ # FIXME Add copyright header
2+
3+
14import islpy as isl
5+ from islpy import dim_type
26import pymbolic .primitives as p
37
48from dataclasses import dataclass
711
812from loopy import LoopKernel
913from loopy .symbolic import WalkMapper
14+ from loopy .translation_unit import for_each_kernel
15+ from loopy .typing import ExpressionT
1016
1117@dataclass (frozen = True )
1218class HappensAfter :
@@ -32,7 +38,7 @@ def __init__(self, kernel: LoopKernel, var_names: set):
3238
3339 super .__init__ ()
3440
35- def map_subscript (self , expr : p . expression , inames : frozenset , insn_id : str ):
41+ def map_subscript (self , expr : ExpressionT , inames : frozenset , insn_id : str ):
3642
3743 domain = self .kernel .get_inames_domain (inames )
3844
@@ -110,7 +116,7 @@ def compute_happens_after(knl: LoopKernel) -> LoopKernel:
110116 # return the kernel with the new instructions
111117 return knl .copy (instructions = new_insns )
112118
113- def add_lexicographic_happens_after (knl : LoopKernel ) -> None :
119+ def add_lexicographic_happens_after_orig (knl : LoopKernel ) -> None :
114120 """
115121 TODO properly format this documentation.
116122
@@ -122,7 +128,7 @@ def add_lexicographic_happens_after(knl: LoopKernel) -> None:
122128 """
123129
124130 # we want to modify the output dimension and OUT = 3
125- dim_type = isl .dim_type ( 3 )
131+ dim_type = isl .dim_type . out
126132
127133 # generate an unordered mapping from statement instances to points in the
128134 # loop domain
@@ -148,3 +154,74 @@ def add_lexicographic_happens_after(knl: LoopKernel) -> None:
148154
149155 # determine a lexicographic order on the space the schedules belong to
150156
157+
@for_each_kernel
def add_lexicographic_happens_after(knl: LoopKernel) -> LoopKernel:
    """Build a lexicographic happens-before relation between each pair of
    textually consecutive instructions of *knl*.

    For every instruction except the first, the instruction immediately
    preceding it in ``knl.instructions`` is taken as its predecessor. A map
    from the predecessor's iname domain to the successor's iname domain is
    constructed (successor inames are renamed with a trailing ``'``) and
    then intersected with a strict lexicographic "less than" relation on
    the inames shared by both instructions.

    :arg knl: the kernel whose instructions are processed.
    :returns: a copy of *knl* whose ``instructions`` are the same
        instruction objects, in the same order.

    .. note::

        NOTE(review): the computed ``happens_before`` map is currently not
        stored on the instruction; the returned kernel carries the
        instructions unchanged. Attaching the relation is still to be done.
    """

    new_insns = []

    for iafter, insn_after in enumerate(knl.instructions):
        if iafter == 0:
            # The first instruction has no predecessor to be ordered after.
            new_insns.append(insn_after)
            continue

        insn_before = knl.instructions[iafter - 1]
        shared_inames = insn_after.within_inames & insn_before.within_inames

        domain_before = knl.get_inames_domain(insn_before.within_inames)
        domain_after = knl.get_inames_domain(insn_after.within_inames)

        # Unconstrained map: every predecessor instance -> every successor
        # instance.
        happens_before = isl.Map.from_domain_and_range(
                domain_before, domain_after)

        # Prime the output (successor) dimension names so that they can be
        # told apart from the input (predecessor) names once both live in
        # a single set space.
        for idim in range(happens_before.dim(dim_type.out)):
            happens_before = happens_before.set_dim_name(
                    dim_type.out, idim,
                    happens_before.get_dim_name(dim_type.out, idim) + "'")

        n_inames_before = happens_before.dim(dim_type.in_)

        # Flatten the map into a set over (before inames, after' inames);
        # only its space is used below.
        happens_before_set = happens_before.move_dims(
                dim_type.out, 0,
                dim_type.in_, 0,
                n_inames_before).range()

        # Shared inames in the order in which they appear in each domain.
        shared_inames_order_before = [
                domain_before.get_dim_name(dim_type.out, idim)
                for idim in range(domain_before.dim(dim_type.out))
                if domain_before.get_dim_name(dim_type.out, idim)
                in shared_inames]
        shared_inames_order_after = [
                domain_after.get_dim_name(dim_type.out, idim)
                for idim in range(domain_after.dim(dim_type.out))
                if domain_after.get_dim_name(dim_type.out, idim)
                in shared_inames]

        # The lexicographic order below is only well-defined if both
        # domains list the shared inames in the same order.
        assert shared_inames_order_after == shared_inames_order_before, (
                "shared inames appear in a different order in the domains "
                "of consecutive instructions")
        shared_inames_order = shared_inames_order_after

        affs = isl.affs_from_space(happens_before_set.space)

        # Strict lexicographic order on the shared inames: a union over
        # positions k of
        #   {outer[j] == outer[j]' for all j < k  and  iname[k] < iname[k]'}
        # The relation is strict, so instance pairs that agree on all
        # shared inames are not included, and with no shared inames the
        # set stays empty.
        # NOTE(review): confirm that excluding the "same iteration" case is
        # intended.
        lex_set = isl.Set.empty(happens_before_set.space)
        for iinnermost, innermost_iname in enumerate(shared_inames_order):
            innermost_set = affs[innermost_iname].lt_set(
                    affs[innermost_iname + "'"])

            for outer_iname in shared_inames_order[:iinnermost]:
                innermost_set = innermost_set & (
                        affs[outer_iname].eq_set(affs[outer_iname + "'"]))

            lex_set = lex_set | innermost_set

        # Turn the set back into a map by moving the leading (predecessor)
        # dimensions to the input side.
        lex_map = isl.Map.from_range(lex_set).move_dims(
                dim_type.in_, 0,
                dim_type.out, 0,
                n_inames_before)

        happens_before = happens_before & lex_map

        new_insns.append(insn_after)

    return knl.copy(instructions=new_insns)
225+
226+
227+
0 commit comments