// cpp_mp/src/mp/math.hpp
//
// 573 lines
// 16 KiB
// C++

#pragma once
#include "int.hpp"
#include "utils.hpp"
namespace mp
{
/// Larger of two sizes. In this file it picks the byte capacity of the result
/// of an operation on two multi-precision operands.
template <size_t SizeA, size_t SizeB>
constexpr size_t ResultMaxSize = (SizeA < SizeB) ? SizeB : SizeA;
template <AnyMpInt TL, AnyInt TR>
struct OpResult;
template <AnyMpInt TLhs, AnyMpInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
struct OpResult<TLhs, TRhs>
{
using type = BasicInt<typename TLhs::ElementType, ResultMaxSize<TLhs::MAX_BYTES, TRhs::MAX_BYTES>>;
};
template <AnyMpInt TLhs, AnyRegularInt TRhs>
struct OpResult<TLhs, TRhs>
{
// using type = BasicInt<typename TLhs::ElementType, ResultMaxSize<TLhs::MAX_BYTES, sizeof(TRhs)>>;
using type = TLhs; // keep the same size as TLhs when operating with regular ints
};
template <AnyMpInt TLhs, AnyInt TRhs>
using OpResultType = typename OpResult<TLhs, TRhs>::type;
/// Exception with the message "Overflow" that also carries a value of type T.
/// Within this file it is thrown by binary_op / ca_binary_op with the result
/// object as it stood when the operation reported an overflow.
template <AnyMpInt T>
class OverflowError : public std::runtime_error
{
    T m_value;

public:
    /// Takes ownership of `value` by moving it into the exception.
    OverflowError(T&& value) : std::runtime_error("Overflow"), m_value(std::move(value)) {}

