Somewhat working but with vibecoded division

This commit is contained in:
tovjemam 2025-11-26 16:38:26 +01:00
parent 7e0b3cb9fa
commit 753dfe8eef
7 changed files with 265 additions and 82 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.vscode/ .vscode/
.vs/ .vs/
.cache/
build/ build/

View File

@ -1,9 +1,10 @@
#include <algorithm>
#include <iostream> #include <iostream>
//#include "mp.hpp" //#include "mp.hpp"
#include "mp/int.hpp" #include "mp/int.hpp"
#include "mp/storage.hpp"
#include "mp/math.hpp" #include "mp/math.hpp"
#include "mp/lib.hpp"
template <class T> template <class T>
static void PrintInt(const char* name, const T& val) static void PrintInt(const char* name, const T& val)
@ -11,6 +12,11 @@ static void PrintInt(const char* name, const T& val)
std::cout << name << " = " << mp::to_hex_string(val) << std::endl; std::cout << name << " = " << mp::to_hex_string(val) << std::endl;
} }
template <mp::AnyMpInt T>
static void PrintDec(const char* name, const T& val){
std::cout << name << " = " << mp::to_string(val) << std::endl;
}
int main() int main()
{ {
// mp::Int a{0xDEADBEEFDEADF154, 0x0123456789ABCDEF, 0x1111222233334444}; // mp::Int a{0xDEADBEEFDEADF154, 0x0123456789ABCDEF, 0x1111222233334444};
@ -68,13 +74,13 @@ int main()
while (true) while (true)
{ {
acc *= mp::Int<1>{10}; acc *= mp::Int<1>{10};
PrintInt("acc", acc); PrintDec("acc", acc);
} }
} }
catch (const mp::OverflowErrorOf<decltype(acc), mp::Int<1>>& e) catch (const mp::OverflowErrorOf<decltype(acc), mp::Int<1>>& e)
{ {
std::cout << "overflow" << std::endl; std::cout << "overflow" << std::endl;
PrintInt("value", e.value()); PrintDec("value", e.value());
} }
catch (const std::exception& e) catch (const std::exception& e)
{ {
@ -105,15 +111,29 @@ int main()
// } // }
//} //}
// {
// mp::Int<32> a{0xA0000000, 0x6D7217CA, 0x431E0FAE, 0x1};
// mp::Int<16> b{0x6FC10000, 0x2386F2};
// PrintDec("a", a);
// PrintDec("b", b);
// auto c = a / b;
// PrintDec("c", c);
// }
{ {
mp::Int<32> a{0xA0000000, 0x6D7217CA, 0x431E0FAE, 0x1}; mp::Int<32> a;
mp::Int<16> b{0x6FC10000, 0x2386F2}; mp::Int<16> b;
mp::Int<32> c;
PrintInt("a", a);
PrintInt("b", b);
auto c = a / b;
PrintInt("c", c);
mp::parse_string("690000100000000000000000000000000000", a);
mp::parse_string("10690010000000000000000", b);
c = a / b;
PrintDec("a", a);
PrintDec("b", b);
PrintDec("c", c);
} }
} }

View File

@ -1,7 +1,5 @@
#pragma once #pragma once
#include <ranges>
#include "utils.hpp" #include "utils.hpp"
#include "storage.hpp" #include "storage.hpp"

100
mp/lib.hpp Normal file
View File

@ -0,0 +1,100 @@
#pragma once
#include <algorithm>
#include <string>
#include <stdexcept>
#include "utils.hpp"
namespace mp
{
inline char hex_digit(uint8_t bits)
{
return bits > 9 ? 'a' + (bits - 10) : '0' + bits;
}
template <AnyConstMpInt T>
inline std::string to_hex_string(const T& number)
{
constexpr size_t ELEMENT_DIGITS = T::ELEMENT_BYTES * 2;
std::string str(number.size_elems() * ELEMENT_DIGITS, '-');
for (size_t elem = 0; elem < number.size_elems(); ++elem)
{
auto v = number[elem];
for (size_t digit = 0; digit < ELEMENT_DIGITS; ++digit)
{
str[str.size() - 1 - (elem * ELEMENT_DIGITS) - digit] = hex_digit((v >> (digit * 4)) & 0xF);
}
}
auto first = str.find_first_not_of('0');
if (first != std::string::npos)
{
return "0x" + str.substr(first);
}
return "0x0";
}
template <AnyMpInt T>
inline typename T::ElementType div_mod(T& value, const typename T::ElementType divisor)
{
using ElementType = typename T::ElementType;
ElementType remainder = 0;
for (size_t i = value.size_elems(); i-- > 0; )
{
using LongerType = DoubleWidthType<ElementType>;
LongerType acc = (static_cast<LongerType>(remainder) << (T::ELEMENT_BYTES * 8)) | value[i];
value[i] = static_cast<ElementType>(acc / divisor);
remainder = static_cast<ElementType>(acc % divisor);
}
return remainder;
}
template <AnyConstMpInt T>
inline std::string to_string(const T& number)
{
std::string str;
T temp = number;
do {
auto rem = div_mod(temp, 10U);
str.push_back(static_cast<char>('0' + rem));
} while (temp != 0U);
std::reverse(str.begin(), str.end());
return str;
}
template <AnyMpInt T>
inline void parse_string(const char* str, T& number)
{
number.zero();
for (; *str; ++str)
{
if (*str < '0' || *str > '9')
{
throw std::invalid_argument("Invalid character in input string");
}
number = number * 10U + static_cast<typename T::ElementType>(*str - '0');
}
}
// template <AnyConstInt T>
// inline T parse_string(const char* str)
// {
// T number;
// parse_string(str, number);
// return number;
// }
}

