#pragma once #include "int.hpp" #include "storage.hpp" #include "utils.hpp" namespace mp { template constexpr size_t ResultMaxSize = std::max(SizeA, SizeB); template struct OpResult; template requires ElementTypesMatch struct OpResult { using type = BasicInt>; }; template struct OpResult { //using type = BasicInt>; using type = TLhs; // keep the same size as TLhs when operating with regular ints }; template using OpResultType = typename OpResult::type; template class OverflowError : public std::runtime_error { public: OverflowError(T&& value) : std::runtime_error("Overflow"), m_value(std::move(value)) {} const T& value() const { return m_value; } private: T m_value; }; template using OverflowErrorOf = OverflowError>; // Regular int to mp int template BasicInt to_mp_int(TR value) { BasicInt res; if constexpr (sizeof(TR) <= sizeof(TElem)) { res.set(0, static_cast(value)); } else { constexpr TR ELEM_MASK = static_cast(std::numeric_limits::max()); size_t idx = 0; while (value != 0) { res.set(idx, static_cast(value & ELEM_MASK)); value >>= (sizeof(TElem) * 8); idx++; } } return res; } // no-op template requires std::is_same_v const TR& to_mp_int(const TR& value) { return value; } template requires ElementTypesMatch auto operator<=>(const TLhs& lhs, const TRhs& rhs) { //auto size_comp = lhs.size_elems() <=> rhs.size_elems(); //if (size_comp != std::strong_ordering::equal) //{ // return size_comp; //} //for (size_t i = lhs.size_elems(); i-- > 0;) //{ // auto elem_comp = lhs.get(i) <=> rhs.get(i); // if (elem_comp != std::strong_ordering::equal) // { // return elem_comp; // } //} size_t max_size = std::max(lhs.size_elems(), rhs.size_elems()); for (size_t i = max_size; i-- > 0;) { auto elem_comp = lhs.get(i) <=> rhs.get(i); if (elem_comp != std::strong_ordering::equal) { return elem_comp; } } return std::strong_ordering::equal; } template auto operator<=>(const TLhs& lhs, TRhs rhs) { return lhs <=> to_mp_int(rhs); } template bool operator==(const TLhs& lhs, const TRhs& rhs) { return (lhs <=> rhs) == std::strong_ordering::equal; } template bool operator!=(const TLhs& lhs, const TRhs& rhs) { return !(lhs == rhs); } struct Addition { template requires ElementTypesMatch && ElementTypesMatch static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) { using ElementType = typename TRes::ElementType; res.zero(); ElementType carry = 0; size_t end = std::max(lhs.size_elems(), rhs.size_elems()); for (size_t i = 0; i < end; ++i) { ElementType a = lhs.get(i); ElementType b = rhs.get(i); ElementType c = carry + a; carry = (c < a) ? 1 : 0; c += b; if (c < b) carry = 1; res.set(i, c); } res.set(end, carry); } }; struct Subtraction { template requires ElementTypesMatch && ElementTypesMatch static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) { using ElementType = typename TRes::ElementType; res.zero(); ElementType borrow = 0; size_t end = std::max(lhs.size_elems(), rhs.size_elems()); for (size_t i = 0; i < end; ++i) { ElementType a = lhs.get(i); ElementType b = rhs.get(i); ElementType c = a - b - borrow; borrow = (a < b) || (borrow && a == b); res.set(i, c); } if (borrow != 0) { throw std::overflow_error("Subtraction underflow"); } } }; struct Multiplication { template requires ElementTypesMatch && ElementTypesMatch static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) { using ElementType = typename TRes::ElementType; using LongerType = DoubleWidthType; res.zero(); bool overflow = false; const size_t n = lhs.size_elems(); const size_t m = rhs.size_elems(); for (size_t i = 0; i < n; i++) { LongerType carry = 0; ElementType a = lhs[i]; for (size_t j = 0; j < m; j++) { ElementType b = rhs[j]; LongerType t = static_cast(a) * b + res.get(i + j) + carry; overflow |= !res.try_set(i + j, static_cast(t)); carry = t >> (sizeof(ElementType) * 8); } overflow |= !res.try_set(i + m, static_cast(carry)); } if (overflow) { throw std::overflow_error("Multiplication overflow"); } } }; struct Division { template static void shift_left(T& value) { using ElementType = typename T::ElementType; for (size_t i = value.size_elems(); i-- > 0;) { ElementType new_val = value.get(i - 1); value.set(i, new_val); } value.set(0, ElementType{0}); } template requires ElementTypesMatch && ElementTypesMatch static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) { using ElementType = typename TRes::ElementType; using LongerType = DoubleWidthType; if (rhs == 0U) { throw std::runtime_error("Division by zero"); } BasicInt remainder = lhs; res.zero(); BasicInt temp; while (remainder >= rhs) { int shift = remainder.size_elems() - rhs.size_elems(); temp = rhs; for (int i = 0; i < shift; ++i) { shift_left(temp); } ElementType q_digit = 0; if (remainder >= temp) { remainder -= temp; q_digit++; } res.set(shift, q_digit); } }; }; // Binary operation template OpResultType binary_op(const TLhs& lhs, const TRhs& rhs) { OpResultType res{}; try { TOp::invoke(lhs, to_mp_int(rhs), res); } catch (const std::overflow_error&) { throw OverflowErrorOf(std::move(res)); } return res; } // Compound assignment binary operation template TLhs& ca_binary_op(TLhs& lhs, const TRhs& rhs) { TLhs res{}; try { TOp::invoke(lhs, to_mp_int(rhs), res); } catch (const std::overflow_error&) { throw OverflowErrorOf(std::move(res)); } lhs = std::move(res); return lhs; } // Addition operators template OpResultType operator+(const TLhs& lhs, const TRhs& rhs) { return binary_op(lhs, rhs); } template TLhs& operator+=(TLhs& lhs, const TRhs& rhs) { return ca_binary_op(lhs, rhs); } // Subtraction operators template OpResultType operator-(const TLhs& lhs, const TRhs& rhs) { return binary_op(lhs, rhs); } template TLhs& operator-=(TLhs& lhs, const TRhs& rhs) { return ca_binary_op(lhs, rhs); } // Multiplication operators template OpResultType operator*(const TLhs& lhs, const TRhs& rhs) { return binary_op(lhs, rhs); } template TLhs& operator*=(TLhs& lhs, const TRhs& rhs) { return ca_binary_op(lhs, rhs); } // Division operators template OpResultType operator/(const TLhs& lhs, const TRhs& rhs) { return binary_op(lhs, rhs); } template TLhs& operator/=(TLhs& lhs, const TRhs& rhs) { return ca_binary_op(lhs, rhs); } } // namespace mp