diff options
Diffstat (limited to 'include/mcl/vint.hpp')
-rw-r--r-- | include/mcl/vint.hpp | 258 |
1 files changed, 186 insertions, 72 deletions
diff --git a/include/mcl/vint.hpp b/include/mcl/vint.hpp index d909923..f791367 100644 --- a/include/mcl/vint.hpp +++ b/include/mcl/vint.hpp @@ -142,6 +142,45 @@ inline void decStr2Int(T& x, const std::string& s) } } +inline uint32_t bin2uint32(const char *s, size_t n) +{ + uint32_t x = 0; + for (size_t i = 0; i < n; i++) { + x <<= 1; + char c = s[i]; + if (c != '0' && c != '1') throw cybozu::Exception("bin2uint32:bad char") << std::string(s, n); + if (c == '1') { + x |= 1; + } + } + return x; +} + +template<class T> +inline void binStr2Int(T& x, const std::string& s) +{ + const size_t width = 32; + size_t size = s.size(); + size_t q = size / width; + size_t r = size % width; + + const char *p = s.c_str(); + uint32_t v; + x = 0; + if (r) { + v = bin2uint32(p, r); + p += r; + T::addu1(x, x, v); + } + while (q) { + v = bin2uint32(p, width); + p += width; + x <<= width; + T::addu1(x, x, v); + q--; + } +} + /* compare x[] and y[] @retval positive if x > y @@ -198,7 +237,7 @@ T addN(T *z, const T *x, const T *y, size_t n) z[] = x[] + y */ template<class T> -T add1(T *z, const T *x, size_t n, T y) +T adds1(T *z, const T *x, size_t n, T y) { assert(n > 0); T t = x[0] + y; @@ -236,7 +275,7 @@ T addNM(T *z, const T *x, size_t xn, const T *y, size_t yn) size_t min = yn; T c = vint::addN(z, x, y, min); if (max > min) { - c = vint::add1(z + min, x + min, max - min, c); + c = vint::adds1(z + min, x + min, max - min, c); } return c; } @@ -267,7 +306,7 @@ T subN(T *z, const T *x, const T *y, size_t n) out[] = x[n] - y */ template<class T> -T sub1(T *z, const T *x, size_t n, T y) +T subs1(T *z, const T *x, size_t n, T y) { assert(n > 0); #if 0 @@ -309,7 +348,7 @@ EXIT_0: @note accept z == x */ template<class T> -T mul1(T *z, const T *x, size_t n, T y) +T mulu1(T *z, const T *x, size_t n, T y) { assert(n > 0); T H = 0; @@ -346,12 +385,12 @@ static inline void mulNM(T *z, const T *x, size_t xn, const T *y, size_t yn) copyN(p, y, yn); y = p; } - z[xn] = mul1(&z[0], x, xn, y[0]); + z[xn] = vint::mulu1(&z[0], x, xn, y[0]); clearN(z + xn + 1, yn - 1); T *t2 = (T*)CYBOZU_ALLOCA(sizeof(T) * (xn + 1)); for (size_t i = 1; i < yn; i++) { - t2[xn] = vint::mul1(&t2[0], x, xn, y[i]); + t2[xn] = vint::mulu1(&t2[0], x, xn, y[i]); vint::addN(&z[i], &z[i], &t2[0], xn + 1); } } @@ -371,7 +410,7 @@ static inline void sqrN(T *y, const T *x, size_t xn) accept q == x */ template<class T> -T div1(T *q, const T *x, size_t n, T y) +T divu1(T *q, const T *x, size_t n, T y) { T r = 0; for (int i = (int)n - 1; i >= 0; i--) { @@ -384,7 +423,7 @@ T div1(T *q, const T *x, size_t n, T y) @retval r = x[] % y */ template<class T> -T mod1(const T *x, size_t n, T y) +T modu1(const T *x, size_t n, T y) { T r = 0; for (int i = (int)n - 1; i >= 0; i--) { @@ -438,74 +477,96 @@ static inline double GetApp(const T *x, size_t xn, bool up) return t; } +template<class T> +size_t getRealSize(const T *x, size_t xn) +{ + int i = (int)xn - 1; + for (; i > 0; i--) { + if (x[i]) { + return i + 1; + } + } + return 1; +} + /* - q[] = x[xn] / y[yn] ; size of q = xn - yn + 1 if q - r[] = x[xn] % y[yn] ; size of r = xn + q[qn] = x[xn] / y[yn] ; qn == xn - yn + 1 if xn >= yn if q + r[rn] = x[xn] % y[yn] ; rn = yn before getRealSiz */ template<class T> -void divNM(T *q, T *r, const T *x, size_t xn, const T *y, size_t yn) +void divNM(T *q, size_t qn, T *r, const T *x, size_t xn, const T *y, size_t yn) { assert(xn > 0 && yn > 0); - if (x == q || x == r) { + assert(xn < yn || (q == 0 || qn == xn - yn + 1)); + const size_t rn = yn; + xn = getRealSize(x, xn); + yn = getRealSize(y, yn); + if (x == q) { T *p = (T*)CYBOZU_ALLOCA(sizeof(T) * xn); copyN(p, x, xn); x = p; } - if (y == q || y == r) { + if (y == q) { T *p = (T*)CYBOZU_ALLOCA(sizeof(T) * yn); copyN(p, y, yn); y = p; } if (q) { - clearN(q, xn - yn + 1); + clearN(q, qn); } if (yn > xn) { + /* + if y > x then q = 0 and r = x + */ copyN(r, x, xn); + clearN(r + xn, rn - xn); return; } if (yn == 1) { T t; if (q) { - t = div1(q, x, xn, y[0]); + t = divu1(q, x, xn, y[0]); } else { - t = mod1(x, xn, y[0]); + t = modu1(x, xn, y[0]); } r[0] = t; - clearN(r + 1, xn - 1); + clearN(r + 1, rn - 1); return; } -// assert(xn >= yn && yn >= 2); + assert(yn >= 2); if (x == y) { assert(xn == yn); - clearN(r, xn); + clearN(r, rn); if (q) { q[0] = 1; } return; } - copyN(r, x, xn); + T *rr = (T*)CYBOZU_ALLOCA(sizeof(T) * xn); + copyN(rr, x, xn); T *t = (T*)CYBOZU_ALLOCA(sizeof(T) * (yn + 1)); double yt = GetApp(y, yn, true); - while (vint::compareNM(r, xn, y, yn) >= 0) { + while (vint::compareNM(rr, xn, y, yn) >= 0) { size_t len = yn; - double xt = GetApp(r, xn, false); - if (vint::compareNM(&r[xn - len], yn, y, yn) < 0) { + double xt = GetApp(rr, xn, false); + if (vint::compareNM(&rr[xn - len], yn, y, yn) < 0) { xt *= double(1ULL << (sizeof(T) * 8 - 1)) * 2; len++; } T qt = T(xt / yt); if (qt == 0) qt = 1; - t[yn] = vint::mul1(&t[0], y, yn, qt); - T b = vint::subN(&r[xn - len], &r[xn - len], &t[0], len); + t[yn] = vint::mulu1(&t[0], y, yn, qt); + T b = vint::subN(&rr[xn - len], &rr[xn - len], &t[0], len); if (b) { assert(!b); } if (q) q[xn - len] += qt; - while (xn >= yn && r[xn - 1] == 0) { + while (xn >= yn && rr[xn - 1] == 0) { xn--; } } + copyN(r, rr, rn); } /* @@ -797,14 +858,14 @@ private: { size_t zn = xn + 1; z.buf_.alloc(zn); - z.buf_[zn - 1] = vint::add1(&z.buf_[0], &x[0], xn, y); + z.buf_[zn - 1] = vint::adds1(&z.buf_[0], &x[0], xn, y); z.trim(zn); } static void usub1(VintT& z, const Buffer& x, size_t xn, Unit y) { size_t zn = xn; z.buf_.alloc(zn); - Unit c = vint::sub1(&z.buf_[0], &x[0], xn, y); + Unit c = vint::subs1(&z.buf_[0], &x[0], xn, y); (void)c; assert(!c); z.trim(zn); @@ -815,7 +876,7 @@ private: z.buf_.alloc(xn); Unit c = vint::subN(&z.buf_[0], &x[0], &y[0], yn); if (xn > yn) { - c = vint::sub1(&z.buf_[yn], &x[yn], xn - yn, c); + c = vint::subs1(&z.buf_[yn], &x[yn], xn - yn, c); } assert(!c); z.trim(xn); @@ -837,7 +898,7 @@ private: z.isNeg_ = yNeg; } } - static void _add1(VintT& z, const VintT& x, bool xNeg, int y, bool yNeg) + static void _adds1(VintT& z, const VintT& x, bool xNeg, int y, bool yNeg) { assert(y >= 0); if ((xNeg ^ yNeg) == 0) { @@ -854,6 +915,22 @@ private: z.isNeg_ = yNeg; } } + static void _addu1(VintT& z, const VintT& x, bool xNeg, Unit y) + { + if (!xNeg) { + // same sign + uadd1(z, x.buf_, x.size(), y); + z.isNeg_ = xNeg; + return; + } + if (x.size() > 1 || x.buf_[0] >= y) { + usub1(z, x.buf_, x.size(), y); + z.isNeg_ = xNeg; + } else { + z = y - x.buf_[0]; + z.isNeg_ = false; + } + } /** @param q [out] x / y if q != 0 @param r [out] x % y @@ -871,12 +948,12 @@ private: if (q) { q->buf_.alloc(qn); } - r.buf_.alloc(xn); - vint::divNM(q ? &q->buf_[0] : 0, &r.buf_[0], &x[0], xn, &y[0], yn); + r.buf_.alloc(yn); + vint::divNM(q ? &q->buf_[0] : 0, qn, &r.buf_[0], &x[0], xn, &y[0], yn); if (q) { q->trim(qn); } - r.trim(xn); + r.trim(yn); } struct MulMod { const VintT *pm; @@ -951,6 +1028,18 @@ public: std::swap(size_, rhs.size_); std::swap(isNeg_, rhs.isNeg_); } + void dump() const + { + printf("size_=%d ", (int)size_); + for (size_t i = 0; i < size_; i++) { +#if CYBOZU_OS_BIT == 32 + printf("%08x", (uint32_t)buf_[size_ - 1 - i]); +#else + printf("%016llx", (unsigned long long)buf_[size_ - 1 - i]); +#endif + } + printf("\n"); + } /* set positive value @note assume little endian system @@ -1012,7 +1101,7 @@ public: std::vector<uint32_t> t; while (!x.isZero()) { - uint32_t r = divMod1(&x, x, i1e9); + uint32_t r = udivModu1(&x, x, i1e9); t.push_back(r); } if (t.empty()) { @@ -1026,7 +1115,7 @@ public: break; case 16: { - os << "0x" << std::hex; + os << std::hex; const size_t n = size(); os << getUnit()[n - 1]; for (size_t i = 1; i < n; i++) { @@ -1087,16 +1176,14 @@ public: neg = true; str = str.substr(1); } - if (str.size() >= 2 && str[0] == '0') { - switch (str[1]) { - case 'x': - if (base != 0 && base != 16) throw cybozu::Exception("bad base in setStr(str)") << base; - base = 16; - str = str.substr(2); - break; - default: - throw cybozu::Exception("not support base in setStr(str) 0") << str[1]; - } + if (str.size() >= 2 && str[0] == '0' && str[1] == 'x') { + if (base != 0 && base != 16) throw cybozu::Exception("Vint:setStr bad base 0x)") << str << base; + base = 16; + str = str.substr(2); + } else if (str.size() >= 2 && str[0] == '0' && str[1] == 'x') { + if (base != 0 && base != 2) throw cybozu::Exception("Vint:setStr bad base 0b") << str << base; + base = 2; + str = str.substr(2); } if (base == 0) { base = 10; @@ -1118,6 +1205,9 @@ public: setArray(&x[0], x.size()); } break; + case 2: + binStr2Int(*this, str); + break; default: case 10: decStr2Int(*this, str); @@ -1150,6 +1240,8 @@ public: uint32_t getLow32bit() const { return (uint32_t)buf_[0]; } bool isOdd() const { return (buf_[0] & 1) == 1; } bool isEven() const { return !isOdd(); } + const Unit *getUnit() const { return &buf_[0]; } + size_t getUnitSize() const { return size_; } static void add(VintT& z, const VintT& x, const VintT& y) { _add(z, x, x.isNeg_, y, y.isNeg_); @@ -1172,36 +1264,48 @@ public: { mul(y, x, x); } - static void add1(VintT& z, const VintT& x, int y) + static void addu1(VintT& z, const VintT& x, Unit y) { - if (y == invalidVar) throw cybozu::Exception("VintT:add1:bad y"); - _add1(z, x, x.isNeg_, std::abs(y), y < 0); + _addu1(z, x, x.isNeg_, y); } - static void sub1(VintT& z, const VintT& x, int y) + static void subu1(VintT& z, const VintT& x, Unit y) { - if (y == invalidVar) throw cybozu::Exception("VintT:sub1:bad y"); - _add1(z, x, x.isNeg_, std::abs(y), !(y < 0)); + _addu1(z, x, x.isNeg_, y); } - static void mul1(VintT& z, const VintT& x, int y) + static void mulu1(VintT& z, const VintT& x, Unit y) { - if (y == invalidVar) throw cybozu::Exception("VintT:mul1:bad y"); size_t xn = x.size(); size_t zn = xn + 1; - Unit absY = std::abs(y); z.buf_.alloc(zn); - z.buf_[zn - 1] = vint::mul1(&z.buf_[0], &x.buf_[0], xn, absY); - z.isNeg_ = x.isNeg_ ^ (y < 0); + z.buf_[zn - 1] = vint::mulu1(&z.buf_[0], &x.buf_[0], xn, y); + z.isNeg_ = x.isNeg_; z.trim(zn); } + static void adds1(VintT& z, const VintT& x, int y) + { + if (y == invalidVar) throw cybozu::Exception("VintT:adds1:bad y"); + _adds1(z, x, x.isNeg_, std::abs(y), y < 0); + } + static void subs1(VintT& z, const VintT& x, int y) + { + if (y == invalidVar) throw cybozu::Exception("VintT:subs1:bad y"); + _adds1(z, x, x.isNeg_, std::abs(y), !(y < 0)); + } + static void muls1(VintT& z, const VintT& x, int y) + { + if (y == invalidVar) throw cybozu::Exception("VintT:muls1:bad y"); + mulu1(z, x, std::abs(y)); + z.isNeg_ ^= (y < 0); + } /* @param q [out] q = x / y if q is not zero @param x [in] @param y [in] must be not zero return x % y */ - static int divMod1(VintT *q, const VintT& x, int y) + static int divMods1(VintT *q, const VintT& x, int y) { - if (y == invalidVar) throw cybozu::Exception("VintT:divMod1:bad y"); + if (y == invalidVar) throw cybozu::Exception("VintT:divMods1:bad y"); bool xNeg = x.isNeg_; bool yNeg = y < 0; Unit absY = std::abs(y); @@ -1210,10 +1314,10 @@ public: if (q) { q->isNeg_ = xNeg ^ yNeg; q->buf_.alloc(xn); - r = vint::div1(&q->buf_[0], &x.buf_[0], xn, absY); + r = vint::divu1(&q->buf_[0], &x.buf_[0], xn, absY); q->trim(xn); } else { - r = vint::mod1(&x.buf_[0], xn, absY); + r = vint::modu1(&x.buf_[0], xn, absY); } return xNeg ? -r : r; } @@ -1240,19 +1344,31 @@ public: { divMod(0, r, x, y); } - static void div1(VintT& q, const VintT& x, int y) + static void divs1(VintT& q, const VintT& x, int y) { - divMod1(&q, x, y); + divMods1(&q, x, y); } - static void mod1(VintT& r, const VintT& x, int y) + static void mods1(VintT& r, const VintT& x, int y) { bool xNeg = x.isNeg_; - r = divMod1(0, x, y); + r = divMods1(0, x, y); r.isNeg_ = xNeg; } + static Unit udivModu1(VintT *q, const VintT& x, Unit y) + { + if (x.isNeg_) throw cybozu::Exception("VintT:udivu1:x is not negative") << x; + size_t xn = x.size(); + if (q) q->buf_.alloc(xn); + Unit r = vint::divu1(q ? &q->buf_[0] : 0, &x.buf_[0], xn, y); + if (q) { + q->trim(xn); + q->isNeg_ = false; + } + return r; + } /* like Python - 13 / 5 = 3 ... 2 + 13 / 5 = 2 ... 3 13 / -5 = -3 ... -2 -13 / 5 = -3 ... 2 -13 / -5 = 2 ... -3 @@ -1554,8 +1670,6 @@ public: VintT& operator--() { sub(*this, *this, 1); return *this; } VintT operator++(int) { VintT c = *this; add(*this, *this, 1); return c; } VintT operator--(int) { VintT c = *this; sub(*this, *this, 1); return c; } - const Unit *getUnit() const { return &buf_[0]; } - size_t getUnitSize() const { return size_; } friend bool operator<(const VintT& x, const VintT& y) { return compare(x, y) < 0; } friend bool operator>=(const VintT& x, const VintT& y) { return !operator<(x, y); } friend bool operator>(const VintT& x, const VintT& y) { return compare(x, y) > 0; } @@ -1569,11 +1683,11 @@ public: VintT& operator%=(const VintT& rhs) { mod(*this, *this, rhs); return *this; } VintT& operator&=(const VintT& rhs) { andBit(*this, *this, rhs); return *this; } VintT& operator|=(const VintT& rhs) { orBit(*this, *this, rhs); return *this; } - VintT& operator+=(int rhs) { add1(*this, *this, rhs); return *this; } - VintT& operator-=(int rhs) { sub1(*this, *this, rhs); return *this; } - VintT& operator*=(int rhs) { mul1(*this, *this, rhs); return *this; } - VintT& operator/=(int rhs) { div1(*this, *this, rhs); return *this; } - VintT& operator%=(int rhs) { mod1(*this, *this, rhs); return *this; } + VintT& operator+=(int rhs) { adds1(*this, *this, rhs); return *this; } + VintT& operator-=(int rhs) { subs1(*this, *this, rhs); return *this; } + VintT& operator*=(int rhs) { muls1(*this, *this, rhs); return *this; } + VintT& operator/=(int rhs) { divs1(*this, *this, rhs); return *this; } + VintT& operator%=(int rhs) { mods1(*this, *this, rhs); return *this; } friend VintT operator+(const VintT& a, const VintT& b) { VintT c; add(c, a, b); return c; } friend VintT operator-(const VintT& a, const VintT& b) { VintT c; sub(c, a, b); return c; } friend VintT operator*(const VintT& a, const VintT& b) { VintT c; mul(c, a, b); return c; } |