573 lines
16 KiB
C++
573 lines
16 KiB
C++
#pragma once
|
|
|
|
#include "int.hpp"
|
|
#include "utils.hpp"
|
|
|
|
namespace mp
|
|
{
|
|
|
|
/// Larger of the two byte sizes given as template arguments.
template <size_t SizeA, size_t SizeB>
constexpr size_t ResultMaxSize = (SizeA < SizeB) ? SizeB : SizeA;
|
|
|
|
/// Trait naming the result type of a binary operation on an operand pair.
/// The primary template is declared only; partial specializations supply `type`.
template <AnyMpInt TLhs, AnyInt TRhs>
struct OpResult;
|
|
|
|
template <AnyMpInt TLhs, AnyMpInt TRhs>
|
|
requires ElementTypesMatch<TLhs, TRhs>
|
|
struct OpResult<TLhs, TRhs>
|
|
|
|
{
|
|
using type = BasicInt<typename TLhs::ElementType, ResultMaxSize<TLhs::MAX_BYTES, TRhs::MAX_BYTES>>;
|
|
};
|
|
|
|
template <AnyMpInt TLhs, AnyRegularInt TRhs>
|
|
struct OpResult<TLhs, TRhs>
|
|
{
|
|
// using type = BasicInt<typename TLhs::ElementType, ResultMaxSize<TLhs::MAX_BYTES, sizeof(TRhs)>>;
|
|
using type = TLhs; // keep the same size as TLhs when operating with regular ints
|
|
};
|
|
|
|
template <AnyMpInt TLhs, AnyInt TRhs>
|
|
using OpResultType = typename OpResult<TLhs, TRhs>::type;
|
|
|
|
template <AnyMpInt T>
|
|
class OverflowError : public std::runtime_error
|
|
{
|
|
public:
|
|
OverflowError(T&& value) : std::runtime_error("Overflow"), m_value(std::move(value)) {}
|
|
|
|
const T& value() const { return m_value; }
|
|
|
|
private:
|
|
T m_value;
|
|
};
|
|
|
|
template <AnyMpInt TLhs, AnyInt TRhs>
|
|
using OverflowErrorOf = OverflowError<OpResultType<TLhs, TRhs>>;
|
|
|
|
// Regular int to mp int
|
|
template <ElementSuitable TElem, AnyRegularInt TR>
|
|
BasicInt<TElem, sizeof(TR)> to_mp_int(TR value)
|
|
{
|
|
BasicInt<TElem, sizeof(TR)> res;
|
|
res = value;
|
|
return res;
|
|
}
|
|
|
|
// no-op
|
|
template <ElementSuitable TElem, AnyMpInt TR>
|
|
requires std::is_same_v<typename TR::ElementType, TElem>
|
|
const TR& to_mp_int(const TR& value)
|
|
{
|
|
return value;
|
|
}
|
|
|
|
template <AnyMpInt TLhs, AnyMpInt TRhs>
|
|
requires ElementTypesMatch<TLhs, TRhs>
|
|
auto operator<=>(const TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
const bool lhs_neg = lhs.negative();
|
|
const bool rhs_neg = rhs.negative();
|
|
|
|
std::strong_ordering res = std::strong_ordering::equal;
|
|
bool nonzero = false;
|
|
|
|
size_t max_size = std::max(lhs.size_elems(), rhs.size_elems());
|
|
for (size_t i = max_size; i-- > 0;)
|
|
{
|
|
auto l = lhs.get(i);
|
|
auto r = rhs.get(i);
|
|
|
|
auto elem_comp = l <=> r;
|
|
if (elem_comp != std::strong_ordering::equal)
|
|
{
|
|
res = elem_comp;
|
|
break;
|
|
}
|
|
else
|
|
{
|
|
if (l != 0)
|
|
nonzero = true;
|
|
}
|
|
}
|
|
|
|
// both zero => ignore sign
|
|
if (res == std::strong_ordering::equal && !nonzero)
|
|
{
|
|
return std::strong_ordering::equal;
|
|
}
|
|
|
|
// different sign, non-zero
|
|
if (lhs_neg != rhs_neg)
|
|
{
|
|
return lhs_neg ? std::strong_ordering::less : std::strong_ordering::greater;
|
|
}
|
|
|
|
// flip for negative
|
|
if (lhs_neg)
|
|
{
|
|
res = (res == std::strong_ordering::less) ? std::strong_ordering::greater : std::strong_ordering::less;
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
template <AnyMpInt TLhs, AnyRegularInt TRhs>
|
|
auto operator<=>(const TLhs& lhs, TRhs rhs)
|
|
{
|
|
return lhs <=> to_mp_int<typename TLhs::ElementType>(rhs);
|
|
}
|
|
|
|
template <AnyMpInt TLhs, AnyInt TRhs>
|
|
bool operator==(const TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
return (lhs <=> rhs) == std::strong_ordering::equal;
|
|
}
|
|
|
|
template <AnyMpInt TLhs, AnyInt TRhs>
|
|
bool operator!=(const TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
return !(lhs == rhs);
|
|
}
|
|
|
|
template <AnyMpInt TLhs, AnyMpInt TRhs>
|
|
requires ElementTypesMatch<TLhs, TRhs>
|
|
inline bool less_ignore_sign(const TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
size_t max_size = std::max(lhs.size_elems(), rhs.size_elems());
|
|
|
|
for (size_t i = max_size; i-- > 0;)
|
|
{
|
|
auto l = lhs.get(i);
|
|
auto r = rhs.get(i);
|
|
|
|
if (l < r)
|
|
return true;
|
|
else if (l > r)
|
|
return false;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
template <AnyMpInt TLhs, AnyMpInt TRhs, AnyMpInt TRes>
|
|
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
|
|
inline void sub_ignore_underflow(const TLhs& lhs, const TRhs& rhs, TRes& res)
|
|
{
|
|
using ElementType = typename TRes::ElementType;
|
|
|
|
res.zero();
|
|
|
|
ElementType borrow = 0;
|
|
size_t end = std::max(lhs.size_elems(), rhs.size_elems());
|
|
|
|
for (size_t i = 0; i < end; ++i)
|
|
{
|
|
ElementType a = lhs.get(i);
|
|
ElementType b = rhs.get(i);
|
|
|
|
ElementType c = a - b - borrow;
|
|
borrow = (a < b) || (borrow && a == b);
|
|
res.set(i, c);
|
|
}
|
|
|
|
if (borrow != 0)
|
|
{
|
|
throw std::overflow_error("Subtraction underflow");
|
|
}
|
|
}
|
|
|
|
template <AnyMpInt TLhs, AnyMpInt TRhs, AnyMpInt TRes>
|
|
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
|
|
inline void sub_ignore_sign(const TLhs& lhs, const TRhs& rhs, TRes& res)
|
|
{
|
|
if (less_ignore_sign(lhs, rhs))
|
|
{
|
|
sub_ignore_underflow(rhs, lhs, res);
|
|
res.set_negative(true);
|
|
}
|
|
else
|
|
{
|
|
sub_ignore_underflow(lhs, rhs, res);
|
|
res.set_negative(false);
|
|
}
|
|
}
|
|
|
|
template <AnyMpInt TLhs, AnyMpInt TRhs, AnyMpInt TRes>
|
|
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
|
|
inline void add_ignore_sign(const TLhs& lhs, const TRhs& rhs, TRes& res)
|
|
{
|
|
using ElementType = typename TRes::ElementType;
|
|
|
|
res.zero();
|
|
|
|
ElementType carry = 0;
|
|
size_t end = std::max(lhs.size_elems(), rhs.size_elems());
|
|
|
|
for (size_t i = 0; i < end; ++i)
|
|
{
|
|
ElementType a = lhs.get(i);
|
|
ElementType b = rhs.get(i);
|
|
|
|
ElementType c = carry + a;
|
|
carry = (c < a) ? 1 : 0;
|
|
c += b;
|
|
if (c < b)
|
|
carry = 1;
|
|
|
|
res.set(i, c);
|
|
}
|
|
|
|
res.set(end, carry);
|
|
}
|
|
|
|
struct Addition
|
|
{
|
|
template <AnyMpInt TLhs, AnyMpInt TRhs, AnyMpInt TRes>
|
|
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
|
|
static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
|
|
{
|
|
const bool lhs_neg = lhs.negative();
|
|
const bool rhs_neg = rhs.negative();
|
|
|
|
if (lhs_neg == rhs_neg)
|
|
{
|
|
add_ignore_sign(lhs, rhs, res);
|
|
res.set_negative(lhs_neg);
|
|
}
|
|
else if (lhs_neg)
|
|
{
|
|
// (-a) + b == b - a
|
|
sub_ignore_sign(rhs, lhs, res);
|
|
}
|
|
else
|
|
{
|
|
// a + (-b) == a - b
|
|
sub_ignore_sign(lhs, rhs, res);
|
|
}
|
|
}
|
|
};
|
|
|
|
struct Subtraction
|
|
{
|
|
template <AnyMpInt TLhs, AnyMpInt TRhs, AnyMpInt TRes>
|
|
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
|
|
static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
|
|
{
|
|
const bool lhs_neg = lhs.negative();
|
|
const bool rhs_neg = rhs.negative();
|
|
|
|
if (!lhs_neg && !rhs_neg)
|
|
{
|
|
// a - b
|
|
sub_ignore_sign(lhs, rhs, res);
|
|
}
|
|
else if (lhs_neg && rhs_neg)
|
|
{
|
|
// (-a) - (-b) == b - a
|
|
sub_ignore_sign(rhs, lhs, res);
|
|
}
|
|
else if (lhs_neg && !rhs_neg)
|
|
{
|
|
// (-a) - b == -(a + b)
|
|
add_ignore_sign(lhs, rhs, res);
|
|
res.set_negative(true);
|
|
}
|
|
else
|
|
{
|
|
// a - (-b) == a + b
|
|
add_ignore_sign(lhs, rhs, res);
|
|
res.set_negative(false);
|
|
}
|
|
}
|
|
};
|
|
|
|
struct Multiplication
|
|
{
|
|
template <AnyMpInt TLhs, AnyMpInt TRhs, AnyMpInt TRes>
|
|
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
|
|
static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
|
|
{
|
|
using ElementType = typename TRes::ElementType;
|
|
using LongerType = DoubleWidthType<ElementType>;
|
|
|
|
res.zero();
|
|
res.set_negative(lhs.negative() != rhs.negative());
|
|
|
|
bool overflow = false;
|
|
|
|
const size_t n = lhs.size_elems();
|
|
const size_t m = rhs.size_elems();
|
|
|
|
for (size_t i = 0; i < n; i++)
|
|
{
|
|
LongerType carry = 0;
|
|
ElementType a = lhs[i];
|
|
for (size_t j = 0; j < m; j++)
|
|
{
|
|
ElementType b = rhs[j];
|
|
LongerType t = static_cast<LongerType>(a) * b + res.get(i + j) + carry;
|
|
|
|
overflow |= !res.try_set(i + j, static_cast<ElementType>(t));
|
|
carry = t >> (sizeof(ElementType) * 8);
|
|
}
|
|
overflow |= !res.try_set(i + m, static_cast<ElementType>(carry));
|
|
}
|
|
|
|
if (overflow)
|
|
{
|
|
throw std::overflow_error("Multiplication overflow");
|
|
}
|
|
}
|
|
};
|
|
|
|
struct Division
|
|
{
|
|
template <AnyMpInt TLhs, AnyMpInt TRhs, AnyMpInt TRes>
|
|
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
|
|
static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
|
|
{
|
|
using ElementType = typename TRes::ElementType;
|
|
using LongerType = DoubleWidthType<ElementType>;
|
|
|
|
if (rhs == 0U)
|
|
{
|
|
throw std::runtime_error("Division by zero");
|
|
}
|
|
|
|
const size_t n = rhs.size_elems();
|
|
const size_t m = (lhs.size_elems() >= n) ? (lhs.size_elems() - n) : 0;
|
|
|
|
// If divisor larger than dividend => quotient = 0
|
|
if (lhs.size_elems() < n)
|
|
{
|
|
res.zero();
|
|
return;
|
|
}
|
|
|
|
// Base and mask
|
|
constexpr unsigned int W = sizeof(ElementType) * 8;
|
|
const LongerType BASE = (static_cast<LongerType>(1) << W);
|
|
const LongerType MASK = BASE - 1;
|
|
|
|
// Prepare u (normalized dividend) length m + n + 1
|
|
BasicInt<ElementType, calculate_sum_bytes(calculate_sum_bytes(TLhs::MAX_BYTES, TRhs::MAX_BYTES), sizeof(ElementType))> u;
|
|
u.zero();
|
|
for (size_t i = 0; i < lhs.size_elems(); ++i)
|
|
u.set(i, lhs.get(i));
|
|
|
|
// Prepare v (normalized divisor) length n
|
|
BasicInt<ElementType, calculate_sum_bytes(TRhs::MAX_BYTES, sizeof(ElementType))> v;
|
|
v.zero();
|
|
for (size_t i = 0; i < n; ++i)
|
|
v.set(i, rhs.get(i));
|
|
|
|
// Normalization factor d = BASE / (v[n-1] + 1)
|
|
ElementType d = 1;
|
|
if (v.get(n - 1) + 1 != 0) // defensive
|
|
{
|
|
d = static_cast<ElementType>(BASE / (static_cast<LongerType>(v.get(n - 1)) + 1));
|
|
}
|
|
|
|
if (d != 1)
|
|
{
|
|
// u = u * d
|
|
LongerType carry = 0;
|
|
for (size_t i = 0; i < (m + n + 1); ++i)
|
|
{
|
|
LongerType t = static_cast<LongerType>(u.get(i)) * d + carry;
|
|
u.set(i, static_cast<ElementType>(t & MASK));
|
|
carry = t >> W;
|
|
}
|
|
// v = v * d
|
|
carry = 0;
|
|
for (size_t i = 0; i < n; ++i)
|
|
{
|
|
LongerType t = static_cast<LongerType>(v.get(i)) * d + carry;
|
|
v.set(i, static_cast<ElementType>(t & MASK));
|
|
carry = t >> W;
|
|
}
|
|
// v[n] implicitly 0
|
|
}
|
|
|
|
// Prepare quotient
|
|
res.zero();
|
|
res.set_negative(lhs.negative() != rhs.negative());
|
|
|
|
// Main loop j = m .. 0
|
|
for (int j = static_cast<int>(m); j >= 0; --j)
|
|
{
|
|
// u[j + n] might be zero or >0
|
|
const LongerType uj_n = static_cast<LongerType>(u.get(j + n));
|
|
const LongerType uj_n1 = static_cast<LongerType>(u.get(j + n - 1));
|
|
const LongerType vn_1 = static_cast<LongerType>(v.get(n - 1));
|
|
|
|
// Estimate q_hat = (u[j+n]*BASE + u[j+n-1]) / v[n-1]
|
|
LongerType numerator = (uj_n * BASE) + uj_n1;
|
|
LongerType qhat = numerator / vn_1;
|
|
LongerType rhat = numerator % vn_1;
|
|
|
|
if (qhat >= BASE)
|
|
qhat = BASE - 1;
|
|
|
|
// Correction loop (only if n >= 2)
|
|
if (n >= 2)
|
|
{
|
|
const LongerType vn_2 = static_cast<LongerType>(v.get(n - 2));
|
|
while (qhat * vn_2 > (rhat * BASE + static_cast<LongerType>(u.get(j + n - 2))))
|
|
{
|
|
qhat -= 1;
|
|
rhat += vn_1;
|
|
if (rhat >= BASE)
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Multiply v by qhat and subtract from u segment starting at j
|
|
LongerType carry_mul = 0;
|
|
unsigned int borrow = 0;
|
|
for (size_t i = 0; i < n; ++i)
|
|
{
|
|
// p = qhat * v[i] + carry_mul
|
|
LongerType p = qhat * static_cast<LongerType>(v.get(i)) + carry_mul;
|
|
ElementType p_low = static_cast<ElementType>(p & MASK);
|
|
carry_mul = p >> W;
|
|
|
|
ElementType uval = u.get(j + i);
|
|
LongerType sub = static_cast<LongerType>(uval);
|
|
LongerType needed = static_cast<LongerType>(p_low) + borrow;
|
|
bool under = (sub < needed);
|
|
ElementType new_u = static_cast<ElementType>((sub + BASE - needed) & MASK);
|
|
u.set(j + i, new_u);
|
|
borrow = under ? 1u : 0u;
|
|
}
|
|
|
|
// Subtract carry_mul and borrow from u[j+n]
|
|
{
|
|
ElementType uval = u.get(j + n);
|
|
LongerType sub = static_cast<LongerType>(uval);
|
|
LongerType needed = carry_mul + borrow;
|
|
bool under = (sub < needed);
|
|
ElementType new_u = static_cast<ElementType>((sub + BASE - needed) & MASK);
|
|
u.set(j + n, new_u);
|
|
borrow = under ? 1u : 0u;
|
|
}
|
|
|
|
if (borrow != 0)
|
|
{
|
|
// qhat was too large, decrement and add v back
|
|
qhat -= 1;
|
|
LongerType carry_add = 0;
|
|
for (size_t i = 0; i < n; ++i)
|
|
{
|
|
LongerType sum =
|
|
static_cast<LongerType>(u.get(j + i)) + static_cast<LongerType>(v.get(i)) + carry_add;
|
|
u.set(j + i, static_cast<ElementType>(sum & MASK));
|
|
carry_add = sum >> W;
|
|
}
|
|
// add carry_add to u[j+n]
|
|
u.set(j + n, static_cast<ElementType>((static_cast<LongerType>(u.get(j + n)) + carry_add) & MASK));
|
|
}
|
|
|
|
// store quotient digit
|
|
res.set(static_cast<size_t>(j), static_cast<ElementType>(qhat & MASK));
|
|
}
|
|
|
|
// Note: remainder unnormalization is not required for quotient; we ignore remainder.
|
|
}
|
|
};
|
|
|
|
// Binary operation
|
|
template <typename TOp, AnyMpInt TLhs, AnyInt TRhs>
|
|
OpResultType<TLhs, TRhs> binary_op(const TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
OpResultType<TLhs, TRhs> res{};
|
|
|
|
try
|
|
{
|
|
TOp::invoke(lhs, to_mp_int<typename TLhs::ElementType>(rhs), res);
|
|
}
|
|
catch (const std::overflow_error&)
|
|
{
|
|
throw OverflowErrorOf<TLhs, TRhs>(std::move(res));
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
// Compound assignment binary operation
|
|
template <typename TOp, AnyMpInt TLhs, AnyInt TRhs>
|
|
TLhs& ca_binary_op(TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
TLhs res{};
|
|
|
|
try
|
|
{
|
|
TOp::invoke(lhs, to_mp_int<typename TLhs::ElementType>(rhs), res);
|
|
}
|
|
catch (const std::overflow_error&)
|
|
{
|
|
throw OverflowErrorOf<TLhs, TRhs>(std::move(res));
|
|
}
|
|
|
|
lhs = std::move(res);
|
|
return lhs;
|
|
}
|
|
|
|
// Addition operators
|
|
|
|
template <AnyMpInt TLhs, AnyInt TRhs>
|
|
OpResultType<TLhs, TRhs> operator+(const TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
return binary_op<Addition>(lhs, rhs);
|
|
}
|
|
|
|
template <AnyMpInt TLhs, AnyInt TRhs>
|
|
TLhs& operator+=(TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
return ca_binary_op<Addition>(lhs, rhs);
|
|
}
|
|
|
|
// Subtraction operators
|
|
|
|
template <AnyMpInt TLhs, AnyInt TRhs>
|
|
OpResultType<TLhs, TRhs> operator-(const TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
return binary_op<Subtraction>(lhs, rhs);
|
|
}
|
|
|
|
template <AnyMpInt TLhs, AnyInt TRhs>
|
|
TLhs& operator-=(TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
return ca_binary_op<Subtraction>(lhs, rhs);
|
|
}
|
|
|
|
// Multiplication operators
|
|
|
|
template <AnyMpInt TLhs, AnyInt TRhs>
|
|
OpResultType<TLhs, TRhs> operator*(const TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
return binary_op<Multiplication>(lhs, rhs);
|
|
}
|
|
|
|
template <AnyMpInt TLhs, AnyInt TRhs>
|
|
TLhs& operator*=(TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
return ca_binary_op<Multiplication>(lhs, rhs);
|
|
}
|
|
|
|
// Division operators
|
|
|
|
template <AnyMpInt TLhs, AnyInt TRhs>
|
|
OpResultType<TLhs, TRhs> operator/(const TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
return binary_op<Division>(lhs, rhs);
|
|
}
|
|
|
|
template <AnyMpInt TLhs, AnyInt TRhs>
|
|
TLhs& operator/=(TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
return ca_binary_op<Division>(lhs, rhs);
|
|
}
|
|
|
|
} // namespace mp
|