Skip to content

Commit 1420d2d

Browse files
Make Magnitude sum N-ary and more general (#652)
When we start collecting like terms as part of adding `Constant` instances to each other, we will naturally form N-ary sums. To prepare the way, we introduce a `MagSum<Ms...>` utility to perform this operation. Now, we _could_ just implement this in terms of the existing `operator+` for `Magnitude`. However, that gets a little weird: it makes the result depend on the order. It's possible for a subset of inputs to overflow (thus making the binary operation return an error), when subsequent inputs might "rescue" the result (say, by subtracting to bring it back into range). Thus, I redid the implementation from scratch, using modular arithmetic for the result, and keeping track of the (signed) number of overflows. This gives us exact information about the true sum (over a very very wide range of values), although we can't always _use_ this information. Specifically, we can only use the result when the overflow is 0 (simple), or when it is -1 _and_ the sum is not 0 (because if sum _were_ 0 in this case it would actually represent -2^64). This new, more symmetrical implementation immediately takes the place of all three `Magnitude`-`Magnitude` implementations of `operator+`. Helps #607. --------- Co-authored-by: Michael Hordijk <hordijk@aurora.tech>
1 parent 289cb47 commit 1420d2d

File tree

2 files changed

+145
-39
lines changed

2 files changed

+145
-39
lines changed

au/magnitude.hh

Lines changed: 91 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,19 @@ using CommonMagnitude = typename CommonMagnitudeImpl<Ms...>::type;
275275
template <typename... Ms>
276276
using CommonMagnitudeT = CommonMagnitude<Ms...>;
277277

278+
// The sum of arbitrarily many `Magnitude` and/or `Zero` types.
279+
//
280+
// We only support this when it is "easy" to compute, where "easy" is defined as:
281+
// 1) all inputs being expressible as integer multiples of some common factor;
282+
// 2) each such integer's absolute value fitting in a `uint64_t`;
283+
// 3) *and*, the absolute value of the sum also fitting in a `uint64_t`.
284+
//
285+
// For all other cases, we currently produce a compile time error.
286+
template <typename... Ms>
287+
struct MagSumImpl;
288+
template <typename... Ms>
289+
using MagSum = typename MagSumImpl<Ms...>::type;
290+
278291
////////////////////////////////////////////////////////////////////////////////////////////////////
279292
// Value based interface for Magnitude (and Zero).
280293

@@ -513,45 +526,8 @@ constexpr Zero mag_round(Zero) { return {}; }
513526

514527
// Addition:
515528
template <typename... BP1s, typename... BP2s>
516-
constexpr auto operator+(Magnitude<BP1s...> m1, Magnitude<BP2s...> m2) {
517-
constexpr auto sgn1 = sign(m1);
518-
constexpr auto sgn2 = sign(m2);
519-
520-
constexpr auto abs_common = Abs<CommonMagnitude<Magnitude<BP1s...>, Magnitude<BP2s...>>>{};
521-
constexpr auto abs_num1 = abs(m1) / abs_common;
522-
constexpr auto abs_num2 = abs(m2) / abs_common;
523-
524-
// These `get_value` calls automatically check that individual _inputs_ fit in `uint64_t`.
525-
constexpr auto abs_num1_u64 = get_value<std::uint64_t>(abs_num1);
526-
constexpr auto abs_num2_u64 = get_value<std::uint64_t>(abs_num2);
527-
528-
// Biggest absolute input determines overall sign.
529-
//
530-
// Note that when the magnitudes are equal, either the choice doesn't matter (when the inputs
531-
// have the same sign), or the outcome should just be `Zero`. In the latter case, we rely on
532-
// the explicit `Negative` overloads below being a better match.
533-
constexpr auto sgn =
534-
std::conditional_t<(abs_num1_u64 > abs_num2_u64), decltype(sgn1), decltype(sgn2)>{};
535-
536-
// Here, we are taking advantage of modular arithmetic on unsigned integers. This actually does
537-
// handle all the signs correctly, although it may not be obvious at first glance.
538-
constexpr auto num1_u64 = is_positive(sgn1) ? abs_num1_u64 : -abs_num1_u64;
539-
constexpr auto num2_u64 = is_positive(sgn2) ? abs_num2_u64 : -abs_num2_u64;
540-
constexpr auto abs_sum_u64 = is_positive(sgn) ? (num1_u64 + num2_u64) : -(num1_u64 + num2_u64);
541-
542-
// Here is where we guard against overflow in the _output_.
543-
static_assert((sgn1 != sgn2) || abs_sum_u64 >= abs_num1_u64,
544-
"Magnitude addition overflowed uint64_t");
545-
546-
return sgn * mag<abs_sum_u64>() * abs_common;
547-
}
548-
template <typename... BPs>
549-
constexpr Zero operator+(Magnitude<Negative, BPs...>, Magnitude<BPs...>) {
550-
return {};
551-
}
552-
template <typename... BPs>
553-
constexpr Zero operator+(Magnitude<BPs...>, Magnitude<Negative, BPs...>) {
554-
return {};
529+
constexpr auto operator+(Magnitude<BP1s...>, Magnitude<BP2s...>) {
530+
return MagSum<Magnitude<BP1s...>, Magnitude<BP2s...>>{};
555531
}
556532
template <typename... BPs>
557533
constexpr auto operator+(Zero, Magnitude<BPs...> m) {
@@ -1276,4 +1252,80 @@ struct CommonMagnitudeImpl<Zero, M> : stdx::type_identity<M> {};
12761252
template <>
12771253
struct CommonMagnitudeImpl<Zero, Zero> : stdx::type_identity<Zero> {};
12781254

1255+
////////////////////////////////////////////////////////////////////////////////////////////////////
1256+
// `MagSum` implementation.
1257+
1258+
namespace detail {
1259+
1260+
// `U64MagSum<Ms...>` is the `Magnitude` (or `Zero`) equal to the sum of the Magnitudes `Ms...`, as
1261+
// long as these preconditions are met:
1262+
//
1263+
// 1. The absolute value of each member of `Ms...` fits in a `std::uint64_t`.
1264+
// 2. The absolute value of the sum of all members of `Ms...` fits in a `std::uint64_t`.
1265+
template <typename... Ms>
1266+
struct U64MagSumImpl {
1267+
struct U64SumResult {
1268+
std::uint64_t sum = 0u;
1269+
int overflow = 0;
1270+
};
1271+
1272+
static constexpr U64SumResult compute() {
1273+
const std::uint64_t abs_values[] = {get_value<std::uint64_t>(Abs<Ms>{})...};
1274+
const int overflows[] = {(IsPositive<Ms>::value ? 0 : -1)...};
1275+
1276+
U64SumResult result = {0u, 0};
1277+
for (std::size_t i = 0u; i < sizeof...(Ms); ++i) {
1278+
std::uint64_t old_sum = result.sum;
1279+
result.sum += (overflows[i] >= 0) ? abs_values[i] : -abs_values[i];
1280+
result.overflow += overflows[i] + (result.sum < old_sum);
1281+
}
1282+
1283+
return result;
1284+
}
1285+
static constexpr std::uint64_t sum = compute().sum;
1286+
static constexpr int overflow = compute().overflow;
1287+
1288+
static_assert((overflow == 0) || (overflow == -1 && sum > 0u),
1289+
"Magnitude sum overflowed uint64_t");
1290+
1291+
using Sign = std::conditional_t<(overflow == -1), Magnitude<Negative>, Magnitude<>>;
1292+
1293+
using AbsMag = std::conditional_t<(overflow == 0) && (sum == 0u),
1294+
Zero,
1295+
// The surprising `sum == 0u` avoids asking for `mag<0>()`.
1296+
// It's fine, because it can never actually be used.
1297+
decltype(mag<(overflow == -1 ? -sum : sum) + (sum == 0u)>())>;
1298+
1299+
using type = decltype(Sign{} * AbsMag{});
1300+
};
1301+
template <typename... Ms>
1302+
using U64MagSum = typename U64MagSumImpl<Ms...>::type;
1303+
1304+
template <typename... Ms>
1305+
constexpr std::uint64_t U64MagSumImpl<Ms...>::sum;
1306+
1307+
template <typename... Ms>
1308+
constexpr int U64MagSumImpl<Ms...>::overflow;
1309+
1310+
template <typename... Ms>
1311+
struct MagSumImplHelper {
1312+
using Common = CommonMagnitude<Ms...>;
1313+
using type = decltype(Common{} * U64MagSum<decltype(Ms{} / Common{})...>{});
1314+
};
1315+
1316+
// The sum of no things is nothing.
1317+
template <>
1318+
struct MagSumImplHelper<> : stdx::type_identity<Zero> {};
1319+
1320+
// Keep stripping off zeros until we find at least one nonzero element, so that the common magnitude
1321+
// machinery can find something meaningful and avoid dividing by zero. (Zeros in the middle will be
1322+
// automatically handled correctly as long as there are nonzero elements.)
1323+
template <typename... Ms>
1324+
struct MagSumImplHelper<Zero, Ms...> : MagSumImplHelper<Ms...> {};
1325+
1326+
} // namespace detail
1327+
1328+
template <typename... Ms>
1329+
struct MagSumImpl : detail::MagSumImplHelper<Ms...> {};
1330+
12791331
} // namespace au

au/magnitude_test.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,6 +1216,60 @@ TEST(IsMagnitudeU64RationalCompatible, FalseForIrrationals) {
12161216
EXPECT_THAT(is_magnitude_u64_rational_compatible(sqrt(mag<2>())), IsFalse());
12171217
}
12181218

1219+
TEST(MagSum, SumsArbitraryNumberOfMagnitudes) {
1220+
StaticAssertTypeEq<MagSum<decltype(mag<1>()), decltype(mag<2>()), decltype(mag<3>())>,
1221+
decltype(mag<6>())>();
1222+
1223+
StaticAssertTypeEq<
1224+
MagSum<decltype(mag<10>()), decltype(mag<20>()), decltype(mag<30>()), decltype(mag<40>())>,
1225+
decltype(mag<100>())>();
1226+
}
1227+
1228+
TEST(MagSum, HandlesNegativeMagnitudes) {
1229+
StaticAssertTypeEq<MagSum<decltype(mag<5>()), decltype(-mag<3>())>, decltype(mag<2>())>();
1230+
1231+
StaticAssertTypeEq<MagSum<decltype(-mag<5>()), decltype(mag<3>())>, decltype(-mag<2>())>();
1232+
1233+
StaticAssertTypeEq<MagSum<decltype(mag<10>()), decltype(-mag<3>()), decltype(-mag<2>())>,
1234+
decltype(mag<5>())>();
1235+
}
1236+
1237+
TEST(MagSum, ProducesZeroWhenInputsCancelOut) {
1238+
StaticAssertTypeEq<MagSum<decltype(mag<5>()), decltype(-mag<5>())>, Zero>();
1239+
1240+
StaticAssertTypeEq<MagSum<decltype(mag<10>()), decltype(-mag<3>()), decltype(-mag<7>())>,
1241+
Zero>();
1242+
}
1243+
1244+
TEST(MagSum, HandlesFractions) {
1245+
StaticAssertTypeEq<MagSum<decltype(mag<1>() / mag<6>()),
1246+
decltype(mag<1>() / mag<3>()),
1247+
decltype(mag<1>() / mag<2>())>,
1248+
decltype(mag<1>())>();
1249+
}
1250+
1251+
TEST(MagSum, IgnoresZeroInputs) {
1252+
StaticAssertTypeEq<MagSum<Zero, decltype(mag<5>())>, decltype(mag<5>())>();
1253+
StaticAssertTypeEq<MagSum<decltype(mag<5>()), Zero>, decltype(mag<5>())>();
1254+
StaticAssertTypeEq<MagSum<Zero, decltype(mag<5>()), Zero>, decltype(mag<5>())>();
1255+
StaticAssertTypeEq<MagSum<Zero, Zero, decltype(mag<5>())>, decltype(mag<5>())>();
1256+
StaticAssertTypeEq<MagSum<decltype(mag<2>()), Zero, decltype(mag<3>())>, decltype(mag<5>())>();
1257+
StaticAssertTypeEq<MagSum<Zero>, Zero>();
1258+
StaticAssertTypeEq<MagSum<Zero, Zero>, Zero>();
1259+
StaticAssertTypeEq<MagSum<Zero, Zero, Zero>, Zero>();
1260+
}
1261+
1262+
TEST(MagSum, ResultIsIndependentOfInputOrder) {
1263+
// A binary fold could only handle this in certain orders; N-ary handles all orders.
1264+
using One = decltype(mag<1>());
1265+
using TwoTo63 = decltype(pow<63>(mag<2>()));
1266+
using NegTwo = decltype(-mag<2>());
1267+
1268+
// 1 + 2^63 + 2^63 - 2 = 2^64 - 1
1269+
using Expected = decltype(mag<-uint64_t{1}>());
1270+
StaticAssertTypeEq<MagSum<One, TwoTo63, TwoTo63, NegTwo>, Expected>();
1271+
}
1272+
12191273
TEST(AssertMagnitudeU64RationalCompatible, NoCompilerErrorForKnownValidInput) {
12201274
(void)AssertMagnitudeU64RationalCompatible<decltype(mag<1>())>{};
12211275
(void)AssertMagnitudeU64RationalCompatible<decltype(mag<1000>())>{};

0 commit comments

Comments
 (0)