Skip to content

Commit cf1619f

Browse files
Shuffle scalable vector in CodeGen_ARM
By design, LLVM shufflevector doesn't accept scalable vectors. So, we try to use llvm.vector.xx intrinsics where possible. However, those are not enough to cover the wide usage of shuffles in Halide. To handle arbitrary index patterns, we decompose a shuffle operation into a sequence of multiple native shuffles, which are lowered to the Arm SVE2 intrinsics TBL or TBL2. Another approach could be to perform the shuffle in fixed-sized vectors by adding conversions between scalable vectors and fixed vectors. However, that seems to be possible only via load/store through memory, which would presumably result in poor performance. This change also includes: - Peep-hole the particular predicate pattern to emit the WHILELT instruction - Shuffle 1-bit type scalable vectors as 8-bit with type casts - Peep-hole concat_vectors for padding to align up vectors - Fix a redundant broadcast in CodeGen_LLVM
1 parent bfd9535 commit cf1619f

File tree

6 files changed

+636
-36
lines changed

6 files changed

+636
-36
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ target_sources(
276276
Debug.cpp
277277
DebugArguments.cpp
278278
DebugToFile.cpp
279+
DecomposeVectorShuffle.cpp
279280
Definition.cpp
280281
Deinterleave.cpp
281282
Derivative.cpp

src/CodeGen_ARM.cpp

Lines changed: 226 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "CodeGen_Posix.h"
77
#include "ConciseCasts.h"
88
#include "Debug.h"
9+
#include "DecomposeVectorShuffle.h"
910
#include "DistributeShifts.h"
1011
#include "IREquality.h"
1112
#include "IRMatch.h"
@@ -20,6 +21,7 @@
2021
namespace Halide {
2122
namespace Internal {
2223

24+
using std::optional;
2325
using std::ostringstream;
2426
using std::pair;
2527
using std::string;
@@ -217,6 +219,9 @@ class CodeGen_ARM : public CodeGen_Posix {
217219

218220
Value *interleave_vectors(const std::vector<Value *> &) override;
219221
Value *shuffle_vectors(Value *a, Value *b, const std::vector<int> &indices) override;
222+
Value *shuffle_scalable_vectors_general(Value *a, Value *b, const std::vector<int> &indices);
223+
Value *codegen_shuffle_indices(int bits, const std::vector<int> &indices);
224+
Value *codegen_whilelt(int total_lanes, int start, int end);
220225
void codegen_vector_reduce(const VectorReduce *, const Expr &) override;
221226
bool codegen_dot_product_vector_reduce(const VectorReduce *, const Expr &);
222227
bool codegen_pairwise_vector_reduce(const VectorReduce *, const Expr &);
@@ -237,6 +242,7 @@ class CodeGen_ARM : public CodeGen_Posix {
237242
};
238243
vector<Pattern> casts, calls, negations;
239244

245+
int natural_vector_size(const Halide::Type &t) const;
240246
string mcpu_target() const override;
241247
string mcpu_tune() const override;
242248
string mattrs() const override;
@@ -267,6 +273,37 @@ class CodeGen_ARM : public CodeGen_Posix {
267273
return Shuffle::make_concat({const_true(true_lanes), const_false(false_lanes)});
268274
}
269275
}
276+
277+
/** Handle general shuffle of vectors. See DecomposeVectorShuffle.h about how it works */
278+
struct VectorShuffler : public DecomposeVectorShuffle<VectorShuffler, Value *> {
279+
VectorShuffler(Value *src_a, Value *src_b, const vector<int> &indices, int vl, CodeGen_ARM &codegen)
280+
: DecomposeVectorShuffle(src_a, src_b, indices, vl), codegen(codegen) {
281+
}
282+
283+
int get_vec_length(Value *v) {
284+
return codegen.get_vector_num_elements(v->getType());
285+
}
286+
287+
Value *align_up_vector(Value *v, int align) {
288+
size_t org_len = get_vec_length(v);
289+
return codegen.slice_vector(v, 0, align_up(org_len, align));
290+
}
291+
292+
Value *slice_vec(Value *v, int start, size_t lanes) {
293+
return codegen.slice_vector(v, start, lanes);
294+
}
295+
296+
Value *concat_vecs(const vector<Value *> &vecs) {
297+
return codegen.concat_vectors(vecs);
298+
}
299+
300+
Value *shuffle_vl_aligned(Value *a, optional<Value *> &b, const vector<int> &indices, int vl) {
301+
return codegen.shuffle_scalable_vectors_general(a, b.value_or(nullptr), indices);
302+
}
303+
304+
private:
305+
CodeGen_ARM &codegen;
306+
};
270307
};
271308

272309
CodeGen_ARM::CodeGen_ARM(const Target &target)
@@ -1981,9 +2018,71 @@ void CodeGen_ARM::visit(const Shuffle *op) {
19812018

19822019
value = codegen_dense_vector_load(load, nullptr, /* slice_to_native */ false);
19832020
value = CodeGen_Posix::shuffle_vectors(value, op->indices);
1984-
} else {
2021+
return;
2022+
}
2023+
2024+
if (target_vscale() == 0) {
19852025
CodeGen_Posix::visit(op);
2026+
return;
19862027
}
2028+
2029+
const int total_lanes = op->type.lanes();
2030+
if (op->type.bits() == 1) {
2031+
// Peep-hole pattern that matches SVE "whilelt" which represents particular pattern of
2032+
// vector predicate. e.g. 11100000 (active_lanes=3, all_lanes=8)
2033+
if (op->is_concat() && op->vectors.size() == 2 &&
2034+
op->type.is_int_or_uint() &&
2035+
is_power_of_two(total_lanes) &&
2036+
total_lanes >= 2 * target_vscale() && total_lanes <= 16 * target_vscale() &&
2037+
is_const_one(op->vectors[0]) && is_const_zero(op->vectors[1])) {
2038+
2039+
int active_lanes = op->vectors[0].type().lanes();
2040+
value = codegen_whilelt(op->type.lanes(), 0, active_lanes);
2041+
return;
2042+
} else {
2043+
// Rewrite to process 1bit type vector as 8 bit vector, and then cast back
2044+
std::vector<Expr> vecs_i8;
2045+
vecs_i8.reserve(op->vectors.size());
2046+
for (const auto &vec_i1 : op->vectors) {
2047+
Type upgraded_type = vec_i1.type().with_bits(8);
2048+
vecs_i8.emplace_back(Cast::make(upgraded_type, vec_i1));
2049+
}
2050+
Expr equiv = Shuffle::make(vecs_i8, op->indices);
2051+
equiv = Cast::make(op->type, equiv);
2052+
equiv = common_subexpression_elimination(equiv);
2053+
value = codegen(equiv);
2054+
return;
2055+
}
2056+
} else if (op->is_concat() && op->vectors.size() == 2) {
2057+
// Here, we deal with some specific patterns of concat(a, b).
2058+
// Others are decomposed by CodeGen_LLVM at first,
2059+
// which in turn calls CodeGen_ARM::concat_vectors().
2060+
2061+
if (const Broadcast *bc_1 = op->vectors[1].as<Broadcast>()) {
2062+
// Common pattern where padding is appended to align lanes.
2063+
// Create broadcast of padding with dst lanes, then insert vec[0] at lane 0.
2064+
Value *val_0 = codegen(op->vectors[0]);
2065+
Value *val_1_scalar = codegen(bc_1->value);
2066+
Value *padding = builder->CreateVectorSplat(llvm::ElementCount::getScalable(total_lanes / target_vscale()), val_1_scalar);
2067+
value = insert_scalable_vector(padding, val_0, 0);
2068+
return;
2069+
}
2070+
} else if (op->is_broadcast()) {
2071+
// Undo simplification to avoid arbitrary-indexed shuffle
2072+
Expr equiv;
2073+
for (int f = 0; f < op->broadcast_factor(); ++f) {
2074+
if (equiv.defined()) {
2075+
equiv = Shuffle::make_concat({equiv, op->vectors[0]});
2076+
} else {
2077+
equiv = op->vectors[0];
2078+
}
2079+
}
2080+
equiv = common_subexpression_elimination(equiv);
2081+
value = codegen(equiv);
2082+
return;
2083+
}
2084+
2085+
CodeGen_Posix::visit(op);
19872086
}
19882087

19892088
llvm::Type *CodeGen_ARM::get_vector_type_from_value(Value *vec_or_scalar, int n) {
@@ -2186,52 +2285,139 @@ Value *CodeGen_ARM::shuffle_vectors(Value *a, Value *b, const std::vector<int> &
21862285
}
21872286

21882287
internal_assert(a->getType() == b->getType());
2288+
llvm::Type *src_type = a->getType();
2289+
llvm::Type *elt = get_vector_element_type(src_type);
2290+
const int bits = elt->getScalarSizeInBits();
2291+
// note: lanes are multiplied by vscale
2292+
const int natural_lanes = natural_vector_size(Int(bits));
2293+
const int src_lanes = get_vector_num_elements(src_type);
2294+
const int dst_lanes = indices.size();
2295+
2296+
if (src_type->isVectorTy()) {
2297+
// i1 -> shuffle with i8 -> i1
2298+
if (src_type->getScalarSizeInBits() == 1) {
2299+
internal_assert(src_type->isIntegerTy()) << "1 bit floating point type is unexpected\n";
2300+
a = builder->CreateIntCast(a, VectorType::get(i8_t, dyn_cast<llvm::VectorType>(src_type)), false);
2301+
b = builder->CreateIntCast(b, VectorType::get(i8_t, dyn_cast<llvm::VectorType>(src_type)), false);
2302+
Value *v = shuffle_vectors(a, b, indices);
2303+
return builder->CreateIntCast(v, VectorType::get(i1_t, dyn_cast<llvm::VectorType>(v->getType())), false);
2304+
}
2305+
2306+
// Check if deinterleaved slice
2307+
{
2308+
// Get the stride of slice
2309+
int slice_stride = 0;
2310+
const int start_index = indices[0];
2311+
if (dst_lanes > 1) {
2312+
const int stride = indices[1] - start_index;
2313+
bool stride_equal = true;
2314+
for (int i = 2; i < dst_lanes; ++i) {
2315+
stride_equal &= (indices[i] == start_index + i * stride);
2316+
}
2317+
slice_stride = stride_equal ? stride : 0;
2318+
}
21892319

2320+
// Lower slice with stride into llvm.vector.deinterleave intrinsic
2321+
const std::set<int> supported_strides{2, 3, 4, 8};
2322+
if (supported_strides.find(slice_stride) != supported_strides.end() &&
2323+
dst_lanes * slice_stride == src_lanes &&
2324+
indices.front() < slice_stride && // Start position cannot be larger than stride
2325+
is_power_of_two(dst_lanes) &&
2326+
dst_lanes % target_vscale() == 0 &&
2327+
dst_lanes / target_vscale() > 1) {
2328+
2329+
std::string instr = concat_strings("llvm.vector.deinterleave", slice_stride, mangle_llvm_type(a->getType()));
2330+
2331+
// We cannot mix FixedVector and ScalableVector, so dst_type must be scalable
2332+
llvm::Type *dst_type = get_vector_type(elt, dst_lanes / target_vscale(), VectorTypeConstraint::VScale);
2333+
StructType *sret_type = StructType::get(*context, std::vector(slice_stride, dst_type));
2334+
std::vector<llvm::Type *> arg_types{a->getType()};
2335+
llvm::FunctionType *fn_type = FunctionType::get(sret_type, arg_types, false);
2336+
FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
2337+
2338+
CallInst *deinterleave = builder->CreateCall(fn, {a});
2339+
// extract one element out of the returned struct
2340+
Value *extracted = builder->CreateExtractValue(deinterleave, indices.front());
2341+
2342+
return extracted;
2343+
}
2344+
}
2345+
}
2346+
2347+
// Perform vector shuffle by decomposing the operation to multiple native shuffle steps
2348+
// which calls shuffle_scalable_vectors_general() which emits TBL/TBL2 instruction
2349+
VectorShuffler shuffler(a, b, indices, natural_lanes, *this);
2350+
Value *v = shuffler.shuffle();
2351+
return v;
2352+
}
2353+
2354+
Value *CodeGen_ARM::shuffle_scalable_vectors_general(Value *a, Value *b, const std::vector<int> &indices) {
21902355
llvm::Type *elt = get_vector_element_type(a->getType());
2356+
const int bits = elt->getScalarSizeInBits();
2357+
const int natural_lanes = natural_vector_size(Int(bits));
21912358
const int src_lanes = get_vector_num_elements(a->getType());
21922359
const int dst_lanes = indices.size();
2360+
llvm::Type *dst_type = get_vector_type(elt, dst_lanes);
21932361

2194-
// Check if deinterleaved slice
2195-
{
2196-
// Get the stride of slice
2197-
int slice_stride = 0;
2198-
const int start_index = indices[0];
2199-
if (dst_lanes > 1) {
2200-
const int stride = indices[1] - start_index;
2201-
bool stride_equal = true;
2202-
for (int i = 2; i < dst_lanes; ++i) {
2203-
stride_equal &= (indices[i] == start_index + i * stride);
2204-
}
2205-
slice_stride = stride_equal ? stride : 0;
2206-
}
2362+
internal_assert(target_vscale() > 0 && is_scalable_vector(a)) << "Only deal with scalable vectors\n";
2363+
internal_assert(src_lanes == natural_lanes && dst_lanes == natural_lanes)
2364+
<< "Only deal with vector with natural_lanes\n";
22072365

2208-
// Lower slice with stride into llvm.vector.deinterleave intrinsic
2209-
const std::set<int> supported_strides{2, 3, 4, 8};
2210-
if (supported_strides.find(slice_stride) != supported_strides.end() &&
2211-
dst_lanes * slice_stride == src_lanes &&
2212-
indices.front() < slice_stride && // Start position cannot be larger than stride
2213-
is_power_of_two(dst_lanes) &&
2214-
dst_lanes % target_vscale() == 0 &&
2215-
dst_lanes / target_vscale() > 1) {
2366+
// We select TBL or TBL2 intrinsic depending on indices range
2367+
bool use_tbl = *std::max_element(indices.begin(), indices.end()) < src_lanes;
2368+
internal_assert(use_tbl || b) << "'b' must be valid in case of tbl2\n";
22162369

2217-
std::string instr = concat_strings("llvm.vector.deinterleave", slice_stride, mangle_llvm_type(a->getType()));
2370+
auto instr = concat_strings("llvm.aarch64.sve.", use_tbl ? "tbl" : "tbl2", mangle_llvm_type(dst_type));
22182371

2219-
// We cannot mix FixedVector and ScalableVector, so dst_type must be scalable
2220-
llvm::Type *dst_type = get_vector_type(elt, dst_lanes / target_vscale(), VectorTypeConstraint::VScale);
2221-
StructType *sret_type = StructType::get(*context, std::vector(slice_stride, dst_type));
2222-
std::vector<llvm::Type *> arg_types{a->getType()};
2223-
llvm::FunctionType *fn_type = FunctionType::get(sret_type, arg_types, false);
2224-
FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
2372+
Value *val_indices = codegen_shuffle_indices(bits, indices);
2373+
llvm::Type *vt_natural = get_vector_type(elt, natural_lanes);
2374+
std::vector<llvm::Type *> llvm_arg_types;
2375+
std::vector<llvm::Value *> llvm_arg_vals;
2376+
if (use_tbl) {
2377+
llvm_arg_types = {vt_natural, val_indices->getType()};
2378+
llvm_arg_vals = {a, val_indices};
2379+
} else {
2380+
llvm_arg_types = {vt_natural, vt_natural, val_indices->getType()};
2381+
llvm_arg_vals = {a, b, val_indices};
2382+
}
2383+
llvm::FunctionType *fn_type = FunctionType::get(vt_natural, llvm_arg_types, false);
2384+
FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
22252385

2226-
CallInst *deinterleave = builder->CreateCall(fn, {a});
2227-
// extract one element out of the returned struct
2228-
Value *extracted = builder->CreateExtractValue(deinterleave, indices.front());
2386+
Value *v = builder->CreateCall(fn, llvm_arg_vals);
2387+
return v;
2388+
}
22292389

2230-
return extracted;
2231-
}
2390+
Value *CodeGen_ARM::codegen_shuffle_indices(int bits, const std::vector<int> &indices) {
2391+
const int lanes = indices.size();
2392+
llvm::Type *index_type = IntegerType::get(module->getContext(), bits);
2393+
llvm::Type *index_vec_type = get_vector_type(index_type, lanes);
2394+
2395+
std::vector<Constant *> llvm_indices(lanes);
2396+
for (int i = 0; i < lanes; i++) {
2397+
int idx = indices[i];
2398+
llvm_indices[i] = idx >= 0 ? ConstantInt::get(index_type, idx) : UndefValue::get(index_type);
22322399
}
22332400

2234-
return CodeGen_Posix::shuffle_vectors(a, b, indices);
2401+
Value *v = ConstantVector::get(llvm_indices);
2402+
v = builder->CreateInsertVector(index_vec_type, UndefValue::get(index_vec_type),
2403+
v, ConstantInt::get(i64_t, 0));
2404+
return v;
2405+
}
2406+
2407+
Value *CodeGen_ARM::codegen_whilelt(int total_lanes, int start, int end) {
2408+
// Generates SVE "whilelt" instruction which represents vector predicate pattern of
2409+
// e.g. 11100000 (total_lanes = 8, start = 0, end = 3)
2410+
// -> @llvm.aarch64.sve.whilelt.nxv8i1.i32(i32 0, i32 3)
2411+
internal_assert(target_vscale() > 0);
2412+
internal_assert(total_lanes % target_vscale() == 0);
2413+
std::string instr = concat_strings("llvm.aarch64.sve.whilelt.nxv", total_lanes / target_vscale(), "i1.i32");
2414+
2415+
llvm::Type *pred_type = get_vector_type(llvm_type_of(Int(1)), total_lanes);
2416+
llvm::FunctionType *fn_type = FunctionType::get(pred_type, {i32_t, i32_t}, false);
2417+
FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
2418+
2419+
value = builder->CreateCall(fn, {ConstantInt::get(i32_t, start), ConstantInt::get(i32_t, end)});
2420+
return value;
22352421
}
22362422

22372423
void CodeGen_ARM::visit(const Ramp *op) {
@@ -2659,6 +2845,11 @@ Type CodeGen_ARM::upgrade_type_for_storage(const Type &t) const {
26592845
return CodeGen_Posix::upgrade_type_for_storage(t);
26602846
}
26612847

2848+
int CodeGen_ARM::natural_vector_size(const Halide::Type &t) const {
2849+
internal_assert(t.bits() > 1) << "natural_vector_size requested with 1 bits\n";
2850+
return native_vector_bits() / t.bits();
2851+
}
2852+
26622853
string CodeGen_ARM::mcpu_target() const {
26632854
if (target.bits == 32) {
26642855
if (target.has_feature(Target::ARMv7s)) {

src/CodeGen_LLVM.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4115,7 +4115,9 @@ void CodeGen_LLVM::visit(const Shuffle *op) {
41154115
} else {
41164116
internal_assert(op->indices[0] == 0);
41174117
}
4118-
value = create_broadcast(value, op->indices.size());
4118+
if (op->indices.size() > 1) {
4119+
value = create_broadcast(value, op->indices.size());
4120+
}
41194121
return;
41204122
}
41214123
}

0 commit comments

Comments
 (0)