#pragma once #include "int.hpp" #include "utils.hpp" namespace mp { template constexpr size_t ResultMaxSize = std::max(SizeA, SizeB); template struct OpResult; template requires ElementTypesMatch struct OpResult { using type = BasicInt>; }; template struct OpResult { //using type = BasicInt>; using type = TLhs; // keep the same size as TLhs when operating with regular ints }; template using OpResultType = typename OpResult::type; template class OverflowError : public std::runtime_error { public: OverflowError(T&& value) : std::runtime_error("Overflow"), m_value(std::move(value)) {} const T& value() const { return m_value; } private: T m_value; }; template using OverflowErrorOf = OverflowError>; // Regular int to mp int template BasicInt to_mp_int(TR value) { BasicInt res; if constexpr (sizeof(TR) <= sizeof(TElem)) { res.set(0, static_cast(value)); } else { constexpr TR ELEM_MASK = static_cast(std::numeric_limits::max()); size_t idx = 0; while (value != 0) { res.set(idx, static_cast(value & ELEM_MASK)); value >>= (sizeof(TElem) * 8); idx++; } } return res; } // no-op template requires std::is_same_v const TR& to_mp_int(const TR& value) { return value; } template requires ElementTypesMatch auto operator<=>(const TLhs& lhs, const TRhs& rhs) { //auto size_comp = lhs.size_elems() <=> rhs.size_elems(); //if (size_comp != std::strong_ordering::equal) //{ // return size_comp; //} //for (size_t i = lhs.size_elems(); i-- > 0;) //{ // auto elem_comp = lhs.get(i) <=> rhs.get(i); // if (elem_comp != std::strong_ordering::equal) // { // return elem_comp; // } //} size_t max_size = std::max(lhs.size_elems(), rhs.size_elems()); for (size_t i = max_size; i-- > 0;) { auto elem_comp = lhs.get(i) <=> rhs.get(i); if (elem_comp != std::strong_ordering::equal) { return elem_comp; } } return std::strong_ordering::equal; } template auto operator<=>(const TLhs& lhs, TRhs rhs) { return lhs <=> to_mp_int(rhs); } template bool operator==(const TLhs& lhs, const TRhs& rhs) { return (lhs <=> rhs) == std::strong_ordering::equal; } template bool operator!=(const TLhs& lhs, const TRhs& rhs) { return !(lhs == rhs); } template requires ElementTypesMatch && ElementTypesMatch inline void sub_ignore_underflow(const TLhs& lhs, const TRhs& rhs, TRes& res) { using ElementType = typename TRes::ElementType; res.zero(); ElementType borrow = 0; size_t end = std::max(lhs.size_elems(), rhs.size_elems()); for (size_t i = 0; i < end; ++i) { ElementType a = lhs.get(i); ElementType b = rhs.get(i); ElementType c = a - b - borrow; borrow = (a < b) || (borrow && a == b); res.set(i, c); } if (borrow != 0) { throw std::overflow_error("Subtraction underflow"); } } template requires ElementTypesMatch && ElementTypesMatch inline void sub_ignore_sign(const TLhs& lhs, const TRhs& rhs, TRes& res) { if (lhs < rhs) { sub_ignore_underflow(rhs, lhs, res); res.set_negative(true); } else { sub_ignore_underflow(lhs, rhs, res); res.set_negative(false); } } template requires ElementTypesMatch && ElementTypesMatch inline void add_ignore_sign(const TLhs& lhs, const TRhs& rhs, TRes& res) { using ElementType = typename TRes::ElementType; res.zero(); ElementType carry = 0; size_t end = std::max(lhs.size_elems(), rhs.size_elems()); for (size_t i = 0; i < end; ++i) { ElementType a = lhs.get(i); ElementType b = rhs.get(i); ElementType c = carry + a; carry = (c < a) ? 1 : 0; c += b; if (c < b) carry = 1; res.set(i, c); } res.set(end, carry); } struct Addition { template requires ElementTypesMatch && ElementTypesMatch static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) { const bool lhs_neg = lhs.negative(); const bool rhs_neg = rhs.negative(); if (lhs_neg == rhs_neg) { add_ignore_sign(lhs, rhs, res); res.set_negative(lhs_neg); } else if (lhs_neg) { // (-a) + b == b - a sub_ignore_sign(rhs, lhs, res); } else { // a + (-b) == a - b sub_ignore_sign(lhs, rhs, res); } } }; struct Subtraction { template requires ElementTypesMatch && ElementTypesMatch static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) { const bool lhs_neg = lhs.negative(); const bool rhs_neg = rhs.negative(); if (!lhs_neg && !rhs_neg) { // a - b sub_ignore_sign(lhs, rhs, res); } else if (lhs_neg && rhs_neg) { // (-a) - (-b) == b - a sub_ignore_sign(rhs, lhs, res); } else if (lhs_neg && !rhs_neg) { // (-a) - b == -(a + b) add_ignore_sign(lhs, rhs, res); res.set_negative(true); } else // if (!lhs_neg && rhs_neg) { // a - (-b) == a + b add_ignore_sign(lhs, rhs, res); res.set_negative(false); } } }; struct Multiplication { template requires ElementTypesMatch && ElementTypesMatch static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) { using ElementType = typename TRes::ElementType; using LongerType = DoubleWidthType; res.zero(); res.set_negative(lhs.negative() != rhs.negative()); bool overflow = false; const size_t n = lhs.size_elems(); const size_t m = rhs.size_elems(); for (size_t i = 0; i < n; i++) { LongerType carry = 0; ElementType a = lhs[i]; for (size_t j = 0; j < m; j++) { ElementType b = rhs[j]; LongerType t = static_cast(a) * b + res.get(i + j) + carry; overflow |= !res.try_set(i + j, static_cast(t)); carry = t >> (sizeof(ElementType) * 8); } overflow |= !res.try_set(i + m, static_cast(carry)); } if (overflow) { throw std::overflow_error("Multiplication overflow"); } } }; struct Division { template requires ElementTypesMatch && ElementTypesMatch static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) { using ElementType = typename TRes::ElementType; using LongerType = DoubleWidthType; if (rhs == 0U) { throw std::runtime_error("Division by zero"); } const size_t n = rhs.size_elems(); const size_t m = (lhs.size_elems() >= n) ? (lhs.size_elems() - n) : 0; // If divisor larger than dividend => quotient = 0 if (lhs.size_elems() < n) { res.zero(); return; } // Base and mask constexpr unsigned int W = sizeof(ElementType) * 8; const LongerType BASE = (static_cast(1) << W); const LongerType MASK = BASE - 1; // Prepare u (normalized dividend) length m + n + 1 BasicInt u; u.zero(); for (size_t i = 0; i < lhs.size_elems(); ++i) u.set(i, lhs.get(i)); // Prepare v (normalized divisor) length n BasicInt v; v.zero(); for (size_t i = 0; i < n; ++i) v.set(i, rhs.get(i)); // Normalization factor d = BASE / (v[n-1] + 1) ElementType d = 1; if (v.get(n - 1) + 1 != 0) // defensive { d = static_cast(BASE / (static_cast(v.get(n - 1)) + 1)); } if (d != 1) { // u = u * d LongerType carry = 0; for (size_t i = 0; i < (m + n + 1); ++i) { LongerType t = static_cast(u.get(i)) * d + carry; u.set(i, static_cast(t & MASK)); carry = t >> W; } // v = v * d carry = 0; for (size_t i = 0; i < n; ++i) { LongerType t = static_cast(v.get(i)) * d + carry; v.set(i, static_cast(t & MASK)); carry = t >> W; } // v[n] implicitly 0 } // Prepare quotient res.zero(); res.set_negative(lhs.negative() != rhs.negative()); // Main loop j = m .. 0 for (int j = static_cast(m); j >= 0; --j) { // u[j + n] might be zero or >0 const LongerType uj_n = static_cast(u.get(j + n)); const LongerType uj_n1 = static_cast(u.get(j + n - 1)); const LongerType vn_1 = static_cast(v.get(n - 1)); // Estimate q_hat = (u[j+n]*BASE + u[j+n-1]) / v[n-1] LongerType numerator = (uj_n * BASE) + uj_n1; LongerType qhat = numerator / vn_1; LongerType rhat = numerator % vn_1; if (qhat >= BASE) qhat = BASE - 1; // Correction loop (only if n >= 2) if (n >= 2) { const LongerType vn_2 = static_cast(v.get(n - 2)); while (qhat * vn_2 > (rhat * BASE + static_cast(u.get(j + n - 2)))) { qhat -= 1; rhat += vn_1; if (rhat >= BASE) break; } } // Multiply v by qhat and subtract from u segment starting at j LongerType carry_mul = 0; unsigned int borrow = 0; for (size_t i = 0; i < n; ++i) { // p = qhat * v[i] + carry_mul LongerType p = qhat * static_cast(v.get(i)) + carry_mul; ElementType p_low = static_cast(p & MASK); carry_mul = p >> W; ElementType uval = u.get(j + i); LongerType sub = static_cast(uval); LongerType needed = static_cast(p_low) + borrow; bool under = (sub < needed); ElementType new_u = static_cast((sub + BASE - needed) & MASK); u.set(j + i, new_u); borrow = under ? 1u : 0u; } // Subtract carry_mul and borrow from u[j+n] { ElementType uval = u.get(j + n); LongerType sub = static_cast(uval); LongerType needed = carry_mul + borrow; bool under = (sub < needed); ElementType new_u = static_cast((sub + BASE - needed) & MASK); u.set(j + n, new_u); borrow = under ? 1u : 0u; } if (borrow != 0) { // qhat was too large, decrement and add v back qhat -= 1; LongerType carry_add = 0; for (size_t i = 0; i < n; ++i) { LongerType sum = static_cast(u.get(j + i)) + static_cast(v.get(i)) + carry_add; u.set(j + i, static_cast(sum & MASK)); carry_add = sum >> W; } // add carry_add to u[j+n] u.set(j + n, static_cast((static_cast(u.get(j + n)) + carry_add) & MASK)); } // store quotient digit res.set(static_cast(j), static_cast(qhat & MASK)); } // Note: remainder unnormalization is not required for quotient; we ignore remainder. } }; // Binary operation template OpResultType binary_op(const TLhs& lhs, const TRhs& rhs) { OpResultType res{}; try { TOp::invoke(lhs, to_mp_int(rhs), res); } catch (const std::overflow_error&) { throw OverflowErrorOf(std::move(res)); } return res; } // Compound assignment binary operation template TLhs& ca_binary_op(TLhs& lhs, const TRhs& rhs) { TLhs res{}; try { TOp::invoke(lhs, to_mp_int(rhs), res); } catch (const std::overflow_error&) { throw OverflowErrorOf(std::move(res)); } lhs = std::move(res); return lhs; } // Addition operators template OpResultType operator+(const TLhs& lhs, const TRhs& rhs) { return binary_op(lhs, rhs); } template TLhs& operator+=(TLhs& lhs, const TRhs& rhs) { return ca_binary_op(lhs, rhs); } // Subtraction operators template OpResultType operator-(const TLhs& lhs, const TRhs& rhs) { return binary_op(lhs, rhs); } template TLhs& operator-=(TLhs& lhs, const TRhs& rhs) { return ca_binary_op(lhs, rhs); } // Multiplication operators template OpResultType operator*(const TLhs& lhs, const TRhs& rhs) { return binary_op(lhs, rhs); } template TLhs& operator*=(TLhs& lhs, const TRhs& rhs) { return ca_binary_op(lhs, rhs); } // Division operators template OpResultType operator/(const TLhs& lhs, const TRhs& rhs) { return binary_op(lhs, rhs); } template TLhs& operator/=(TLhs& lhs, const TRhs& rhs) { return ca_binary_op(lhs, rhs); } } // namespace mp