@@ -55,9 +55,9 @@ class GEGLU:
5555 @util .autotune (geglu_configs (), ["m_size" , "k_size" , "x_size" ])
5656 @triton .heuristics (
5757 {
58- "require_m_boundary_check" : lambda args : args ["m_size" ] % args ["m_block_size" ] == 0 ,
59- "require_k_boundary_check" : lambda args : args ["k_size" ] % args ["k_block_size" ] == 0 ,
60- "require_x_boundary_check" : lambda args : args ["x_size" ] % args ["x_block_size" ] == 0 ,
58+ "require_m_boundary_check" : lambda args : args ["m_size" ] % args ["m_block_size" ],
59+ "require_k_boundary_check" : lambda args : args ["k_size" ] % args ["k_block_size" ],
60+ "require_x_boundary_check" : lambda args : args ["x_size" ] % args ["x_block_size" ],
6161 }
6262 )
6363 @triton .jit
@@ -71,7 +71,6 @@ def forward(
7171 n_size : tl .int32 ,
7272 k_size : tl .int32 ,
7373 x_size : tl .int32 ,
74- input_batch_stride : tl .int32 ,
7574 input_m_stride : tl .int32 ,
7675 input_k_stride : tl .int32 ,
7776 weight_n_stride : tl .int32 ,
@@ -89,31 +88,30 @@ def forward(
8988 num_m_blocks = tl .cdiv (m_size , m_block_size )
9089 num_x_blocks = tl .cdiv (x_size , x_block_size )
9190 num_blocks = num_m_blocks * num_x_blocks
92- batch = pid // num_blocks
9391 block = pid % num_blocks
9492 m_block = block // num_x_blocks
9593 x_block = block % num_x_blocks
9694 m_offset = m_block * m_block_size
9795 x_offset = x_block * x_block_size
9896
9997 output_block_ptr = tl .make_block_ptr (
100- output_ptr + batch * m_size * x_size ,
98+ output_ptr ,
10199 shape = (m_size , x_size ),
102100 strides = (x_size , 1 ),
103101 offsets = (m_offset , x_offset ),
104102 block_shape = (m_block_size , x_block_size ),
105103 order = (1 , 0 ),
106104 )
107105 state_block_ptr = tl .make_block_ptr (
108- state_gate_ptr + batch * m_size * n_size ,
106+ state_gate_ptr ,
109107 shape = (m_size , n_size ),
110108 strides = (n_size , 1 ),
111109 offsets = (m_offset , x_offset ),
112110 block_shape = (m_block_size , x_block_size ),
113111 order = (1 , 0 ),
114112 )
115113 gate_block_ptr = tl .make_block_ptr (
116- state_gate_ptr + batch * m_size * n_size ,
114+ state_gate_ptr ,
117115 shape = (m_size , n_size ),
118116 strides = (n_size , 1 ),
119117 offsets = (m_offset , x_offset + x_size ),
@@ -122,7 +120,7 @@ def forward(
122120 )
123121
124122 state = language .Linear .forward (
125- input_ptr + batch * input_batch_stride ,
123+ input_ptr ,
126124 weight_ptr ,
127125 bias_ptr ,
128126 m_size ,
@@ -144,7 +142,7 @@ def forward(
144142 dtype ,
145143 )
146144 gate = language .Linear .forward (
147- input_ptr + batch * input_batch_stride ,
145+ input_ptr ,
148146 weight_ptr ,
149147 bias_ptr ,
150148 m_size ,
@@ -167,16 +165,21 @@ def forward(
167165 )
168166 output = state * language .math .GELU .forward (gate )
169167
170- if require_m_boundary_check & require_x_boundary_check :
171- tl .store (output_block_ptr , output .to (dtype ))
172- tl .store (state_block_ptr , state .to (dtype ))
173- tl .store (gate_block_ptr , gate .to (dtype ))
174- else :
168+ if require_m_boundary_check | require_x_boundary_check :
175169 tl .store (output_block_ptr , output .to (dtype ), boundary_check = (0 , 1 ))
176170 tl .store (state_block_ptr , state .to (dtype ), boundary_check = (0 , 1 ))
177171 tl .store (gate_block_ptr , gate .to (dtype ), boundary_check = (0 , 1 ))
172+ else :
173+ tl .store (output_block_ptr , output .to (dtype ))
174+ tl .store (state_block_ptr , state .to (dtype ))
175+ tl .store (gate_block_ptr , gate .to (dtype ))
178176
179177 @staticmethod
178+ @triton .heuristics (
179+ {
180+ "require_x_boundary_check" : lambda args : args ["x_size" ] % args ["x_block_size" ],
181+ }
182+ )
180183 @triton .jit
181184 def backward (
182185 grad_state_gate_ptr : tl .tensor ,
@@ -187,56 +190,67 @@ def backward(
187190 x_size : tl .int32 ,
188191 dtype : tl .constexpr ,
189192 x_block_size : tl .constexpr ,
193+ require_x_boundary_check : tl .constexpr ,
190194 ):
191195 pid = tl .program_id (0 )
192- batch = pid // m_size
193196 m_offset = pid % m_size
194197
195198 grad_state_block_ptr = tl .make_block_ptr (
196- grad_state_gate_ptr + batch * m_size * n_size ,
199+ grad_state_gate_ptr ,
197200 shape = (m_size , n_size ),
198201 strides = (n_size , 1 ),
199202 offsets = (m_offset , 0 ),
200203 block_shape = (1 , x_block_size ),
201204 order = (1 , 0 ),
202205 )
203206 grad_gate_block_ptr = tl .make_block_ptr (
204- grad_state_gate_ptr + batch * m_size * n_size ,
207+ grad_state_gate_ptr ,
205208 shape = (m_size , n_size ),
206209 strides = (n_size , 1 ),
207210 offsets = (m_offset , x_size ),
208211 block_shape = (1 , x_block_size ),
209212 order = (1 , 0 ),
210213 )
211214 grad_output_block_ptr = tl .make_block_ptr (
212- grad_output_ptr + batch * m_size * x_size ,
215+ grad_output_ptr ,
213216 shape = (m_size , x_size ),
214217 strides = (x_size , 1 ),
215218 offsets = (m_offset , 0 ),
216219 block_shape = (1 , x_block_size ),
217220 order = (1 , 0 ),
218221 )
219222 state_block_ptr = tl .make_block_ptr (
220- state_gate_ptr + batch * m_size * n_size ,
223+ state_gate_ptr ,
221224 shape = (m_size , n_size ),
222225 strides = (n_size , 1 ),
223226 offsets = (m_offset , 0 ),
224227 block_shape = (1 , x_block_size ),
225228 order = (1 , 0 ),
226229 )
227230 gate_block_ptr = tl .make_block_ptr (
228- state_gate_ptr + batch * m_size * n_size ,
231+ state_gate_ptr ,
229232 shape = (m_size , n_size ),
230233 strides = (n_size , 1 ),
231234 offsets = (m_offset , x_size ),
232235 block_shape = (1 , x_block_size ),
233236 order = (1 , 0 ),
234237 )
235238
236- grad_output = tl .load (grad_output_block_ptr , boundary_check = (1 ,))
237- state = tl .load (state_block_ptr , boundary_check = (1 ,))
238- gate = tl .load (gate_block_ptr , boundary_check = (1 ,))
239+ if require_x_boundary_check :
240+ grad_output = tl .load (grad_output_block_ptr , boundary_check = (1 ,))
241+ state = tl .load (state_block_ptr , boundary_check = (1 ,))
242+ gate = tl .load (gate_block_ptr , boundary_check = (1 ,))
243+ else :
244+ grad_output = tl .load (grad_output_block_ptr )
245+ state = tl .load (state_block_ptr )
246+ gate = tl .load (gate_block_ptr )
247+
239248 grad_state = grad_output * language .math .GELU .forward (gate )
240249 grad_gate = language .math .GELU .backward (grad_output * state , gate )
241- tl .store (grad_state_block_ptr , grad_state .to (dtype ), boundary_check = (1 ,))
242- tl .store (grad_gate_block_ptr , grad_gate .to (dtype ), boundary_check = (1 ,))
250+
251+ if require_x_boundary_check :
252+ tl .store (grad_state_block_ptr , grad_state .to (dtype ), boundary_check = (1 ,))
253+ tl .store (grad_gate_block_ptr , grad_gate .to (dtype ), boundary_check = (1 ,))
254+ else :
255+ tl .store (grad_state_block_ptr , grad_state .to (dtype ))
256+ tl .store (grad_gate_block_ptr , grad_gate .to (dtype ))
0 commit comments