View File

@ -1,7 +1,6 @@
#pragma once #pragma once
#include "int.hpp" #include "int.hpp"
#include "storage.hpp"
#include "utils.hpp" #include "utils.hpp"
namespace mp namespace mp
@ -229,20 +228,6 @@ struct Multiplication
struct Division struct Division
{ {
template <AnyMpInt T>
static void shift_left(T& value)
{
using ElementType = typename T::ElementType;
for (size_t i = value.size_elems(); i-- > 0;)
{
ElementType new_val = value.get(i - 1);
value.set(i, new_val);
}
value.set(0, ElementType{0});
}
template <AnyConstMpInt TLhs, AnyConstMpInt TRhs, AnyMpInt TRes> template <AnyConstMpInt TLhs, AnyConstMpInt TRhs, AnyMpInt TRes>
requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes> requires ElementTypesMatch<TLhs, TRes> && ElementTypesMatch<TRhs, TRes>
static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res)
@ -255,31 +240,144 @@ struct Division
throw std::runtime_error("Division by zero"); throw std::runtime_error("Division by zero");
} }
BasicInt<ElementType, TRes::MAX_BYTES> remainder = lhs; const size_t n = rhs.size_elems();
res.zero(); const size_t m = (lhs.size_elems() >= n) ? (lhs.size_elems() - n) : 0;
BasicInt<ElementType, TRhs::MAX_BYTES> temp;
while (remainder >= rhs) // If divisor larger than dividend => quotient = 0
if (lhs.size_elems() < n)
{ {
int shift = remainder.size_elems() - rhs.size_elems(); res.zero();
return;
temp = rhs;
for (int i = 0; i < shift; ++i)
{
shift_left(temp);
}
ElementType q_digit = 0;
if (remainder >= temp)
{
remainder -= temp;
q_digit++;
}
res.set(shift, q_digit);
} }
};
// Base and mask
constexpr unsigned int W = sizeof(ElementType) * 8;
const LongerType BASE = (static_cast<LongerType>(1) << W);
const LongerType MASK = BASE - 1;
// Prepare u (normalized dividend) length m + n + 1
BasicInt<ElementType, TLhs::MAX_BYTES + TRhs::MAX_BYTES + sizeof(ElementType)> u;
u.zero();
for (size_t i = 0; i < lhs.size_elems(); ++i)
u.set(i, lhs.get(i));
// Prepare v (normalized divisor) length n
BasicInt<ElementType, TRhs::MAX_BYTES + sizeof(ElementType)> v;
v.zero();
for (size_t i = 0; i < n; ++i)
v.set(i, rhs.get(i));
// Normalization factor d = BASE / (v[n-1] + 1)
ElementType d = 1;
if (v.get(n - 1) + 1 != 0) // defensive
{
d = static_cast<ElementType>(BASE / (static_cast<LongerType>(v.get(n - 1)) + 1));
}
if (d != 1)
{
// u = u * d
LongerType carry = 0;
for (size_t i = 0; i < (m + n + 1); ++i)
{
LongerType t = static_cast<LongerType>(u.get(i)) * d + carry;
u.set(i, static_cast<ElementType>(t & MASK));
carry = t >> W;
}
// v = v * d
carry = 0;
for (size_t i = 0; i < n; ++i)
{
LongerType t = static_cast<LongerType>(v.get(i)) * d + carry;
v.set(i, static_cast<ElementType>(t & MASK));
carry = t >> W;
}
// v[n] implicitly 0
}
// Prepare quotient
res.zero();
// Main loop j = m .. 0
for (int j = static_cast<int>(m); j >= 0; --j)
{
// u[j + n] might be zero or >0
const LongerType uj_n = static_cast<LongerType>(u.get(j + n));
const LongerType uj_n1 = static_cast<LongerType>(u.get(j + n - 1));
const LongerType vn_1 = static_cast<LongerType>(v.get(n - 1));
// Estimate q_hat = (u[j+n]*BASE + u[j+n-1]) / v[n-1]
LongerType numerator = (uj_n * BASE) + uj_n1;
LongerType qhat = numerator / vn_1;
LongerType rhat = numerator % vn_1;
if (qhat >= BASE)
qhat = BASE - 1;
// Correction loop (only if n >= 2)
if (n >= 2)
{
const LongerType vn_2 = static_cast<LongerType>(v.get(n - 2));
while (qhat * vn_2 > (rhat * BASE + static_cast<LongerType>(u.get(j + n - 2))))
{
qhat -= 1;
rhat += vn_1;
if (rhat >= BASE)
break;
}
}
// Multiply v by qhat and subtract from u segment starting at j
LongerType carry_mul = 0;
unsigned int borrow = 0;
for (size_t i = 0; i < n; ++i)
{
// p = qhat * v[i] + carry_mul
LongerType p = qhat * static_cast<LongerType>(v.get(i)) + carry_mul;
ElementType p_low = static_cast<ElementType>(p & MASK);
carry_mul = p >> W;
ElementType uval = u.get(j + i);
LongerType sub = static_cast<LongerType>(uval);
LongerType needed = static_cast<LongerType>(p_low) + borrow;
bool under = (sub < needed);
ElementType new_u = static_cast<ElementType>((sub + BASE - needed) & MASK);
u.set(j + i, new_u);
borrow = under ? 1u : 0u;
}
// Subtract carry_mul and borrow from u[j+n]
{
ElementType uval = u.get(j + n);
LongerType sub = static_cast<LongerType>(uval);
LongerType needed = carry_mul + borrow;
bool under = (sub < needed);
ElementType new_u = static_cast<ElementType>((sub + BASE - needed) & MASK);
u.set(j + n, new_u);
borrow = under ? 1u : 0u;
}
if (borrow != 0)
{
// qhat was too large, decrement and add v back
qhat -= 1;
LongerType carry_add = 0;
for (size_t i = 0; i < n; ++i)
{
LongerType sum = static_cast<LongerType>(u.get(j + i)) + static_cast<LongerType>(v.get(i)) + carry_add;
u.set(j + i, static_cast<ElementType>(sum & MASK));
carry_add = sum >> W;
}
// add carry_add to u[j+n]
u.set(j + n, static_cast<ElementType>((static_cast<LongerType>(u.get(j + n)) + carry_add) & MASK));
}
// store quotient digit
res.set(static_cast<size_t>(j), static_cast<ElementType>(qhat & MASK));
}
// Note: remainder unnormalization is not required for quotient; we ignore remainder.
}
}; };
// Binary operation // Binary operation

