Skip to content

Commit 9e6306a

Browse files
DiamonDinoiaserge-sans-paille
authored andcommitted
fix: make is_cross_lane check 128-bit lane boundaries
Public API now checks 128-bit (16-byte) lanes, the standard for SSE/AVX/AVX512. Internal helper available for explicit lane sizes. Uses C++14 constexpr procedural style.
1 parent 41eb7fc commit 9e6306a

File tree

2 files changed

+101
-6
lines changed

2 files changed

+101
-6
lines changed

include/xsimd/arch/common/xsimd_common_swizzle.hpp

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,41 @@ namespace xsimd
167167
return cross_impl<0, sizeof...(Vs), sizeof...(Vs) / 2, Vs...>::value;
168168
}
169169

170+
/**
171+
* @brief Internal: Check if a swizzle pattern crosses lane boundaries
172+
*
173+
* @tparam LaneSizeBytes Size of a lane in bytes (must be > 0)
174+
* @tparam ElemT Element type to determine element size
175+
* @tparam U Type of the index values
176+
* @tparam Vs... Index values for the swizzle pattern
177+
*
178+
* @return true if any element accesses data from a different lane
179+
*
180+
* This is an internal helper. Architecture-specific code can call this directly
181+
* with explicit lane sizes (e.g., detail::is_cross_lane_with_lane_size<16, float, ...>()
182+
* for 128-bit lanes).
183+
*/
184+
template <std::size_t LaneSizeBytes, typename ElemT, typename U, U... Vs>
185+
XSIMD_INLINE constexpr bool is_cross_lane_with_lane_size() noexcept
186+
{
187+
static_assert(std::is_integral<U>::value, "swizzle mask values must be integral");
188+
static_assert(sizeof...(Vs) >= 1, "need at least one value");
189+
static_assert(LaneSizeBytes > 0, "lane size must be positive");
190+
191+
constexpr std::size_t lane_elems = LaneSizeBytes / sizeof(ElemT);
192+
constexpr U values[] = { Vs... };
193+
constexpr std::size_t N = sizeof...(Vs);
194+
195+
for (std::size_t i = 0; i < N; ++i)
196+
{
197+
std::size_t elem_lane = i / lane_elems;
198+
std::size_t target_lane = static_cast<std::size_t>(values[i]) / lane_elems;
199+
if (elem_lane != target_lane)
200+
return true;
201+
}
202+
return false;
203+
}
204+
170205
template <typename T, T... Vs>
171206
XSIMD_INLINE constexpr bool is_identity() noexcept { return detail::identity_impl<0, T, Vs...>(); }
172207
template <typename T, T... Vs>
@@ -184,7 +219,39 @@ namespace xsimd
184219
template <typename T, class A, T... Vs>
185220
XSIMD_INLINE constexpr bool is_only_from_hi(batch_constant<T, A, Vs...>) noexcept { return detail::is_only_from_hi<T, Vs...>(); }
186221
template <typename T, class A, T... Vs>
187-
XSIMD_INLINE constexpr bool is_cross_lane(batch_constant<T, A, Vs...>) noexcept { return detail::is_cross_lane<Vs...>(); }
222+
XSIMD_INLINE constexpr bool is_cross_lane(batch_constant<T, A, Vs...>) noexcept
223+
{
224+
return detail::is_cross_lane_with_lane_size<16, T, T, Vs...>();
225+
}
226+
227+
/**
228+
* @brief Public: Check if a swizzle pattern crosses 128-bit lane boundaries
229+
*
230+
* Checks if indices cross 128-bit (16-byte) lane boundaries, which is the
231+
* standard lane size for SSE/AVX/AVX512 shuffle operations.
232+
*
233+
* @tparam ElemT Element type to determine element size
234+
* @tparam U Type of the index values
235+
* @tparam Vs... Index values for the swizzle pattern
236+
*
237+
* @return true if any element accesses data from a different 128-bit lane
238+
*
239+
* Examples:
240+
* - is_cross_lane<float, 0, 1, 2, 3, 4, 5, 6, 7>() // no crossing (within 128-bit)
241+
* - is_cross_lane<float, 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15>() // crosses
242+
*/
243+
template <typename ElemT, typename U, U... Vs>
244+
XSIMD_INLINE constexpr bool is_cross_lane() noexcept
245+
{
246+
return is_cross_lane_with_lane_size<16, ElemT, U, Vs...>();
247+
}
248+
249+
// Overload with std::size_t indices
250+
template <typename ElemT, std::size_t... Vs>
251+
XSIMD_INLINE constexpr bool is_cross_lane() noexcept
252+
{
253+
return is_cross_lane<ElemT, std::size_t, Vs...>();
254+
}
188255

