aboutsummaryrefslogtreecommitdiffstats
path: root/include/mcl/vint.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/mcl/vint.hpp')
-rw-r--r--include/mcl/vint.hpp258
1 files changed, 186 insertions, 72 deletions
diff --git a/include/mcl/vint.hpp b/include/mcl/vint.hpp
index d909923..f791367 100644
--- a/include/mcl/vint.hpp
+++ b/include/mcl/vint.hpp
@@ -142,6 +142,45 @@ inline void decStr2Int(T& x, const std::string& s)
}
}
+inline uint32_t bin2uint32(const char *s, size_t n)
+{
+ uint32_t x = 0;
+ for (size_t i = 0; i < n; i++) {
+ x <<= 1;
+ char c = s[i];
+ if (c != '0' && c != '1') throw cybozu::Exception("bin2uint32:bad char") << std::string(s, n);
+ if (c == '1') {
+ x |= 1;
+ }
+ }
+ return x;
+}
+
+template<class T>
+inline void binStr2Int(T& x, const std::string& s)
+{
+ const size_t width = 32;
+ size_t size = s.size();
+ size_t q = size / width;
+ size_t r = size % width;
+
+ const char *p = s.c_str();
+ uint32_t v;
+ x = 0;
+ if (r) {
+ v = bin2uint32(p, r);
+ p += r;
+ T::addu1(x, x, v);
+ }
+ while (q) {
+ v = bin2uint32(p, width);
+ p += width;
+ x <<= width;
+ T::addu1(x, x, v);
+ q--;
+ }
+}
+
/*
compare x[] and y[]
@retval positive if x > y
@@ -198,7 +237,7 @@ T addN(T *z, const T *x, const T *y, size_t n)
z[] = x[] + y
*/
template<class T>
-T add1(T *z, const T *x, size_t n, T y)
+T adds1(T *z, const T *x, size_t n, T y)
{
assert(n > 0);
T t = x[0] + y;
@@ -236,7 +275,7 @@ T addNM(T *z, const T *x, size_t xn, const T *y, size_t yn)
size_t min = yn;
T c = vint::addN(z, x, y, min);
if (max > min) {
- c = vint::add1(z + min, x + min, max - min, c);
+ c = vint::adds1(z + min, x + min, max - min, c);
}
return c;
}
@@ -267,7 +306,7 @@ T subN(T *z, const T *x, const T *y, size_t n)
out[] = x[n] - y
*/
template<class T>
-T sub1(T *z, const T *x, size_t n, T y)
+T subs1(T *z, const T *x, size_t n, T y)
{
assert(n > 0);
#if 0
@@ -309,7 +348,7 @@ EXIT_0:
@note accept z == x
*/
template<class T>
-T mul1(T *z, const T *x, size_t n, T y)
+T mulu1(T *z, const T *x, size_t n, T y)
{
assert(n > 0);
T H = 0;
@@ -346,12 +385,12 @@ static inline void mulNM(T *z, const T *x, size_t xn, const T *y, size_t yn)
copyN(p, y, yn);
y = p;
}
- z[xn] = mul1(&z[0], x, xn, y[0]);
+ z[xn] = vint::mulu1(&z[0], x, xn, y[0]);
clearN(z + xn + 1, yn - 1);
T *t2 = (T*)CYBOZU_ALLOCA(sizeof(T) * (xn + 1));
for (size_t i = 1; i < yn; i++) {
- t2[xn] = vint::mul1(&t2[0], x, xn, y[i]);
+ t2[xn] = vint::mulu1(&t2[0], x, xn, y[i]);
vint::addN(&z[i], &z[i], &t2[0], xn + 1);
}
}
@@ -371,7 +410,7 @@ static inline void sqrN(T *y, const T *x, size_t xn)
accept q == x
*/
template<class T>
-T div1(T *q, const T *x, size_t n, T y)
+T divu1(T *q, const T *x, size_t n, T y)
{
T r = 0;
for (int i = (int)n - 1; i >= 0; i--) {
@@ -384,7 +423,7 @@ T div1(T *q, const T *x, size_t n, T y)
@retval r = x[] % y
*/
template<class T>
-T mod1(const T *x, size_t n, T y)
+T modu1(const T *x, size_t n, T y)
{
T r = 0;
for (int i = (int)n - 1; i >= 0; i--) {
@@ -438,74 +477,96 @@ static inline double GetApp(const T *x, size_t xn, bool up)
return t;
}
+template<class T>
+size_t getRealSize(const T *x, size_t xn)
+{
+ int i = (int)xn - 1;
+ for (; i > 0; i--) {
+ if (x[i]) {
+ return i + 1;
+ }
+ }
+ return 1;
+}
+
/*
- q[] = x[xn] / y[yn] ; size of q = xn - yn + 1 if q
- r[] = x[xn] % y[yn] ; size of r = xn
+ q[qn] = x[xn] / y[yn] ; qn == xn - yn + 1 if xn >= yn if q
+ r[rn] = x[xn] % y[yn] ; rn = yn before getRealSiz
*/
template<class T>
-void divNM(T *q, T *r, const T *x, size_t xn, const T *y, size_t yn)
+void divNM(T *q, size_t qn, T *r, const T *x, size_t xn, const T *y, size_t yn)
{
assert(xn > 0 && yn > 0);
- if (x == q || x == r) {
+ assert(xn < yn || (q == 0 || qn == xn - yn + 1));
+ const size_t rn = yn;
+ xn = getRealSize(x, xn);
+ yn = getRealSize(y, yn);
+ if (x == q) {
T *p = (T*)CYBOZU_ALLOCA(sizeof(T) * xn);
copyN(p, x, xn);
x = p;
}
- if (y == q || y == r) {
+ if (y == q) {
T *p = (T*)CYBOZU_ALLOCA(sizeof(T) * yn);
copyN(p, y, yn);
y = p;
}
if (q) {
- clearN(q, xn - yn + 1);
+ clearN(q, qn);
}
if (yn > xn) {
+ /*
+ if y > x then q = 0 and r = x
+ */
copyN(r, x, xn);
+ clearN(r + xn, rn - xn);
return;
}
if (yn == 1) {
T t;
if (q) {
- t = div1(q, x, xn, y[0]);
+ t = divu1(q, x, xn, y[0]);
} else {
- t = mod1(x, xn, y[0]);
+ t = modu1(x, xn, y[0]);
}
r[0] = t;
- clearN(r + 1, xn - 1);
+ clearN(r + 1, rn - 1);
return;
}
-// assert(xn >= yn && yn >= 2);
+ assert(yn >= 2);
if (x == y) {
assert(xn == yn);
- clearN(r, xn);
+ clearN(r, rn);
if (q) {
q[0] = 1;
}
return;
}
- copyN(r, x, xn);
+ T *rr = (T*)CYBOZU_ALLOCA(sizeof(T) * xn);
+ copyN(rr, x, xn);
T *t = (T*)CYBOZU_ALLOCA(sizeof(T) * (yn + 1));
double yt = GetApp(y, yn, true);
- while (vint::compareNM(r, xn, y, yn) >= 0) {
+ while (vint::compareNM(rr, xn, y, yn) >= 0) {
size_t len = yn;
- double xt = GetApp(r, xn, false);
- if (vint::compareNM(&r[xn - len], yn, y, yn) < 0) {
+ double xt = GetApp(rr, xn, false);
+ if (vint::compareNM(&rr[xn - len], yn, y, yn) < 0) {
xt *= double(1ULL << (sizeof(T) * 8 - 1)) * 2;
len++;
}
T qt = T(xt / yt);
if (qt == 0) qt = 1;
- t[yn] = vint::mul1(&t[0], y, yn, qt);
- T b = vint::subN(&r[xn - len], &r[xn - len], &t[0], len);
+ t[yn] = vint::mulu1(&t[0], y, yn, qt);
+ T b = vint::subN(&rr[xn - len], &rr[xn - len], &t[0], len);
if (b) {
assert(!b);
}
if (q) q[xn - len] += qt;
- while (xn >= yn && r[xn - 1] == 0) {
+ while (xn >= yn && rr[xn - 1] == 0) {
xn--;
}
}
+ copyN(r, rr, rn);
}
/*
@@ -797,14 +858,14 @@ private:
{
size_t zn = xn + 1;
z.buf_.alloc(zn);
- z.buf_[zn - 1] = vint::add1(&z.buf_[0], &x[0], xn, y);
+ z.buf_[zn - 1] = vint::adds1(&z.buf_[0], &x[0], xn, y);
z.trim(zn);
}
static void usub1(VintT& z, const Buffer& x, size_t xn, Unit y)
{
size_t zn = xn;
z.buf_.alloc(zn);
- Unit c = vint::sub1(&z.buf_[0], &x[0], xn, y);
+ Unit c = vint::subs1(&z.buf_[0], &x[0], xn, y);
(void)c;
assert(!c);
z.trim(zn);
@@ -815,7 +876,7 @@ private:
z.buf_.alloc(xn);
Unit c = vint::subN(&z.buf_[0], &x[0], &y[0], yn);
if (xn > yn) {
- c = vint::sub1(&z.buf_[yn], &x[yn], xn - yn, c);
+ c = vint::subs1(&z.buf_[yn], &x[yn], xn - yn, c);
}
assert(!c);
z.trim(xn);
@@ -837,7 +898,7 @@ private:
z.isNeg_ = yNeg;
}
}
- static void _add1(VintT& z, const VintT& x, bool xNeg, int y, bool yNeg)
+ static void _adds1(VintT& z, const VintT& x, bool xNeg, int y, bool yNeg)
{
assert(y >= 0);
if ((xNeg ^ yNeg) == 0) {
@@ -854,6 +915,22 @@ private:
z.isNeg_ = yNeg;
}
}
+ static void _addu1(VintT& z, const VintT& x, bool xNeg, Unit y)
+ {
+ if (!xNeg) {
+ // same sign
+ uadd1(z, x.buf_, x.size(), y);
+ z.isNeg_ = xNeg;
+ return;
+ }
+ if (x.size() > 1 || x.buf_[0] >= y) {
+ usub1(z, x.buf_, x.size(), y);
+ z.isNeg_ = xNeg;
+ } else {
+ z = y - x.buf_[0];
+ z.isNeg_ = false;
+ }
+ }
/**
@param q [out] x / y if q != 0
@param r [out] x % y
@@ -871,12 +948,12 @@ private:
if (q) {
q->buf_.alloc(qn);
}
- r.buf_.alloc(xn);
- vint::divNM(q ? &q->buf_[0] : 0, &r.buf_[0], &x[0], xn, &y[0], yn);
+ r.buf_.alloc(yn);
+ vint::divNM(q ? &q->buf_[0] : 0, qn, &r.buf_[0], &x[0], xn, &y[0], yn);
if (q) {
q->trim(qn);
}
- r.trim(xn);
+ r.trim(yn);
}
struct MulMod {
const VintT *pm;
@@ -951,6 +1028,18 @@ public:
std::swap(size_, rhs.size_);
std::swap(isNeg_, rhs.isNeg_);
}
+ void dump() const
+ {
+ printf("size_=%d ", (int)size_);
+ for (size_t i = 0; i < size_; i++) {
+#if CYBOZU_OS_BIT == 32
+ printf("%08x", (uint32_t)buf_[size_ - 1 - i]);
+#else
+ printf("%016llx", (unsigned long long)buf_[size_ - 1 - i]);
+#endif
+ }
+ printf("\n");
+ }
/*
set positive value
@note assume little endian system
@@ -1012,7 +1101,7 @@ public:
std::vector<uint32_t> t;
while (!x.isZero()) {
- uint32_t r = divMod1(&x, x, i1e9);
+ uint32_t r = udivModu1(&x, x, i1e9);
t.push_back(r);
}
if (t.empty()) {
@@ -1026,7 +1115,7 @@ public:
break;
case 16:
{
- os << "0x" << std::hex;
+ os << std::hex;
const size_t n = size();
os << getUnit()[n - 1];
for (size_t i = 1; i < n; i++) {
@@ -1087,16 +1176,14 @@ public:
neg = true;
str = str.substr(1);
}
- if (str.size() >= 2 && str[0] == '0') {
- switch (str[1]) {
- case 'x':
- if (base != 0 && base != 16) throw cybozu::Exception("bad base in setStr(str)") << base;
- base = 16;
- str = str.substr(2);
- break;
- default:
- throw cybozu::Exception("not support base in setStr(str) 0") << str[1];
- }
+ if (str.size() >= 2 && str[0] == '0' && str[1] == 'x') {
+ if (base != 0 && base != 16) throw cybozu::Exception("Vint:setStr bad base 0x)") << str << base;
+ base = 16;
+ str = str.substr(2);
+ } else if (str.size() >= 2 && str[0] == '0' && str[1] == 'x') {
+ if (base != 0 && base != 2) throw cybozu::Exception("Vint:setStr bad base 0b") << str << base;
+ base = 2;
+ str = str.substr(2);
}
if (base == 0) {
base = 10;
@@ -1118,6 +1205,9 @@ public:
setArray(&x[0], x.size());
}
break;
+ case 2:
+ binStr2Int(*this, str);
+ break;
default:
case 10:
decStr2Int(*this, str);
@@ -1150,6 +1240,8 @@ public:
uint32_t getLow32bit() const { return (uint32_t)buf_[0]; }
bool isOdd() const { return (buf_[0] & 1) == 1; }
bool isEven() const { return !isOdd(); }
+ const Unit *getUnit() const { return &buf_[0]; }
+ size_t getUnitSize() const { return size_; }
static void add(VintT& z, const VintT& x, const VintT& y)
{
_add(z, x, x.isNeg_, y, y.isNeg_);
@@ -1172,36 +1264,48 @@ public:
{
mul(y, x, x);
}
- static void add1(VintT& z, const VintT& x, int y)
+ static void addu1(VintT& z, const VintT& x, Unit y)
{
- if (y == invalidVar) throw cybozu::Exception("VintT:add1:bad y");
- _add1(z, x, x.isNeg_, std::abs(y), y < 0);
+ _addu1(z, x, x.isNeg_, y);
}
- static void sub1(VintT& z, const VintT& x, int y)
+ static void subu1(VintT& z, const VintT& x, Unit y)
{
- if (y == invalidVar) throw cybozu::Exception("VintT:sub1:bad y");
- _add1(z, x, x.isNeg_, std::abs(y), !(y < 0));
+ _addu1(z, x, x.isNeg_, y);
}
- static void mul1(VintT& z, const VintT& x, int y)
+ static void mulu1(VintT& z, const VintT& x, Unit y)
{
- if (y == invalidVar) throw cybozu::Exception("VintT:mul1:bad y");
size_t xn = x.size();
size_t zn = xn + 1;
- Unit absY = std::abs(y);
z.buf_.alloc(zn);
- z.buf_[zn - 1] = vint::mul1(&z.buf_[0], &x.buf_[0], xn, absY);
- z.isNeg_ = x.isNeg_ ^ (y < 0);
+ z.buf_[zn - 1] = vint::mulu1(&z.buf_[0], &x.buf_[0], xn, y);
+ z.isNeg_ = x.isNeg_;
z.trim(zn);
}
+ static void adds1(VintT& z, const VintT& x, int y)
+ {
+ if (y == invalidVar) throw cybozu::Exception("VintT:adds1:bad y");
+ _adds1(z, x, x.isNeg_, std::abs(y), y < 0);
+ }
+ static void subs1(VintT& z, const VintT& x, int y)
+ {
+ if (y == invalidVar) throw cybozu::Exception("VintT:subs1:bad y");
+ _adds1(z, x, x.isNeg_, std::abs(y), !(y < 0));
+ }
+ static void muls1(VintT& z, const VintT& x, int y)
+ {
+ if (y == invalidVar) throw cybozu::Exception("VintT:muls1:bad y");
+ mulu1(z, x, std::abs(y));
+ z.isNeg_ ^= (y < 0);
+ }
/*
@param q [out] q = x / y if q is not zero
@param x [in]
@param y [in] must be not zero
return x % y
*/
- static int divMod1(VintT *q, const VintT& x, int y)
+ static int divMods1(VintT *q, const VintT& x, int y)
{
- if (y == invalidVar) throw cybozu::Exception("VintT:divMod1:bad y");
+ if (y == invalidVar) throw cybozu::Exception("VintT:divMods1:bad y");
bool xNeg = x.isNeg_;
bool yNeg = y < 0;
Unit absY = std::abs(y);
@@ -1210,10 +1314,10 @@ public:
if (q) {
q->isNeg_ = xNeg ^ yNeg;
q->buf_.alloc(xn);
- r = vint::div1(&q->buf_[0], &x.buf_[0], xn, absY);
+ r = vint::divu1(&q->buf_[0], &x.buf_[0], xn, absY);
q->trim(xn);
} else {
- r = vint::mod1(&x.buf_[0], xn, absY);
+ r = vint::modu1(&x.buf_[0], xn, absY);
}
return xNeg ? -r : r;
}
@@ -1240,19 +1344,31 @@ public:
{
divMod(0, r, x, y);
}
- static void div1(VintT& q, const VintT& x, int y)
+ static void divs1(VintT& q, const VintT& x, int y)
{
- divMod1(&q, x, y);
+ divMods1(&q, x, y);
}
- static void mod1(VintT& r, const VintT& x, int y)
+ static void mods1(VintT& r, const VintT& x, int y)
{
bool xNeg = x.isNeg_;
- r = divMod1(0, x, y);
+ r = divMods1(0, x, y);
r.isNeg_ = xNeg;
}
+ static Unit udivModu1(VintT *q, const VintT& x, Unit y)
+ {
+ if (x.isNeg_) throw cybozu::Exception("VintT:udivu1:x is not negative") << x;
+ size_t xn = x.size();
+ if (q) q->buf_.alloc(xn);
+ Unit r = vint::divu1(q ? &q->buf_[0] : 0, &x.buf_[0], xn, y);
+ if (q) {
+ q->trim(xn);
+ q->isNeg_ = false;
+ }
+ return r;
+ }
/*
like Python
- 13 / 5 = 3 ... 2
+ 13 / 5 = 2 ... 3
13 / -5 = -3 ... -2
-13 / 5 = -3 ... 2
-13 / -5 = 2 ... -3
@@ -1554,8 +1670,6 @@ public:
VintT& operator--() { sub(*this, *this, 1); return *this; }
VintT operator++(int) { VintT c = *this; add(*this, *this, 1); return c; }
VintT operator--(int) { VintT c = *this; sub(*this, *this, 1); return c; }
- const Unit *getUnit() const { return &buf_[0]; }
- size_t getUnitSize() const { return size_; }
friend bool operator<(const VintT& x, const VintT& y) { return compare(x, y) < 0; }
friend bool operator>=(const VintT& x, const VintT& y) { return !operator<(x, y); }
friend bool operator>(const VintT& x, const VintT& y) { return compare(x, y) > 0; }
@@ -1569,11 +1683,11 @@ public:
VintT& operator%=(const VintT& rhs) { mod(*this, *this, rhs); return *this; }
VintT& operator&=(const VintT& rhs) { andBit(*this, *this, rhs); return *this; }
VintT& operator|=(const VintT& rhs) { orBit(*this, *this, rhs); return *this; }
- VintT& operator+=(int rhs) { add1(*this, *this, rhs); return *this; }
- VintT& operator-=(int rhs) { sub1(*this, *this, rhs); return *this; }
- VintT& operator*=(int rhs) { mul1(*this, *this, rhs); return *this; }
- VintT& operator/=(int rhs) { div1(*this, *this, rhs); return *this; }
- VintT& operator%=(int rhs) { mod1(*this, *this, rhs); return *this; }
+ VintT& operator+=(int rhs) { adds1(*this, *this, rhs); return *this; }
+ VintT& operator-=(int rhs) { subs1(*this, *this, rhs); return *this; }
+ VintT& operator*=(int rhs) { muls1(*this, *this, rhs); return *this; }
+ VintT& operator/=(int rhs) { divs1(*this, *this, rhs); return *this; }
+ VintT& operator%=(int rhs) { mods1(*this, *this, rhs); return *this; }
friend VintT operator+(const VintT& a, const VintT& b) { VintT c; add(c, a, b); return c; }
friend VintT operator-(const VintT& a, const VintT& b) { VintT c; sub(c, a, b); return c; }
friend VintT operator*(const VintT& a, const VintT& b) { VintT c; mul(c, a, b); return c; }