View File

@ -1,14 +1,10 @@
#pragma once #pragma once
#include <array> #include <array>
#include <concepts>
#include <cstddef> #include <cstddef>
#include <span>
#include <stdexcept> #include <stdexcept>
#include <vector> #include <vector>
#include "utils.hpp"
namespace mp namespace mp
{ {

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <concepts> #include <concepts>
#include <cstddef>
#include <cstdint> #include <cstdint>
#include <type_traits> #include <type_traits>
@ -30,7 +31,7 @@ struct DoubleWidth<uint32_t>
using type = uint64_t; using type = uint64_t;
}; };
#ifdef __SIZEOF_INT128__ #if USE_UINT128
template <> template <>
struct DoubleWidth<uint64_t> struct DoubleWidth<uint64_t>
{ {
@ -98,7 +99,7 @@ concept AnyConstInt = AnyRegularInt<T> || AnyConstMpInt<T>;
//======================== Utils =======================// //======================== Utils =======================//
#ifdef __SIZEOF_INT128__ #ifdef USE_UINT128
using LongestElementSuitableType = uint64_t; using LongestElementSuitableType = uint64_t;
#else #else
using LongestElementSuitableType = HalfWidthType<size_t>; using LongestElementSuitableType = HalfWidthType<size_t>;
@ -126,35 +127,4 @@ constexpr TElem calculate_last_elem_mask()
return mask; return mask;
} }
inline char hex_digit(uint8_t bits)
{
return bits > 9 ? 'a' + (bits - 10) : '0' + bits;
}
template <AnyConstMpInt T>
std::string to_hex_string(const T& number)
{
constexpr size_t ELEMENT_DIGITS = T::ELEMENT_BYTES * 2;
std::string str(number.size_elems() * ELEMENT_DIGITS, '-');
for (size_t elem = 0; elem < number.size_elems(); ++elem)
{
auto v = number[elem];
for (size_t digit = 0; digit < ELEMENT_DIGITS; ++digit)
{
str[str.size() - 1 - (elem * ELEMENT_DIGITS) - digit] = hex_digit((v >> (digit * 4)) & 0xF);
}
}
auto first = str.find_first_not_of('0');
if (first != std::string::npos)
{
return "0x" + str.substr(first);
}
return "0x0";
}
} }