This commit is contained in:
tovjemam 2025-11-23 19:00:32 +01:00
parent 08ea929084
commit 7e0b3cb9fa
4 changed files with 299 additions and 36 deletions

View File

@ -81,4 +81,39 @@ int main()
std::cout << "error: " << e.what() << std::endl;
}
}
//{
// mp::Int<32> val{1};
// for (mp::Int<4> a{1}; a < 100U; a += 1U)
// {
// for (size_t i = 0; i < 20; ++i)
// {
// val *= a;
// //PrintInt("val", val);
// }
// for (size_t i = 0; i < 20; ++i)
// {
// val /= a;
// }
// //if (val != 0U)
// //{
// PrintInt("val", val);
// //}
// }
//}
{
mp::Int<32> a{0xA0000000, 0x6D7217CA, 0x431E0FAE, 0x1};
mp::Int<16> b{0x6FC10000, 0x2386F2};
PrintInt("a", a);
PrintInt("b", b);
auto c = a / b;
PrintInt("c", c);
}
}

View File

@ -83,6 +83,17 @@ class BasicInt
return m_data[idx];
}
void fix_leading_zeros()
{
size_t new_size = m_data.size();
while (new_size > 0 && m_data[new_size - 1] == TElem{0})
{
new_size--;
}
m_data.resize(new_size);
}
//std::span<TElem> data() { return m_data; }
size_t size_elems() const { return m_data.size(); }

View File

