From 753dfe8eefe2adb950fff54a2d65287cd924d1c5 Mon Sep 17 00:00:00 2001 From: tovjemam Date: Wed, 26 Nov 2025 16:38:26 +0100 Subject: [PATCH] Somewhat working but with vibecoded division --- .gitignore | 1 + main.cpp | 42 +++++++++---- mp/int.hpp | 2 - mp/lib.hpp | 100 ++++++++++++++++++++++++++++++ mp/math.hpp | 162 +++++++++++++++++++++++++++++++++++++++---------- mp/storage.hpp | 4 -- mp/utils.hpp | 36 +---------- 7 files changed, 265 insertions(+), 82 deletions(-) create mode 100644 mp/lib.hpp diff --git a/.gitignore b/.gitignore index 24e741c..836a95c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .vscode/ .vs/ +.cache/ build/ diff --git a/main.cpp b/main.cpp index cb559eb..69a004f 100644 --- a/main.cpp +++ b/main.cpp @@ -1,9 +1,10 @@ +#include #include //#include "mp.hpp" #include "mp/int.hpp" -#include "mp/storage.hpp" #include "mp/math.hpp" +#include "mp/lib.hpp" template static void PrintInt(const char* name, const T& val) @@ -11,6 +12,11 @@ static void PrintInt(const char* name, const T& val) std::cout << name << " = " << mp::to_hex_string(val) << std::endl; } +template +static void PrintDec(const char* name, const T& val){ + std::cout << name << " = " << mp::to_string(val) << std::endl; +} + int main() { // mp::Int a{0xDEADBEEFDEADF154, 0x0123456789ABCDEF, 0x1111222233334444}; @@ -68,13 +74,13 @@ int main() while (true) { acc *= mp::Int<1>{10}; - PrintInt("acc", acc); + PrintDec("acc", acc); } } catch (const mp::OverflowErrorOf>& e) { std::cout << "overflow" << std::endl; - PrintInt("value", e.value()); + PrintDec("value", e.value()); } catch (const std::exception& e) { @@ -105,15 +111,29 @@ int main() // } //} + // { + // mp::Int<32> a{0xA0000000, 0x6D7217CA, 0x431E0FAE, 0x1}; + // mp::Int<16> b{0x6FC10000, 0x2386F2}; + + // PrintDec("a", a); + // PrintDec("b", b); + + // auto c = a / b; + // PrintDec("c", c); + + // } + { - mp::Int<32> a{0xA0000000, 0x6D7217CA, 0x431E0FAE, 0x1}; - mp::Int<16> b{0x6FC10000, 0x2386F2}; - - PrintInt("a", a); - PrintInt("b", b); - - auto c = a / b; - PrintInt("c", c); + mp::Int<32> a; + mp::Int<16> b; + mp::Int<32> c; + mp::parse_string("690000100000000000000000000000000000", a); + mp::parse_string("10690010000000000000000", b); + c = a / b; + PrintDec("a", a); + PrintDec("b", b); + PrintDec("c", c); } + } \ No newline at end of file diff --git a/mp/int.hpp b/mp/int.hpp index 3ae9193..f82efdb 100644 --- a/mp/int.hpp +++ b/mp/int.hpp @@ -1,7 +1,5 @@ #pragma once -#include - #include "utils.hpp" #include "storage.hpp" diff --git a/mp/lib.hpp b/mp/lib.hpp new file mode 100644 index 0000000..790be82 --- /dev/null +++ b/mp/lib.hpp @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include + +#include "utils.hpp" + +namespace mp +{ + +inline char hex_digit(uint8_t bits) +{ + return bits > 9 ? 'a' + (bits - 10) : '0' + bits; +} + +template +inline std::string to_hex_string(const T& number) +{ + constexpr size_t ELEMENT_DIGITS = T::ELEMENT_BYTES * 2; + + std::string str(number.size_elems() * ELEMENT_DIGITS, '-'); + + for (size_t elem = 0; elem < number.size_elems(); ++elem) + { + auto v = number[elem]; + for (size_t digit = 0; digit < ELEMENT_DIGITS; ++digit) + { + str[str.size() - 1 - (elem * ELEMENT_DIGITS) - digit] = hex_digit((v >> (digit * 4)) & 0xF); + } + } + + auto first = str.find_first_not_of('0'); + if (first != std::string::npos) + { + return "0x" + str.substr(first); + } + + return "0x0"; + +} + +template +inline typename T::ElementType div_mod(T& value, const typename T::ElementType divisor) +{ + using ElementType = typename T::ElementType; + + ElementType remainder = 0; + for (size_t i = value.size_elems(); i-- > 0; ) + { + using LongerType = DoubleWidthType; + + LongerType acc = (static_cast(remainder) << (T::ELEMENT_BYTES * 8)) | value[i]; + value[i] = static_cast(acc / divisor); + remainder = static_cast(acc % divisor); + } + + return remainder; +} + +template +inline std::string to_string(const T& number) +{ + std::string str; + T temp = number; + do { + auto rem = div_mod(temp, 10U); + str.push_back(static_cast('0' + rem)); + } while (temp != 0U); + + std::reverse(str.begin(), str.end()); + return str; +} + +template +inline void parse_string(const char* str, T& number) +{ + number.zero(); + + for (; *str; ++str) + { + if (*str < '0' || *str > '9') + { + throw std::invalid_argument("Invalid character in input string"); + } + + number = number * 10U + static_cast(*str - '0'); + } +} + + +// template +// inline T parse_string(const char* str) +// { +// T number; +// parse_string(str, number); +// return number; +// } + +} \ No newline at end of file diff --git a/mp/math.hpp b/mp/math.hpp index 4102313..5fb4ec7 100644 --- a/mp/math.hpp +++ b/mp/math.hpp @@ -1,7 +1,6 @@ #pragma once #include "int.hpp" -#include "storage.hpp" #include "utils.hpp" namespace mp @@ -229,20 +228,6 @@ struct Multiplication struct Division { - template - static void shift_left(T& value) - { - using ElementType = typename T::ElementType; - - for (size_t i = value.size_elems(); i-- > 0;) - { - ElementType new_val = value.get(i - 1); - value.set(i, new_val); - } - - value.set(0, ElementType{0}); - } - template requires ElementTypesMatch && ElementTypesMatch static void invoke(const TLhs& lhs, const TRhs& rhs, TRes& res) @@ -255,31 +240,144 @@ struct Division throw std::runtime_error("Division by zero"); } - BasicInt remainder = lhs; - res.zero(); - BasicInt temp; + const size_t n = rhs.size_elems(); + const size_t m = (lhs.size_elems() >= n) ? (lhs.size_elems() - n) : 0; - while (remainder >= rhs) + // If divisor larger than dividend => quotient = 0 + if (lhs.size_elems() < n) { - int shift = remainder.size_elems() - rhs.size_elems(); - - temp = rhs; - for (int i = 0; i < shift; ++i) + res.zero(); + return; + } + + // Base and mask + constexpr unsigned int W = sizeof(ElementType) * 8; + const LongerType BASE = (static_cast(1) << W); + const LongerType MASK = BASE - 1; + + // Prepare u (normalized dividend) length m + n + 1 + BasicInt u; + u.zero(); + for (size_t i = 0; i < lhs.size_elems(); ++i) + u.set(i, lhs.get(i)); + + // Prepare v (normalized divisor) length n + BasicInt v; + v.zero(); + for (size_t i = 0; i < n; ++i) + v.set(i, rhs.get(i)); + + // Normalization factor d = BASE / (v[n-1] + 1) + ElementType d = 1; + if (v.get(n - 1) + 1 != 0) // defensive + { + d = static_cast(BASE / (static_cast(v.get(n - 1)) + 1)); + } + + if (d != 1) + { + // u = u * d + LongerType carry = 0; + for (size_t i = 0; i < (m + n + 1); ++i) { - shift_left(temp); + LongerType t = static_cast(u.get(i)) * d + carry; + u.set(i, static_cast(t & MASK)); + carry = t >> W; + } + // v = v * d + carry = 0; + for (size_t i = 0; i < n; ++i) + { + LongerType t = static_cast(v.get(i)) * d + carry; + v.set(i, static_cast(t & MASK)); + carry = t >> W; + } + // v[n] implicitly 0 + } + + // Prepare quotient + res.zero(); + + // Main loop j = m .. 0 + for (int j = static_cast(m); j >= 0; --j) + { + // u[j + n] might be zero or >0 + const LongerType uj_n = static_cast(u.get(j + n)); + const LongerType uj_n1 = static_cast(u.get(j + n - 1)); + const LongerType vn_1 = static_cast(v.get(n - 1)); + + // Estimate q_hat = (u[j+n]*BASE + u[j+n-1]) / v[n-1] + LongerType numerator = (uj_n * BASE) + uj_n1; + LongerType qhat = numerator / vn_1; + LongerType rhat = numerator % vn_1; + + if (qhat >= BASE) + qhat = BASE - 1; + + // Correction loop (only if n >= 2) + if (n >= 2) + { + const LongerType vn_2 = static_cast(v.get(n - 2)); + while (qhat * vn_2 > (rhat * BASE + static_cast(u.get(j + n - 2)))) + { + qhat -= 1; + rhat += vn_1; + if (rhat >= BASE) + break; + } } - ElementType q_digit = 0; - - if (remainder >= temp) + // Multiply v by qhat and subtract from u segment starting at j + LongerType carry_mul = 0; + unsigned int borrow = 0; + for (size_t i = 0; i < n; ++i) { - remainder -= temp; - q_digit++; + // p = qhat * v[i] + carry_mul + LongerType p = qhat * static_cast(v.get(i)) + carry_mul; + ElementType p_low = static_cast(p & MASK); + carry_mul = p >> W; + + ElementType uval = u.get(j + i); + LongerType sub = static_cast(uval); + LongerType needed = static_cast(p_low) + borrow; + bool under = (sub < needed); + ElementType new_u = static_cast((sub + BASE - needed) & MASK); + u.set(j + i, new_u); + borrow = under ? 1u : 0u; } - res.set(shift, q_digit); - } - }; + // Subtract carry_mul and borrow from u[j+n] + { + ElementType uval = u.get(j + n); + LongerType sub = static_cast(uval); + LongerType needed = carry_mul + borrow; + bool under = (sub < needed); + ElementType new_u = static_cast((sub + BASE - needed) & MASK); + u.set(j + n, new_u); + borrow = under ? 1u : 0u; + } + + if (borrow != 0) + { + // qhat was too large, decrement and add v back + qhat -= 1; + LongerType carry_add = 0; + for (size_t i = 0; i < n; ++i) + { + LongerType sum = static_cast(u.get(j + i)) + static_cast(v.get(i)) + carry_add; + u.set(j + i, static_cast(sum & MASK)); + carry_add = sum >> W; + } + // add carry_add to u[j+n] + u.set(j + n, static_cast((static_cast(u.get(j + n)) + carry_add) & MASK)); + } + + // store quotient digit + res.set(static_cast(j), static_cast(qhat & MASK)); + } + + // Note: remainder unnormalization is not required for quotient; we ignore remainder. + } }; // Binary operation diff --git a/mp/storage.hpp b/mp/storage.hpp index 697166c..d452969 100644 --- a/mp/storage.hpp +++ b/mp/storage.hpp @@ -1,14 +1,10 @@ #pragma once #include -#include #include -#include #include #include -#include "utils.hpp" - namespace mp { diff --git a/mp/utils.hpp b/mp/utils.hpp index 3751247..cca9972 100644 --- a/mp/utils.hpp +++ b/mp/utils.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -30,7 +31,7 @@ struct DoubleWidth using type = uint64_t; }; -#ifdef __SIZEOF_INT128__ +#if USE_UINT128 template <> struct DoubleWidth { @@ -98,7 +99,7 @@ concept AnyConstInt = AnyRegularInt || AnyConstMpInt; //======================== Utils =======================// -#ifdef __SIZEOF_INT128__ +#ifdef USE_UINT128 using LongestElementSuitableType = uint64_t; #else using LongestElementSuitableType = HalfWidthType; @@ -126,35 +127,4 @@ constexpr TElem calculate_last_elem_mask() return mask; } -inline char hex_digit(uint8_t bits) -{ - return bits > 9 ? 'a' + (bits - 10) : '0' + bits; -} - -template -std::string to_hex_string(const T& number) -{ - constexpr size_t ELEMENT_DIGITS = T::ELEMENT_BYTES * 2; - - std::string str(number.size_elems() * ELEMENT_DIGITS, '-'); - - for (size_t elem = 0; elem < number.size_elems(); ++elem) - { - auto v = number[elem]; - for (size_t digit = 0; digit < ELEMENT_DIGITS; ++digit) - { - str[str.size() - 1 - (elem * ELEMENT_DIGITS) - digit] = hex_digit((v >> (digit * 4)) & 0xF); - } - } - - auto first = str.find_first_not_of('0'); - if (first != std::string::npos) - { - return "0x" + str.substr(first); - } - - return "0x0"; - -} - } \ No newline at end of file