aboutsummaryrefslogtreecommitdiffstats
path: root/include/mcl/ec.hpp
blob: 8b70b70dc99ab1887833729a88c9474d897c2477 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
#pragma once
/**
    @file
    @brief elliptic curve
    @author MITSUNARI Shigeo(@herumi)
    @license modified new BSD license
    http://opensource.org/licenses/BSD-3-Clause
*/
#include <sstream>
#include <cybozu/exception.hpp>
#include <cybozu/bitvector.hpp>
#include <mcl/operator.hpp>
#include <mcl/power.hpp>
#include <mcl/gmp_util.hpp>

namespace mcl {

/*
    identifiers of the coordinate systems EcT can use internally
*/
#define MCL_EC_USE_AFFINE 0
#define MCL_EC_USE_PROJ 1
#define MCL_EC_USE_JACOBI 2

/*
    select the coordinate system by defining MCL_EC_COORD to one of the
    MCL_EC_USE_* values before including this header; projective is the default
*/
//#define MCL_EC_COORD MCL_EC_USE_JACOBI
//#define MCL_EC_COORD MCL_EC_USE_PROJ
#ifndef MCL_EC_COORD
    #define MCL_EC_COORD MCL_EC_USE_PROJ
#endif
// reject unsupported values early instead of failing deep inside EcT
#if (MCL_EC_COORD != MCL_EC_USE_AFFINE) && (MCL_EC_COORD != MCL_EC_USE_PROJ) && (MCL_EC_COORD != MCL_EC_USE_JACOBI)
    #error "MCL_EC_COORD must be MCL_EC_USE_AFFINE, MCL_EC_USE_PROJ or MCL_EC_USE_JACOBI"
#endif
/*
    elliptic curve over Fp, in the coordinate system selected by MCL_EC_COORD
    y^2 = x^3 + ax + b (affine)
    Y^2 Z = X^3 + aXZ^2 + bZ^3 (projective) x = X/Z, y = Y/Z
    Y^2 = X^3 + aXZ^4 + bZ^6 (Jacobi) x = X/Z^2, y = Y/Z^3
    the point at infinity (the group identity, printed as "0") is
    inf_ == true in affine mode and z == 0 otherwise
*/
template<class _Fp>
class EcT : public ope::addsub<EcT<_Fp>,
    ope::comparable<EcT<_Fp>,
    ope::hasNegative<EcT<_Fp> > > > {
    /*
        classification of the curve coefficient a (set by setParam);
        dbl uses cheaper formulas for a == 0 and a == -3
    */
    enum {
        zero,
        minus3,
        generic
    };
public:
    typedef _Fp Fp;
    typedef typename Fp::BlockType BlockType;
#if MCL_EC_COORD == MCL_EC_USE_AFFINE
    Fp x, y;
    bool inf_; // true iff this is the point at infinity
#else
    /*
        mutable so that const members (normalize, compare, operator<<, ...)
        can rescale to z == 1 without changing the represented point
    */
    mutable Fp x, y, z; // z == 0 iff this is the point at infinity
#endif
    static Fp a_; // curve coefficient a
    static Fp b_; // curve coefficient b
    static int specialA_; // zero, minus3 or generic; see setParam
    static bool compressedExpression_; // serialize only the parity of y
    // default: the point at infinity (assuming Fp::clear() sets 0)
#if MCL_EC_COORD == MCL_EC_USE_AFFINE
    EcT() : inf_(true) {}
#else
    EcT() { z.clear(); }
#endif
    /*
        construct the affine point (_x, _y)
        throws cybozu::Exception if (_x, _y) is not on the curve
    */
    EcT(const Fp& _x, const Fp& _y)
    {
        set(_x, _y);
    }
    /*
        rescale the coordinates so that z == 1, i.e. (x, y) become affine
        does nothing for the point at infinity, when z is already 1, or in affine mode
    */
    void normalize() const
    {
#if MCL_EC_COORD == MCL_EC_USE_JACOBI
        if (isZero() || z == 1) return;
        Fp rz, rz2;
        Fp::inv(rz, z);
        rz2 = rz * rz;
        x *= rz2;
        y *= rz2 * rz;
        z = 1;
#elif MCL_EC_COORD == MCL_EC_USE_PROJ
        if (isZero() || z == 1) return;
        Fp rz;
        Fp::inv(rz, z);
        x *= rz;
        y *= rz;
        z = 1;
#endif
    }

    /*
        set the curve coefficients a and b from their string forms (Fp::fromStr)
        and classify a so that dbl can use a specialized formula
    */
    static inline void setParam(const std::string& astr, const std::string& bstr)
    {
        a_.fromStr(astr);
        b_.fromStr(bstr);
        if (a_.isZero()) {
            specialA_ = zero;
        } else if (a_ == -3) {
            specialA_ = minus3;
        } else {
            specialA_ = generic;
        }
    }
    /*
        true iff (_x, _y) satisfies the affine equation y^2 = x^3 + ax + b
    */
    static inline bool isValid(const Fp& _x, const Fp& _y)
    {
        return _y * _y == (_x * _x + a_) * _x + b_;
    }
    /*
        set *this to the affine point (_x, _y)
        throws cybozu::Exception if verify is true and the point is not on the curve
    */
    void set(const Fp& _x, const Fp& _y, bool verify = true)
    {
        if (verify && !isValid(_x, _y)) throw cybozu::Exception("ec:EcT:set") << _x << _y;
        x = _x; y = _y;
#if MCL_EC_COORD == MCL_EC_USE_AFFINE
        inf_ = false;
#else
        z = 1;
#endif
    }
    /*
        set *this to the point at infinity
    */
    void clear()
    {
#if MCL_EC_COORD == MCL_EC_USE_AFFINE
        inf_ = true;
#else
        z = 0;
#endif
        x.clear();
        y.clear();
    }

    /*
        R = 2P
        verifyInf == false means the caller guarantees P is not the point at
        infinity (add relies on this after its own checks)
        R may be the same object as P: every coordinate of P is read before
        the corresponding coordinate of R is written
    */
    static inline void dbl(EcT& R, const EcT& P, bool verifyInf = true)
    {
        if (verifyInf) {
            if (P.isZero()) {
                R.clear(); return;
            }
        }
#if MCL_EC_COORD == MCL_EC_USE_JACOBI
        Fp S, M, t, y2;
        Fp::square(y2, P.y);
        Fp::mul(S, P.x, y2);
        S += S;
        S += S; // S = 4 x y^2
        Fp::square(M, P.x);
        switch (specialA_) {
        case zero:
            Fp::add(t, M, M);
            M += t; // M = 3 x^2
            break;
        case minus3:
            Fp::square(t, P.z);
            Fp::square(t, t);
            M -= t;
            Fp::add(t, M, M);
            M += t; // M = 3 (x^2 - z^4)
            break;
        case generic:
        default:
            Fp::square(t, P.z);
            Fp::square(t, t);
            t *= a_;
            t += M;
            M += M;
            M += t; // M = 3 x^2 + a z^4
            break;
        }
        Fp::square(R.x, M);
        R.x -= S;
        R.x -= S; // X3 = M^2 - 2S
        Fp::mul(R.z, P.y, P.z);
        R.z += R.z; // Z3 = 2 y z
        Fp::square(y2, y2);
        y2 += y2;
        y2 += y2;
        y2 += y2; // 8 y^4
        Fp::sub(R.y, S, R.x);
        R.y *= M;
        R.y -= y2; // Y3 = M (S - X3) - 8 y^4
#elif MCL_EC_COORD == MCL_EC_USE_PROJ
        Fp w, t, h;
        switch (specialA_) {
        case zero:
            Fp::square(w, P.x);
            Fp::add(t, w, w);
            w += t;
            break;
        case minus3:
            Fp::square(w, P.x);
            Fp::square(t, P.z);
            w -= t;
            Fp::add(t, w, w);
            w += t;
            break;
        case generic:
        default:
            Fp::square(w, P.z);
            w *= a_;
            Fp::square(t, P.x);
            w += t;
            w += t;
            w += t; // w = a z^2 + 3x^2
            break;
        }
        Fp::mul(R.z, P.y, P.z); // s = yz
        Fp::mul(t, R.z, P.x);
        t *= P.y; // xys
        t += t;
        t += t; // 4(xys) ; 4B
        Fp::square(h, w);
        h -= t;
        h -= t; // w^2 - 8B
        Fp::mul(R.x, h, R.z);
        t -= h; // h is free
        t *= w;
        Fp::square(w, P.y);
        R.x += R.x;
        R.z += R.z;
        Fp::square(h, R.z);
        w *= h;
        R.z *= h;
        Fp::sub(R.y, t, w);
        R.y -= w;
#else
        if (P.y.isZero()) {
            // the tangent at a point with y == 0 is vertical: 2P is the point at infinity
            R.clear();
            return;
        }
        Fp t, s;
        Fp::square(t, P.x);
        Fp::add(s, t, t);
        t += s;
        t += a_;
        Fp::add(s, P.y, P.y);
        t /= s; // slope (3x^2 + a) / 2y
        Fp::square(s, t);
        s -= P.x;
        Fp x3;
        Fp::sub(x3, s, P.x);
        Fp::sub(s, P.x, x3);
        s *= t;
        Fp::sub(R.y, s, P.y);
        R.x = x3;
        R.inf_ = false;
#endif
    }
    /*
        R = P + Q
        R may be the same object as P and/or Q: all needed input coordinates
        are read before R is written
    */
    static inline void add(EcT& R, const EcT& P, const EcT& Q)
    {
        if (P.isZero()) { R = Q; return; }
        if (Q.isZero()) { R = P; return; }
#if MCL_EC_COORD == MCL_EC_USE_JACOBI
        Fp r, U1, S1, H, H3;
        Fp::square(r, P.z);
        Fp::square(S1, Q.z);
        Fp::mul(U1, P.x, S1);
        Fp::mul(H, Q.x, r);
        H -= U1; // H = U2 - U1
        r *= P.z;
        S1 *= Q.z;
        S1 *= P.y;
        Fp::mul(r, Q.y, r);
        r -= S1; // r = S2 - S1
        if (H.isZero()) {
            // same affine x: Q == P (double) or Q == -P (infinity)
            if (r.isZero()) {
                dbl(R, P, false);
            } else {
                R.clear();
            }
            return;
        }
        Fp::mul(R.z, P.z, Q.z);
        R.z *= H;
        Fp::square(H3, H); // H^2
        Fp::square(R.y, r); // r^2
        U1 *= H3; // U1 H^2
        H3 *= H; // H^3
        R.y -= U1;
        R.y -= U1;
        Fp::sub(R.x, R.y, H3);
        U1 -= R.x;
        U1 *= r;
        H3 *= S1;
        Fp::sub(R.y, U1, H3);
#elif MCL_EC_COORD == MCL_EC_USE_PROJ
        Fp r, PyQz, v, A, vv;
        Fp::mul(r, P.x, Q.z);
        Fp::mul(PyQz, P.y, Q.z);
        Fp::mul(A, Q.y, P.z);
        Fp::mul(v, Q.x, P.z);
        v -= r;
        if (v.isZero()) {
            // same affine x: Q == -P (infinity) or Q == P (double)
            Fp::add(vv, A, PyQz);
            if (vv.isZero()) {
                R.clear();
            } else {
                dbl(R, P, false);
            }
            return;
        }
        Fp::sub(R.y, A, PyQz);
        Fp::square(A, R.y);
        Fp::square(vv, v);
        r *= vv;
        vv *= v;
        Fp::mul(R.z, P.z, Q.z);
        A *= R.z;
        R.z *= vv;
        A -= vv;
        vv *= PyQz;
        A -= r;
        A -= r;
        Fp::mul(R.x, v, A);
        r -= A;
        R.y *= r;
        R.y -= vv;
#else
        Fp t;
        Fp::sub(t, Q.x, P.x);
        if (t.isZero()) {
            /*
                same x: for points on the curve Q is either P or -P.
                x must be compared first, since points with different x
                can still have P.y == -Q.y
            */
            if (P.y == Q.y) {
                dbl(R, P, false); // gives infinity when P.y == 0 (then P == -P)
            } else {
                R.clear(); // Q == -P
            }
            return;
        }
        Fp s;
        Fp::sub(s, Q.y, P.y);
        Fp::div(t, s, t); // slope (Qy - Py) / (Qx - Px)
        R.inf_ = false;
        Fp x3;
        Fp::square(x3, t);
        x3 -= P.x;
        x3 -= Q.x;
        Fp::sub(s, P.x, x3);
        s *= t;
        Fp::sub(R.y, s, P.y);
        R.x = x3;
#endif
    }
    /*
        R = P - Q, computed as P + (-Q)
    */
    static inline void sub(EcT& R, const EcT& P, const EcT& Q)
    {
        EcT nQ;
        neg(nQ, Q);
        add(R, P, nQ);
    }
    /*
        R = -P (y negated; infinity stays infinity)
    */
    static inline void neg(EcT& R, const EcT& P)
    {
        if (P.isZero()) {
            R.clear();
            return;
        }
#if MCL_EC_COORD == MCL_EC_USE_AFFINE
        R.inf_ = false;
        R.x = P.x;
        Fp::neg(R.y, P.y);
#else
        R.x = P.x;
        Fp::neg(R.y, P.y);
        R.z = P.z;
#endif
    }
    /*
        z = x^y in multiplicative notation (presumably the y-fold sum of x,
        via the TagMultiGr adaptor); delegated to power_impl::power
    */
    template<class N>
    static inline void power(EcT& z, const EcT& x, const N& y)
    {
        power_impl::power(z, x, y);
    }
    /*
        0 <= P for any P
        (Px, Py) <= (P'x, P'y) iff Px < P'x or Px == P'x and Py <= P'y
        both operands are normalized in place first
    */
    static inline int compare(const EcT& P, const EcT& Q)
    {
        P.normalize();
        Q.normalize();
        if (P.isZero()) {
            if (Q.isZero()) return 0;
            return -1;
        }
        if (Q.isZero()) return 1;
        int c = _Fp::compare(P.x, Q.x);
        if (c > 0) return 1;
        if (c < 0) return -1;
        return _Fp::compare(P.y, Q.y);
    }
    /*
        true iff *this is the point at infinity
    */
    bool isZero() const
    {
#if MCL_EC_COORD == MCL_EC_USE_AFFINE
        return inf_;
#else
        return z.isZero();
#endif
    }
    /*
        write "0" for infinity, otherwise "<hex x>_<hex y>",
        or "<hex x>_<0|1>" (parity of y) when compressedExpression_ is set
    */
    friend inline std::ostream& operator<<(std::ostream& os, const EcT& self)
    {
        if (self.isZero()) {
            return os << '0';
        }
        self.normalize();
        os << self.x.toStr(16) << '_';
        if (compressedExpression_) {
            // write a character, not a bool, so the output does not depend on
            // std::boolalpha and always matches what operator>> accepts
            return os << (Fp::isYodd(self.y) ? '1' : '0');
        }
        return os << self.y.toStr(16);
    }
    /*
        read one whitespace-delimited token in the format written by operator<<
        throws cybozu::Exception on a malformed token
        NOTE(review): in the uncompressed form the point is not checked to lie on the curve
    */
    friend inline std::istream& operator>>(std::istream& is, EcT& self)
    {
        std::string str;
        is >> str;
        if (str == "0") {
            self.clear();
        } else {
#if MCL_EC_COORD == MCL_EC_USE_AFFINE
            self.inf_ = false;
#else
            self.z = 1;
#endif
            size_t pos = str.find('_');
            if (pos == std::string::npos) throw cybozu::Exception("EcT:operator>>:bad format") << str;
            str[pos] = '\0';
            self.x.fromStr(&str[0], 16);
            if (compressedExpression_) {
                const char c = str[pos + 1];
                if ((c == '0' || c == '1') && str.size() == pos + 2) {
                    getYfromX(self.y, self.x, c == '1');
                } else {
                    str[pos] = '_';
                    throw cybozu::Exception("EcT:operator>>:bad y") << str;
                }
            } else {
                self.y.fromStr(&str[pos + 1], 16);
            }
        }
        return is;
    }
    /*
        choose whether serialization stores y in full or only its parity
    */
    static inline void setCompressedExpression(bool compressedExpression)
    {
        compressedExpression_ = compressedExpression;
    }
    /*
        append to bv(not clear bv)
            elem |x|y|z|
            size  n n 1 if not compressed
            size  n 1 1 if compressed
        the last bit is 1 for a finite point; infinity is getBitVecSize() zero bits
    */
    void appendToBitVec(cybozu::BitVector& bv) const
    {
        normalize(); // no-op in affine mode
        if (isZero()) {
            bv.resize(bv.size() + getBitVecSize());
            return;
        }
        x.appendToBitVec(bv);
        if (compressedExpression_) {
            bv.append(Fp::isYodd(y), 1);
        } else {
            y.appendToBitVec(bv);
        }
        bv.append(1, 1); // z = 1
    }
    /*
        inverse of appendToBitVec; bv must hold exactly getBitVecSize() bits
        throws cybozu::Exception on a size mismatch
    */
    void fromBitVec(const cybozu::BitVector& bv)
    {
        const size_t bitLen = _Fp::getModBitLen();
        const size_t maxBitLen = getBitVecSize();
        if (bv.size() != maxBitLen) {
            throw cybozu::Exception("EcT:fromBitVec:bad size") << bv.size() << maxBitLen;
        }
        if (!bv.get(maxBitLen - 1)) { // if z = 0
            clear();
            return;
        }
        cybozu::BitVector t;
        bv.extract(t, 0, bitLen);
        x.fromBitVec(t);
        if (compressedExpression_) {
            bool odd = bv.get(bitLen); // y
            getYfromX(y, x, odd);
        } else {
            bv.extract(t, bitLen, bitLen);
            y.fromBitVec(t);
        }
#if MCL_EC_COORD == MCL_EC_USE_AFFINE
        inf_ = false;
#else
        z = 1;
#endif
    }
    /*
        number of bits appendToBitVec writes: x, then y or its parity, then the z flag
    */
    static inline size_t getBitVecSize()
    {
        const size_t bitLen = _Fp::getModBitLen();
        return compressedExpression_ ? (bitLen + 1 + 1) : (bitLen * 2 + 1);
    }
    /*
        set y to a square root of x^3 + ax + b whose parity (Fp::isYodd) equals isYodd
        NOTE(review): any return value of Fp::squareRoot is not checked here, so the
        result for an x with no point on the curve depends on Fp::squareRoot — confirm
    */
    static inline void getYfromX(Fp& y, const Fp& x, bool isYodd)
    {
        Fp t;
        Fp::square(t, x);
        t += a_;
        t *= x;
        t += b_;
        Fp::squareRoot(y, t);
        if (Fp::isYodd(y) ^ isYodd) {
            Fp::neg(y, y);
        }
    }
};

/*
    adaptor that lets code written for multiplicative groups (presumably
    power_impl in mcl/power.hpp) drive EcT, whose group law is additive:
    init -> clear (identity), square -> dbl, mul -> add, inv -> neg, div -> sub
*/
template<class T>
struct TagMultiGr<EcT<T> > {
private:
    typedef EcT<T> Ec;
public:
    static void init(Ec& x) { x.clear(); }
    static void square(Ec& z, const Ec& x) { Ec::dbl(z, x); }
    static void mul(Ec& z, const Ec& x, const Ec& y) { Ec::add(z, x, y); }
    static void inv(Ec& z, const Ec& x) { Ec::neg(z, x); }
    static void div(Ec& z, const Ec& x, const Ec& y) { Ec::sub(z, x, y); }
};

/*
    out-of-class definitions of EcT's shared curve state
    compressedExpression_ starts false and specialA_ starts at 0 (the `zero`
    enumerator), exactly their zero-initialized values; a_ and b_ are
    default-constructed until setParam() assigns them
*/
template<class _Fp> bool EcT<_Fp>::compressedExpression_ = false;
template<class _Fp> int EcT<_Fp>::specialA_ = 0;
template<class _Fp> _Fp EcT<_Fp>::b_;
template<class _Fp> _Fp EcT<_Fp>::a_;

/*
    textual description of a named curve; every field except bitLen is a
    C string whose numeric base is not fixed by this struct
*/
struct EcParam {
    const char *name;    // curve name
    const char *p;       // presumably the prime modulus of Fp
    const char *a, *b;   // presumably the coefficients of y^2 = x^3 + ax + b
    const char *gx, *gy; // presumably the base point coordinates
    const char *n;       // presumably the order of the base point
    size_t bitLen;       // bit length of p
};

} // mcl

namespace std { CYBOZU_NAMESPACE_TR1_BEGIN
template<class T> struct hash;

/*
    hash of an elliptic curve point: 0 for the point at infinity; otherwise
    the Fp hash of x, chained as the seed into the Fp hash of y.
    The point is normalized first, so different projective/Jacobian
    representations of the same point hash alike.
*/
template<class _Fp>
struct hash<mcl::EcT<_Fp> > {
    size_t operator()(const mcl::EcT<_Fp>& P) const
    {
        typedef hash<_Fp> FpHash;
        if (!P.isZero()) {
            P.normalize(); // canonical z == 1 form
            uint64_t seed = FpHash()(P.x);
            const uint64_t combined = FpHash()(P.y, seed);
            return static_cast<size_t>(combined);
        }
        return 0;
    }
};

CYBOZU_NAMESPACE_TR1_END } // std