@ -10,11 +10,28 @@ namespace mp
template <size_t SizeA, size_t SizeB>
constexpr size_t ResultMaxSize = std::max(SizeA, SizeB);
template <AnyInt TLhs, AnyInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
using OpResult = BasicInt<typename TLhs::ElementType, ResultMaxSize<TLhs::MAX_BYTES, TRhs::MAX_BYTES>>;
template <AnyConstMpInt TL, AnyConstInt TR>
struct OpResult;
template <AnyInt T>
template <AnyConstMpInt TLhs, AnyConstMpInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
struct OpResult<TLhs, TRhs>
{
using type = BasicInt<typename TLhs::ElementType, ResultMaxSize<TLhs::MAX_BYTES, TRhs::MAX_BYTES>>;
};
template <AnyConstMpInt TLhs, AnyRegularInt TRhs>
struct OpResult<TLhs, TRhs>
{
//using type = BasicInt<typename TLhs::ElementType, ResultMaxSize<TLhs::MAX_BYTES, sizeof(TRhs)>>;
using type = TLhs; // keep the same size as TLhs when operating with regular ints
};
template <AnyConstMpInt TLhs, AnyConstInt TRhs>
using OpResultType = typename OpResult<TLhs, TRhs>::type;
template <AnyConstMpInt T>
class OverflowError : public std::runtime_error
{
public:
@ -26,12 +43,94 @@ class OverflowError : public std::runtime_error
T m_value;
};
template <AnyInt TLhs, AnyInt TRhs>
using OverflowErrorOf = OverflowError<OpResult<TLhs, TRhs>>;
template <AnyConstMpInt TLhs, AnyConstInt TRhs>
using OverflowErrorOf = OverflowError<OpResultType<TLhs, TRhs>>;
// Regular int to mp int
template <ElementSuitable TElem, AnyRegularInt TR>
BasicInt<TElem, sizeof(TR)> to_mp_int(TR value)
{
BasicInt<TElem, sizeof(TR)> res;
if constexpr (sizeof(TR) <= sizeof(TElem))
{
res.set(0, static_cast<TElem>(value));
}
else
{
constexpr TR ELEM_MASK = static_cast<TR>(std::numeric_limits<TElem>::max());
size_t idx = 0;
while (value != 0)
{
res.set(idx, static_cast<TElem>(value & ELEM_MASK));
value >>= (sizeof(TElem) * 8);
idx++;
}
}
return res;
}
// no-op
template <ElementSuitable TElem, AnyConstMpInt TR>
requires std::is_same_v<typename TR::ElementType, TElem>
const TR& to_mp_int(const TR& value)
{
return value;
}
template <AnyConstMpInt TLhs, AnyConstMpInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
auto operator<=>(const TLhs& lhs, const TRhs& rhs)
{
//auto size_comp = lhs.size_elems() <=> rhs.size_elems();
//if (size_comp != std::strong_ordering::equal)
//{
// return size_comp;
//}
//for (size_t i = lhs.size_elems(); i-- > 0;)
//{
// auto elem_comp = lhs.get(i) <=> rhs.get(i);
// if (elem_comp != std::strong_ordering::equal)
// {
// return elem_comp;
// }
//}
size_t max_size = std::max(lhs.size_elems(), rhs.size_elems());
for (size_t i = max_size; i-- > 0;)
{
auto elem_comp = lhs.get(i) <=> rhs.get(i);
if (elem_comp != std::strong_ordering::equal)
{
return elem_comp;
}
}
return std::strong_ordering::equal;
}
template <AnyConstMpInt TLhs, AnyRegularInt TRhs>
auto operator<=>(const TLhs& lhs, TRhs rhs)
{
return lhs <=> to_mp_int<typename TLhs::ElementType>(rhs);
}
template <AnyConstMpInt TLhs, AnyConstInt TRhs>
bool operator==(const TLhs& lhs, const TRhs& rhs)
{
return (lhs <=> rhs) == std::strong_ordering::equal;
}
template <AnyConstMpInt TLhs, AnyConstInt TRhs>
bool operator!=(const TLhs& lhs, const TRhs& rhs)
{
return !(lhs == rhs);
}
struct Addition
{
template <AnyInt TLhs, AnyInt TRhs, AnyInt TRes>
template <AnyConstMpInt TLhs, AnyConstMpInt TRhs, AnyMpInt TRes>
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
{
@ -60,14 +159,44 @@ struct Addition
}
};
struct Multiplication
struct Subtraction
{
template <AnyInt TLhs, AnyInt TRhs, AnyInt TRes>
template <AnyConstMpInt TLhs, AnyConstMpInt TRhs, AnyMpInt TRes>
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
{
using ElementType = typename TRes::ElementType;
using DoubleType = DoubleWidthType<ElementType>;
res.zero();
ElementType borrow = 0;
size_t end = std::max(lhs.size_elems(), rhs.size_elems());
for (size_t i = 0; i < end; ++i)
{
ElementType a = lhs.get(i);
ElementType b = rhs.get(i);
ElementType c = a - b - borrow;
borrow = (a < b) || (borrow && a == b);
res.set(i, c);
}
if (borrow != 0)
{
throw std::overflow_error("Subtraction underflow");
}
}
};
struct Multiplication
{
template <AnyConstMpInt TLhs, AnyConstMpInt TRhs, AnyMpInt TRes>
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
{
using ElementType = typename TRes::ElementType;
using LongerType = DoubleWidthType<ElementType>;
res.zero();
@ -78,12 +207,12 @@ struct Multiplication
for (size_t i = 0; i < n; i++)
{
DoubleType carry = 0;
LongerType carry = 0;
ElementType a = lhs[i];
for (size_t j = 0; j < m; j++)
{
ElementType b = rhs[j];
DoubleType t = static_cast<DoubleType>(a) * b + res.get(i + j) + carry;
LongerType t = static_cast<LongerType>(a) * b + res.get(i + j) + carry;
overflow |= !res.try_set(i + j, static_cast<ElementType>(t));
carry = t >> (sizeof(ElementType) * 8);
@ -95,19 +224,73 @@ struct Multiplication
{
throw std::overflow_error("Multiplication overflow");
}
}
};
struct Division
{
template <AnyMpInt T>
static void shift_left(T& value)
{
using ElementType = typename T::ElementType;
for (size_t i = value.size_elems(); i-- > 0;)
{
ElementType new_val = value.get(i - 1);
value.set(i, new_val);
}
value.set(0, ElementType{0});
}
template <AnyConstMpInt TLhs, AnyConstMpInt TRhs, AnyMpInt TRes>
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
{
using ElementType = typename TRes::ElementType;
using LongerType = DoubleWidthType<ElementType>;
if (rhs == 0U)
{
throw std::runtime_error("Division by zero");
}
BasicInt<ElementType, TRes::MAX_BYTES> remainder = lhs;
res.zero();
BasicInt<ElementType, TRhs::MAX_BYTES> temp;
while (remainder >= rhs)
{
int shift = remainder.size_elems() - rhs.size_elems();
temp = rhs;
for (int i = 0; i < shift; ++i)
{
shift_left(temp);
}
ElementType q_digit = 0;
if (remainder >= temp)
{
remainder -= temp;
q_digit++;
}
res.set(shift, q_digit);
}
};
};
// Binary operation
template <typename TOp, AnyInt TLhs, AnyInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
OpResult<TLhs, TRhs> binary_op(const TLhs& lhs, const TRhs& rhs)
template <typename TOp, AnyConstMpInt TLhs, AnyConstInt TRhs>
OpResultType<TLhs, TRhs> binary_op(const TLhs& lhs, const TRhs& rhs)
{
OpResult<TLhs, TRhs> res{};
OpResultType<TLhs, TRhs> res{};
try
{
TOp::invoke(lhs, rhs, res);
TOp::invoke(lhs, to_mp_int<typename TLhs::ElementType>(rhs), res);
}
catch (const std::overflow_error&)
{
@ -118,15 +301,14 @@ OpResult<TLhs, TRhs> binary_op(const TLhs& lhs, const TRhs& rhs)
}
// Compound assignment binary operation
template <typename TOp, AnyInt TLhs, AnyInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
template <typename TOp, AnyMpInt TLhs, AnyConstInt TRhs>
TLhs& ca_binary_op(TLhs& lhs, const TRhs& rhs)
{
TLhs res{};
try
{
TOp::invoke(lhs, rhs, res);
TOp::invoke(lhs, to_mp_int<typename TLhs::ElementType>(rhs), res);
}
catch (const std::overflow_error&)
{
@ -139,34 +321,59 @@ TLhs& ca_binary_op(TLhs& lhs, const TRhs& rhs)
// Addition operators
template <AnyInt TLhs, AnyInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
OpResult<TLhs, TRhs> operator+(const TLhs& lhs, const TRhs& rhs)
template <AnyConstMpInt TLhs, AnyConstInt TRhs>
OpResultType<TLhs, TRhs> operator+(const TLhs& lhs, const TRhs& rhs)
{
return binary_op<Addition>(lhs, rhs);
}
template <AnyInt TLhs, AnyInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
template <AnyMpInt TLhs, AnyConstInt TRhs>
TLhs& operator+=(TLhs& lhs, const TRhs& rhs)
{
return ca_binary_op<Addition>(lhs, rhs);
}
// Subtraction operators
template <AnyConstMpInt TLhs, AnyConstInt TRhs>
OpResultType<TLhs, TRhs> operator-(const TLhs& lhs, const TRhs& rhs)
{
return binary_op<Subtraction>(lhs, rhs);
}
template <AnyMpInt TLhs, AnyConstInt TRhs>
TLhs& operator-=(TLhs& lhs, const TRhs& rhs)
{
return ca_binary_op<Subtraction>(lhs, rhs);
}
// Multiplication operators
template <AnyInt TLhs, AnyInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
OpResult<TLhs, TRhs> operator*(const TLhs& lhs, const TRhs& rhs)
template <AnyConstMpInt TLhs, AnyConstInt TRhs>
OpResultType<TLhs, TRhs> operator*(const TLhs& lhs, const TRhs& rhs)
{
return binary_op<Multiplication>(lhs, rhs);
}
template <AnyInt TLhs, AnyInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
template <AnyMpInt TLhs, AnyConstInt TRhs>
TLhs& operator*=(TLhs& lhs, const TRhs& rhs)
{
return ca_binary_op<Multiplication>(lhs, rhs);
}
// Division operators
template <AnyConstMpInt TLhs, AnyConstInt TRhs>
OpResultType<TLhs, TRhs> operator/(const TLhs& lhs, const TRhs& rhs)
{
return binary_op<Division>(lhs, rhs);
}
template <AnyMpInt TLhs, AnyConstInt TRhs>
TLhs& operator/=(TLhs& lhs, const TRhs& rhs)
{
return ca_binary_op<Division>(lhs, rhs);
}
} // namespace mp

View File

@ -73,12 +73,9 @@ template <typename T>
concept ElementSuitable = std::unsigned_integral<T>;
template <typename T, typename TElem = T::ElementType>
concept AnyInt = requires(T t, TElem a, size_t i) {
concept AnyConstMpInt = requires(T t, TElem a, size_t i) {
{ t[i] } -> std::convertible_to<TElem>;
{ t.set(i, a) };
{ t.try_set(i, a) } -> std::convertible_to<bool>;
{ t.get(i) } -> std::convertible_to<TElem>;
{ t.zero() };
{ t.size_elems() } -> std::convertible_to<size_t>;
{ T::MAX_BYTES } -> std::convertible_to<size_t>;
{ T::MAX_ELEMS } -> std::convertible_to<size_t>;
@ -86,6 +83,19 @@ concept AnyInt = requires(T t, TElem a, size_t i) {
{ T::LAST_ELEM_MASK } -> std::convertible_to<TElem>;
};
template <typename T, typename TElem = T::ElementType>
concept AnyMpInt = AnyConstMpInt<T> && requires(T t, TElem a, size_t i) {
{ t.set(i, a) };
{ t.try_set(i, a) } -> std::convertible_to<bool>;
{ t.zero() };
};
template <typename T>
concept AnyRegularInt = std::unsigned_integral<T>;
template <typename T>
concept AnyConstInt = AnyRegularInt<T> || AnyConstMpInt<T>;
//======================== Utils =======================//
#ifdef __SIZEOF_INT128__
@ -94,7 +104,7 @@ using LongestElementSuitableType = uint64_t;
using LongestElementSuitableType = HalfWidthType<size_t>;
#endif
template <AnyInt TA, AnyInt TB>
template <AnyConstMpInt TA, AnyConstMpInt TB>
constexpr bool ElementTypesMatch = std::is_same_v<typename TA::ElementType, typename TB::ElementType>;
template <ElementSuitable TElem, size_t MaxBytes>
@ -121,7 +131,7 @@ inline char hex_digit(uint8_t bits)
return bits > 9 ? 'a' + (bits - 10) : '0' + bits;
}
template <AnyInt T>
template <AnyConstMpInt T>
std::string to_hex_string(const T& number)
{
constexpr size_t ELEMENT_DIGITS = T::ELEMENT_BYTES * 2;