Skip to content
This repository was archived by the owner on Oct 16, 2023. It is now read-only.

Commit c6908ec

Browse files
committed
Optimize Linear and GEGLU
1 parent cf768d2 commit c6908ec

File tree

6 files changed

+101
-112
lines changed

6 files changed

+101
-112
lines changed

tests/test_linear.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def test_forward(num_batches, m_size, n_size, k_size, device):
3131
assert util.equal(torch.nn.functional.linear(input, weight, bias), trident.function.linear(input, weight, bias))
3232

3333
input = input.permute(0, 2, 1)
34-
weight = weight.permute(1, 0)
3534

3635
assert util.equal(torch.nn.functional.linear(input, weight), trident.function.linear(input, weight))
3736

@@ -56,7 +55,6 @@ def train(func):
5655
assert util.equal(y, b)
5756

5857
input = input.permute(0, 2, 1).reshape(num_batches, m_size, k_size)
59-
weight = weight.permute(1, 0).reshape(n_size, k_size)
6058

6159
(x, y) = train(torch.nn.functional.linear)
6260
(a, b) = train(trident.function.linear)

trident/function/function.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,12 @@ def geglu(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None,
7474
See GEGLU for details.
7575
"""
7676
if input.dim() == 2:
77-
output = operation.GEGLU.apply(input.view(1, *input.shape), weight, bias, use_accelerator)
78-
return output.view(output.shape[1:3])
79-
else:
8077
return operation.GEGLU.apply(input, weight, bias, use_accelerator)
78+
else:
79+
input = input if input.is_contiguous() else input.contiguous()
80+
output_shape = (*input.shape[0:2], weight.shape[0] // 2)
81+
output = operation.GEGLU.apply(input.view(-1, input.shape[2]), weight, bias, use_accelerator)
82+
return output.view(*output_shape)
8183

8284

8385
def gelu(input: torch.Tensor):
@@ -156,11 +158,7 @@ def linear(
156158
157159
See Linear for more details.
158160
"""
159-
if input.dim() == 2:
160-
output = operation.Linear.apply(input.view(1, *input.shape), weight, bias, use_accelerator)
161-
return output.view(output.shape[1:3])
162-
else:
163-
return operation.Linear.apply(input, weight, bias, use_accelerator)
161+
return operation.Linear.apply(input, weight, bias, use_accelerator)
164162

165163

166164
def max(input: torch.Tensor, dim: int):

trident/kernel/geglu.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ class GEGLU:
5555
@util.autotune(geglu_configs(), ["m_size", "k_size", "x_size"])
5656
@triton.heuristics(
5757
{
58-
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0,
59-
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0,
60-
"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"] == 0,
58+
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"],
59+
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"],
60+
"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"],
6161
}
6262
)
6363
@triton.jit
@@ -71,7 +71,6 @@ def forward(
7171
n_size: tl.int32,
7272
k_size: tl.int32,
7373
x_size: tl.int32,
74-
input_batch_stride: tl.int32,
7574
input_m_stride: tl.int32,
7675
input_k_stride: tl.int32,
7776
weight_n_stride: tl.int32,
@@ -89,31 +88,30 @@ def forward(
8988
num_m_blocks = tl.cdiv(m_size, m_block_size)
9089
num_x_blocks = tl.cdiv(x_size, x_block_size)
9190
num_blocks = num_m_blocks * num_x_blocks
92-
batch = pid // num_blocks
9391
block = pid % num_blocks
9492
m_block = block // num_x_blocks
9593
x_block = block % num_x_blocks
9694
m_offset = m_block * m_block_size
9795
x_offset = x_block * x_block_size
9896

9997
output_block_ptr = tl.make_block_ptr(
100-
output_ptr + batch * m_size * x_size,
98+
output_ptr,
10199
shape=(m_size, x_size),
102100
strides=(x_size, 1),
103101
offsets=(m_offset, x_offset),
104102
block_shape=(m_block_size, x_block_size),
105103
order=(1, 0),
106104
)
107105
state_block_ptr = tl.make_block_ptr(
108-
state_gate_ptr + batch * m_size * n_size,
106+
state_gate_ptr,
109107
shape=(m_size, n_size),
110108
strides=(n_size, 1),
111109
offsets=(m_offset, x_offset),
112110
block_shape=(m_block_size, x_block_size),
113111
order=(1, 0),
114112
)
115113
gate_block_ptr = tl.make_block_ptr(
116-
state_gate_ptr + batch * m_size * n_size,
114+
state_gate_ptr,
117115
shape=(m_size, n_size),
118116
strides=(n_size, 1),
119117
offsets=(m_offset, x_offset + x_size),
@@ -122,7 +120,7 @@ def forward(
122120
)
123121

124122
state = language.Linear.forward(
125-
input_ptr + batch * input_batch_stride,
123+
input_ptr,
126124
weight_ptr,
127125
bias_ptr,
128126
m_size,
@@ -144,7 +142,7 @@ def forward(
144142
dtype,
145143
)
146144
gate = language.Linear.forward(
147-
input_ptr + batch * input_batch_stride,
145+
input_ptr,
148146
weight_ptr,
149147
bias_ptr,
150148
m_size,
@@ -167,16 +165,21 @@ def forward(
167165
)
168166
output = state * language.math.GELU.forward(gate)
169167

170-
if require_m_boundary_check & require_x_boundary_check:
171-
tl.store(output_block_ptr, output.to(dtype))
172-
tl.store(state_block_ptr, state.to(dtype))
173-
tl.store(gate_block_ptr, gate.to(dtype))
174-
else:
168+
if require_m_boundary_check | require_x_boundary_check:
175169
tl.store(output_block_ptr, output.to(dtype), boundary_check=(0, 1))
176170
tl.store(state_block_ptr, state.to(dtype), boundary_check=(0, 1))
177171
tl.store(gate_block_ptr, gate.to(dtype), boundary_check=(0, 1))
172+
else:
173+
tl.store(output_block_ptr, output.to(dtype))
174+
tl.store(state_block_ptr, state.to(dtype))
175+
tl.store(gate_block_ptr, gate.to(dtype))
178176

179177
@staticmethod
178+
@triton.heuristics(
179+
{
180+
"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"],
181+
}
182+
)
180183
@triton.jit
181184
def backward(
182185
grad_state_gate_ptr: tl.tensor,
@@ -187,56 +190,67 @@ def backward(
187190
x_size: tl.int32,
188191
dtype: tl.constexpr,
189192
x_block_size: tl.constexpr,
193+
require_x_boundary_check: tl.constexpr,
190194
):
191195
pid = tl.program_id(0)
192-
batch = pid // m_size
193196
m_offset = pid % m_size
194197

195198
grad_state_block_ptr = tl.make_block_ptr(
196-
grad_state_gate_ptr + batch * m_size * n_size,
199+
grad_state_gate_ptr,
197200
shape=(m_size, n_size),
198201
strides=(n_size, 1),
199202
offsets=(m_offset, 0),
200203
block_shape=(1, x_block_size),
201204
order=(1, 0),
202205
)
203206
grad_gate_block_ptr = tl.make_block_ptr(
204-
grad_state_gate_ptr + batch * m_size * n_size,
207+
grad_state_gate_ptr,
205208
shape=(m_size, n_size),
206209
strides=(n_size, 1),
207210
offsets=(m_offset, x_size),
208211
block_shape=(1, x_block_size),
209212
order=(1, 0),
210213
)
211214
grad_output_block_ptr = tl.make_block_ptr(
212-
grad_output_ptr + batch * m_size * x_size,
215+
grad_output_ptr,
213216
shape=(m_size, x_size),
214217
strides=(x_size, 1),
215218
offsets=(m_offset, 0),
216219
block_shape=(1, x_block_size),
217220
order=(1, 0),
218221
)
219222
state_block_ptr = tl.make_block_ptr(
220-
state_gate_ptr + batch * m_size * n_size,
223+
state_gate_ptr,
221224
shape=(m_size, n_size),
222225
strides=(n_size, 1),
223226
offsets=(m_offset, 0),
224227
block_shape=(1, x_block_size),
225228
order=(1, 0),
226229
)
227230
gate_block_ptr = tl.make_block_ptr(
228-
state_gate_ptr + batch * m_size * n_size,
231+
state_gate_ptr,
229232
shape=(m_size, n_size),
230233
strides=(n_size, 1),
231234
offsets=(m_offset, x_size),
232235
block_shape=(1, x_block_size),
233236
order=(1, 0),
234237
)
235238

236-
grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,))
237-
state = tl.load(state_block_ptr, boundary_check=(1,))
238-
gate = tl.load(gate_block_ptr, boundary_check=(1,))
239+
if require_x_boundary_check:
240+
grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,))
241+
state = tl.load(state_block_ptr, boundary_check=(1,))
242+
gate = tl.load(gate_block_ptr, boundary_check=(1,))
243+
else:
244+
grad_output = tl.load(grad_output_block_ptr)
245+
state = tl.load(state_block_ptr)
246+
gate = tl.load(gate_block_ptr)
247+
239248
grad_state = grad_output * language.math.GELU.forward(gate)
240249
grad_gate = language.math.GELU.backward(grad_output * state, gate)
241-
tl.store(grad_state_block_ptr, grad_state.to(dtype), boundary_check=(1,))
242-
tl.store(grad_gate_block_ptr, grad_gate.to(dtype), boundary_check=(1,))
250+
251+
if require_x_boundary_check:
252+
tl.store(grad_state_block_ptr, grad_state.to(dtype), boundary_check=(1,))
253+
tl.store(grad_gate_block_ptr, grad_gate.to(dtype), boundary_check=(1,))
254+
else:
255+
tl.store(grad_state_block_ptr, grad_state.to(dtype))
256+
tl.store(grad_gate_block_ptr, grad_gate.to(dtype))

trident/kernel/linear.py

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def linear_configs_for_backward_bias():
111111

112112
class Linear:
113113
@staticmethod
114-
@util.autotune(linear_configs([16, 64, 128], [32, 64, 128], [32, 64]), ["m_size", "n_size", "k_size"])
114+
@util.autotune(linear_configs([16, 64, 128, 256], [32, 64, 128], [32, 64]), ["m_size", "n_size", "k_size"])
115115
@triton.heuristics(
116116
{
117117
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"],
@@ -128,7 +128,6 @@ def forward(
128128
m_size: tl.int32,
129129
n_size: tl.int32,
130130
k_size: tl.int32,
131-
input_batch_stride: tl.int32,
132131
input_m_stride: tl.int32,
133132
input_k_stride: tl.int32,
134133
weight_n_stride: tl.int32,
@@ -143,18 +142,14 @@ def forward(
143142
require_k_boundary_check: tl.constexpr,
144143
):
145144
pid = tl.program_id(0)
146-
num_m_blocks = tl.cdiv(m_size, m_block_size)
147145
num_n_blocks = tl.cdiv(n_size, n_block_size)
148-
num_blocks = num_m_blocks * num_n_blocks
149-
batch = pid // num_blocks
150-
block = pid % num_blocks
151-
m_block = block // num_n_blocks
152-
n_block = block % num_n_blocks
146+
m_block = pid // num_n_blocks
147+
n_block = pid % num_n_blocks
153148
m_offset = m_block * m_block_size
154149
n_offset = n_block * n_block_size
155150

156151
output = language.Linear.forward(
157-
input_ptr + batch * input_batch_stride,
152+
input_ptr,
158153
weight_ptr,
159154
bias_ptr,
160155
m_size,
@@ -177,7 +172,7 @@ def forward(
177172
)
178173

179174
output_block_ptr = tl.make_block_ptr(
180-
output_ptr + batch * m_size * n_size,
175+
output_ptr,
181176
shape=(m_size, n_size),
182177
strides=(n_size, 1),
183178
offsets=(m_offset, n_offset),
@@ -223,15 +218,14 @@ def backward(
223218
num_m_blocks = tl.cdiv(m_size, m_block_size)
224219
num_k_blocks = tl.cdiv(k_size, k_block_size)
225220
num_blocks = num_m_blocks * num_k_blocks
226-
batch = pid // num_blocks
227221
block = pid % num_blocks
228222
m_block = block // num_k_blocks
229223
k_block = block % num_k_blocks
230224
m_offset = m_block * m_block_size
231225
k_offset = k_block * k_block_size
232226

233227
grad_input = language.Linear.backward(
234-
grad_output_ptr + batch * m_size * n_size,
228+
grad_output_ptr,
235229
weight_ptr,
236230
m_size,
237231
n_size,
@@ -251,7 +245,7 @@ def backward(
251245
)
252246

253247
grad_input_block_ptr = tl.make_block_ptr(
254-
grad_input_ptr + batch * m_size * k_size,
248+
grad_input_ptr,
255249
shape=(m_size, k_size),
256250
strides=(input_m_stride, input_k_stride),
257251
offsets=(m_offset, k_offset),
@@ -278,13 +272,12 @@ def backward(
278272
)
279273
@triton.jit
280274
def backward_weight(
281-
grad_weight_staging_ptr: tl.tensor,
275+
grad_weight_ptr: tl.tensor,
282276
grad_output_ptr: tl.tensor,
283277
input_ptr: tl.tensor,
284278
m_size: tl.int32,
285279
n_size: tl.int32,
286280
k_size: tl.int32,
287-
input_batch_stride: tl.int32,
288281
input_m_stride: tl.int32,
289282
input_k_stride: tl.int32,
290283
use_accelerator: tl.constexpr,
@@ -300,16 +293,15 @@ def backward_weight(
300293
num_n_blocks = tl.cdiv(n_size, n_block_size)
301294
num_k_blocks = tl.cdiv(k_size, k_block_size)
302295
num_blocks = num_n_blocks * num_k_blocks
303-
batch = pid // num_blocks
304296
block = pid % num_blocks
305297
n_block = block // num_k_blocks
306298
k_block = block % num_k_blocks
307299
n_offset = n_block * n_block_size
308300
k_offset = k_block * k_block_size
309301

310302
grad_weight = language.Linear.backward_weight(
311-
grad_output_ptr + batch * m_size * n_size,
312-
input_ptr + batch * input_batch_stride,
303+
grad_output_ptr,
304+
input_ptr,
313305
m_size,
314306
n_size,
315307
k_size,
@@ -328,7 +320,7 @@ def backward_weight(
328320
)
329321

330322
grad_weight_staging_block_ptr = tl.make_block_ptr(
331-
grad_weight_staging_ptr + batch * n_size * k_size,
323+
grad_weight_ptr,
332324
shape=(n_size, k_size),
333325
strides=(k_size, 1),
334326
offsets=(n_offset, k_offset),
@@ -346,7 +338,7 @@ def backward_weight(
346338
@triton.heuristics({"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"]})
347339
@triton.jit
348340
def backward_bias(
349-
grad_bias_staging_ptr: tl.tensor,
341+
grad_bias_ptr: tl.tensor,
350342
grad_output_ptr: tl.tensor,
351343
m_size: tl.int32,
352344
n_size: tl.int32,
@@ -355,10 +347,9 @@ def backward_bias(
355347
require_m_boundary_check: tl.constexpr,
356348
):
357349
pid = tl.program_id(0)
358-
batch = pid // n_size
359350
n_offset = pid % n_size
360351
grad_bias = language.Linear.backward_bias(
361-
grad_output_ptr + batch * m_size * n_size,
352+
grad_output_ptr,
362353
m_size,
363354
n_size,
364355
n_offset,
@@ -367,12 +358,12 @@ def backward_bias(
367358
dtype,
368359
)
369360

370-
grad_bias_staging_block_ptr = tl.make_block_ptr(
371-
grad_bias_staging_ptr + batch * n_size,
361+
grad_bias_block_ptr = tl.make_block_ptr(
362+
grad_bias_ptr,
372363
shape=(n_size,),
373364
strides=(1,),
374365
offsets=(n_offset,),
375366
block_shape=(1,),
376367
order=(0,),
377368
)
378-
tl.store(grad_bias_staging_block_ptr, grad_bias)
369+
tl.store(grad_bias_block_ptr, grad_bias)

0 commit comments

Comments (0)