diff --git a/main.cpp b/main.cpp index fc93991..cb559eb 100644 --- a/main.cpp +++ b/main.cpp @@ -81,4 +81,39 @@ int main() std::cout << "error: " << e.what() << std::endl; } } + + //{ + // mp::Int<32> val{1}; + + // for (mp::Int<4> a{1}; a < 100U; a += 1U) + // { + // for (size_t i = 0; i < 20; ++i) + // { + // val *= a; + // //PrintInt("val", val); + // } + + // for (size_t i = 0; i < 20; ++i) + // { + // val /= a; + // } + + // //if (val != 0U) + // //{ + // PrintInt("val", val); + // //} + // } + //} + + { + mp::Int<32> a{0xA0000000, 0x6D7217CA, 0x431E0FAE, 0x1}; + mp::Int<16> b{0x6FC10000, 0x2386F2}; + + PrintInt("a", a); + PrintInt("b", b); + + auto c = a / b; + PrintInt("c", c); + + } } \ No newline at end of file diff --git a/mp/int.hpp b/mp/int.hpp index 2f5eca8..3ae9193 100644 --- a/mp/int.hpp +++ b/mp/int.hpp @@ -83,6 +83,17 @@ class BasicInt return m_data[idx]; } + void fix_leading_zeros() + { + size_t new_size = m_data.size(); + while (new_size > 0 && m_data[new_size - 1] == TElem{0}) + { + new_size--; + } + + m_data.resize(new_size); + } + //std::span data() { return m_data; } size_t size_elems() const { return m_data.size(); } diff --git a/mp/math.hpp b/mp/math.hpp index 99f5d9a..4102313 100644 --- a/mp/math.hpp +++ b/mp/math.hpp @@ -10,11 +10,28 @@ namespace mp template constexpr size_t ResultMaxSize = std::max(SizeA, SizeB); -template - requires ElementTypesMatch -using OpResult = BasicInt>; +template +struct OpResult; -template +template + requires ElementTypesMatch +struct OpResult + +{ + using type = BasicInt>; +}; + +template +struct OpResult +{ + //using type = BasicInt>; + using type = TLhs; // keep the same size as TLhs when operating with regular ints +}; + +template +using OpResultType = typename OpResult::type; + +template class OverflowError : public std::runtime_error { public: @@ -26,12 +43,94 @@ class OverflowError : public std::runtime_error T m_value; }; -template -using OverflowErrorOf = OverflowError>; +template +using OverflowErrorOf = OverflowError>; + +// Regular int to mp int +template +BasicInt to_mp_int(TR value) +{ + BasicInt res; + if constexpr (sizeof(TR) <= sizeof(TElem)) + { + res.set(0, static_cast(value)); + } + else + { + constexpr TR ELEM_MASK = static_cast(std::numeric_limits::max()); + size_t idx = 0; + while (value != 0) + { + res.set(idx, static_cast(value & ELEM_MASK)); + value >>= (sizeof(TElem) * 8); + idx++; + } + } + + return res; +} + +// no-op +template + requires std::is_same_v +const TR& to_mp_int(const TR& value) +{ + return value; +} + +template + requires ElementTypesMatch +auto operator<=>(const TLhs& lhs, const TRhs& rhs) +{ + //auto size_comp = lhs.size_elems() <=> rhs.size_elems(); + //if (size_comp != std::strong_ordering::equal) + //{ + // return size_comp; + //} + + //for (size_t i = lhs.size_elems(); i-- > 0;) + //{ + // auto elem_comp = lhs.get(i) <=> rhs.get(i); + // if (elem_comp != std::strong_ordering::equal) + // { + // return elem_comp; + // } + //} + + size_t max_size = std::max(lhs.size_elems(), rhs.size_elems()); + for (size_t i = max_size; i-- > 0;) + { + auto elem_comp = lhs.get(i) <=> rhs.get(i); + if (elem_comp != std::strong_ordering::equal) + { + return elem_comp; + } + } + + return std::strong_ordering::equal; +} + +template +auto operator<=>(const TLhs& lhs, TRhs rhs) +{ + return lhs <=> to_mp_int(rhs); +} + +template +bool operator==(const TLhs& lhs, const TRhs& rhs) +{ + return (lhs <=> rhs) == std::strong_ordering::equal; +} + +template +bool operator!=(const TLhs& lhs, const TRhs& rhs) +{ + return !(lhs == rhs); +} struct Addition { - template + template requires ElementTypesMatch && ElementTypesMatch static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) { @@ -60,14 +159,44 @@ struct Addition } }; -struct Multiplication +struct Subtraction { - template + template requires ElementTypesMatch && ElementTypesMatch static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) { using ElementType = typename TRes::ElementType; - using DoubleType = DoubleWidthType; + + res.zero(); + + ElementType borrow = 0; + size_t end = std::max(lhs.size_elems(), rhs.size_elems()); + + for (size_t i = 0; i < end; ++i) + { + ElementType a = lhs.get(i); + ElementType b = rhs.get(i); + + ElementType c = a - b - borrow; + borrow = (a < b) || (borrow && a == b); + res.set(i, c); + } + + if (borrow != 0) + { + throw std::overflow_error("Subtraction underflow"); + } + } +}; + +struct Multiplication +{ + template + requires ElementTypesMatch && ElementTypesMatch + static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) + { + using ElementType = typename TRes::ElementType; + using LongerType = DoubleWidthType; res.zero(); @@ -78,12 +207,12 @@ struct Multiplication for (size_t i = 0; i < n; i++) { - DoubleType carry = 0; + LongerType carry = 0; ElementType a = lhs[i]; for (size_t j = 0; j < m; j++) { ElementType b = rhs[j]; - DoubleType t = static_cast(a) * b + res.get(i + j) + carry; + LongerType t = static_cast(a) * b + res.get(i + j) + carry; overflow |= !res.try_set(i + j, static_cast(t)); carry = t >> (sizeof(ElementType) * 8); @@ -95,19 +224,73 @@ struct Multiplication { throw std::overflow_error("Multiplication overflow"); } + } +}; + +struct Division +{ + template + static void shift_left(T& value) + { + using ElementType = typename T::ElementType; + + for (size_t i = value.size_elems(); i-- > 0;) + { + ElementType new_val = value.get(i - 1); + value.set(i, new_val); + } + + value.set(0, ElementType{0}); + } + + template + requires ElementTypesMatch && ElementTypesMatch + static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) + { + using ElementType = typename TRes::ElementType; + using LongerType = DoubleWidthType; + + if (rhs == 0U) + { + throw std::runtime_error("Division by zero"); + } + + BasicInt remainder = lhs; + res.zero(); + BasicInt temp; + + while (remainder >= rhs) + { + int shift = remainder.size_elems() - rhs.size_elems(); + + temp = rhs; + for (int i = 0; i < shift; ++i) + { + shift_left(temp); + } + + ElementType q_digit = 0; + + if (remainder >= temp) + { + remainder -= temp; + q_digit++; + } + + res.set(shift, q_digit); + } }; }; // Binary operation -template - requires ElementTypesMatch -OpResult binary_op(const TLhs& lhs, const TRhs& rhs) +template +OpResultType binary_op(const TLhs& lhs, const TRhs& rhs) { - OpResult res{}; + OpResultType res{}; try { - TOp::invoke(lhs, rhs, res); + TOp::invoke(lhs, to_mp_int(rhs), res); } catch (const std::overflow_error&) { @@ -118,15 +301,14 @@ OpResult binary_op(const TLhs& lhs, const TRhs& rhs) } // Compound assignment binary operation -template - requires ElementTypesMatch +template TLhs& ca_binary_op(TLhs& lhs, const TRhs& rhs) { TLhs res{}; try { - TOp::invoke(lhs, rhs, res); + TOp::invoke(lhs, to_mp_int(rhs), res); } catch (const std::overflow_error&) { @@ -139,34 +321,59 @@ TLhs& ca_binary_op(TLhs& lhs, const TRhs& rhs) // Addition operators -template - requires ElementTypesMatch -OpResult operator+(const TLhs& lhs, const TRhs& rhs) +template +OpResultType operator+(const TLhs& lhs, const TRhs& rhs) { return binary_op(lhs, rhs); } -template - requires ElementTypesMatch +template TLhs& operator+=(TLhs& lhs, const TRhs& rhs) { return ca_binary_op(lhs, rhs); } +// Subtraction operators + +template +OpResultType operator-(const TLhs& lhs, const TRhs& rhs) +{ + return binary_op(lhs, rhs); +} + +template +TLhs& operator-=(TLhs& lhs, const TRhs& rhs) +{ + return ca_binary_op(lhs, rhs); +} + // Multiplication operators -template - requires ElementTypesMatch -OpResult operator*(const TLhs& lhs, const TRhs& rhs) +template +OpResultType operator*(const TLhs& lhs, const TRhs& rhs) { return binary_op(lhs, rhs); } -template - requires ElementTypesMatch +template TLhs& operator*=(TLhs& lhs, const TRhs& rhs) { return ca_binary_op(lhs, rhs); } +// Division operators + +template +OpResultType operator/(const TLhs& lhs, const TRhs& rhs) +{ + return binary_op(lhs, rhs); +} + +template +TLhs& operator/=(TLhs& lhs, const TRhs& rhs) +{ + return ca_binary_op(lhs, rhs); +} + + } // namespace mp \ No newline at end of file diff --git a/mp/utils.hpp b/mp/utils.hpp index 868be88..3751247 100644 --- a/mp/utils.hpp +++ b/mp/utils.hpp @@ -73,12 +73,9 @@ template concept ElementSuitable = std::unsigned_integral; template -concept AnyInt = requires(T t, TElem a, size_t i) { +concept AnyConstMpInt = requires(T t, TElem a, size_t i) { { t[i] } -> std::convertible_to; - { t.set(i, a) }; - { t.try_set(i, a) } -> std::convertible_to; { t.get(i) } -> std::convertible_to; - { t.zero() }; { t.size_elems() } -> std::convertible_to; { T::MAX_BYTES } -> std::convertible_to; { T::MAX_ELEMS } -> std::convertible_to; @@ -86,6 +83,19 @@ concept AnyInt = requires(T t, TElem a, size_t i) { { T::LAST_ELEM_MASK } -> std::convertible_to; }; +template +concept AnyMpInt = AnyConstMpInt && requires(T t, TElem a, size_t i) { + { t.set(i, a) }; + { t.try_set(i, a) } -> std::convertible_to; + { t.zero() }; +}; + +template +concept AnyRegularInt = std::unsigned_integral; + +template +concept AnyConstInt = AnyRegularInt || AnyConstMpInt; + //======================== Utils =======================// #ifdef __SIZEOF_INT128__ @@ -94,7 +104,7 @@ using LongestElementSuitableType = uint64_t; using LongestElementSuitableType = HalfWidthType; #endif -template +template constexpr bool ElementTypesMatch = std::is_same_v; template @@ -121,7 +131,7 @@ inline char hex_digit(uint8_t bits) return bits > 9 ? 'a' + (bits - 10) : '0' + bits; } -template +template std::string to_hex_string(const T& number) { constexpr size_t ELEMENT_DIGITS = T::ELEMENT_BYTES * 2;