diff --git a/main.cpp b/main.cpp index 0600d8c..42d154a 100644 --- a/main.cpp +++ b/main.cpp @@ -1,11 +1,7 @@ -#include #include #include -//#include "mp.hpp" -#include "mp/int.hpp" -#include "mp/math.hpp" -#include "mp/lib.hpp" +#include "mp/mp.hpp" template static void PrintInt(const char* name, const T& val) @@ -25,9 +21,18 @@ static void display_op(const char* desc, const MyInt& a, const MyInt& b, std::fu std::cout << "op: " << mp::to_string(a) << ' ' << desc << ' ' << mp::to_string(b) << " = " << mp::to_string(op(a, b)) << std::endl; } +template +constexpr static auto operator""_mpi() +{ + constexpr char str[] = {Cs..., '\0'}; + return mp::from_string<128>(str); +} + int main() { + 1000000_mpi; + // mp::Int a{0xDEADBEEFDEADF154, 0x0123456789ABCDEF, 0x1111222233334444}; mp::Int<32> a{0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF}; mp::Int<1> b{0x55}; @@ -73,7 +78,7 @@ int main() std::cout << sizeof(mp::Int<150>) << std::endl; std::cout << sizeof(mp::Int<160>) << std::endl; - + // auto x = 12345678901234567890123456789012345678901234567890123456789012345678901234567890_mpi; // { // mp::Int<1024> acc{1}; @@ -133,12 +138,10 @@ int main() // } { - mp::Int<32> a; - mp::Int<16> b; + mp::Int<32> a = "690000100000000000000000000000000000"; + mp::Int<16> b{"10690010000000000000000"}; mp::Int<32> c; - mp::parse_string("690000100000000000000000000000000000", a); - mp::parse_string("10690010000000000000000", b); c = a / b; PrintDec("a", a); PrintDec("b", b); @@ -146,75 +149,20 @@ int main() } { - display_op("+", - mp::from_string<32>("123456789012345678901234567890"), - mp::from_string<32>("987654321098765432109876543210"), - std::plus<>{}); - - display_op("+", - mp::from_string<32>("987654321098765432109876543210"), - mp::from_string<32>("-123456789012345678901234567890"), - std::plus<>{}); - - display_op("+", - mp::from_string<32>("-123456789012345678901234567890"), - mp::from_string<32>("-987654321098765432109876543210"), - std::plus<>{}); - - display_op("+", - mp::from_string<32>("123456789012345678901234567890"), - mp::from_string<32>("-987654321098765432109876543210"), - std::plus<>{}); - - display_op("-", - mp::from_string<32>("987654321098765432109876543210"), - mp::from_string<32>("123456789012345678901234567890"), - std::minus<>{}); - - display_op("-", - mp::from_string<32>("123456789012345678901234567890"), - mp::from_string<32>("987654321098765432109876543210"), - std::minus<>{}); - - display_op("-", - mp::from_string<32>("-123456789012345678901234567890"), - mp::from_string<32>("-987654321098765432109876543210"), - std::minus<>{}); - - display_op("-", - mp::from_string<32>("987654321098765432109876543210"), - mp::from_string<32>("-123456789012345678901234567890"), - std::minus<>{}); - - display_op("*", - mp::from_string<32>("12345678901234567890"), - mp::from_string<32>("98765432109876543210"), - std::multiplies<>{}); - - display_op("*", - mp::from_string<32>("-12345678901234567890"), - mp::from_string<32>("98765432109876543210"), - std::multiplies<>{}); - - display_op("*", - mp::from_string<32>("-12345678901234567890"), - mp::from_string<32>("-98765432109876543210"), - std::multiplies<>{}); - - display_op("/", - mp::from_string<32>("1219326311370217952237463801111263506900"), - mp::from_string<32>("12345678901234567890"), - std::divides<>{}); - - display_op("/", - mp::from_string<32>("-1219326311370217952237463801111263506900"), - mp::from_string<32>("12345678901234567890"), - std::divides<>{}); - - display_op("/", - mp::from_string<32>("-1219326311370217952237463801111263506900"), - mp::from_string<32>("-12345678901234567890"), - std::divides<>{}); + display_op("+", 123456789012345678901234567890_mpi, 987654321098765432109876543210_mpi, std::plus<>{}); + display_op("+", 987654321098765432109876543210_mpi, -123456789012345678901234567890_mpi, std::plus<>{}); + display_op("+", "-123456789012345678901234567890", "-987654321098765432109876543210", std::plus<>{}); + display_op("+", "123456789012345678901234567890", "-987654321098765432109876543210", std::plus<>{}); + display_op("-", "987654321098765432109876543210", "123456789012345678901234567890", std::minus<>{}); + display_op("-", "123456789012345678901234567890", "987654321098765432109876543210", std::minus<>{}); + display_op("-", "-123456789012345678901234567890", "-987654321098765432109876543210", std::minus<>{}); + display_op("-", "987654321098765432109876543210", "-123456789012345678901234567890", std::minus<>{}); + display_op("*", "12345678901234567890", "98765432109876543210", std::multiplies<>{}); + display_op("*", "-12345678901234567890", "98765432109876543210", std::multiplies<>{}); + display_op("*", "-12345678901234567890", "-98765432109876543210", std::multiplies<>{}); + display_op("/", "1219326311370217952237463801111263526900", "12345678901234567890", std::divides<>{}); + display_op("/", "-1219326311370217952237463801111263526900", "12345678901234567890", std::divides<>{}); + display_op("/", "-1219326311370217952237463801111263526900", "-12345678901234567890", std::divides<>{}); auto a = mp::from_string<32>("12345678901234567890"); auto b = mp::from_string<32>("98765432109876543210"); @@ -225,8 +173,12 @@ int main() b, std::multiplies<>{}); - - + for (mp::UnlimitedInt i = "10000000000000000000000000000000000000000000"; + i > mp::UnlimitedInt("-20000000000000000000000000000000000000000000"); + i -= mp::UnlimitedInt("1000000000000000000000000000000000000000000")) + { + PrintDec("i", i); + } } } \ No newline at end of file diff --git a/mp/int.hpp b/mp/int.hpp index a00e6b0..ecdecd2 100644 --- a/mp/int.hpp +++ b/mp/int.hpp @@ -1,11 +1,39 @@ #pragma once +#include + #include "utils.hpp" #include "storage.hpp" namespace mp { +template +inline void parse_string(const char* str, T& number) +{ + number.zero(); + + bool negative = false; + + if (*str == '-') + { + negative = true; + ++str; + } + + for (; *str; ++str) + { + if (*str < '0' || *str > '9') + { + throw std::invalid_argument("Invalid character in input string"); + } + + number = number * 10U + static_cast(*str - '0'); + } + + number.set_negative(negative); +} + template class BasicInt { @@ -13,7 +41,7 @@ class BasicInt using ElementType = TElem; constexpr static size_t MAX_BYTES = MaxBytes; constexpr static size_t ELEMENT_BYTES = sizeof(TElem); - constexpr static size_t MAX_ELEMS = (MAX_BYTES + ELEMENT_BYTES - 1) / ELEMENT_BYTES; + constexpr static size_t MAX_ELEMS = calculate_max_elems(MAX_BYTES, ELEMENT_BYTES); constexpr static TElem LAST_ELEM_MASK = calculate_last_elem_mask(); BasicInt() = default; @@ -31,6 +59,92 @@ class BasicInt BasicInt& operator=(const BasicInt& other) = default; BasicInt& operator=(BasicInt&& other) noexcept = default; + template + BasicInt(const T& other) + { + *this = other; + } + + template + BasicInt(T value) + { + *this = value; + } + + BasicInt(const char* str) + { + *this = str; + } + + template + requires ElementTypesMatch + BasicInt& operator=(const T& other) + { + size_t other_size = other.size_elems(); + for (size_t i = 0; i < other_size; ++i) + { + set(i, other.get(i)); + } + + if (size_elems() > other_size) + resize(other_size); + + set_negative(other.negative()); + return *this; + } + + template + BasicInt& operator=(T value) + { + using UnsignedType = std::make_unsigned_t; + UnsignedType uvalue; + + if (std::is_signed_v && value < 0) + { + uvalue = static_cast(-value); + set_negative(true); + } + else + { + uvalue = static_cast(value); + set_negative(false); + } + + zero(); + + if constexpr (sizeof(UnsignedType) <= sizeof(TElem)) + { + set(0, static_cast(uvalue)); + } + else + { + constexpr UnsignedType ELEM_MASK = static_cast(std::numeric_limits::max()); + size_t idx = 0; + + while (uvalue != 0) + { + set(idx, static_cast(uvalue & ELEM_MASK)); + uvalue >>= (sizeof(TElem) * 8); + idx++; + } + } + + return *this; + } + + BasicInt& operator=(const char* str) + { + parse_string(str, *this); + return *this; + } + + BasicInt operator-() const + { + BasicInt res = *this; + res.set_negative(!res.negative()); + return res; + } + TElem& operator[](size_t index) { return m_data[index]; } const TElem& operator[](size_t index) const { return m_data[index]; } @@ -112,4 +226,7 @@ class BasicInt template using Int = BasicInt; +constexpr size_t UNLIMITED = std::numeric_limits::max(); +using UnlimitedInt = BasicInt; + } // namespace mp \ No newline at end of file diff --git a/mp/lib.hpp b/mp/lib.hpp index 37ae889..f34b419 100644 --- a/mp/lib.hpp +++ b/mp/lib.hpp @@ -15,7 +15,7 @@ inline char hex_digit(uint8_t bits) return bits > 9 ? 'a' + (bits - 10) : '0' + bits; } -template +template inline std::string to_hex_string(const T& number) { constexpr size_t ELEMENT_DIGITS = T::ELEMENT_BYTES * 2; @@ -59,7 +59,7 @@ inline typename T::ElementType div_mod(T& value, const typename T::ElementType d return remainder; } -template +template inline std::string to_string(const T& number) { std::string str; @@ -76,33 +76,7 @@ inline std::string to_string(const T& number) return str; } -template -inline void parse_string(const char* str, T& number) -{ - number.zero(); - - bool negative = false; - - if (*str == '-') - { - negative = true; - ++str; - } - - for (; *str; ++str) - { - if (*str < '0' || *str > '9') - { - throw std::invalid_argument("Invalid character in input string"); - } - - number = number * 10U + static_cast(*str - '0'); - } - - number.set_negative(negative); -} - -template +template inline T from_string(const char* str) { T number; @@ -118,8 +92,25 @@ inline Int from_string(const char* str) return number; } +template +inline T factorial(const T& n) +{ + if (n < 0) + { + throw std::invalid_argument("Factorial is not defined for negative numbers"); + } -// template + T result = 1; + for (T i = 2; i <= n; ++i) + { + result *= i; + } + + return result; +} + + +// template // inline T parse_string(const char* str) // { // T number; diff --git a/mp/math.hpp b/mp/math.hpp index c9764d9..8bdce22 100644 --- a/mp/math.hpp +++ b/mp/math.hpp @@ -9,10 +9,10 @@ namespace mp template constexpr size_t ResultMaxSize = std::max(SizeA, SizeB); -template +template struct OpResult; -template +template requires ElementTypesMatch struct OpResult @@ -20,17 +20,17 @@ struct OpResult using type = BasicInt>; }; -template +template struct OpResult { //using type = BasicInt>; using type = TLhs; // keep the same size as TLhs when operating with regular ints }; -template +template using OpResultType = typename OpResult::type; -template +template class OverflowError : public std::runtime_error { public: @@ -42,7 +42,7 @@ class OverflowError : public std::runtime_error T m_value; }; -template +template using OverflowErrorOf = OverflowError>; // Regular int to mp int @@ -50,84 +50,107 @@ template BasicInt to_mp_int(TR value) { BasicInt res; - if constexpr (sizeof(TR) <= sizeof(TElem)) - { - res.set(0, static_cast(value)); - } - else - { - constexpr TR ELEM_MASK = static_cast(std::numeric_limits::max()); - size_t idx = 0; - while (value != 0) - { - res.set(idx, static_cast(value & ELEM_MASK)); - value >>= (sizeof(TElem) * 8); - idx++; - } - } - + res = value; return res; } // no-op -template +template requires std::is_same_v const TR& to_mp_int(const TR& value) { return value; } -template +template requires ElementTypesMatch auto operator<=>(const TLhs& lhs, const TRhs& rhs) { - //auto size_comp = lhs.size_elems() <=> rhs.size_elems(); - //if (size_comp != std::strong_ordering::equal) - //{ - // return size_comp; - //} + const bool lhs_neg = lhs.negative(); + const bool rhs_neg = rhs.negative(); - //for (size_t i = lhs.size_elems(); i-- > 0;) - //{ - // auto elem_comp = lhs.get(i) <=> rhs.get(i); - // if (elem_comp != std::strong_ordering::equal) - // { - // return elem_comp; - // } - //} + std::strong_ordering res = std::strong_ordering::equal; + bool nonzero = false; size_t max_size = std::max(lhs.size_elems(), rhs.size_elems()); for (size_t i = max_size; i-- > 0;) { - auto elem_comp = lhs.get(i) <=> rhs.get(i); + auto l = lhs.get(i); + auto r = rhs.get(i); + + auto elem_comp = l <=> r; if (elem_comp != std::strong_ordering::equal) { - return elem_comp; + res = elem_comp; + break; + } + else + { + if (l != 0) + nonzero = true; } } - return std::strong_ordering::equal; + // both zero => ignore sign + if (res == std::strong_ordering::equal && !nonzero) + { + return std::strong_ordering::equal; + } + + // different sign, non-zero + if (lhs_neg != rhs_neg) + { + return lhs_neg ? std::strong_ordering::less : std::strong_ordering::greater; + } + + // flip for negative + if (lhs_neg) + { + res = (res == std::strong_ordering::less) ? std::strong_ordering::greater : std::strong_ordering::less; + } + + return res; } -template +template auto operator<=>(const TLhs& lhs, TRhs rhs) { return lhs <=> to_mp_int(rhs); } -template +template bool operator==(const TLhs& lhs, const TRhs& rhs) { return (lhs <=> rhs) == std::strong_ordering::equal; } -template +template bool operator!=(const TLhs& lhs, const TRhs& rhs) { return !(lhs == rhs); } -template +template + requires ElementTypesMatch +inline bool less_ignore_sign(const TLhs& lhs, const TRhs& rhs) +{ + size_t max_size = std::max(lhs.size_elems(), rhs.size_elems()); + + for (size_t i = max_size; i-- > 0; ) + { + auto l = lhs.get(i); + auto r = rhs.get(i); + + if (l < r) + return true; + else if (l > r) + return false; + } + + return false; +} + +template requires ElementTypesMatch && ElementTypesMatch inline void sub_ignore_underflow(const TLhs& lhs, const TRhs& rhs, TRes& res) { @@ -154,11 +177,11 @@ inline void sub_ignore_underflow(const TLhs& lhs, const TRhs& rhs, TRes& res) } } -template +template requires ElementTypesMatch && ElementTypesMatch inline void sub_ignore_sign(const TLhs& lhs, const TRhs& rhs, TRes& res) { - if (lhs < rhs) + if (less_ignore_sign(lhs, rhs)) { sub_ignore_underflow(rhs, lhs, res); res.set_negative(true); @@ -170,7 +193,7 @@ inline void sub_ignore_sign(const TLhs& lhs, const TRhs& rhs, TRes& res) } } -template +template requires ElementTypesMatch && ElementTypesMatch inline void add_ignore_sign(const TLhs& lhs, const TRhs& rhs, TRes& res) { @@ -200,7 +223,7 @@ inline void add_ignore_sign(const TLhs& lhs, const TRhs& rhs, TRes& res) struct Addition { - template + template requires ElementTypesMatch && ElementTypesMatch static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) { @@ -227,7 +250,7 @@ struct Addition struct Subtraction { - template + template requires ElementTypesMatch && ElementTypesMatch static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) { @@ -250,7 +273,7 @@ struct Subtraction add_ignore_sign(lhs, rhs, res); res.set_negative(true); } - else // if (!lhs_neg && rhs_neg) + else { // a - (-b) == a + b add_ignore_sign(lhs, rhs, res); @@ -263,7 +286,7 @@ struct Subtraction struct Multiplication { - template + template requires ElementTypesMatch && ElementTypesMatch static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) { @@ -302,7 +325,7 @@ struct Multiplication struct Division { - template + template requires ElementTypesMatch && ElementTypesMatch static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) { @@ -456,7 +479,7 @@ struct Division }; // Binary operation -template +template OpResultType binary_op(const TLhs& lhs, const TRhs& rhs) { OpResultType res{}; @@ -474,7 +497,7 @@ OpResultType binary_op(const TLhs& lhs, const TRhs& rhs) } // Compound assignment binary operation -template +template TLhs& ca_binary_op(TLhs& lhs, const TRhs& rhs) { TLhs res{}; @@ -494,13 +517,13 @@ TLhs& ca_binary_op(TLhs& lhs, const TRhs& rhs) // Addition operators -template +template OpResultType operator+(const TLhs& lhs, const TRhs& rhs) { return binary_op(lhs, rhs); } -template +template TLhs& operator+=(TLhs& lhs, const TRhs& rhs) { return ca_binary_op(lhs, rhs); @@ -508,13 +531,13 @@ TLhs& operator+=(TLhs& lhs, const TRhs& rhs) // Subtraction operators -template +template OpResultType operator-(const TLhs& lhs, const TRhs& rhs) { return binary_op(lhs, rhs); } -template +template TLhs& operator-=(TLhs& lhs, const TRhs& rhs) { return ca_binary_op(lhs, rhs); @@ -522,13 +545,13 @@ TLhs& operator-=(TLhs& lhs, const TRhs& rhs) // Multiplication operators -template +template OpResultType operator*(const TLhs& lhs, const TRhs& rhs) { return binary_op(lhs, rhs); } -template +template TLhs& operator*=(TLhs& lhs, const TRhs& rhs) { return ca_binary_op(lhs, rhs); @@ -536,13 +559,13 @@ TLhs& operator*=(TLhs& lhs, const TRhs& rhs) // Division operators -template +template OpResultType operator/(const TLhs& lhs, const TRhs& rhs) { return binary_op(lhs, rhs); } -template +template TLhs& operator/=(TLhs& lhs, const TRhs& rhs) { return ca_binary_op(lhs, rhs); diff --git a/mp/mp.hpp b/mp/mp.hpp new file mode 100644 index 0000000..5092ebe --- /dev/null +++ b/mp/mp.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "int.hpp" +#include "math.hpp" +#include "lib.hpp" diff --git a/mp/storage.hpp b/mp/storage.hpp index d452969..760c117 100644 --- a/mp/storage.hpp +++ b/mp/storage.hpp @@ -65,6 +65,6 @@ constexpr size_t MAX_ARRAY_BYTES = 128; template using Container = - std::conditional_t<(sizeof(TElem) * MaxSize > MAX_ARRAY_BYTES), VectorContainer, ArrayContainer>; + std::conditional_t<(MaxSize > MAX_ARRAY_BYTES / sizeof(TElem)), VectorContainer, ArrayContainer>; } // namespace mp \ No newline at end of file diff --git a/mp/utils.hpp b/mp/utils.hpp index 4fa694b..ce5ab33 100644 --- a/mp/utils.hpp +++ b/mp/utils.hpp @@ -78,10 +78,14 @@ template concept ElementSuitable = std::unsigned_integral; template -concept AnyConstMpInt = requires(T t, TElem a, size_t i) { +concept AnyMpInt = requires(T t, TElem a, size_t i, bool b) { { t[i] } -> std::convertible_to; + { t.zero() }; + { t.set(i, a) }; + { t.try_set(i, a) } -> std::convertible_to; { t.get(i) } -> std::convertible_to; { t.size_elems() } -> std::convertible_to; + { t.set_negative(b) }; { t.negative() } -> std::convertible_to; { T::MAX_BYTES } -> std::convertible_to; { T::MAX_ELEMS } -> std::convertible_to; @@ -89,19 +93,11 @@ concept AnyConstMpInt = requires(T t, TElem a, size_t i) { { T::LAST_ELEM_MASK } -> std::convertible_to; }; -template -concept AnyMpInt = AnyConstMpInt && requires(T t, TElem a, size_t i, bool b) { - { t.set(i, a) }; - { t.try_set(i, a) } -> std::convertible_to; - { t.set_negative(b) }; - { t.zero() }; -}; +template +concept AnyRegularInt = std::integral; template -concept AnyRegularInt = std::unsigned_integral; - -template -concept AnyConstInt = AnyRegularInt || AnyConstMpInt; +concept AnyInt = AnyRegularInt || AnyMpInt; //======================== Utils =======================// @@ -111,9 +107,17 @@ using LongestElementSuitableType = uint64_t; using LongestElementSuitableType = HalfWidthType; #endif -template +template constexpr bool ElementTypesMatch = std::is_same_v; +constexpr size_t calculate_max_elems(size_t max_bytes, size_t elem_bytes) +{ + if (max_bytes % elem_bytes == 0) + return max_bytes / elem_bytes; + else + return (max_bytes / elem_bytes) + 1; +} + template constexpr TElem calculate_last_elem_mask() {