cpp_mp/mp/math.hpp
2025-11-22 19:36:26 +01:00

172 lines
4.1 KiB
C++

#pragma once
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <utility>

#include "int.hpp"
#include "storage.hpp"
#include "utils.hpp"
namespace mp
{
/// The larger of two byte sizes; OpResult uses it to size the result of a
/// binary operation from its operands' MAX_BYTES.
template <size_t SizeA, size_t SizeB>
constexpr size_t ResultMaxSize = (SizeA < SizeB) ? SizeB : SizeA;
/// Result type of a binary operation on two AnyInt operands that share an
/// element type: a BasicInt over that element type, sized by ResultMaxSize of
/// the two operands' MAX_BYTES (i.e. the larger of the two).
template <AnyInt TLhs, AnyInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
using OpResult = BasicInt<
    typename TLhs::ElementType,
    ResultMaxSize<TLhs::MAX_BYTES, TRhs::MAX_BYTES>>;
/// Exception raised when an arithmetic result does not fit its destination.
///
/// The exception owns the value the operation had produced when the overflow
/// was detected, so callers can inspect it via value(). How meaningful that
/// value is (truncated, partial, ...) depends on the operation that threw —
/// NOTE(review): confirm per operation before relying on its contents.
///
/// @tparam T integer type of the carried value.
template <AnyInt T>
class OverflowError : public std::runtime_error
{
public:
    /// Takes ownership of the value computed before overflow was detected.
    /// `explicit` so an integer never silently converts into an exception.
    explicit OverflowError(T&& value) : std::runtime_error("Overflow"), m_value(std::move(value)) {}

