@@ -6,6 +6,7 @@
 #include "CodeGen_Posix.h"
 #include "ConciseCasts.h"
 #include "Debug.h"
+#include "DecomposeVectorShuffle.h"
 #include "DistributeShifts.h"
 #include "IREquality.h"
 #include "IRMatch.h"
@@ -20,6 +21,7 @@
 namespace Halide {
 namespace Internal {
 
+using std::optional;
 using std::ostringstream;
 using std::pair;
 using std::string;
@@ -217,6 +219,9 @@ class CodeGen_ARM : public CodeGen_Posix {
 
     Value *interleave_vectors(const std::vector<Value *> &) override;
     Value *shuffle_vectors(Value *a, Value *b, const std::vector<int> &indices) override;
+    Value *shuffle_scalable_vectors_general(Value *a, Value *b, const std::vector<int> &indices);
+    Value *codegen_shuffle_indices(int bits, const std::vector<int> &indices);
+    Value *codegen_whilelt(int total_lanes, int start, int end);
     void codegen_vector_reduce(const VectorReduce *, const Expr &) override;
     bool codegen_dot_product_vector_reduce(const VectorReduce *, const Expr &);
     bool codegen_pairwise_vector_reduce(const VectorReduce *, const Expr &);
@@ -237,6 +242,7 @@ class CodeGen_ARM : public CodeGen_Posix {
     };
     vector<Pattern> casts, calls, negations;
 
+    int natural_vector_size(const Halide::Type &t) const;
     string mcpu_target() const override;
     string mcpu_tune() const override;
     string mattrs() const override;
@@ -267,6 +273,37 @@ class CodeGen_ARM : public CodeGen_Posix {
             return Shuffle::make_concat({const_true(true_lanes), const_false(false_lanes)});
         }
     }
+
+    /** Handle a general shuffle of vectors. See DecomposeVectorShuffle.h for how it works. */
+    struct VectorShuffler : public DecomposeVectorShuffle<VectorShuffler, Value *> {
+        VectorShuffler(Value *src_a, Value *src_b, const vector<int> &indices, int vl, CodeGen_ARM &codegen)
+            : DecomposeVectorShuffle(src_a, src_b, indices, vl), codegen(codegen) {
+        }
+
+        int get_vec_length(Value *v) {
+            return codegen.get_vector_num_elements(v->getType());
+        }
+
+        Value *align_up_vector(Value *v, int align) {
+            size_t org_len = get_vec_length(v);
+            return codegen.slice_vector(v, 0, align_up(org_len, align));
+        }
+
+        Value *slice_vec(Value *v, int start, size_t lanes) {
+            return codegen.slice_vector(v, start, lanes);
+        }
+
+        Value *concat_vecs(const vector<Value *> &vecs) {
+            return codegen.concat_vectors(vecs);
+        }
+
+        Value *shuffle_vl_aligned(Value *a, optional<Value *> &b, const vector<int> &indices, int vl) {
+            return codegen.shuffle_scalable_vectors_general(a, b.value_or(nullptr), indices);
+        }
+
+    private:
+        CodeGen_ARM &codegen;
+    };
 };
 
 CodeGen_ARM::CodeGen_ARM(const Target &target)
@@ -1981,9 +2018,71 @@ void CodeGen_ARM::visit(const Shuffle *op) {
 
         value = codegen_dense_vector_load(load, nullptr, /* slice_to_native */ false);
         value = CodeGen_Posix::shuffle_vectors(value, op->indices);
-    } else {
+        return;
+    }
+
+    if (target_vscale() == 0) {
         CodeGen_Posix::visit(op);
+        return;
     }
+
+    const int total_lanes = op->type.lanes();
+    if (op->type.bits() == 1) {
+        // Peephole pattern that matches the SVE "whilelt" instruction, which represents a
+        // particular vector-predicate pattern, e.g. 11100000 (active_lanes=3, all_lanes=8).
+        if (op->is_concat() && op->vectors.size() == 2 &&
+            op->type.is_int_or_uint() &&
+            is_power_of_two(total_lanes) &&
+            total_lanes >= 2 * target_vscale() && total_lanes <= 16 * target_vscale() &&
+            is_const_one(op->vectors[0]) && is_const_zero(op->vectors[1])) {
+
+            int active_lanes = op->vectors[0].type().lanes();
+            value = codegen_whilelt(op->type.lanes(), 0, active_lanes);
+            return;
+        } else {
+            // Rewrite to process a 1-bit type vector as an 8-bit vector, then cast back.
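+            // e.g. (illustrative) a shuffle of concat(x : 4xi1, y : 4xi1) is computed as
+            // cast<8xi1>(shuffle(concat(cast<4xi8>(x), cast<4xi8>(y)))).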
+            std::vector<Expr> vecs_i8;
+            vecs_i8.reserve(op->vectors.size());
+            for (const auto &vec_i1 : op->vectors) {
+                Type upgraded_type = vec_i1.type().with_bits(8);
+                vecs_i8.emplace_back(Cast::make(upgraded_type, vec_i1));
+            }
+            Expr equiv = Shuffle::make(vecs_i8, op->indices);
+            equiv = Cast::make(op->type, equiv);
+            equiv = common_subexpression_elimination(equiv);
+            value = codegen(equiv);
+            return;
+        }
+    } else if (op->is_concat() && op->vectors.size() == 2) {
+        // Here we deal with some specific patterns of concat(a, b).
+        // Others are decomposed by CodeGen_LLVM first,
+        // which in turn calls CodeGen_ARM::concat_vectors().
+
+        if (const Broadcast *bc_1 = op->vectors[1].as<Broadcast>()) {
+            // Common pattern where padding is appended to align lanes.
+            // Create a broadcast of the padding with dst lanes, then insert vec[0] at lane 0.
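+            // e.g. (illustrative) concat(v : 6 lanes, broadcast(z, 2)) -> splat z across all
+            // 8 destination lanes, then overwrite lanes 0..5 with v.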
+            Value *val_0 = codegen(op->vectors[0]);
+            Value *val_1_scalar = codegen(bc_1->value);
+            Value *padding = builder->CreateVectorSplat(llvm::ElementCount::getScalable(total_lanes / target_vscale()), val_1_scalar);
+            value = insert_scalable_vector(padding, val_0, 0);
+            return;
+        }
+    } else if (op->is_broadcast()) {
+        // Undo the simplification to avoid an arbitrarily-indexed shuffle.
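+        // e.g. (illustrative) a broadcast_factor()==2 shuffle of a 4-lane vector v
+        // (indices {0,1,2,3,0,1,2,3}) becomes concat(v, v).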
+        Expr equiv;
+        for (int f = 0; f < op->broadcast_factor(); ++f) {
+            if (equiv.defined()) {
+                equiv = Shuffle::make_concat({equiv, op->vectors[0]});
+            } else {
+                equiv = op->vectors[0];
+            }
+        }
+        equiv = common_subexpression_elimination(equiv);
+        value = codegen(equiv);
+        return;
+    }
+
+    CodeGen_Posix::visit(op);
 }
 
 llvm::Type *CodeGen_ARM::get_vector_type_from_value(Value *vec_or_scalar, int n) {
@@ -2186,52 +2285,139 @@ Value *CodeGen_ARM::shuffle_vectors(Value *a, Value *b, const std::vector<int> &
     }
 
     internal_assert(a->getType() == b->getType());
+    llvm::Type *src_type = a->getType();
+    llvm::Type *elt = get_vector_element_type(src_type);
+    const int bits = elt->getScalarSizeInBits();
+    // Note: lanes are multiplied by vscale.
+    const int natural_lanes = natural_vector_size(Int(bits));
+    const int src_lanes = get_vector_num_elements(src_type);
+    const int dst_lanes = indices.size();
+
+    if (src_type->isVectorTy()) {
+        // Shuffle i1 vectors as i8, then cast the result back to i1.
+        if (src_type->getScalarSizeInBits() == 1) {
+            internal_assert(src_type->isIntegerTy()) << "1-bit floating point type is unexpected\n";
+            a = builder->CreateIntCast(a, VectorType::get(i8_t, dyn_cast<llvm::VectorType>(src_type)), false);
+            b = builder->CreateIntCast(b, VectorType::get(i8_t, dyn_cast<llvm::VectorType>(src_type)), false);
+            Value *v = shuffle_vectors(a, b, indices);
+            return builder->CreateIntCast(v, VectorType::get(i1_t, dyn_cast<llvm::VectorType>(v->getType())), false);
+        }
+
+        // Check if this is a deinterleaving slice.
+        {
+            // Get the stride of the slice.
+            int slice_stride = 0;
+            const int start_index = indices[0];
+            if (dst_lanes > 1) {
+                const int stride = indices[1] - start_index;
+                bool stride_equal = true;
+                for (int i = 2; i < dst_lanes; ++i) {
+                    stride_equal &= (indices[i] == start_index + i * stride);
+                }
+                slice_stride = stride_equal ? stride : 0;
+            }
 
+            // Lower a slice with stride into the llvm.vector.deinterleave intrinsic.
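+            // e.g. (illustrative) indices {1, 3, 5, 7} of an 8-lane source: stride 2, start 1
+            // -> call llvm.vector.deinterleave2 and extract element 1 of the returned struct.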
+            const std::set<int> supported_strides{2, 3, 4, 8};
+            if (supported_strides.find(slice_stride) != supported_strides.end() &&
+                dst_lanes * slice_stride == src_lanes &&
+                indices.front() < slice_stride &&  // The start position cannot be larger than the stride
+                is_power_of_two(dst_lanes) &&
+                dst_lanes % target_vscale() == 0 &&
+                dst_lanes / target_vscale() > 1) {
+
+                std::string instr = concat_strings("llvm.vector.deinterleave", slice_stride, mangle_llvm_type(a->getType()));
+
+                // We cannot mix FixedVector and ScalableVector, so dst_type must be scalable.
+                llvm::Type *dst_type = get_vector_type(elt, dst_lanes / target_vscale(), VectorTypeConstraint::VScale);
+                StructType *sret_type = StructType::get(*context, std::vector(slice_stride, dst_type));
+                std::vector<llvm::Type *> arg_types{a->getType()};
+                llvm::FunctionType *fn_type = FunctionType::get(sret_type, arg_types, false);
+                FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
+
+                CallInst *deinterleave = builder->CreateCall(fn, {a});
+                // Extract one element out of the returned struct.
+                Value *extracted = builder->CreateExtractValue(deinterleave, indices.front());
+
+                return extracted;
+            }
+        }
+    }
+
+    // Perform the vector shuffle by decomposing it into multiple native-width shuffle steps;
+    // each step calls shuffle_scalable_vectors_general(), which emits a TBL/TBL2 instruction.
+    VectorShuffler shuffler(a, b, indices, natural_lanes, *this);
+    Value *v = shuffler.shuffle();
+    return v;
+}
+
+Value *CodeGen_ARM::shuffle_scalable_vectors_general(Value *a, Value *b, const std::vector<int> &indices) {
     llvm::Type *elt = get_vector_element_type(a->getType());
+    const int bits = elt->getScalarSizeInBits();
+    const int natural_lanes = natural_vector_size(Int(bits));
     const int src_lanes = get_vector_num_elements(a->getType());
     const int dst_lanes = indices.size();
+    llvm::Type *dst_type = get_vector_type(elt, dst_lanes);
 
-    // Check if deinterleaved slice
-    {
-        // Get the stride of slice
-        int slice_stride = 0;
-        const int start_index = indices[0];
-        if (dst_lanes > 1) {
-            const int stride = indices[1] - start_index;
-            bool stride_equal = true;
-            for (int i = 2; i < dst_lanes; ++i) {
-                stride_equal &= (indices[i] == start_index + i * stride);
-            }
-            slice_stride = stride_equal ? stride : 0;
-        }
+    internal_assert(target_vscale() > 0 && is_scalable_vector(a)) << "Only deal with scalable vectors\n";
+    internal_assert(src_lanes == natural_lanes && dst_lanes == natural_lanes)
+        << "Only deal with vectors with natural_lanes\n";
 
-        // Lower slice with stride into llvm.vector.deinterleave intrinsic
-        const std::set<int> supported_strides{2, 3, 4, 8};
-        if (supported_strides.find(slice_stride) != supported_strides.end() &&
-            dst_lanes * slice_stride == src_lanes &&
-            indices.front() < slice_stride &&  // Start position cannot be larger than stride
-            is_power_of_two(dst_lanes) &&
-            dst_lanes % target_vscale() == 0 &&
-            dst_lanes / target_vscale() > 1) {
+    // We select the TBL or TBL2 intrinsic depending on the range of the indices.
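+    // TBL gathers dst[i] = a[indices[i]] from a single source vector; TBL2 gathers from the
+    // concatenation {a, b}, so it accepts indices up to 2 * src_lanes (brief summary of the
+    // SVE semantics; out-of-range indices yield zero).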
+    bool use_tbl = *std::max_element(indices.begin(), indices.end()) < src_lanes;
+    internal_assert(use_tbl || b) << "'b' must be valid in case of tbl2\n";
 
-            std::string instr = concat_strings("llvm.vector.deinterleave", slice_stride, mangle_llvm_type(a->getType()));
+    auto instr = concat_strings("llvm.aarch64.sve.", use_tbl ? "tbl" : "tbl2", mangle_llvm_type(dst_type));
 
-            // We cannot mix FixedVector and ScalableVector, so dst_type must be scalable
-            llvm::Type *dst_type = get_vector_type(elt, dst_lanes / target_vscale(), VectorTypeConstraint::VScale);
-            StructType *sret_type = StructType::get(*context, std::vector(slice_stride, dst_type));
-            std::vector<llvm::Type *> arg_types{a->getType()};
-            llvm::FunctionType *fn_type = FunctionType::get(sret_type, arg_types, false);
-            FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
+    Value *val_indices = codegen_shuffle_indices(bits, indices);
+    llvm::Type *vt_natural = get_vector_type(elt, natural_lanes);
+    std::vector<llvm::Type *> llvm_arg_types;
+    std::vector<llvm::Value *> llvm_arg_vals;
+    if (use_tbl) {
+        llvm_arg_types = {vt_natural, val_indices->getType()};
+        llvm_arg_vals = {a, val_indices};
+    } else {
+        llvm_arg_types = {vt_natural, vt_natural, val_indices->getType()};
+        llvm_arg_vals = {a, b, val_indices};
+    }
+    llvm::FunctionType *fn_type = FunctionType::get(vt_natural, llvm_arg_types, false);
+    FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
 
-            CallInst *deinterleave = builder->CreateCall(fn, {a});
-            // extract one element out of the returned struct
-            Value *extracted = builder->CreateExtractValue(deinterleave, indices.front());
+    Value *v = builder->CreateCall(fn, llvm_arg_vals);
+    return v;
+}
 
-            return extracted;
-        }
+Value *CodeGen_ARM::codegen_shuffle_indices(int bits, const std::vector<int> &indices) {
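+    // Build the TBL index operand as a constant vector, mapping negative ("don't care")
+    // indices to undef, e.g. (illustrative) bits=8, indices {3, 2, -1, 0, ...}
+    // -> <i8 3, i8 2, i8 undef, i8 0, ...>, inserted into a scalable vector at lane 0.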
+    const int lanes = indices.size();
+    llvm::Type *index_type = IntegerType::get(module->getContext(), bits);
+    llvm::Type *index_vec_type = get_vector_type(index_type, lanes);
+
+    std::vector<Constant *> llvm_indices(lanes);
+    for (int i = 0; i < lanes; i++) {
+        int idx = indices[i];
+        llvm_indices[i] = idx >= 0 ? ConstantInt::get(index_type, idx) : UndefValue::get(index_type);
     }
 
-    return CodeGen_Posix::shuffle_vectors(a, b, indices);
+    Value *v = ConstantVector::get(llvm_indices);
+    v = builder->CreateInsertVector(index_vec_type, UndefValue::get(index_vec_type),
+                                    v, ConstantInt::get(i64_t, 0));
+    return v;
+}
+
+Value *CodeGen_ARM::codegen_whilelt(int total_lanes, int start, int end) {
+    // Generates the SVE "whilelt" instruction, which represents a vector-predicate pattern,
+    // e.g. 11100000 (total_lanes = 8, start = 0, end = 3)
+    // -> @llvm.aarch64.sve.whilelt.nxv8i1.i32(i32 0, i32 3)
+    internal_assert(target_vscale() > 0);
+    internal_assert(total_lanes % target_vscale() == 0);
+    std::string instr = concat_strings("llvm.aarch64.sve.whilelt.nxv", total_lanes / target_vscale(), "i1.i32");
+
+    llvm::Type *pred_type = get_vector_type(llvm_type_of(Int(1)), total_lanes);
+    llvm::FunctionType *fn_type = FunctionType::get(pred_type, {i32_t, i32_t}, false);
+    FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
+
+    value = builder->CreateCall(fn, {ConstantInt::get(i32_t, start), ConstantInt::get(i32_t, end)});
+    return value;
 }
 
 void CodeGen_ARM::visit(const Ramp *op) {
@@ -2659,6 +2845,11 @@ Type CodeGen_ARM::upgrade_type_for_storage(const Type &t) const {
     return CodeGen_Posix::upgrade_type_for_storage(t);
 }
 
+int CodeGen_ARM::natural_vector_size(const Halide::Type &t) const {
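+    // e.g. (illustrative) with native_vector_bits() == 256 (i.e. vscale == 2 with 128-bit
+    // granules): natural_vector_size(Int(32)) == 8 and natural_vector_size(UInt(8)) == 32.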
+    internal_assert(t.bits() > 1) << "natural_vector_size requested for a 1-bit type\n";
+    return native_vector_bits() / t.bits();
+}
+
 string CodeGen_ARM::mcpu_target() const {
     if (target.bits == 32) {
         if (target.has_feature(Target::ARMv7s)) {