Somewhat working but with vibecoded division
This commit is contained in:
parent
7e0b3cb9fa
commit
753dfe8eef
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
.vscode/
|
.vscode/
|
||||||
.vs/
|
.vs/
|
||||||
|
.cache/
|
||||||
build/
|
build/
|
||||||
|
|||||||
42
main.cpp
42
main.cpp
@ -1,9 +1,10 @@
|
|||||||
|
#include <algorithm>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
//#include "mp.hpp"
|
//#include "mp.hpp"
|
||||||
#include "mp/int.hpp"
|
#include "mp/int.hpp"
|
||||||
#include "mp/storage.hpp"
|
|
||||||
#include "mp/math.hpp"
|
#include "mp/math.hpp"
|
||||||
|
#include "mp/lib.hpp"
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
static void PrintInt(const char* name, const T& val)
|
static void PrintInt(const char* name, const T& val)
|
||||||
@ -11,6 +12,11 @@ static void PrintInt(const char* name, const T& val)
|
|||||||
std::cout << name << " = " << mp::to_hex_string(val) << std::endl;
|
std::cout << name << " = " << mp::to_hex_string(val) << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <mp::AnyMpInt T>
|
||||||
|
static void PrintDec(const char* name, const T& val){
|
||||||
|
std::cout << name << " = " << mp::to_string(val) << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
int main()
|
int main()
|
||||||
{
|
{
|
||||||
// mp::Int a{0xDEADBEEFDEADF154, 0x0123456789ABCDEF, 0x1111222233334444};
|
// mp::Int a{0xDEADBEEFDEADF154, 0x0123456789ABCDEF, 0x1111222233334444};
|
||||||
@ -68,13 +74,13 @@ int main()
|
|||||||
while (true)
|
while (true)
|
||||||
{
|
{
|
||||||
acc *= mp::Int<1>{10};
|
acc *= mp::Int<1>{10};
|
||||||
PrintInt("acc", acc);
|
PrintDec("acc", acc);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
catch (const mp::OverflowErrorOf<decltype(acc), mp::Int<1>>& e)
|
catch (const mp::OverflowErrorOf<decltype(acc), mp::Int<1>>& e)
|
||||||
{
|
{
|
||||||
std::cout << "overflow" << std::endl;
|
std::cout << "overflow" << std::endl;
|
||||||
PrintInt("value", e.value());
|
PrintDec("value", e.value());
|
||||||
}
|
}
|
||||||
catch (const std::exception& e)
|
catch (const std::exception& e)
|
||||||
{
|
{
|
||||||
@ -105,15 +111,29 @@ int main()
|
|||||||
// }
|
// }
|
||||||
//}
|
//}
|
||||||
|
|
||||||
|
// {
|
||||||
|
// mp::Int<32> a{0xA0000000, 0x6D7217CA, 0x431E0FAE, 0x1};
|
||||||
|
// mp::Int<16> b{0x6FC10000, 0x2386F2};
|
||||||
|
|
||||||
|
// PrintDec("a", a);
|
||||||
|
// PrintDec("b", b);
|
||||||
|
|
||||||
|
// auto c = a / b;
|
||||||
|
// PrintDec("c", c);
|
||||||
|
|
||||||
|
// }
|
||||||
|
|
||||||
{
|
{
|
||||||
mp::Int<32> a{0xA0000000, 0x6D7217CA, 0x431E0FAE, 0x1};
|
mp::Int<32> a;
|
||||||
mp::Int<16> b{0x6FC10000, 0x2386F2};
|
mp::Int<16> b;
|
||||||
|
mp::Int<32> c;
|
||||||
PrintInt("a", a);
|
|
||||||
PrintInt("b", b);
|
|
||||||
|
|
||||||
auto c = a / b;
|
|
||||||
PrintInt("c", c);
|
|
||||||
|
|
||||||
|
mp::parse_string("690000100000000000000000000000000000", a);
|
||||||
|
mp::parse_string("10690010000000000000000", b);
|
||||||
|
c = a / b;
|
||||||
|
PrintDec("a", a);
|
||||||
|
PrintDec("b", b);
|
||||||
|
PrintDec("c", c);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -1,7 +1,5 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <ranges>
|
|
||||||
|
|
||||||
#include "utils.hpp"
|
#include "utils.hpp"
|
||||||
#include "storage.hpp"
|
#include "storage.hpp"
|
||||||
|
|
||||||
|
|||||||
100
mp/lib.hpp
Normal file
100
mp/lib.hpp
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <string>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
namespace mp
|
||||||
|
{
|
||||||
|
|
||||||
|
inline char hex_digit(uint8_t bits)
|
||||||
|
{
|
||||||
|
return bits > 9 ? 'a' + (bits - 10) : '0' + bits;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <AnyConstMpInt T>
|
||||||
|
inline std::string to_hex_string(const T& number)
|
||||||
|
{
|
||||||
|
constexpr size_t ELEMENT_DIGITS = T::ELEMENT_BYTES * 2;
|
||||||
|
|
||||||
|
std::string str(number.size_elems() * ELEMENT_DIGITS, '-');
|
||||||
|
|
||||||
|
for (size_t elem = 0; elem < number.size_elems(); ++elem)
|
||||||
|
{
|
||||||
|
auto v = number[elem];
|
||||||
|
for (size_t digit = 0; digit < ELEMENT_DIGITS; ++digit)
|
||||||
|
{
|
||||||
|
str[str.size() - 1 - (elem * ELEMENT_DIGITS) - digit] = hex_digit((v >> (digit * 4)) & 0xF);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto first = str.find_first_not_of('0');
|
||||||
|
if (first != std::string::npos)
|
||||||
|
{
|
||||||
|
return "0x" + str.substr(first);
|
||||||
|
}
|
||||||
|
|
||||||
|
return "0x0";
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
template <AnyMpInt T>
|
||||||
|
inline typename T::ElementType div_mod(T& value, const typename T::ElementType divisor)
|
||||||
|
{
|
||||||
|
using ElementType = typename T::ElementType;
|
||||||
|
|
||||||
|
ElementType remainder = 0;
|
||||||
|
for (size_t i = value.size_elems(); i-- > 0; )
|
||||||
|
{
|
||||||
|
using LongerType = DoubleWidthType<ElementType>;
|
||||||
|
|
||||||
|
LongerType acc = (static_cast<LongerType>(remainder) << (T::ELEMENT_BYTES * 8)) | value[i];
|
||||||
|
value[i] = static_cast<ElementType>(acc / divisor);
|
||||||
|
remainder = static_cast<ElementType>(acc % divisor);
|
||||||
|
}
|
||||||
|
|
||||||
|
return remainder;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <AnyConstMpInt T>
|
||||||
|
inline std::string to_string(const T& number)
|
||||||
|
{
|
||||||
|
std::string str;
|
||||||
|
T temp = number;
|
||||||
|
do {
|
||||||
|
auto rem = div_mod(temp, 10U);
|
||||||
|
str.push_back(static_cast<char>('0' + rem));
|
||||||
|
} while (temp != 0U);
|
||||||
|
|
||||||
|
std::reverse(str.begin(), str.end());
|
||||||
|
return str;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <AnyMpInt T>
|
||||||
|
inline void parse_string(const char* str, T& number)
|
||||||
|
{
|
||||||
|
number.zero();
|
||||||
|
|
||||||
|
for (; *str; ++str)
|
||||||
|
{
|
||||||
|
if (*str < '0' || *str > '9')
|
||||||
|
{
|
||||||
|
throw std::invalid_argument("Invalid character in input string");
|
||||||
|
}
|
||||||
|
|
||||||
|
number = number * 10U + static_cast<typename T::ElementType>(*str - '0');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// template <AnyConstInt T>
|
||||||
|
// inline T parse_string(const char* str)
|
||||||
|
// {
|
||||||
|
// T number;
|
||||||
|
// parse_string(str, number);
|
||||||
|
// return number;
|
||||||
|
// }
|
||||||
|
|
||||||
|
}
|
||||||
172
mp/math.hpp
172
mp/math.hpp
@ -1,7 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "int.hpp"
|
#include "int.hpp"
|
||||||
#include "storage.hpp"
|
|
||||||
#include "utils.hpp"
|
#include "utils.hpp"
|
||||||
|
|
||||||
namespace mp
|
namespace mp
|
||||||
@ -229,20 +228,6 @@ struct Multiplication
|
|||||||
|
|
||||||
struct Division
|
struct Division
|
||||||
{
|
{
|
||||||
template <AnyMpInt T>
|
|
||||||
static void shift_left(T& value)
|
|
||||||
{
|
|
||||||
using ElementType = typename T::ElementType;
|
|
||||||
|
|
||||||
for (size_t i = value.size_elems(); i-- > 0;)
|
|
||||||
{
|
|
||||||
ElementType new_val = value.get(i - 1);
|
|
||||||
value.set(i, new_val);
|
|
||||||
}
|
|
||||||
|
|
||||||
value.set(0, ElementType{0});
|
|
||||||
}
|
|
||||||
|
|
||||||
template <AnyConstMpInt TLhs, AnyConstMpInt TRhs, AnyMpInt TRes>
|
template <AnyConstMpInt TLhs, AnyConstMpInt TRhs, AnyMpInt TRes>
|
||||||
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
|
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
|
||||||
static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
|
static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
|
||||||
@ -255,31 +240,144 @@ struct Division
|
|||||||
throw std::runtime_error("Division by zero");
|
throw std::runtime_error("Division by zero");
|
||||||
}
|
}
|
||||||
|
|
||||||
BasicInt<ElementType, TRes::MAX_BYTES> remainder = lhs;
|
const size_t n = rhs.size_elems();
|
||||||
res.zero();
|
const size_t m = (lhs.size_elems() >= n) ? (lhs.size_elems() - n) : 0;
|
||||||
BasicInt<ElementType, TRhs::MAX_BYTES> temp;
|
|
||||||
|
|
||||||
while (remainder >= rhs)
|
// If divisor larger than dividend => quotient = 0
|
||||||
|
if (lhs.size_elems() < n)
|
||||||
{
|
{
|
||||||
int shift = remainder.size_elems() - rhs.size_elems();
|
res.zero();
|
||||||
|
return;
|
||||||
temp = rhs;
|
|
||||||
for (int i = 0; i < shift; ++i)
|
|
||||||
{
|
|
||||||
shift_left(temp);
|
|
||||||
}
|
|
||||||
|
|
||||||
ElementType q_digit = 0;
|
|
||||||
|
|
||||||
if (remainder >= temp)
|
|
||||||
{
|
|
||||||
remainder -= temp;
|
|
||||||
q_digit++;
|
|
||||||
}
|
|
||||||
|
|
||||||
res.set(shift, q_digit);
|
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
// Base and mask
|
||||||
|
constexpr unsigned int W = sizeof(ElementType) * 8;
|
||||||
|
const LongerType BASE = (static_cast<LongerType>(1) << W);
|
||||||
|
const LongerType MASK = BASE - 1;
|
||||||
|
|
||||||
|
// Prepare u (normalized dividend) length m + n + 1
|
||||||
|
BasicInt<ElementType, TLhs::MAX_BYTES + TRhs::MAX_BYTES + sizeof(ElementType)> u;
|
||||||
|
u.zero();
|
||||||
|
for (size_t i = 0; i < lhs.size_elems(); ++i)
|
||||||
|
u.set(i, lhs.get(i));
|
||||||
|
|
||||||
|
// Prepare v (normalized divisor) length n
|
||||||
|
BasicInt<ElementType, TRhs::MAX_BYTES + sizeof(ElementType)> v;
|
||||||
|
v.zero();
|
||||||
|
for (size_t i = 0; i < n; ++i)
|
||||||
|
v.set(i, rhs.get(i));
|
||||||
|
|
||||||
|
// Normalization factor d = BASE / (v[n-1] + 1)
|
||||||
|
ElementType d = 1;
|
||||||
|
if (v.get(n - 1) + 1 != 0) // defensive
|
||||||
|
{
|
||||||
|
d = static_cast<ElementType>(BASE / (static_cast<LongerType>(v.get(n - 1)) + 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (d != 1)
|
||||||
|
{
|
||||||
|
// u = u * d
|
||||||
|
LongerType carry = 0;
|
||||||
|
for (size_t i = 0; i < (m + n + 1); ++i)
|
||||||
|
{
|
||||||
|
LongerType t = static_cast<LongerType>(u.get(i)) * d + carry;
|
||||||
|
u.set(i, static_cast<ElementType>(t & MASK));
|
||||||
|
carry = t >> W;
|
||||||
|
}
|
||||||
|
// v = v * d
|
||||||
|
carry = 0;
|
||||||
|
for (size_t i = 0; i < n; ++i)
|
||||||
|
{
|
||||||
|
LongerType t = static_cast<LongerType>(v.get(i)) * d + carry;
|
||||||
|
v.set(i, static_cast<ElementType>(t & MASK));
|
||||||
|
carry = t >> W;
|
||||||
|
}
|
||||||
|
// v[n] implicitly 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare quotient
|
||||||
|
res.zero();
|
||||||
|
|
||||||
|
// Main loop j = m .. 0
|
||||||
|
for (int j = static_cast<int>(m); j >= 0; --j)
|
||||||
|
{
|
||||||
|
// u[j + n] might be zero or >0
|
||||||
|
const LongerType uj_n = static_cast<LongerType>(u.get(j + n));
|
||||||
|
const LongerType uj_n1 = static_cast<LongerType>(u.get(j + n - 1));
|
||||||
|
const LongerType vn_1 = static_cast<LongerType>(v.get(n - 1));
|
||||||
|
|
||||||
|
// Estimate q_hat = (u[j+n]*BASE + u[j+n-1]) / v[n-1]
|
||||||
|
LongerType numerator = (uj_n * BASE) + uj_n1;
|
||||||
|
LongerType qhat = numerator / vn_1;
|
||||||
|
LongerType rhat = numerator % vn_1;
|
||||||
|
|
||||||
|
if (qhat >= BASE)
|
||||||
|
qhat = BASE - 1;
|
||||||
|
|
||||||
|
// Correction loop (only if n >= 2)
|
||||||
|
if (n >= 2)
|
||||||
|
{
|
||||||
|
const LongerType vn_2 = static_cast<LongerType>(v.get(n - 2));
|
||||||
|
while (qhat * vn_2 > (rhat * BASE + static_cast<LongerType>(u.get(j + n - 2))))
|
||||||
|
{
|
||||||
|
qhat -= 1;
|
||||||
|
rhat += vn_1;
|
||||||
|
if (rhat >= BASE)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiply v by qhat and subtract from u segment starting at j
|
||||||
|
LongerType carry_mul = 0;
|
||||||
|
unsigned int borrow = 0;
|
||||||
|
for (size_t i = 0; i < n; ++i)
|
||||||
|
{
|
||||||
|
// p = qhat * v[i] + carry_mul
|
||||||
|
LongerType p = qhat * static_cast<LongerType>(v.get(i)) + carry_mul;
|
||||||
|
ElementType p_low = static_cast<ElementType>(p & MASK);
|
||||||
|
carry_mul = p >> W;
|
||||||
|
|
||||||
|
ElementType uval = u.get(j + i);
|
||||||
|
LongerType sub = static_cast<LongerType>(uval);
|
||||||
|
LongerType needed = static_cast<LongerType>(p_low) + borrow;
|
||||||
|
bool under = (sub < needed);
|
||||||
|
ElementType new_u = static_cast<ElementType>((sub + BASE - needed) & MASK);
|
||||||
|
u.set(j + i, new_u);
|
||||||
|
borrow = under ? 1u : 0u;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subtract carry_mul and borrow from u[j+n]
|
||||||
|
{
|
||||||
|
ElementType uval = u.get(j + n);
|
||||||
|
LongerType sub = static_cast<LongerType>(uval);
|
||||||
|
LongerType needed = carry_mul + borrow;
|
||||||
|
bool under = (sub < needed);
|
||||||
|
ElementType new_u = static_cast<ElementType>((sub + BASE - needed) & MASK);
|
||||||
|
u.set(j + n, new_u);
|
||||||
|
borrow = under ? 1u : 0u;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (borrow != 0)
|
||||||
|
{
|
||||||
|
// qhat was too large, decrement and add v back
|
||||||
|
qhat -= 1;
|
||||||
|
LongerType carry_add = 0;
|
||||||
|
for (size_t i = 0; i < n; ++i)
|
||||||
|
{
|
||||||
|
LongerType sum = static_cast<LongerType>(u.get(j + i)) + static_cast<LongerType>(v.get(i)) + carry_add;
|
||||||
|
u.set(j + i, static_cast<ElementType>(sum & MASK));
|
||||||
|
carry_add = sum >> W;
|
||||||
|
}
|
||||||
|
// add carry_add to u[j+n]
|
||||||
|
u.set(j + n, static_cast<ElementType>((static_cast<LongerType>(u.get(j + n)) + carry_add) & MASK));
|
||||||
|
}
|
||||||
|
|
||||||
|
// store quotient digit
|
||||||
|
res.set(static_cast<size_t>(j), static_cast<ElementType>(qhat & MASK));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: remainder unnormalization is not required for quotient; we ignore remainder.
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Binary operation
|
// Binary operation
|
||||||
|
|||||||
@ -1,14 +1,10 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <concepts>
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <span>
|
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
namespace mp
|
namespace mp
|
||||||
{
|
{
|
||||||
|
|
||||||
|
|||||||
36
mp/utils.hpp
36
mp/utils.hpp
@ -1,6 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <concepts>
|
#include <concepts>
|
||||||
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
@ -30,7 +31,7 @@ struct DoubleWidth<uint32_t>
|
|||||||
using type = uint64_t;
|
using type = uint64_t;
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef __SIZEOF_INT128__
|
#if USE_UINT128
|
||||||
template <>
|
template <>
|
||||||
struct DoubleWidth<uint64_t>
|
struct DoubleWidth<uint64_t>
|
||||||
{
|
{
|
||||||
@ -98,7 +99,7 @@ concept AnyConstInt = AnyRegularInt<T> || AnyConstMpInt<T>;
|
|||||||
|
|
||||||
//======================== Utils =======================//
|
//======================== Utils =======================//
|
||||||
|
|
||||||
#ifdef __SIZEOF_INT128__
|
#ifdef USE_UINT128
|
||||||
using LongestElementSuitableType = uint64_t;
|
using LongestElementSuitableType = uint64_t;
|
||||||
#else
|
#else
|
||||||
using LongestElementSuitableType = HalfWidthType<size_t>;
|
using LongestElementSuitableType = HalfWidthType<size_t>;
|
||||||
@ -126,35 +127,4 @@ constexpr TElem calculate_last_elem_mask()
|
|||||||
return mask;
|
return mask;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline char hex_digit(uint8_t bits)
|
|
||||||
{
|
|
||||||
return bits > 9 ? 'a' + (bits - 10) : '0' + bits;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <AnyConstMpInt T>
|
|
||||||
std::string to_hex_string(const T& number)
|
|
||||||
{
|
|
||||||
constexpr size_t ELEMENT_DIGITS = T::ELEMENT_BYTES * 2;
|
|
||||||
|
|
||||||
std::string str(number.size_elems() * ELEMENT_DIGITS, '-');
|
|
||||||
|
|
||||||
for (size_t elem = 0; elem < number.size_elems(); ++elem)
|
|
||||||
{
|
|
||||||
auto v = number[elem];
|
|
||||||
for (size_t digit = 0; digit < ELEMENT_DIGITS; ++digit)
|
|
||||||
{
|
|
||||||
str[str.size() - 1 - (elem * ELEMENT_DIGITS) - digit] = hex_digit((v >> (digit * 4)) & 0xF);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto first = str.find_first_not_of('0');
|
|
||||||
if (first != std::string::npos)
|
|
||||||
{
|
|
||||||
return "0x" + str.substr(first);
|
|
||||||
}
|
|
||||||
|
|
||||||
return "0x0";
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user