aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Makefile2
-rw-r--r--include/mcl/bn.hpp13
-rw-r--r--include/mcl/conversion.hpp2
-rw-r--r--include/mcl/fp.hpp173
-rw-r--r--include/mcl/gmp_util.hpp162
-rw-r--r--include/mcl/operator.hpp2
-rw-r--r--include/mcl/vint.hpp237
-rw-r--r--src/fp.cpp55
-rw-r--r--test/fp_test.cpp1
-rw-r--r--test/gmp_test.cpp39
10 files changed, 450 insertions, 236 deletions
diff --git a/Makefile b/Makefile
index c4c5eba..7cb811a 100644
--- a/Makefile
+++ b/Makefile
@@ -261,7 +261,7 @@ endif
emcc -o $@ src/fp.cpp src/she_c384.cpp $(EMCC_OPT) -DMCL_MAX_BIT_SIZE=384 -s TOTAL_MEMORY=67108864 -s DISABLE_EXCEPTION_CATCHING=0
../mcl-wasm/mcl_c.js: src/bn_c256.cpp $(MCL_C_DEP)
- emcc -o $@ src/fp.cpp src/bn_c256.cpp $(EMCC_OPT) -DMCL_MAX_BIT_SIZE=256 -DMCL_USE_WEB_CRYPTO_API -s DISABLE_EXCEPTION_CATCHING=1
+ emcc -o $@ src/fp.cpp src/bn_c256.cpp $(EMCC_OPT) -DMCL_MAX_BIT_SIZE=256 -DMCL_USE_WEB_CRYPTO_API -s DISABLE_EXCEPTION_CATCHING=1 #-DCYBOZU_DONT_USE_EXCEPTION -DCYBOZU_DONT_USE_STRING
../mcl-wasm/mcl_c512.js: src/bn_c512.cpp $(MCL_C_DEP)
emcc -o $@ src/fp.cpp src/bn_c512.cpp $(EMCC_OPT) -DMCL_MAX_BIT_SIZE=512 -DMCL_USE_WEB_CRYPTO_API -s DISABLE_EXCEPTION_CATCHING=1
diff --git a/include/mcl/bn.hpp b/include/mcl/bn.hpp
index a686999..98020ac 100644
--- a/include/mcl/bn.hpp
+++ b/include/mcl/bn.hpp
@@ -890,7 +890,8 @@ struct Param {
{
this->cp = cp;
isBLS12 = cp.curveType == MCL_BLS12_381;
- z = mpz_class(cp.z);
+ gmp::setStr(pb, z, cp.z);
+ if (!*pb) return;
isNegative = z < 0;
if (isNegative) {
abs_z = -z;
@@ -970,12 +971,14 @@ struct Param {
glv2.init(r, z, isBLS12);
*pb = true;
}
+#ifndef CYBOZU_DONT_EXCEPTION
void init(const mcl::CurveParam& cp, fp::Mode mode)
{
bool b;
init(&b, cp, mode);
if (!b) throw cybozu::Exception("Param:init");
}
+#endif
};
template<size_t dummyImpl = 0>
@@ -1828,6 +1831,7 @@ inline void precomputedMillerLoop2(Fp12& f, const G1& P1, const mcl::Array<Fp6>&
}
inline void mapToG1(bool *pb, G1& P, const Fp& x) { *pb = BN::param.mapTo.calcG1(P, x); }
inline void mapToG2(bool *pb, G2& P, const Fp2& x) { *pb = BN::param.mapTo.calcG2(P, x); }
+#ifndef CYBOZU_DONT_EXCEPTION
inline void mapToG1(G1& P, const Fp& x)
{
bool b;
@@ -1840,6 +1844,7 @@ inline void mapToG2(G2& P, const Fp2& x)
mapToG2(&b, P, x);
if (!b) throw cybozu::Exception("mapToG2:bad value") << x;
}
+#endif
inline void hashAndMapToG1(G1& P, const void *buf, size_t bufSize)
{
Fp t;
@@ -1861,6 +1866,7 @@ inline void hashAndMapToG2(G2& P, const void *buf, size_t bufSize)
assert(b);
(void)b;
}
+#ifndef CYBOZU_DONT_USE_STRING
inline void hashAndMapToG1(G1& P, const std::string& str)
{
hashAndMapToG1(P, str.c_str(), str.size());
@@ -1869,6 +1875,7 @@ inline void hashAndMapToG2(G2& P, const std::string& str)
{
hashAndMapToG2(P, str.c_str(), str.size());
}
+#endif
inline void verifyOrderG1(bool doVerify)
{
if (BN::param.isBLS12) {
@@ -1936,12 +1943,14 @@ inline void init(bool *pb, const mcl::CurveParam& cp = mcl::BN254, fp::Mode mode
*pb = true;
}
+#ifndef CYBOZU_DONT_EXCEPTION
inline void init(const mcl::CurveParam& cp = mcl::BN254, fp::Mode mode = fp::FP_AUTO)
{
bool b;
init(&b, cp, mode);
if (!b) throw cybozu::Exception("BN:init");
}
+#endif
} // mcl::bn::BN
@@ -1950,12 +1959,14 @@ inline void initPairing(bool *pb, const mcl::CurveParam& cp = mcl::BN254, fp::Mo
BN::init(pb, cp, mode);
}
+#ifndef CYBOZU_DONT_EXCEPTION
inline void initPairing(const mcl::CurveParam& cp = mcl::BN254, fp::Mode mode = fp::FP_AUTO)
{
bool b;
BN::init(&b, cp, mode);
if (!b) throw cybozu::Exception("bn:initPairing");
}
+#endif
} } // mcl::bn
diff --git a/include/mcl/conversion.hpp b/include/mcl/conversion.hpp
index 1000257..b5faa50 100644
--- a/include/mcl/conversion.hpp
+++ b/include/mcl/conversion.hpp
@@ -26,6 +26,7 @@ bool skipSpace(char *c, InputStream& is)
}
}
+#ifndef CYBOZU_DONT_USE_STRING
template<class InputStream>
void loadWord(std::string& s, InputStream& is)
{
@@ -39,6 +40,7 @@ void loadWord(std::string& s, InputStream& is)
s += c;
}
}
+#endif
template<class InputStream>
size_t loadWord(char *buf, size_t bufSize, InputStream& is)
diff --git a/include/mcl/fp.hpp b/include/mcl/fp.hpp
index b114a1f..caa67c9 100644
--- a/include/mcl/fp.hpp
+++ b/include/mcl/fp.hpp
@@ -48,7 +48,14 @@ int64_t getInt64(bool *pb, fp::Block& b, const fp::Op& op);
const char *ModeToStr(Mode mode);
-Mode StrToMode(const std::string& s);
+Mode StrToMode(const char *s);
+
+#ifndef CYBOZU_DONT_USE_STRING
+inline Mode StrToMode(const std::string& s)
+{
+ return StrToMode(s.c_str());
+}
+#endif
inline void dumpUnit(Unit x)
{
@@ -124,36 +131,14 @@ public:
static inline void init(bool *pb, const char *mstr, fp::Mode mode = fp::FP_AUTO)
{
mpz_class p;
- gmp::setStr(pb, p, mstr, strlen(mstr));
+ gmp::setStr(pb, p, mstr);
if (!*pb) return;
init(pb, p, mode);
}
- static inline void init(const mpz_class& _p, fp::Mode mode = fp::FP_AUTO)
- {
- bool b;
- init(&b, _p, mode);
- if (!b) throw cybozu::Exception("Fp:init");
- }
- static inline void init(const std::string& mstr, fp::Mode mode = fp::FP_AUTO)
- {
- bool b;
- init(&b, mstr.c_str(), mode);
- if (!b) throw cybozu::Exception("Fp:init");
- }
static inline size_t getModulo(char *buf, size_t bufSize)
{
return gmp::getStr(buf, bufSize, op_.mp);
}
- static inline void getModulo(std::string& pstr)
- {
- gmp::getStr(pstr, op_.mp);
- }
- static std::string getModulo()
- {
- std::string s;
- getModulo(s);
- return s;
- }
static inline bool isFullBit() { return op_.isFullBit; }
/*
binary patter of p
@@ -176,8 +161,8 @@ public:
x.getMpz(mx);
bool b = op_.sq.get(my, mx);
if (!b) return false;
- y.setMpz(my);
- return true;
+ y.setMpz(&b, my);
+ return b;
}
FpT() {}
FpT(const FpT& x)
@@ -194,10 +179,6 @@ public:
op_.fp_clear(v_);
}
FpT(int64_t x) { operator=(x); }
- explicit FpT(const std::string& str, int base = 0)
- {
- Serializer::setStr(str, base);
- }
FpT& operator=(int64_t x)
{
if (x == 1) {
@@ -290,20 +271,6 @@ public:
}
cybozu::write(pb, os, buf + sizeof(buf) - len, len);
}
- template<class OutputStream>
- void save(OutputStream& os, int ioMode = IoSerialize) const
- {
- bool b;
- save(&b, os, ioMode);
- if (!b) throw cybozu::Exception("fp:save") << ioMode;
- }
- template<class InputStream>
- void load(InputStream& is, int ioMode = IoSerialize)
- {
- bool b;
- load(&b, is, ioMode);
- if (!b) throw cybozu::Exception("fp:load") << ioMode;
- }
template<class S>
void setArray(bool *pb, const S *x, size_t n)
{
@@ -311,16 +278,6 @@ public:
toMont();
}
/*
- throw exception if x >= p
- */
- template<class S>
- void setArray(const S *x, size_t n)
- {
- bool b;
- setArray(&b, x, n);
- if (!b) throw cybozu::Exception("Fp:setArray");
- }
- /*
mask x with (1 << bitLen) and subtract p if x >= p
*/
template<class S>
@@ -368,10 +325,6 @@ public:
uint32_t size = op_.hash(buf, static_cast<uint32_t>(sizeof(buf)), msg, static_cast<uint32_t>(msgSize));
setArrayMask(buf, size);
}
- void setHashOf(const std::string& msg)
- {
- setHashOf(msg.data(), msg.size());
- }
void getMpz(mpz_class& x) const
{
fp::Block b;
@@ -392,12 +345,6 @@ public:
}
setArray(pb, gmp::getUnit(x), gmp::getUnitSize(x));
}
- void setMpz(const mpz_class& x)
- {
- bool b;
- setMpz(&b, x);
- if (!b) throw cybozu::Exception("Fp:setMpz:neg");
- }
static inline void add(FpT& z, const FpT& x, const FpT& y) { op_.fp_add(z.v_, x.v_, y.v_, op_.p); }
static inline void sub(FpT& z, const FpT& x, const FpT& y) { op_.fp_sub(z.v_, x.v_, y.v_, op_.p); }
static inline void addPre(FpT& z, const FpT& x, const FpT& y) { op_.fp_addPre(z.v_, x.v_, y.v_); }
@@ -458,20 +405,6 @@ public:
getBlock(b);
return fp::getInt64(pb, b, op_);
}
- uint64_t getUint64() const
- {
- bool b;
- uint64_t v = getUint64(&b);
- if (!b) throw cybozu::Exception("Fp:getUint64:large value");
- return v;
- }
- int64_t getInt64() const
- {
- bool b;
- int64_t v = getInt64(&b);
- if (!b) throw cybozu::Exception("Fp:getInt64:large value");
- return v;
- }
bool operator==(const FpT& rhs) const { return fp::isEqualArray(v_, rhs.v_, op_.N); }
bool operator!=(const FpT& rhs) const { return !operator==(rhs); }
friend inline std::ostream& operator<<(std::ostream& os, const FpT& self)
@@ -526,16 +459,94 @@ public:
ioMode_ = ioMode;
}
static inline int getIoMode() { return ioMode_; }
+ static inline size_t getModBitLen() { return getBitSize(); }
+ static inline void setHashFunc(uint32_t hash(void *out, uint32_t maxOutSize, const void *msg, uint32_t msgSize))
+ {
+ op_.hash = hash;
+ }
+#ifndef CYBOZU_DONT_USE_STRING
+ explicit FpT(const std::string& str, int base = 0)
+ {
+ Serializer::setStr(str, base);
+ }
+ static inline void getModulo(std::string& pstr)
+ {
+ gmp::getStr(pstr, op_.mp);
+ }
+ static std::string getModulo()
+ {
+ std::string s;
+ getModulo(s);
+ return s;
+ }
+ void setHashOf(const std::string& msg)
+ {
+ setHashOf(msg.data(), msg.size());
+ }
// backward compatibility
static inline void setModulo(const std::string& mstr, fp::Mode mode = fp::FP_AUTO)
{
init(mstr, mode);
}
- static inline size_t getModBitLen() { return getBitSize(); }
- static inline void setHashFunc(uint32_t hash(void *out, uint32_t maxOutSize, const void *msg, uint32_t msgSize))
+#endif
+#ifndef CYBOZU_DONT_USE_EXCEPTION
+ static inline void init(const mpz_class& _p, fp::Mode mode = fp::FP_AUTO)
{
- op_.hash = hash;
+ bool b;
+ init(&b, _p, mode);
+ if (!b) throw cybozu::Exception("Fp:init");
+ }
+ static inline void init(const std::string& mstr, fp::Mode mode = fp::FP_AUTO)
+ {
+ bool b;
+ init(&b, mstr.c_str(), mode);
+ if (!b) throw cybozu::Exception("Fp:init");
+ }
+ template<class OutputStream>
+ void save(OutputStream& os, int ioMode = IoSerialize) const
+ {
+ bool b;
+ save(&b, os, ioMode);
+ if (!b) throw cybozu::Exception("fp:save") << ioMode;
+ }
+ template<class InputStream>
+ void load(InputStream& is, int ioMode = IoSerialize)
+ {
+ bool b;
+ load(&b, is, ioMode);
+ if (!b) throw cybozu::Exception("fp:load") << ioMode;
+ }
+ /*
+ throw exception if x >= p
+ */
+ template<class S>
+ void setArray(const S *x, size_t n)
+ {
+ bool b;
+ setArray(&b, x, n);
+ if (!b) throw cybozu::Exception("Fp:setArray");
+ }
+ void setMpz(const mpz_class& x)
+ {
+ bool b;
+ setMpz(&b, x);
+ if (!b) throw cybozu::Exception("Fp:setMpz:neg");
+ }
+ uint64_t getUint64() const
+ {
+ bool b;
+ uint64_t v = getUint64(&b);
+ if (!b) throw cybozu::Exception("Fp:getUint64:large value");
+ return v;
+ }
+ int64_t getInt64() const
+ {
+ bool b;
+ int64_t v = getInt64(&b);
+ if (!b) throw cybozu::Exception("Fp:getInt64:large value");
+ return v;
}
+#endif
};
template<class tag, size_t maxBitSize> fp::Op FpT<tag, maxBitSize>::op_;
diff --git a/include/mcl/gmp_util.hpp b/include/mcl/gmp_util.hpp
index 399041c..bb461af 100644
--- a/include/mcl/gmp_util.hpp
+++ b/include/mcl/gmp_util.hpp
@@ -64,12 +64,13 @@ typedef mpz_class ImplType;
// z = [buf[n-1]:..:buf[1]:buf[0]]
// eg. buf[] = {0x12345678, 0xaabbccdd}; => z = 0xaabbccdd12345678;
template<class T>
-void setArray(mpz_class& z, const T *buf, size_t n)
+void setArray(bool *pb, mpz_class& z, const T *buf, size_t n)
{
#ifdef MCL_USE_VINT
- z.setArray(buf, n);
+ z.setArray(pb, buf, n);
#else
mpz_import(z.get_mpz_t(), n, -1, sizeof(*buf), 0, 0, buf);
+ *pb = true;
#endif
}
/*
@@ -78,44 +79,43 @@ void setArray(mpz_class& z, const T *buf, size_t n)
*/
#ifndef MCL_USE_VINT
template<class T>
-void getArray(T *buf, size_t maxSize, const mpz_srcptr x)
+bool getArray_(T *buf, size_t maxSize, const mpz_srcptr x)
{
const size_t bufByteSize = sizeof(T) * maxSize;
const int xn = x->_mp_size;
- if (xn < 0) throw cybozu::Exception("gmp:getArray:x is negative");
+ if (xn < 0) return false;
size_t xByteSize = sizeof(*x->_mp_d) * xn;
- if (xByteSize > bufByteSize) throw cybozu::Exception("gmp:getArray:too small") << xn << maxSize;
+ if (xByteSize > bufByteSize) return false;
memcpy(buf, x->_mp_d, xByteSize);
memset((char*)buf + xByteSize, 0, bufByteSize - xByteSize);
+ return true;
}
#endif
template<class T>
-void getArray(T *buf, size_t maxSize, const mpz_class& x)
+void getArray(bool *pb, T *buf, size_t maxSize, const mpz_class& x)
{
#ifdef MCL_USE_VINT
- x.getArray(buf, maxSize);
+ x.getArray(pb, buf, maxSize);
#else
- getArray(buf, maxSize, x.get_mpz_t());
+ *pb = getArray_(buf, maxSize, x.get_mpz_t());
#endif
}
inline void set(mpz_class& z, uint64_t x)
{
- setArray(z, &x, 1);
+ bool b;
+ setArray(&b, z, &x, 1);
+ assert(b);
+ (void)b;
}
-inline void setStr(bool *pb, mpz_class& z, const char *str, size_t strSize, int base = 0)
+inline void setStr(bool *pb, mpz_class& z, const char *str, int base = 0)
{
#ifdef MCL_USE_VINT
- z.setStr(pb, str, strSize, base);
+ z.setStr(pb, str, base);
#else
- *pb = z.set_str(std::string(str, strSize), base) == 0;
+ *pb = z.set_str(str, base) == 0;
#endif
}
-inline void setStr(mpz_class& z, const std::string& str, int base = 0)
-{
- bool b;
- setStr(&b, z, str.c_str(), str.size(), base);
- if (!b) throw cybozu::Exception("gmp:setStr");
-}
+
/*
set buf with string terminated by '\0'
return strlen(buf) if success else 0
@@ -125,18 +125,19 @@ inline size_t getStr(char *buf, size_t bufSize, const mpz_class& z, int base = 1
#ifdef MCL_USE_VINT
return z.getStr(buf, bufSize, base);
#else
- std::string str = z.get_str(base);
- if (str.size() < bufSize) {
- memcpy(buf, str.c_str(), str.size() + 1);
- return str.size();
- }
- return 0;
+ __gmp_alloc_cstring tmp(mpz_get_str(0, base, z.get_mpz_t()));
+ size_t n = strlen(tmp.str);
+ if (n + 1 > bufSize) return 0;
+ memcpy(buf, tmp.str, n + 1);
+ return n;
#endif
}
+
+#ifndef CYBOZU_DONT_USE_STRING
inline void getStr(std::string& str, const mpz_class& z, int base = 10)
{
#ifdef MCL_USE_VINT
- str = z.getStr(base);
+ z.getStr(str, base);
#else
str = z.get_str(base);
#endif
@@ -144,9 +145,11 @@ inline void getStr(std::string& str, const mpz_class& z, int base = 10)
inline std::string getStr(const mpz_class& z, int base = 10)
{
std::string s;
- getStr(s, z, base);
+ gmp::getStr(s, z, base);
return s;
}
+#endif
+
inline void add(mpz_class& z, const mpz_class& x, const mpz_class& y)
{
#ifdef MCL_USE_VINT
@@ -365,11 +368,12 @@ inline int legendre(const mpz_class& a, const mpz_class& p)
return mpz_legendre(a.get_mpz_t(), p.get_mpz_t());
#endif
}
-inline bool isPrime(const mpz_class& x)
+inline bool isPrime(bool *pb, const mpz_class& x)
{
#ifdef MCL_USE_VINT
- return x.isPrime(32);
+ return x.isPrime(pb, 32);
#else
+ *pb = true;
return mpz_probab_prime_p(x.get_mpz_t(), 32) != 0;
#endif
}
@@ -438,7 +442,7 @@ inline mpz_class abs(const mpz_class& x)
#endif
}
-inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen())
+inline void getRand(bool *pb, mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen())
{
if (rg.isZero()) rg = fp::RandGen::get();
assert(bitSize > 1);
@@ -447,7 +451,7 @@ inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen()
uint32_t buf[128];
assert(n <= CYBOZU_NUM_OF_ARRAY(buf));
if (n > CYBOZU_NUM_OF_ARRAY(buf)) {
- z = 0;
+ *pb = false;
return;
}
rg.read(buf, n * sizeof(buf[0]));
@@ -459,22 +463,26 @@ inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen()
v |= 1U << (rem - 1);
}
buf[n - 1] = v;
- setArray(z, buf, n);
+ setArray(pb, z, buf, n);
}
-inline void getRandPrime(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen(), bool setSecondBit = false, bool mustBe3mod4 = false)
+inline void getRandPrime(bool *pb, mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen(), bool setSecondBit = false, bool mustBe3mod4 = false)
{
if (rg.isZero()) rg = fp::RandGen::get();
assert(bitSize > 2);
- do {
- getRand(z, bitSize, rg);
+ for (;;) {
+ getRand(pb, z, bitSize, rg);
+ if (!*pb) return;
if (setSecondBit) {
z |= mpz_class(1) << (bitSize - 2);
}
if (mustBe3mod4) {
z |= 3;
}
- } while (!(isPrime(z)));
+ bool ret = isPrime(pb, z);
+ if (!*pb) return;
+ if (ret) return;
+ }
}
inline mpz_class getQuadraticNonResidue(const mpz_class& p)
{
@@ -566,6 +574,49 @@ bool getNAF(Vec& v, const mpz_class& x)
}
}
+#ifndef CYBOZU_DONT_USE_EXCEPTION
+inline void setStr(mpz_class& z, const std::string& str, int base = 0)
+{
+ bool b;
+ setStr(&b, z, str.c_str(), base);
+ if (!b) throw cybozu::Exception("gmp:setStr");
+}
+template<class T>
+void setArray(mpz_class& z, const T *buf, size_t n)
+{
+ bool b;
+ setArray(&b, z, buf, n);
+ if (!b) throw cybozu::Exception("gmp:setArray");
+}
+template<class T>
+void getArray(T *buf, size_t maxSize, const mpz_class& x)
+{
+ bool b;
+ getArray(&b, buf, maxSize, x);
+ if (!b) throw cybozu::Exception("gmp:getArray");
+}
+inline bool isPrime(const mpz_class& x)
+{
+ bool b;
+ bool ret = isPrime(&b, x);
+ if (!b) throw cybozu::Exception("gmp:isPrime");
+ return ret;
+}
+inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen())
+{
+ bool b;
+ getRand(&b, z, bitSize, rg);
+ if (!b) throw cybozu::Exception("gmp:getRand");
+}
+inline void getRandPrime(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen(), bool setSecondBit = false, bool mustBe3mod4 = false)
+{
+ bool b;
+ getRandPrime(&b, z, bitSize, rg, setSecondBit, mustBe3mod4);
+ if (!b) throw cybozu::Exception("gmp:getRandPrime");
+}
+#endif
+
+
} // mcl::gmp
/*
@@ -591,12 +642,19 @@ public:
s = 0;
q_add_1_div_2 = 0;
}
- void set(const mpz_class& _p)
+ void set(bool *pb, const mpz_class& _p)
{
p = _p;
- if (p <= 2) throw cybozu::Exception("SquareRoot:bad p") << p;
- isPrime = gmp::isPrime(p);
- if (!isPrime) return; // don't throw until get() is called
+ if (p <= 2) {
+ *pb = false;
+ return;
+ }
+ isPrime = gmp::isPrime(pb, p);
+ if (!*pb) return;
+ if (!isPrime) {
+ *pb = false;
+ return;
+ }
g = gmp::getQuadraticNonResidue(p);
// p - 1 = 2^r q, q is odd
r = 0;
@@ -607,13 +665,18 @@ public:
}
gmp::powMod(s, g, q, p);
q_add_1_div_2 = (q + 1) / 2;
+ *pb = true;
}
/*
solve x^2 = a mod p
*/
- bool get(mpz_class& x, const mpz_class& a) const
+ bool get(bool *pb, mpz_class& x, const mpz_class& a) const
{
- if (!isPrime) throw cybozu::Exception("SquareRoot:get:not prime") << p;
+ if (!isPrime) {
+ *pb = false;
+ return false;
+ }
+ *pb = true;
if (a == 0) {
x = 0;
return true;
@@ -653,7 +716,7 @@ public:
template<class Fp>
bool get(Fp& x, const Fp& a) const
{
- if (Fp::getOp().mp != p) throw cybozu::Exception("bad Fp") << Fp::getOp().mp << p;
+ assert(Fp::getOp().mp == p);
if (a == 0) {
x = 0;
return true;
@@ -691,6 +754,21 @@ public:
}
return true;
}
+#ifndef CYBOZU_DONT_USE_EXCEPTION
+ void set(const mpz_class& _p)
+ {
+ bool b;
+ set(&b, _p);
+ if (!b) throw cybozu::Exception("gmp:SquareRoot:set");
+ }
+ bool get(mpz_class& x, const mpz_class& a) const
+ {
+ bool b;
+ bool ret = get(&b, x, a);
+ if (!b) throw cybozu::Exception("gmp:SquareRoot:get:not prime");
+ return ret;
+ }
+#endif
};
} // mcl
diff --git a/include/mcl/operator.hpp b/include/mcl/operator.hpp
index 7198929..29a66f5 100644
--- a/include/mcl/operator.hpp
+++ b/include/mcl/operator.hpp
@@ -136,6 +136,7 @@ struct Serializable : public E {
buf[n] = '\0';
return n;
}
+#ifndef CYBOZU_DONT_USE_STRING
void setStr(const std::string& str, int ioMode = 0)
{
cybozu::StringInputStream is(str);
@@ -153,6 +154,7 @@ struct Serializable : public E {
getStr(str, ioMode);
return str;
}
+#endif
// return written bytes
size_t serialize(void *buf, size_t maxBufSize, int ioMode = IoSerialize) const
{
diff --git a/include/mcl/vint.hpp b/include/mcl/vint.hpp
index a7e9f7a..8fa24b9 100644
--- a/include/mcl/vint.hpp
+++ b/include/mcl/vint.hpp
@@ -618,6 +618,7 @@ void divNM(T *q, size_t qn, T *r, const T *x, size_t xn, const T *y, size_t yn)
}
}
+#ifndef MCL_VINT_FIXED_BUFFER
template<class T>
class Buffer {
size_t allocSize_;
@@ -651,21 +652,6 @@ public:
std::swap(allocSize_, rhs.allocSize_);
std::swap(ptr_, rhs.ptr_);
}
-#if 0
-#if CYBOZU_CPP_VERSION >= CYBOZU_CPP_VERSION_CPP11
- Buffer(Buffer&& rhs) noexcept
- : allocSize_(0)
- , ptr_(0)
- {
- swap(rhs);
- }
- Buffer& operator=(Buffer&& rhs) noexcept
- {
- swap(rhs);
- return *this;
- }
-#endif
-#endif
void clear()
{
allocSize_ = 0;
@@ -676,17 +662,29 @@ public:
/*
@note extended buffer may be not cleared
*/
- void alloc(size_t n)
+ void alloc(bool *pb, size_t n)
{
if (n > allocSize_) {
T *p = (T*)malloc(n * sizeof(T));
- if (p == 0) throw cybozu::Exception("Buffer:alloc:malloc:") << n;
+ if (p == 0) {
+ *pb = false;
+ return;
+ }
copyN(p, ptr_, allocSize_);
free(ptr_);
ptr_ = p;
allocSize_ = n;
}
+ *pb = true;
}
+#ifndef CYBOZU_DONT_USE_EXCEPTION
+ void alloc(size_t n)
+ {
+ bool b;
+ alloc(&b, n);
+ if (!b) throw cybozu::Exception("Buffer:alloc");
+ }
+#endif
/*
*this = rhs
rhs may be destroyed
@@ -694,6 +692,7 @@ public:
const T& operator[](size_t n) const { return ptr_[n]; }
T& operator[](size_t n) { return ptr_[n]; }
};
+#endif
template<class T, size_t BitLen>
class FixedBuffer {
@@ -721,11 +720,23 @@ public:
return *this;
}
void clear() { size_ = 0; }
- void alloc(size_t n)
+ void alloc(bool *pb, size_t n)
{
- verify(n);
+ if (n > N) {
+ *pb = false;
+ return;
+ }
size_ = n;
+ *pb = true;
}
+#ifndef CYBOZU_DONT_USE_EXCEPTION
+ void alloc(size_t n)
+ {
+ bool b;
+ alloc(&b, n);
+ if (!b) throw cybozu::Exception("FixedBuffer:alloc");
+ }
+#endif
void swap(FixedBuffer& rhs)
{
FixedBuffer *p1 = this;
@@ -745,9 +756,8 @@ public:
// to avoid warning of gcc
void verify(size_t n) const
{
- if (n > N) {
- throw cybozu::Exception("verify:too large size") << n << (int)N;
- }
+ assert(n <= N);
+ (void)n;
}
const T& operator[](size_t n) const { verify(n); return v_[n]; }
T& operator[](size_t n) { verify(n); return v_[n]; }
@@ -946,6 +956,19 @@ private:
}
r.trim(yn);
}
+ /*
+ @param x [inout] x <- d
+ @retval s for x = 2^s d where d is odd
+ */
+ static uint32_t countTrailingZero(VintT& x)
+ {
+ uint32_t s = 0;
+ while (x.isEven()) {
+ x >>= 1;
+ s++;
+ }
+ return s;
+ }
struct MulMod {
const VintT *pm;
void operator()(VintT& z, const VintT& x, const VintT& y) const
@@ -973,11 +996,6 @@ public:
{
*this = x;
}
- explicit VintT(const std::string& str)
- : size_(0)
- {
- setStr(str);
- }
VintT(const VintT& rhs)
: buf_(rhs.buf_)
, size_(rhs.size_)
@@ -1057,12 +1075,13 @@ public:
/*
set [0, max) randomly
*/
- void setRand(const VintT& max)
+ void setRand(bool *pb, const VintT& max)
{
+ assert(max > 0);
fp::RandGen& rg = fp::RandGen::get();
- if (max <= 0) throw cybozu::Exception("Vint:setRand:bad value") << max;
size_t n = max.size();
- buf_.alloc(n);
+ buf_.alloc(pb, n);
+ if (!*pb) return;
rg.read(&buf_[0], n * sizeof(buf_[0]));
trim(n);
*this %= max;
@@ -1073,12 +1092,16 @@ public:
buf_[size, maxSize) with zero
@note assume little endian system
*/
- void getArray(Unit *x, size_t maxSize) const
+ void getArray(bool *pb, Unit *x, size_t maxSize) const
{
size_t n = size();
- if (n > maxSize) throw cybozu::Exception("Vint:getArray:small maxSize") << maxSize << n;
+ if (n > maxSize) {
+ *pb = false;
+ return;
+ }
vint::copyN(x, &buf_[0], n);
vint::clearN(x + n, maxSize - n);
+ *pb = true;
}
void clear() { *this = 0; }
template<class OutputStream>
@@ -1093,13 +1116,6 @@ public:
}
cybozu::write(pb, os, buf + sizeof(buf) - n, n);
}
- template<class OutputStream>
- void save(OutputStream& os, int base = 10) const
- {
- bool b;
- save(&b, os, base);
- if (!b) throw cybozu::Exception("Vint:save");
- }
/*
set buf with string terminated by '\0'
return strlen(buf) if success else 0
@@ -1114,13 +1130,6 @@ public:
buf[n] = '\0';
return n;
}
- std::string getStr(int base = 10) const
- {
- std::string s;
- cybozu::StringOutputStream os(s);
- save(os, base);
- return s;
- }
/*
return bitSize(abs(*this))
@note return 1 if zero
@@ -1138,7 +1147,7 @@ public:
{
size_t q = i / unitBitSize;
size_t r = i % unitBitSize;
- if (q > size()) throw cybozu::Exception("Vint:testBit:large i") << q << size();
+ assert(q <= size());
Unit mask = Unit(1) << r;
return (buf_[q] & mask) != 0;
}
@@ -1146,7 +1155,7 @@ public:
{
size_t q = i / unitBitSize;
size_t r = i % unitBitSize;
- if (q > size()) throw cybozu::Exception("Vint:setBit:large i") << q << size();
+ assert(q <= size());
buf_.alloc(q + 1);
Unit mask = Unit(1) << r;
if (v) {
@@ -1162,23 +1171,19 @@ public:
"0b..." => base = 2
otherwise => base = 10
*/
- void setStr(bool *pb, const char *str, size_t strSize, int base = 0)
+ void setStr(bool *pb, const char *str, int base = 0)
{
const size_t maxN = MCL_MAX_BIT_SIZE / (sizeof(MCL_SIZEOF_UNIT) * 8);
- buf_.alloc(maxN);
+ buf_.alloc(pb, maxN);
+ if (!*pb) return;
*pb = false;
isNeg_ = false;
- size_t n = fp::strToArray(&isNeg_, &buf_[0], maxN, str, strSize, base);
+ size_t len = strlen(str);
+ size_t n = fp::strToArray(&isNeg_, &buf_[0], maxN, str, len, base);
if (n == 0) return;
trim(n);
*pb = true;
}
- void setStr(std::string str, int base = 0)
- {
- bool b;
- setStr(&b, str.c_str(), str.size(), base);
- if (!b) throw cybozu::Exception("Vint:setStr") << str;
- }
static int compare(const VintT& x, const VintT& y)
{
if (x.isNeg_ ^ y.isNeg_) {
@@ -1372,10 +1377,6 @@ public:
usub(r, yy.buf_, yy.size(), r.buf_, r.size());
}
}
- inline friend std::ostream& operator<<(std::ostream& os, const VintT& x)
- {
- return os << x.getStr(os.flags() & std::ios_base::hex ? 16 : 10);
- }
template<class InputStream>
void load(bool *pb, InputStream& is, int ioMode)
{
@@ -1391,18 +1392,6 @@ public:
trim(n);
*pb = true;
}
- template<class InputStream>
- void load(InputStream& is, int ioMode = 0)
- {
- bool b;
- load(&b, is, ioMode);
- if (!b) throw cybozu::Exception("Vint:load");
- }
- inline friend std::istream& operator>>(std::istream& is, VintT& x)
- {
- x.load(is);
- return is;
- }
// logical left shift (copy sign)
static void shl(VintT& y, const VintT& x, size_t shiftBit)
{
@@ -1575,26 +1564,12 @@ public:
b -= a * q;
}
}
-private:
- /*
- @param x [inout] x <- d
- @retval s for x = 2^s d where d is odd
- */
- static uint32_t countTrailingZero(VintT& x)
- {
- uint32_t s = 0;
- while (x.isEven()) {
- x >>= 1;
- s++;
- }
- return s;
- }
-public:
/*
Miller-Rabin
*/
- static bool isPrime(const VintT& n, int tryNum = 32)
+ static bool isPrime(bool *pb, const VintT& n, int tryNum = 32)
{
+ *pb = true;
if (n <= 1) return false;
if (n == 2 || n == 3) return true;
if (n.isEven()) return false;
@@ -1604,7 +1579,8 @@ public:
// n - 1 = 2^r d
VintT a, x;
for (int i = 0; i < tryNum; i++) {
- a.setRand(n - 3);
+ a.setRand(pb, n - 3);
+ if (!*pb) return false;
a += 2; // a in [2, n - 2]
powMod(x, a, d, n);
if (x == 1 || x == nm1) {
@@ -1621,9 +1597,9 @@ public:
}
return true;
}
- bool isPrime(int tryNum = 32) const
+ bool isPrime(bool *pb, int tryNum = 32) const
{
- return isPrime(*this, tryNum);
+ return isPrime(pb, *this, tryNum);
}
static void gcd(VintT& z, VintT x, VintT y)
{
@@ -1665,7 +1641,7 @@ public:
*/
static int jacobi(VintT m, VintT n)
{
- if (n.isEven()) throw cybozu::Exception();
+ assert(n.isOdd());
if (n == 1) return 1;
if (m < 0 || m > n) {
quotRem(0, m, m, n); // m = m mod n
@@ -1693,6 +1669,81 @@ public:
}
return j;
}
+#ifndef CYBOZU_DONT_USE_STRING
+ explicit VintT(const std::string& str)
+ : size_(0)
+ {
+ setStr(str);
+ }
+ void getStr(std::string& s, int base = 10) const
+ {
+ cybozu::StringOutputStream os(s);
+ save(os, base);
+ }
+ std::string getStr(int base = 10) const
+ {
+ std::string s;
+ getStr(s, base);
+ return s;
+ }
+ inline friend std::ostream& operator<<(std::ostream& os, const VintT& x)
+ {
+ return os << x.getStr(os.flags() & std::ios_base::hex ? 16 : 10);
+ }
+ inline friend std::istream& operator>>(std::istream& is, VintT& x)
+ {
+ x.load(is);
+ return is;
+ }
+#endif
+#ifndef CYBOZU_DONT_USE_EXCEPTION
+ void setStr(const std::string& str, int base = 0)
+ {
+ bool b;
+ setStr(&b, str.c_str(), base);
+ if (!b) throw cybozu::Exception("Vint:setStr") << str;
+ }
+ void setRand(const VintT& max)
+ {
+ bool b;
+ setRand(&b, max);
+ if (!b) throw cybozu::Exception("Vint:setRand");
+ }
+ void getArray(Unit *x, size_t maxSize) const
+ {
+ bool b;
+ getArray(&b, x, maxSize);
+ if (!b) throw cybozu::Exception("Vint:getArray");
+ }
+ template<class InputStream>
+ void load(InputStream& is, int ioMode = 0)
+ {
+ bool b;
+ load(&b, is, ioMode);
+ if (!b) throw cybozu::Exception("Vint:load");
+ }
+ template<class OutputStream>
+ void save(OutputStream& os, int base = 10) const
+ {
+ bool b;
+ save(&b, os, base);
+ if (!b) throw cybozu::Exception("Vint:save");
+ }
+ static bool isPrime(const VintT& n, int tryNum = 32)
+ {
+ bool b;
+ bool ret = isPrime(&b, n, tryNum);
+ if (!b) throw cybozu::Exception("Vint:isPrime");
+ return ret;
+ }
+ bool isPrime(int tryNum = 32) const
+ {
+ bool b;
+ bool ret = isPrime(&b, *this, tryNum);
+ if (!b) throw cybozu::Exception("Vint:isPrime");
+ return ret;
+ }
+#endif
VintT& operator++() { adds1(*this, *this, 1); return *this; }
VintT& operator--() { subs1(*this, *this, 1); return *this; }
VintT operator++(int) { VintT c = *this; adds1(*this, *this, 1); return c; }
diff --git a/src/fp.cpp b/src/fp.cpp
index e45217b..ba9e484 100644
--- a/src/fp.cpp
+++ b/src/fp.cpp
@@ -77,7 +77,7 @@ const char *ModeToStr(Mode mode)
}
}
-Mode StrToMode(const std::string& s)
+Mode StrToMode(const char *s)
{
static const struct {
const char *s;
@@ -91,7 +91,7 @@ Mode StrToMode(const std::string& s)
{ "xbyak", FP_XBYAK },
};
for (size_t i = 0; i < CYBOZU_NUM_OF_ARRAY(tbl); i++) {
- if (s == tbl[i].s) return tbl[i].mode;
+ if (strcmp(s, tbl[i].s) == 0) return tbl[i].mode;
}
return FP_AUTO;
}
@@ -176,19 +176,24 @@ static inline void set_mpz_t(mpz_t& z, const Unit* p, int n)
static inline void fp_invOpC(Unit *y, const Unit *x, const Op& op)
{
const int N = (int)op.N;
+ bool b;
#ifdef MCL_USE_VINT
Vint vx, vy, vp;
- vx.setArray(x, N);
- vp.setArray(op.p, N);
+ vx.setArray(&b, x, N);
+ assert(b);
+ vp.setArray(&b, op.p, N);
+ assert(b);
Vint::invMod(vy, vx, vp);
- vy.getArray(y, N);
+ vy.getArray(&b, y, N);
+ assert(b);
#else
mpz_class my;
mpz_t mx, mp;
set_mpz_t(mx, x, N);
set_mpz_t(mp, op.p, N);
mpz_invert(my.get_mpz_t(), mx, mp);
- gmp::getArray(y, N, my);
+ gmp::getArray(&b, y, N, my);
+ assert(b);
#endif
}
@@ -323,20 +328,24 @@ static void initInvTbl(Op& op)
}
#endif
-static void initForMont(Op& op, const Unit *p, Mode mode)
+static bool initForMont(Op& op, const Unit *p, Mode mode)
{
const size_t N = op.N;
+ bool b;
{
mpz_class t = 1, R;
- gmp::getArray(op.one, N, t);
+ gmp::getArray(&b, op.one, N, t);
+ if (!b) return false;
R = (t << (N * UnitBitSize)) % op.mp;
t = (R * R) % op.mp;
- gmp::getArray(op.R2, N, t);
+ gmp::getArray(&b, op.R2, N, t);
+ if (!b) return false;
t = (t * R) % op.mp;
- gmp::getArray(op.R3, N, t);
+ gmp::getArray(&b, op.R3, N, t);
+ if (!b) return false;
}
op.rp = getMontgomeryCoeff(p[0]);
- if (mode != FP_XBYAK) return;
+ if (mode != FP_XBYAK) return true;
#ifdef MCL_USE_XBYAK
if (op.fg == 0) op.fg = Op::createFpGenerator();
op.fg->init(op);
@@ -346,6 +355,7 @@ static void initForMont(Op& op, const Unit *p, Mode mode)
initInvTbl(op);
}
#endif
+ return true;
}
bool Op::init(const mpz_class& _p, size_t maxBitSize, Mode mode, size_t mclMaxBitSize)
@@ -359,11 +369,13 @@ bool Op::init(const mpz_class& _p, size_t maxBitSize, Mode mode, size_t mclMaxBi
if (maxBitSize > MCL_MAX_BIT_SIZE) return false;
if (_p <= 0) return false;
clear();
+ bool b;
{
const size_t maxN = (maxBitSize + fp::UnitBitSize - 1) / fp::UnitBitSize;
N = gmp::getUnitSize(_p);
if (N > maxN) return false;
- gmp::getArray(p, N, _p);
+ gmp::getArray(&b, p, N, _p);
+ if (!b) return false;
mp = _p;
}
bitSize = gmp::getBitSize(mp);
@@ -417,10 +429,16 @@ bool Op::init(const mpz_class& _p, size_t maxBitSize, Mode mode, size_t mclMaxBi
}
#endif
#if defined(MCL_USE_VINT) && MCL_SIZEOF_UNIT == 8
- if (mp == mpz_class("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f")) {
- primeMode = PM_SECP256K1;
- isMont = false;
- isFastMod = true;
+ {
+ const char *secp256k1Str = "0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f";
+ bool b;
+ mpz_class secp256k1;
+ gmp::setStr(&b, secp256k1, secp256k1Str);
+ if (b && mp == secp256k1) {
+ primeMode = PM_SECP256K1;
+ isMont = false;
+ isFastMod = true;
+ }
}
#endif
switch (N) {
@@ -477,8 +495,9 @@ bool Op::init(const mpz_class& _p, size_t maxBitSize, Mode mode, size_t mclMaxBi
fpDbl_mod = &mcl::vint::mcl_fpDbl_mod_SECP256K1;
}
#endif
- fp::initForMont(*this, p, mode);
- sq.set(mp);
+ if (!fp::initForMont(*this, p, mode)) return false;
+ sq.set(&b, mp);
+ if (!b) return false;
if (N * UnitBitSize <= 256) {
hash = sha256;
} else {
diff --git a/test/fp_test.cpp b/test/fp_test.cpp
index f81e0c8..71d6986 100644
--- a/test/fp_test.cpp
+++ b/test/fp_test.cpp
@@ -351,6 +351,7 @@ void compareTest()
void moduloTest(const char *pStr)
{
+std::cout << std::hex;
std::string str;
Fp::getModulo(str);
CYBOZU_TEST_EQUAL(str, mcl::gmp::getStr(mpz_class(pStr)));
diff --git a/test/gmp_test.cpp b/test/gmp_test.cpp
index 22c80dd..1fe9d4e 100644
--- a/test/gmp_test.cpp
+++ b/test/gmp_test.cpp
@@ -21,6 +21,45 @@ CYBOZU_TEST_AUTO(testBit)
}
}
+CYBOZU_TEST_AUTO(getStr)
+{
+ const struct {
+ int x;
+ const char *dec;
+ const char *hex;
+ } tbl[] = {
+ { 0, "0", "0" },
+ { 1, "1", "1" },
+ { 10, "10", "a" },
+ { 16, "16", "10" },
+ { 123456789, "123456789", "75bcd15" },
+ { -1, "-1", "-1" },
+ { -10, "-10", "-a" },
+ { -16, "-16", "-10" },
+ { -100000000, "-100000000", "-5f5e100" },
+ { -987654321, "-987654321", "-3ade68b1" },
+ { -2147483647, "-2147483647", "-7fffffff" },
+ };
+ for (size_t i = 0; i < CYBOZU_NUM_OF_ARRAY(tbl); i++) {
+ mpz_class x = tbl[i].x;
+ char buf[32];
+ size_t n, len;
+ len = strlen(tbl[i].dec);
+ n = mcl::gmp::getStr(buf, len, x, 10);
+ CYBOZU_TEST_EQUAL(n, 0);
+ n = mcl::gmp::getStr(buf, len + 1, x, 10);
+ CYBOZU_TEST_EQUAL(n, len);
+ CYBOZU_TEST_EQUAL_ARRAY(buf, tbl[i].dec, n);
+
+ len = strlen(tbl[i].hex);
+ n = mcl::gmp::getStr(buf, len, x, 16);
+ CYBOZU_TEST_EQUAL(n, 0);
+ n = mcl::gmp::getStr(buf, len + 1, x, 16);
+ CYBOZU_TEST_EQUAL(n, len);
+ CYBOZU_TEST_EQUAL_ARRAY(buf, tbl[i].hex, n);
+ }
+}
+
CYBOZU_TEST_AUTO(getRandPrime)
{
for (int i = 0; i < 10; i++) {