189256
} // namespace detail
190257
} // namespace kernel

test/test_batch_manip.cpp

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,39 @@ namespace xsimd
5252
static_assert(is_dup_hi<std::uint32_t, 2, 3, 2, 3>(), "4-lane dup_hi failed");
5353
static_assert(!is_dup_lo<std::uint32_t, 2, 3, 2, 3>(), "4-lane dup_lo on dup_hi");
5454

55-
static_assert(is_cross_lane<0, 1, 0, 1>(), "dup-lo only → crossing");
56-
static_assert(is_cross_lane<2, 3, 2, 3>(), "dup-hi only → crossing");
57-
static_assert(is_cross_lane<0, 3, 3, 3>(), "one low + rest high → crossing");
58-
static_assert(!is_cross_lane<1, 0, 2, 3>(), "mixed low/high → no crossing");
59-
static_assert(!is_cross_lane<0, 1, 2, 3>(), "mixed low/high → no crossing");
55+
static_assert(is_cross_lane<double, 0, 1, 0, 1>(), "dup-lo only → crossing");
56+
static_assert(is_cross_lane<double, 2, 3, 2, 3>(), "dup-hi only → crossing");
57+
static_assert(is_cross_lane<double, 0, 3, 3, 3>(), "one low + rest high → crossing");
58+
static_assert(!is_cross_lane<double, 1, 0, 2, 3>(), "mixed low/high → no crossing");
59+
static_assert(!is_cross_lane<double, 0, 1, 2, 3>(), "mixed low/high → no crossing");
60+
// 8-element 128-bit lane crossing checks
61+
// For 8 doubles (64 bytes): lanes are [0-1], [2-3], [4-5], [6-7]
62+
static_assert(!is_cross_lane<double, 1, 0, 3, 2, 5, 4, 7, 6>(), "8-lane reverse within 128-bit lanes → no crossing");
63+
static_assert(!is_cross_lane<double, 0, 1, 2, 3, 4, 5, 6, 7>(), "identity 8-lane → no crossing");
64+
static_assert(is_cross_lane<double, 2, 3, 0, 1, 4, 5, 6, 7>(), "8-lane double swap first two 128-bit lanes → crossing");
65+
// For 8 int32 (32 bytes): lanes are [0-3], [4-7]
66+
static_assert(is_cross_lane<std::int32_t, 4, 5, 6, 7, 0, 1, 2, 3>(), "8-lane int32_t swap 128-bit lanes → crossing");
67+
68+
// Additional compile-time checks for 16-element batches (e.g. float/int32)
69+
static_assert(is_cross_lane<float, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7>(),
70+
"16-lane 128-bit swap → crossing");
71+
static_assert(!is_cross_lane<float, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15>(),
72+
"identity 16-lane → no crossing");
73+
static_assert(is_cross_lane<std::uint32_t, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7>(),
74+
"16-lane uint32_t swap → crossing");
75+
76+
// Explicit 128-bit lane boundary checks (LaneSizeBytes = 16)
77+
// For float (4 bytes): 16 bytes = 4 elements per 128-bit lane
78+
static_assert(detail::is_cross_lane_with_lane_size<16, float, std::size_t, 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15>(),
79+
"float: swap first two 128-bit lanes → crossing");
80+
static_assert(!detail::is_cross_lane_with_lane_size<16, float, std::size_t, 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12>(),
81+
"float: reverse within each 128-bit lane → no crossing");
82+
83+
// For double (8 bytes): 16 bytes = 2 elements per 128-bit lane
84+
static_assert(detail::is_cross_lane_with_lane_size<16, double, std::size_t, 2, 3, 0, 1, 4, 5, 6, 7>(),
85+
"double: swap first two 128-bit lanes → crossing");
86+
static_assert(!detail::is_cross_lane_with_lane_size<16, double, std::size_t, 1, 0, 3, 2, 5, 4, 7, 6>(),
87+
"double: reverse within each 128-bit lane → no crossing");
6088
}
6189
}
6290
}

0 commit comments

Comments
 (0)