    /// The value carried by this exception.
    [[nodiscard]] const T& value() const noexcept { return m_value; }

private:
    T m_value;
};

/// The OverflowError type thrown by binary operations on TLhs and TRhs; it
/// carries an OpResult<TLhs, TRhs>.
template <AnyInt TLhs, AnyInt TRhs>
using OverflowErrorOf = OverflowError<OpResult<TLhs, TRhs>>;
/// Limb-by-limb (schoolbook) addition of two multi-precision integers.
struct Addition
{
    /// Computes res = lhs + rhs, propagating a one-limb carry.
    ///
    /// res is zeroed first. Limbs 0 .. end-1 and the carry-out limb at index
    /// end = max(lhs.size_elems(), rhs.size_elems()) are written with try_set,
    /// like Multiplication does. A rejected write is remembered, the loop runs
    /// to completion, and std::overflow_error is thrown afterwards. binary_op
    /// and ca_binary_op turn that into OverflowErrorOf with the partial result.
    /// NOTE(review): assumes try_set only fails when the limb cannot be
    /// represented in res — confirm against storage.hpp.
    ///
    /// @throws std::overflow_error if the sum does not fit in res.
    template <AnyInt TLhs, AnyInt TRhs, AnyInt TRes>
    requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
    static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
    {
        using ElementType = typename TRes::ElementType;
        res.zero();
        bool overflow = false;
        ElementType carry = 0;
        const size_t end = std::max(lhs.size_elems(), rhs.size_elems());
        for (size_t i = 0; i < end; ++i)
        {
            const ElementType a = lhs.get(i);
            const ElementType b = rhs.get(i);
            // Unsigned wrap-around: a sum smaller than one of its addends means
            // a carry-out. Both steps cannot carry at once, so carry stays 0/1.
            ElementType c = static_cast<ElementType>(carry + a);
            carry = (c < a) ? 1 : 0;
            c = static_cast<ElementType>(c + b);
            if (c < b)
                carry = 1;
            overflow |= !res.try_set(i, c);
        }
        overflow |= !res.try_set(end, carry);
        if (overflow)
        {
            throw std::overflow_error("Addition overflow");
        }
    }
};
/// Schoolbook (row-by-column) multiplication of two multi-precision integers.
struct Multiplication
{
    /// Computes res = lhs * rhs.
    ///
    /// For each limb of lhs, the row of partial products with every limb of
    /// rhs is accumulated into res in a double-width accumulator. Assuming
    /// DoubleWidthType is twice the limb width, a*b + res limb + carry cannot
    /// overflow it. Every limb is stored with try_set. A rejected store is
    /// remembered and reported once, after the whole product is accumulated.
    ///
    /// @throws std::overflow_error if any limb could not be stored in res.
    template <AnyInt TLhs, AnyInt TRhs, AnyInt TRes>
    requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
    static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
    {
        using Limb = typename TRes::ElementType;
        using Wide = DoubleWidthType<Limb>;
        constexpr auto kLimbBits = sizeof(Limb) * 8;

        res.zero();
        const size_t lhs_len = lhs.size_elems();
        const size_t rhs_len = rhs.size_elems();
        bool all_stored = true;

        for (size_t row = 0; row < lhs_len; ++row)
        {
            const Limb multiplier = lhs[row];
            Wide carry = 0;
            size_t col = 0;
            while (col < rhs_len)
            {
                const Limb multiplicand = rhs[col];
                const size_t pos = row + col;
                const Wide acc = static_cast<Wide>(multiplier) * multiplicand + res.get(pos) + carry;
                // try_set is the left operand, so it runs even after a failure.
                all_stored = res.try_set(pos, static_cast<Limb>(acc)) && all_stored;
                carry = acc >> kLimbBits;
                ++col;
            }
            // Store the row's final carry in the limb just past the row.
            all_stored = res.try_set(row + rhs_len, static_cast<Limb>(carry)) && all_stored;
        }

        if (!all_stored)
            throw std::overflow_error("Multiplication overflow");
    }
};
/// Binary operation driver: runs TOp into a fresh, value-initialized
/// OpResult<TLhs, TRhs> and returns it.
///
/// TOp must provide `static void invoke(const TLhs&, const TRhs&, Res&)`. If
/// invoke throws std::overflow_error, the error is rethrown as
/// OverflowErrorOf<TLhs, TRhs>, which carries whatever invoke had written into
/// the result up to that point.
template <typename TOp, AnyInt TLhs, AnyInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
OpResult<TLhs, TRhs> binary_op(const TLhs& lhs, const TRhs& rhs)
{
    using Result = OpResult<TLhs, TRhs>;
    Result out{};
    try
    {
        TOp::invoke(lhs, rhs, out);
    }
    catch (const std::overflow_error&)
    {
        // Attach the partially computed value to the library's own error type.
        throw OverflowErrorOf<TLhs, TRhs>(std::move(out));
    }
    return out;
}
/// Compound-assignment driver: computes `lhs TOp rhs` into a value-initialized
/// temporary of type TLhs and assigns it back to lhs, returning lhs.
///
/// lhs is assigned only after TOp::invoke returns normally, so a failed
/// operation leaves it untouched. Using a separate temporary also means lhs is
/// never read and written at once (e.g. `a *= a`). A std::overflow_error from
/// invoke is rethrown as OverflowErrorOf<TLhs, TRhs> carrying the partial
/// result.
/// NOTE(review): the partial result has type TLhs while OverflowErrorOf holds
/// an OpResult<TLhs, TRhs>. When those types differ, this relies on BasicInt
/// converting between them — confirm for every instantiation.
template <typename TOp, AnyInt TLhs, AnyInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
TLhs& ca_binary_op(TLhs& lhs, const TRhs& rhs)
{
    TLhs tmp{};
    try
    {
        TOp::invoke(lhs, rhs, tmp);
    }
    catch (const std::overflow_error&)
    {
        throw OverflowErrorOf<TLhs, TRhs>(std::move(tmp));
    }
    lhs = std::move(tmp);
    return lhs;
}
// ---- Addition operators ----

/// Returns lhs + rhs as an OpResult<TLhs, TRhs> (computed by Addition through
/// binary_op). A std::overflow_error raised by Addition::invoke surfaces as
/// OverflowErrorOf<TLhs, TRhs>.
template <AnyInt TLhs, AnyInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
OpResult<TLhs, TRhs> operator+(const TLhs& lhs, const TRhs& rhs)
{
    return binary_op<Addition, TLhs, TRhs>(lhs, rhs);
}
/// In-place addition: lhs = lhs + rhs, keeping lhs's own type TLhs.
/// Delegates to ca_binary_op, so lhs is unchanged if Addition throws.
template <AnyInt TLhs, AnyInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
TLhs& operator+=(TLhs& lhs, const TRhs& rhs)
{
    return ca_binary_op<Addition, TLhs, TRhs>(lhs, rhs);
}
// ---- Multiplication operators ----

/// Returns lhs * rhs as an OpResult<TLhs, TRhs> (computed by Multiplication
/// through binary_op). A std::overflow_error raised by Multiplication::invoke
/// surfaces as OverflowErrorOf<TLhs, TRhs>.
template <AnyInt TLhs, AnyInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
OpResult<TLhs, TRhs> operator*(const TLhs& lhs, const TRhs& rhs)
{
    return binary_op<Multiplication, TLhs, TRhs>(lhs, rhs);
}
/// In-place multiplication: lhs = lhs * rhs, keeping lhs's own type TLhs.
/// Delegates to ca_binary_op, so lhs is unchanged if Multiplication throws.
template <AnyInt TLhs, AnyInt TRhs>
requires ElementTypesMatch<TLhs, TRhs>
TLhs& operator*=(TLhs& lhs, const TRhs& rhs)
{
    return ca_binary_op<Multiplication, TLhs, TRhs>(lhs, rhs);
}
} // namespace mp