@@ -17,6 +17,34 @@ library Math {
1717 Expand // Away from zero
1818 }
1919
20+ /**
21+ * @dev Return the 512-bit addition of two uint256.
22+ *
23+ * The result is stored in two 256 variables such that sum = high * 2²⁵⁶ + low.
24+ */
25+ function add512 (uint256 a , uint256 b ) internal pure returns (uint256 high , uint256 low ) {
26+ assembly ("memory-safe" ) {
27+ low := add (a, b)
28+ high := lt (low, a)
29+ }
30+ }
31+
32+ /**
33+ * @dev Return the 512-bit multiplication of two uint256.
34+ *
35+ * The result is stored in two 256 variables such that product = high * 2²⁵⁶ + low.
36+ */
37+ function mul512 (uint256 a , uint256 b ) internal pure returns (uint256 high , uint256 low ) {
38+ // 512-bit multiply [high low] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use
39+ // the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256
40+ // variables such that product = high * 2²⁵⁶ + low.
41+ assembly ("memory-safe" ) {
42+ let mm := mulmod (a, b, not (0 ))
43+ low := mul (a, b)
44+ high := sub (sub (mm, low), lt (mm, low))
45+ }
46+ }
47+
2048 /**
2149 * @dev Returns the addition of two unsigned integers, with an success flag (no overflow).
2250 */
@@ -143,42 +171,34 @@ library Math {
143171 */
144172 function mulDiv (uint256 x , uint256 y , uint256 denominator ) internal pure returns (uint256 result ) {
145173 unchecked {
146- // 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use
147- // the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256
148- // variables such that product = prod1 * 2²⁵⁶ + prod0.
149- uint256 prod0 = x * y; // Least significant 256 bits of the product
150- uint256 prod1; // Most significant 256 bits of the product
151- assembly {
152- let mm := mulmod (x, y, not (0 ))
153- prod1 := sub (sub (mm, prod0), lt (mm, prod0))
154- }
174+ (uint256 high , uint256 low ) = mul512 (x, y);
155175
156176 // Handle non-overflow cases, 256 by 256 division.
157- if (prod1 == 0 ) {
177+ if (high == 0 ) {
158178 // Solidity will revert if denominator == 0, unlike the div opcode on its own.
159179 // The surrounding unchecked block does not change this fact.
160180 // See https://docs.soliditylang.org/en/latest/control-structures.html#checked-or-unchecked-arithmetic.
161- return prod0 / denominator;
181+ return low / denominator;
162182 }
163183
164184 // Make sure the result is less than 2²⁵⁶. Also prevents denominator == 0.
165- if (denominator <= prod1 ) {
185+ if (denominator <= high ) {
166186 Panic.panic (ternary (denominator == 0 , Panic.DIVISION_BY_ZERO, Panic.UNDER_OVERFLOW));
167187 }
168188
169189 ///////////////////////////////////////////////
170190 // 512 by 256 division.
171191 ///////////////////////////////////////////////
172192
173- // Make division exact by subtracting the remainder from [prod1 prod0 ].
193+ // Make division exact by subtracting the remainder from [high low ].
174194 uint256 remainder;
175195 assembly {
176196 // Compute remainder using mulmod.
177197 remainder := mulmod (x, y, denominator)
178198
179199 // Subtract 256 bit number from 512 bit number.
180- prod1 := sub (prod1 , gt (remainder, prod0 ))
181- prod0 := sub (prod0 , remainder)
200+ high := sub (high , gt (remainder, low ))
201+ low := sub (low , remainder)
182202 }
183203
184204 // Factor powers of two out of denominator and compute largest power of two divisor of denominator.
@@ -189,15 +209,15 @@ library Math {
189209 // Divide denominator by twos.
190210 denominator := div (denominator, twos)
191211
192- // Divide [prod1 prod0 ] by twos.
193- prod0 := div (prod0 , twos)
212+ // Divide [high low ] by twos.
213+ low := div (low , twos)
194214
195215 // Flip twos such that it is 2²⁵⁶ / twos. If twos is zero, then it becomes one.
196216 twos := add (div (sub (0 , twos), twos), 1 )
197217 }
198218
199- // Shift in bits from prod1 into prod0 .
200- prod0 |= prod1 * twos;
219+ // Shift in bits from high into low .
220+ low |= high * twos;
201221
202222 // Invert denominator mod 2²⁵⁶. Now that denominator is an odd number, it has an inverse modulo 2²⁵⁶ such
203223 // that denominator * inv ≡ 1 mod 2²⁵⁶. Compute the inverse by starting with a seed that is correct for
@@ -215,9 +235,9 @@ library Math {
215235
216236 // Because the division is now exact we can divide by multiplying with the modular inverse of denominator.
217237 // This will give us the correct result modulo 2²⁵⁶. Since the preconditions guarantee that the outcome is
218- // less than 2²⁵⁶, this is the final result. We don't need to compute the high bits of the result and prod1
238+ // less than 2²⁵⁶, this is the final result. We don't need to compute the high bits of the result and high
219239 // is no longer required.
220- result = prod0 * inverse;
240+ result = low * inverse;
221241 return result;
222242 }
223243 }
@@ -229,6 +249,26 @@ library Math {
229249 return mulDiv (x, y, denominator) + SafeCast.toUint (unsignedRoundsUp (rounding) && mulmod (x, y, denominator) > 0 );
230250 }
231251
252+ /**
253+ * @dev Calculates floor(x * y >> n) with full precision. Throws if result overflows a uint256.
254+ */
255+ function mulShr (uint256 x , uint256 y , uint8 n ) internal pure returns (uint256 result ) {
256+ unchecked {
257+ (uint256 high , uint256 low ) = mul512 (x, y);
258+ if (high >= 1 << n) {
259+ Panic.panic (Panic.UNDER_OVERFLOW);
260+ }
261+ return (high << (256 - n)) | (low >> n);
262+ }
263+ }
264+
265+ /**
266+ * @dev Calculates x * y >> n with full precision, following the selected rounding direction.
267+ */
268+ function mulShr (uint256 x , uint256 y , uint8 n , Rounding rounding ) internal pure returns (uint256 ) {
269+ return mulShr (x, y, n) + SafeCast.toUint (unsignedRoundsUp (rounding) && mulmod (x, y, 1 << n) > 0 );
270+ }
271+
232272 /**
233273 * @dev Calculate the modular multiplicative inverse of a number in Z/nZ.
234274 *
0 commit comments