    /// The value handed to the constructor.
    const T& value() const { return m_value; }
};
/// OverflowError specialised for the result type of an operation on TLhs and TRhs.
template <AnyMpInt TLhs, AnyInt TRhs>
using OverflowErrorOf = OverflowError<typename OpResult<TLhs, TRhs>::type>;
/// Converts a regular integer into a multi-precision integer whose element
/// type is TElem and whose byte capacity is sizeof(TR).
/// The value is transferred through BasicInt's assignment from TR.
template <ElementSuitable TElem, AnyRegularInt TR>
BasicInt<TElem, sizeof(TR)> to_mp_int(TR value)
{
    using Converted = BasicInt<TElem, sizeof(TR)>;
    Converted converted;
    converted = value;
    return converted;
}
/// Identity overload: a multi-precision integer that already uses TElem as its
/// element type is handed back by reference, without copying.
template <ElementSuitable TElem, AnyMpInt TR>
    requires std::is_same_v<typename TR::ElementType, TElem>
auto to_mp_int(const TR& value) -> const TR&
{
    return value;
}
/// Three-way comparison of two multi-precision integers with matching element types.
///
/// Magnitudes are compared element by element from the most significant index
/// down; the sign flags are then applied:
///  - two zero magnitudes compare equal regardless of their sign flags;
///  - operands with different sign flags are ordered by sign alone;
///  - two negative operands are ordered by the reverse of their magnitude order,
///    and equal magnitudes stay equal.
///
/// @return std::strong_ordering describing lhs relative to rhs.
template <AnyMpInt TLhs, AnyMpInt TRhs>
    requires ElementTypesMatch<TLhs, TRhs>
auto operator<=>(const TLhs& lhs, const TRhs& rhs)
{
    const bool lhs_neg = lhs.negative();
    const bool rhs_neg = rhs.negative();

    // Magnitude comparison, most significant element first.
    std::strong_ordering magnitude = std::strong_ordering::equal;
    bool nonzero = false;
    const size_t max_size = std::max(lhs.size_elems(), rhs.size_elems());
    for (size_t i = max_size; i-- > 0;)
    {
        const auto l = lhs.get(i);
        const auto r = rhs.get(i);
        const auto elem_comp = l <=> r;
        if (elem_comp != std::strong_ordering::equal)
        {
            magnitude = elem_comp;
            break;
        }
        if (l != 0)
        {
            nonzero = true;
        }
    }

    // Both magnitudes are zero => the sign flags are irrelevant.
    if (magnitude == std::strong_ordering::equal && !nonzero)
    {
        return std::strong_ordering::equal;
    }

    // Different signs and at least one operand is non-zero: the sign decides.
    if (lhs_neg != rhs_neg)
    {
        return lhs_neg ? std::strong_ordering::less : std::strong_ordering::greater;
    }

    // Same sign. For negatives a larger magnitude means a smaller value, so
    // reverse the ordering; equal magnitudes must remain equal, otherwise
    // x == x would be false for every negative x.
    if (lhs_neg && magnitude != std::strong_ordering::equal)
    {
        return (magnitude == std::strong_ordering::less) ? std::strong_ordering::greater
                                                         : std::strong_ordering::less;
    }
    return magnitude;
}
/// Three-way comparison against a regular integer: the right operand is first
/// converted with to_mp_int (using lhs's element type) and then compared with
/// the multi-precision overload.
template <AnyMpInt TLhs, AnyRegularInt TRhs>
auto operator<=>(const TLhs& lhs, TRhs rhs)
{
    using ElementType = typename TLhs::ElementType;
    const auto rhs_as_mp = to_mp_int<ElementType>(rhs);
    return lhs <=> rhs_as_mp;
}
/// Equality, defined as the three-way comparison yielding "equal".
template <AnyMpInt TLhs, AnyInt TRhs>
bool operator==(const TLhs& lhs, const TRhs& rhs)
{
    const auto ordering = lhs <=> rhs;
    return std::is_eq(ordering);
}
/// Inequality, defined as the negation of operator==.
template <AnyMpInt TLhs, AnyInt TRhs>
bool operator!=(const TLhs& lhs, const TRhs& rhs)
{
    const bool same = (lhs == rhs);
    return !same;
}
/// Returns true when the magnitude of lhs is strictly smaller than that of rhs.
/// Sign flags are not consulted. Elements are compared from the highest index
/// (taken from the longer operand) downwards; the first differing pair decides.
template <AnyMpInt TLhs, AnyMpInt TRhs>
    requires ElementTypesMatch<TLhs, TRhs>
inline bool less_ignore_sign(const TLhs& lhs, const TRhs& rhs)
{
    size_t index = std::max(lhs.size_elems(), rhs.size_elems());
    while (index > 0)
    {
        --index;
        const auto lhs_elem = lhs.get(index);
        const auto rhs_elem = rhs.get(index);
        if (lhs_elem != rhs_elem)
        {
            return lhs_elem < rhs_elem;
        }
    }
    // Every element pair matched: the magnitudes are equal, hence not less.
    return false;
}
/// Element-wise magnitude subtraction: res = |lhs| - |rhs|, with borrow
/// propagation. Sign flags of the operands are not consulted.
///
/// res is zeroed first and then written for every index below
/// max(lhs.size_elems(), rhs.size_elems()).
///
/// @throws std::overflow_error ("Subtraction underflow") if a borrow remains
///         after the last element, i.e. |lhs| < |rhs|. Despite the function's
///         name, the underflow is reported rather than silently ignored.
template <AnyMpInt TLhs, AnyMpInt TRhs, AnyMpInt TRes>
    requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
inline void sub_ignore_underflow(const TLhs& lhs, const TRhs& rhs, TRes& res)
{
    using ElementType = typename TRes::ElementType;

    res.zero();
    ElementType borrow = 0;
    const size_t elem_count = std::max(lhs.size_elems(), rhs.size_elems());
    size_t pos = 0;
    while (pos < elem_count)
    {
        const ElementType minuend = lhs.get(pos);
        const ElementType subtrahend = rhs.get(pos);
        const ElementType difference = minuend - subtrahend - borrow;
        // A borrow is needed when the minuend is smaller, or when the two are
        // equal and an incoming borrow still has to be paid.
        const bool borrow_out = (minuend < subtrahend) || (borrow && minuend == subtrahend);
        borrow = borrow_out;
        res.set(pos, difference);
        ++pos;
    }

    if (borrow == 0)
    {
        return;
    }
    throw std::overflow_error("Subtraction underflow");
}
/// Signed difference of two magnitudes: res = |lhs| - |rhs|.
/// The smaller magnitude is always subtracted from the larger one; res is
/// flagged negative exactly when |lhs| < |rhs|.
template <AnyMpInt TLhs, AnyMpInt TRhs, AnyMpInt TRes>
    requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
inline void sub_ignore_sign(const TLhs& lhs, const TRhs& rhs, TRes& res)
{
    const bool operands_swapped = less_ignore_sign(lhs, rhs);
    if (operands_swapped)
    {
        sub_ignore_underflow(rhs, lhs, res);
    }
    else
    {
        sub_ignore_underflow(lhs, rhs, res);
    }
    res.set_negative(operands_swapped);
}
/// Element-wise magnitude addition: res = |lhs| + |rhs|, with carry
/// propagation. Sign flags of the operands are not consulted and the sign of
/// res is left to the caller.
///
/// res is zeroed first; elements below max(lhs.size_elems(), rhs.size_elems())
/// receive the sum, and the element at that index receives the final carry
/// (which is written even when it is zero).
// NOTE(review): whether res.set() on the carry index can fail when res has no
// spare element is not visible from this file - confirm against BasicInt::set.
template <AnyMpInt TLhs, AnyMpInt TRhs, AnyMpInt TRes>
    requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
inline void add_ignore_sign(const TLhs& lhs, const TRhs& rhs, TRes& res)
{
    using ElementType = typename TRes::ElementType;

    res.zero();
    ElementType carry = 0;
    const size_t elem_count = std::max(lhs.size_elems(), rhs.size_elems());
    for (size_t pos = 0; pos != elem_count; ++pos)
    {
        const ElementType lhs_elem = lhs.get(pos);
        const ElementType rhs_elem = rhs.get(pos);

        // Two-step add; a wrap-around in either step produces a carry.
        ElementType sum = static_cast<ElementType>(carry + lhs_elem);
        const bool wrapped_on_carry = sum < lhs_elem;
        sum = static_cast<ElementType>(sum + rhs_elem);
        const bool wrapped_on_rhs = sum < rhs_elem;

        carry = (wrapped_on_carry || wrapped_on_rhs) ? 1 : 0;
        res.set(pos, sum);
    }
    res.set(elem_count, carry);
}
/// Signed addition policy used as the TOp argument of binary_op / ca_binary_op.
struct Addition
{
    /// Computes res = lhs + rhs from the operands' sign flags and magnitudes.
    template <AnyMpInt TLhs, AnyMpInt TRhs, AnyMpInt TRes>
        requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
    static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
    {
        const bool lhs_negative = lhs.negative();
        const bool rhs_negative = rhs.negative();

        // Same sign: add the magnitudes and keep the common sign.
        if (lhs_negative == rhs_negative)
        {
            add_ignore_sign(lhs, rhs, res);
            res.set_negative(lhs_negative);
            return;
        }

        // Opposite signs: subtract the negative operand's magnitude from the
        // positive operand's magnitude; sub_ignore_sign sets the result sign.
        if (lhs_negative)
        {
            sub_ignore_sign(rhs, lhs, res); // (-a) + b == b - a
        }
        else
        {
            sub_ignore_sign(lhs, rhs, res); // a + (-b) == a - b
        }
    }
};
/// Signed subtraction policy used as the TOp argument of binary_op / ca_binary_op.
struct Subtraction
{
    /// Computes res = lhs - rhs from the operands' sign flags and magnitudes.
    template <AnyMpInt TLhs, AnyMpInt TRhs, AnyMpInt TRes>
        requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
    static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
    {
        const bool lhs_negative = lhs.negative();
        const bool rhs_negative = rhs.negative();

        // Opposite signs: the magnitudes add up and the result takes lhs's sign.
        //   (-a) - b == -(a + b)      a - (-b) == a + b
        if (lhs_negative != rhs_negative)
        {
            add_ignore_sign(lhs, rhs, res);
            res.set_negative(lhs_negative);
            return;
        }

        // Same sign: a magnitude difference; sub_ignore_sign sets the result sign.
        if (lhs_negative)
        {
            sub_ignore_sign(rhs, lhs, res); // (-a) - (-b) == b - a
        }
        else
        {
            sub_ignore_sign(lhs, rhs, res); // a - b
        }
    }
};
/// Signed multiplication policy used as the TOp argument of binary_op / ca_binary_op.
struct Multiplication
{
    /// Computes res = lhs * rhs by schoolbook multiplication, accumulating
    /// partial products in a double-width type.
    ///
    /// res is flagged negative when the operands' sign flags differ.
    ///
    /// @throws std::overflow_error ("Multiplication overflow") if any
    ///         res.try_set() call returned false. The exception is raised only
    ///         after the whole product loop has run.
    template <AnyMpInt TLhs, AnyMpInt TRhs, AnyMpInt TRes>
        requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
    static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
    {
        using ElementType = typename TRes::ElementType;
        using LongerType = DoubleWidthType<ElementType>;
        constexpr auto ELEMENT_BITS = sizeof(ElementType) * 8;

        res.zero();
        const bool signs_differ = lhs.negative() != rhs.negative();
        res.set_negative(signs_differ);

        bool truncated = false;
        const size_t lhs_len = lhs.size_elems();
        const size_t rhs_len = rhs.size_elems();
        for (size_t row = 0; row < lhs_len; ++row)
        {
            const ElementType multiplicand = lhs[row];
            LongerType carry = 0;
            for (size_t col = 0; col < rhs_len; ++col)
            {
                const ElementType multiplier = rhs[col];
                const size_t pos = row + col;
                // product + element already accumulated at pos + carry from the
                // previous column
                const LongerType acc = static_cast<LongerType>(multiplicand) * multiplier + res.get(pos) + carry;
                if (!res.try_set(pos, static_cast<ElementType>(acc)))
                {
                    truncated = true;
                }
                carry = acc >> ELEMENT_BITS;
            }
            // The carry left over from this row lands one element past the row.
            if (!res.try_set(row + rhs_len, static_cast<ElementType>(carry)))
            {
                truncated = true;
            }
        }

        if (truncated)
        {
            throw std::overflow_error("Multiplication overflow");
        }
    }
};
/// Signed division policy (quotient only) used as the TOp argument of
/// binary_op / ca_binary_op. Implements schoolbook long division in the style
/// of Knuth's Algorithm D (TAOCP vol. 2, section 4.3.1); the remainder is
/// computed internally but discarded.
struct Division
{
    /// Computes res = lhs / rhs.
    ///
    /// The quotient magnitude is floor(|lhs| / |rhs|); res is flagged negative
    /// when the operands' sign flags differ. If lhs has fewer elements than
    /// the divisor has significant elements, res is simply zeroed.
    ///
    /// @throws std::runtime_error ("Division by zero") if every element of rhs
    ///         is zero.
    template <AnyMpInt TLhs, AnyMpInt TRhs, AnyMpInt TRes>
        requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
    static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
    {
        using ElementType = typename TRes::ElementType;
        using LongerType = DoubleWidthType<ElementType>;

        // n = number of significant divisor elements. Leading zero elements are
        // skipped because the algorithm requires a non-zero top element:
        // with v[n-1] == 0 the normalization factor would truncate to 0 and the
        // quotient estimate would divide by zero.
        size_t n = rhs.size_elems();
        while (n > 0 && rhs.get(n - 1) == 0)
        {
            --n;
        }
        if (n == 0)
        {
            throw std::runtime_error("Division by zero");
        }

        // Dividend shorter than divisor => quotient = 0
        const size_t lhs_len = lhs.size_elems();
        if (lhs_len < n)
        {
            res.zero();
            return;
        }
        const size_t m = lhs_len - n; // the quotient has at most m + 1 elements

        // Base and mask
        constexpr unsigned int W = sizeof(ElementType) * 8;
        const LongerType BASE = (static_cast<LongerType>(1) << W);
        const LongerType MASK = BASE - 1;

        // u: working copy of the dividend, m + n + 1 elements (one extra on top
        // for the carry produced by normalization)
        BasicInt<ElementType, calculate_sum_bytes(calculate_sum_bytes(TLhs::MAX_BYTES, TRhs::MAX_BYTES), sizeof(ElementType))> u;
        u.zero();
        for (size_t i = 0; i < lhs_len; ++i)
            u.set(i, lhs.get(i));

        // v: working copy of the divisor, n elements
        BasicInt<ElementType, calculate_sum_bytes(TRhs::MAX_BYTES, sizeof(ElementType))> v;
        v.zero();
        for (size_t i = 0; i < n; ++i)
            v.set(i, rhs.get(i));

        // Normalization factor d = BASE / (v[n-1] + 1), evaluated in the wide
        // type so that v[n-1] + 1 cannot wrap. Since v[n-1] >= 1 the result is
        // at most BASE / 2 and fits an element; it makes v[n-1] * d >= BASE / 2.
        const ElementType d = static_cast<ElementType>(BASE / (static_cast<LongerType>(v.get(n - 1)) + 1));
        if (d != 1)
        {
            // u = u * d
            LongerType carry = 0;
            for (size_t i = 0; i < (m + n + 1); ++i)
            {
                const LongerType t = static_cast<LongerType>(u.get(i)) * d + carry;
                u.set(i, static_cast<ElementType>(t & MASK));
                carry = t >> W;
            }
            // v = v * d (cannot carry out of v[n-1] by the choice of d)
            carry = 0;
            for (size_t i = 0; i < n; ++i)
            {
                const LongerType t = static_cast<LongerType>(v.get(i)) * d + carry;
                v.set(i, static_cast<ElementType>(t & MASK));
                carry = t >> W;
            }
        }

        // Top divisor element after normalization; invariant for the main loop.
        const LongerType vn_1 = static_cast<LongerType>(v.get(n - 1));

        // Prepare quotient
        res.zero();
        res.set_negative(lhs.negative() != rhs.negative());

        // Main loop j = m .. 0: one quotient element per iteration
        for (size_t j = m + 1; j-- > 0;)
        {
            const LongerType uj_n = static_cast<LongerType>(u.get(j + n));
            const LongerType uj_n1 = static_cast<LongerType>(u.get(j + n - 1));

            // Estimate qhat = (u[j+n]*BASE + u[j+n-1]) / v[n-1]
            const LongerType numerator = (uj_n * BASE) + uj_n1;
            LongerType qhat = numerator / vn_1;
            LongerType rhat = numerator % vn_1;
            if (qhat >= BASE)
            {
                // Clamp to a single element and recompute the matching
                // remainder. Leaving rhat at numerator % v[n-1] here would make
                // the correction test below compare against a too-small value
                // and push qhat below the true quotient digit.
                qhat = BASE - 1;
                rhat = numerator - qhat * vn_1;
            }

            // Correction using the second divisor element (only if n >= 2).
            // Once rhat reaches BASE the test can no longer succeed, and
            // rhat * BASE would not fit the wide type, so stop there.
            if (n >= 2)
            {
                const LongerType vn_2 = static_cast<LongerType>(v.get(n - 2));
                const LongerType uj_n2 = static_cast<LongerType>(u.get(j + n - 2));
                while (rhat < BASE && qhat * vn_2 > (rhat * BASE + uj_n2))
                {
                    qhat -= 1;
                    rhat += vn_1;
                }
            }

            // Multiply v by qhat and subtract from the u segment starting at j
            LongerType carry_mul = 0;
            unsigned int borrow = 0;
            for (size_t i = 0; i < n; ++i)
            {
                // p = qhat * v[i] + carry_mul
                const LongerType p = qhat * static_cast<LongerType>(v.get(i)) + carry_mul;
                const ElementType p_low = static_cast<ElementType>(p & MASK);
                carry_mul = p >> W;

                const LongerType sub = static_cast<LongerType>(u.get(j + i));
                const LongerType needed = static_cast<LongerType>(p_low) + borrow;
                const bool under = (sub < needed);
                // Adding BASE before subtracting keeps the intermediate non-negative.
                u.set(j + i, static_cast<ElementType>((sub + BASE - needed) & MASK));
                borrow = under ? 1u : 0u;
            }
            // Subtract carry_mul and borrow from u[j+n]
            {
                const LongerType sub = static_cast<LongerType>(u.get(j + n));
                const LongerType needed = carry_mul + borrow;
                const bool under = (sub < needed);
                u.set(j + n, static_cast<ElementType>((sub + BASE - needed) & MASK));
                borrow = under ? 1u : 0u;
            }

            if (borrow != 0)
            {
                // qhat was one too large: decrement it and add v back
                qhat -= 1;
                LongerType carry_add = 0;
                for (size_t i = 0; i < n; ++i)
                {
                    const LongerType sum =
                        static_cast<LongerType>(u.get(j + i)) + static_cast<LongerType>(v.get(i)) + carry_add;
                    u.set(j + i, static_cast<ElementType>(sum & MASK));
                    carry_add = sum >> W;
                }
                // add carry_add to u[j+n]; the carry out of the top cancels the borrow
                u.set(j + n, static_cast<ElementType>((static_cast<LongerType>(u.get(j + n)) + carry_add) & MASK));
            }

            // store quotient digit
            res.set(j, static_cast<ElementType>(qhat & MASK));
        }
        // The remainder left in u is still scaled by d; it is not needed for
        // the quotient and is discarded.
    }
};
/// Runs the binary operation TOp on lhs and rhs and returns a fresh result.
///
/// rhs is passed through to_mp_int (with lhs's element type) before being
/// handed to TOp::invoke.
///
/// @throws OverflowErrorOf<TLhs, TRhs> carrying the result object whenever
///         std::overflow_error escapes from the conversion or from TOp::invoke.
template <typename TOp, AnyMpInt TLhs, AnyInt TRhs>
OpResultType<TLhs, TRhs> binary_op(const TLhs& lhs, const TRhs& rhs)
{
    using ElementType = typename TLhs::ElementType;
    using Result = OpResultType<TLhs, TRhs>;

    Result result{};
    try
    {
        const auto& rhs_mp = to_mp_int<ElementType>(rhs);
        TOp::invoke(lhs, rhs_mp, result);
    }
    catch (const std::overflow_error&)
    {
        throw OverflowErrorOf<TLhs, TRhs>(std::move(result));
    }
    return result;
}
/// Compound-assignment form of a binary operation: computes TOp(lhs, rhs) into
/// a temporary of type TLhs and, on success, moves it into lhs.
///
/// lhs is left untouched when the operation throws.
///
/// @return reference to lhs.
/// @throws OverflowErrorOf<TLhs, TRhs> built from the temporary whenever
///         std::overflow_error escapes from the conversion or from TOp::invoke.
template <typename TOp, AnyMpInt TLhs, AnyInt TRhs>
TLhs& ca_binary_op(TLhs& lhs, const TRhs& rhs)
{
    using ElementType = typename TLhs::ElementType;

    TLhs scratch{};
    try
    {
        const auto& rhs_mp = to_mp_int<ElementType>(rhs);
        TOp::invoke(lhs, rhs_mp, scratch);
    }
    catch (const std::overflow_error&)
    {
        throw OverflowErrorOf<TLhs, TRhs>(std::move(scratch));
    }
    lhs = std::move(scratch);
    return lhs;
}
// Addition operators
/// Returns lhs + rhs as a new value of type OpResultType<TLhs, TRhs>.
/// Forwards to binary_op<Addition>, which reports overflow by throwing
/// OverflowErrorOf<TLhs, TRhs>.
template <AnyMpInt TLhs, AnyInt TRhs>
OpResultType<TLhs, TRhs> operator+(const TLhs& lhs, const TRhs& rhs)
{
return binary_op<Addition>(lhs, rhs);
}
/// Adds rhs to lhs in place and returns lhs.
/// Forwards to ca_binary_op<Addition>, so the result is computed in a
/// temporary of type TLhs before being assigned to lhs.
template <AnyMpInt TLhs, AnyInt TRhs>
TLhs& operator+=(TLhs& lhs, const TRhs& rhs)
{
return ca_binary_op<Addition>(lhs, rhs);
}
// Subtraction operators
/// Returns lhs - rhs as a new value of type OpResultType<TLhs, TRhs>.
/// Forwards to binary_op<Subtraction>, which reports overflow by throwing
/// OverflowErrorOf<TLhs, TRhs>.
template <AnyMpInt TLhs, AnyInt TRhs>
OpResultType<TLhs, TRhs> operator-(const TLhs& lhs, const TRhs& rhs)
{
return binary_op<Subtraction>(lhs, rhs);
}
/// Subtracts rhs from lhs in place and returns lhs.
/// Forwards to ca_binary_op<Subtraction>, so the result is computed in a
/// temporary of type TLhs before being assigned to lhs.
template <AnyMpInt TLhs, AnyInt TRhs>
TLhs& operator-=(TLhs& lhs, const TRhs& rhs)
{
return ca_binary_op<Subtraction>(lhs, rhs);
}
// Multiplication operators
/// Returns lhs * rhs as a new value of type OpResultType<TLhs, TRhs>.
/// Forwards to binary_op<Multiplication>, which reports overflow by throwing
/// OverflowErrorOf<TLhs, TRhs>.
template <AnyMpInt TLhs, AnyInt TRhs>
OpResultType<TLhs, TRhs> operator*(const TLhs& lhs, const TRhs& rhs)
{
return binary_op<Multiplication>(lhs, rhs);
}
/// Multiplies lhs by rhs in place and returns lhs.
/// Forwards to ca_binary_op<Multiplication>, so the result is computed in a
/// temporary of type TLhs before being assigned to lhs.
template <AnyMpInt TLhs, AnyInt TRhs>
TLhs& operator*=(TLhs& lhs, const TRhs& rhs)
{
return ca_binary_op<Multiplication>(lhs, rhs);
}
// Division operators
/// Returns the quotient lhs / rhs as a new value of type OpResultType<TLhs, TRhs>.
/// Forwards to binary_op<Division>; Division::invoke throws std::runtime_error
/// ("Division by zero") for a zero divisor.
template <AnyMpInt TLhs, AnyInt TRhs>
OpResultType<TLhs, TRhs> operator/(const TLhs& lhs, const TRhs& rhs)
{
return binary_op<Division>(lhs, rhs);
}
/// Divides lhs by rhs in place (quotient only) and returns lhs.
/// Forwards to ca_binary_op<Division>; Division::invoke throws
/// std::runtime_error ("Division by zero") for a zero divisor.
template <AnyMpInt TLhs, AnyInt TRhs>
TLhs& operator/=(TLhs& lhs, const TRhs& rhs)
{
return ca_binary_op<Division>(lhs, rhs);
}
} // namespace mp