172 lines
4.1 KiB
C++
172 lines
4.1 KiB
C++
#pragma once
|
|
|
|
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <utility>

#include "int.hpp"
#include "storage.hpp"
#include "utils.hpp"
|
|
|
|
namespace mp
|
|
{
|
|
|
|
/// Capacity of the result of a binary operation whose operands have capacities
/// `SizeA` and `SizeB`: the larger of the two.
///
/// Declared `inline` so every translation unit including this header refers to a
/// single entity with external linkage; without it, a const-qualified
/// namespace-scope variable template has internal linkage on compilers that
/// predate the CWG 2387 resolution.
template <size_t SizeA, size_t SizeB>
inline constexpr size_t ResultMaxSize = std::max(SizeA, SizeB);
|
|
|
|
template <AnyInt TLhs, AnyInt TRhs>
|
|
requires ElementTypesMatch<TLhs, TRhs>
|
|
using OpResult = BasicInt<typename TLhs::ElementType, ResultMaxSize<TLhs::MAX_BYTES, TRhs::MAX_BYTES>>;
|
|
|
|
template <AnyInt T>
|
|
class OverflowError : public std::runtime_error
|
|
{
|
|
public:
|
|
OverflowError(T&& value) : std::runtime_error("Overflow"), m_value(std::move(value)) {}
|
|
|
|
const T& value() const { return m_value; }
|
|
|
|
private:
|
|
T m_value;
|
|
};
|
|
|
|
template <AnyInt TLhs, AnyInt TRhs>
|
|
using OverflowErrorOf = OverflowError<OpResult<TLhs, TRhs>>;
|
|
|
|
struct Addition
|
|
{
|
|
template <AnyInt TLhs, AnyInt TRhs, AnyInt TRes>
|
|
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
|
|
static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
|
|
{
|
|
using ElementType = typename TRes::ElementType;
|
|
|
|
res.zero();
|
|
|
|
ElementType carry = 0;
|
|
size_t end = std::max(lhs.size_elems(), rhs.size_elems());
|
|
|
|
for (size_t i = 0; i < end; ++i)
|
|
{
|
|
ElementType a = lhs.get(i);
|
|
ElementType b = rhs.get(i);
|
|
|
|
ElementType c = carry + a;
|
|
carry = (c < a) ? 1 : 0;
|
|
c += b;
|
|
if (c < b)
|
|
carry = 1;
|
|
|
|
res.set(i, c);
|
|
}
|
|
|
|
res.set(end, carry);
|
|
}
|
|
};
|
|
|
|
struct Multiplication
|
|
{
|
|
template <AnyInt TLhs, AnyInt TRhs, AnyInt TRes>
|
|
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
|
|
static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
|
|
{
|
|
using ElementType = typename TRes::ElementType;
|
|
using DoubleType = DoubleWidthType<ElementType>;
|
|
|
|
res.zero();
|
|
|
|
bool overflow = false;
|
|
|
|
const size_t n = lhs.size_elems();
|
|
const size_t m = rhs.size_elems();
|
|
|
|
for (size_t i = 0; i < n; i++)
|
|
{
|
|
DoubleType carry = 0;
|
|
ElementType a = lhs[i];
|
|
for (size_t j = 0; j < m; j++)
|
|
{
|
|
ElementType b = rhs[j];
|
|
DoubleType t = static_cast<DoubleType>(a) * b + res.get(i + j) + carry;
|
|
|
|
overflow |= !res.try_set(i + j, static_cast<ElementType>(t));
|
|
carry = t >> (sizeof(ElementType) * 8);
|
|
}
|
|
overflow |= !res.try_set(i + m, static_cast<ElementType>(carry));
|
|
}
|
|
|
|
if (overflow)
|
|
{
|
|
throw std::overflow_error("Multiplication overflow");
|
|
}
|
|
};
|
|
};
|
|
|
|
// Binary operation
|
|
template <typename TOp, AnyInt TLhs, AnyInt TRhs>
|
|
requires ElementTypesMatch<TLhs, TRhs>
|
|
OpResult<TLhs, TRhs> binary_op(const TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
OpResult<TLhs, TRhs> res{};
|
|
|
|
try
|
|
{
|
|
TOp::invoke(lhs, rhs, res);
|
|
}
|
|
catch (const std::overflow_error&)
|
|
{
|
|
throw OverflowErrorOf<TLhs, TRhs>(std::move(res));
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
// Compound assignment binary operation
|
|
template <typename TOp, AnyInt TLhs, AnyInt TRhs>
|
|
requires ElementTypesMatch<TLhs, TRhs>
|
|
TLhs& ca_binary_op(TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
TLhs res{};
|
|
|
|
try
|
|
{
|
|
TOp::invoke(lhs, rhs, res);
|
|
}
|
|
catch (const std::overflow_error&)
|
|
{
|
|
throw OverflowErrorOf<TLhs, TRhs>(std::move(res));
|
|
}
|
|
|
|
lhs = std::move(res);
|
|
return lhs;
|
|
}
|
|
|
|
// Addition operators
|
|
|
|
template <AnyInt TLhs, AnyInt TRhs>
|
|
requires ElementTypesMatch<TLhs, TRhs>
|
|
OpResult<TLhs, TRhs> operator+(const TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
return binary_op<Addition>(lhs, rhs);
|
|
}
|
|
|
|
template <AnyInt TLhs, AnyInt TRhs>
|
|
requires ElementTypesMatch<TLhs, TRhs>
|
|
TLhs& operator+=(TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
return ca_binary_op<Addition>(lhs, rhs);
|
|
}
|
|
|
|
// Multiplication operators
|
|
|
|
template <AnyInt TLhs, AnyInt TRhs>
|
|
requires ElementTypesMatch<TLhs, TRhs>
|
|
OpResult<TLhs, TRhs> operator*(const TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
return binary_op<Multiplication>(lhs, rhs);
|
|
}
|
|
|
|
template <AnyInt TLhs, AnyInt TRhs>
|
|
requires ElementTypesMatch<TLhs, TRhs>
|
|
TLhs& operator*=(TLhs& lhs, const TRhs& rhs)
|
|
{
|
|
return ca_binary_op<Multiplication>(lhs, rhs);
|
|
}
|
|
|
|
} // namespace mp
|