diff options
Diffstat (limited to 'vendor/github.com/dexon-foundation/mcl/src/fp_generator.hpp')
-rw-r--r-- | vendor/github.com/dexon-foundation/mcl/src/fp_generator.hpp | 1460 |
1 files changed, 985 insertions, 475 deletions
diff --git a/vendor/github.com/dexon-foundation/mcl/src/fp_generator.hpp b/vendor/github.com/dexon-foundation/mcl/src/fp_generator.hpp index 581a0de87..b496bc4d4 100644 --- a/vendor/github.com/dexon-foundation/mcl/src/fp_generator.hpp +++ b/vendor/github.com/dexon-foundation/mcl/src/fp_generator.hpp @@ -8,6 +8,7 @@ */ #if CYBOZU_HOST == CYBOZU_HOST_INTEL #define XBYAK_NO_OP_NAMES +#define XBYAK_DISABLE_AVX512 #include "xbyak/xbyak_util.h" #if MCL_SIZEOF_UNIT == 8 @@ -126,6 +127,71 @@ if (rm.isReg()) { \ namespace fp { +struct Profiler { + FILE *fp_; + const char *suf_; + const uint8_t *prev_; + Profiler() + : fp_(0) + , suf_(0) + , prev_(0) + { + } + void init(const char *suf, const uint8_t *prev) + { +#ifdef __linux__ + close(); + const char *s = getenv("MCL_PERF"); + if (s == 0 || strcmp(s, "1") != 0) return; + fprintf(stderr, "use perf suf=%s\n", suf); + suf_ = suf; + const int pid = getpid(); + char name[128]; + snprintf(name, sizeof(name), "/tmp/perf-%d.map", pid); + fp_ = fopen(name, "wb"); + if (fp_ == 0) throw cybozu::Exception("PerMap") << name; + prev_ = prev; +#else + (void)suf; + (void)prev; +#endif + } + ~Profiler() + { + close(); + } + void close() + { +#ifdef __linux__ + if (fp_ == 0) return; + fclose(fp_); + fp_ = 0; + prev_ = 0; +#endif + } + void set(const uint8_t *p, size_t n, const char *name) const + { +#ifdef __linux__ + if (fp_ == 0) return; + fprintf(fp_, "%llx %zx %s%s\n", (long long)p, n, name, suf_); +#else + (void)p; + (void)n; + (void)name; +#endif + } + void set(const char *name, const uint8_t *cur) + { +#ifdef __linux__ + set(prev_, cur - prev_, name); + prev_ = cur; +#else + (void)name; + (void)cur; +#endif + } +}; + struct FpGenerator : Xbyak::CodeGenerator { typedef Xbyak::RegExp RegExp; typedef Xbyak::Reg64 Reg64; @@ -192,28 +258,17 @@ struct FpGenerator : Xbyak::CodeGenerator { const Reg64& gt8; const Reg64& gt9; const mcl::fp::Op *op_; - Label *pL_; // valid only in init_inner + Label pL_; // pointer to p + // the following labels assume sf(this, 3, 10 | UseRDX) + Label mulPreL; + Label fpDbl_modL; + Label fp_mulL; const uint64_t *p_; uint64_t rp_; int pn_; int FpByte_; bool isFullBit_; - // add/sub without carry. return true if overflow - typedef bool (*bool3op)(uint64_t*, const uint64_t*, const uint64_t*); - - // add/sub with mod -// typedef void (*void3op)(uint64_t*, const uint64_t*, const uint64_t*); - - // mul without carry. return top of z - typedef uint64_t (*uint3opI)(uint64_t*, const uint64_t*, uint64_t); - - // neg - typedef void (*void2op)(uint64_t*, const uint64_t*); - - // preInv - typedef int (*int2op)(uint64_t*, const uint64_t*); - void4u mul_; -// uint3opI mulUnit_; + Profiler prof_; /* @param op [in] ; use op.p, op.N, op.isFullBit @@ -248,34 +303,25 @@ struct FpGenerator : Xbyak::CodeGenerator { , rp_(0) , pn_(0) , FpByte_(0) - , mul_(0) -// , mulUnit_(0) { useMulx_ = cpu_.has(Xbyak::util::Cpu::tBMI2); useAdx_ = cpu_.has(Xbyak::util::Cpu::tADX); } - void init(Op& op) + bool init(Op& op) { + if (!cpu_.has(Xbyak::util::Cpu::tAVX)) return false; reset(); // reset jit code for reuse setProtectModeRW(); // read/write memory init_inner(op); +// printf("code size=%d\n", (int)getSize()); setProtectModeRE(); // set read/exec memory + return true; } private: void init_inner(Op& op) { - // the following labels assume sf(this, 3, 10 | UseRDX) - Label mulPreL; - Label fpDbl_modL; - Label fp_mulL; - Label pL; // label to p_ op_ = &op; - pL_ = &pL; - /* - first 4096-byte is data area - remain is code area - */ - L(pL); + L(pL_); p_ = reinterpret_cast<const uint64_t*>(getCurr()); for (size_t i = 0; i < op.N; i++) { dq(op.p[i]); @@ -285,167 +331,101 @@ private: FpByte_ = int(op.maxN * sizeof(uint64_t)); isFullBit_ = op.isFullBit; // printf("p=%p, pn_=%d, isFullBit_=%d\n", p_, pn_, isFullBit_); - // code from here - setSize(4096); - assert((getCurr<size_t>() & 4095) == 0); - op.fp_addPre = getCurr<u3u>(); - gen_addSubPre(true, pn_); - align(16); - op.fp_subPre = getCurr<u3u>(); - gen_addSubPre(false, pn_); - align(16); - op.fp_sub = getCurr<void4u>(); - op.fp_subA_ = getCurr<void3u>(); - gen_fp_sub(); - align(16); - op.fp_add = getCurr<void4u>(); - op.fp_addA_ = getCurr<void3u>(); - gen_fp_add(); + static char suf[] = "_0"; + prof_.init(suf, getCurr()); + suf[1]++; - align(16); - op.fp_shr1 = getCurr<void2u>(); - gen_shr1(); + op.fp_addPre = gen_addSubPre(true, pn_); + prof_.set("Fp_addPre", getCurr()); - align(16); - op.fp_negA_ = getCurr<void2u>(); - gen_fp_neg(); + op.fp_subPre = gen_addSubPre(false, pn_); + prof_.set("Fp_subPre", getCurr()); - // setup fp_tower - op.fp2_mulNF = 0; - if (pn_ <= 4 || (pn_ == 6 && !isFullBit_)) { - align(16); - op.fpDbl_addA_ = getCurr<void3u>(); - gen_fpDbl_add(); - align(16); - op.fpDbl_subA_ = getCurr<void3u>(); - gen_fpDbl_sub(); - } - if (op.isFullBit) { - op.fpDbl_addPre = 0; - op.fpDbl_subPre = 0; - } else { - align(16); - op.fpDbl_addPreA_ = getCurr<void3u>(); - gen_addSubPre(true, pn_ * 2); - align(16); - op.fpDbl_subPreA_ = getCurr<void3u>(); - gen_addSubPre(false, pn_ * 2); - } - if ((useMulx_ && op.N == 2) || op.N == 3 || op.N == 4 || (useAdx_ && op.N == 6)) { - align(16); - op.fpDbl_mulPreA_ = getCurr<void3u>(); - if (op.N == 4) { - /* - fpDbl_mulPre is available as C function - this function calls mulPreL directly. - */ - StackFrame sf(this, 3, 10 | UseRDX, 0, false); - mulPre4(gp0, gp1, gp2, sf.t); - sf.close(); // make epilog - L(mulPreL); // called only from asm code - mulPre4(gp0, gp1, gp2, sf.t); - ret(); - } else if (op.N == 6 && useAdx_) { -#if 1 - StackFrame sf(this, 3, 7 | UseRDX, 0, false); - mulPre6(gp0, gp1, gp2, sf.t); - sf.close(); // make epilog - L(mulPreL); // called only from asm code - mulPre6(gp0, gp1, gp2, sf.t); - ret(); -#else - { - StackFrame sf(this, 3, 7 | UseRDX); - mulPre6(gp0, gp1, gp2, sf.t); - } - { - StackFrame sf(this, 3, 10 | UseRDX, 0, false); - L(mulPreL); // called only from asm code - mulPre6(gp0, gp1, gp2, sf.t); - ret(); - } -#endif - } else { - gen_fpDbl_mulPre(); - } - } - if (op.N == 2 || op.N == 3 || op.N == 4 || (op.N == 6 && !isFullBit_ && useAdx_)) { - align(16); - op.fpDbl_modA_ = getCurr<void2u>(); - if (op.N == 4) { - StackFrame sf(this, 3, 10 | UseRDX, 0, false); - call(fpDbl_modL); - sf.close(); - L(fpDbl_modL); - gen_fpDbl_mod4(gp0, gp1, sf.t, gp2); - ret(); - } else if (op.N == 6 && !isFullBit_ && useAdx_) { - StackFrame sf(this, 3, 10 | UseRDX, 0, false); - call(fpDbl_modL); - sf.close(); - L(fpDbl_modL); - Pack t = sf.t; - t.append(gp2); - gen_fpDbl_mod6(gp0, gp1, t); - ret(); - } else { - gen_fpDbl_mod(op); - } - } - if (op.N > 4) return; - if ((useMulx_ && op.N == 2) || op.N == 3 || op.N == 4) { - align(16); - op.fpDbl_sqrPreA_ = getCurr<void2u>(); - gen_fpDbl_sqrPre(op); + op.fp_addA_ = gen_fp_add(); + prof_.set("Fp_add", getCurr()); + + op.fp_subA_ = gen_fp_sub(); + prof_.set("Fp_sub", getCurr()); + + op.fp_shr1 = gen_shr1(); + prof_.set("Fp_shr1", getCurr()); + + op.fp_negA_ = gen_fp_neg(); + prof_.set("Fp_neg", getCurr()); + + op.fpDbl_addA_ = gen_fpDbl_add(); + prof_.set("FpDbl_add", getCurr()); + + op.fpDbl_subA_ = gen_fpDbl_sub(); + prof_.set("FpDbl_sub", getCurr()); + + op.fpDbl_addPre = gen_addSubPre(true, pn_ * 2); + prof_.set("FpDbl_addPre", getCurr()); + + op.fpDbl_subPre = gen_addSubPre(false, pn_ * 2); + prof_.set("FpDbl_subPre", getCurr()); + + op.fpDbl_mulPreA_ = gen_fpDbl_mulPre(); + prof_.set("FpDbl_mulPre", getCurr()); + + op.fpDbl_sqrPreA_ = gen_fpDbl_sqrPre(); + prof_.set("FpDbl_sqrPre", getCurr()); + + op.fpDbl_modA_ = gen_fpDbl_mod(op); + prof_.set("FpDbl_mod", getCurr()); + + op.fp_mulA_ = gen_mul(); + prof_.set("Fp_mul", getCurr()); + if (op.fp_mulA_) { + op.fp_mul = fp::func_ptr_cast<void4u>(op.fp_mulA_); // used in toMont/fromMont } - align(16); - op.fp_mul = getCurr<void4u>(); // used in toMont/fromMont - op.fp_mulA_ = getCurr<void3u>(); - gen_mul(fp_mulL); -// if (op.N > 4) return; - align(16); - op.fp_sqrA_ = getCurr<void2u>(); - gen_sqr(); + op.fp_sqrA_ = gen_sqr(); + prof_.set("Fp_sqr", getCurr()); + if (op.primeMode != PM_NIST_P192 && op.N <= 4) { // support general op.N but not fast for op.N > 4 align(16); op.fp_preInv = getCurr<int2u>(); gen_preInv(); + prof_.set("preInv", getCurr()); } - if (op.N == 4 && !isFullBit_) { - align(16); - op.fp2_addA_ = getCurr<void3u>(); - gen_fp2_add4(); - align(16); - op.fp2_subA_ = getCurr<void3u>(); - gen_fp2_sub4(); - align(16); - op.fp2_negA_ = getCurr<void2u>(); - gen_fp2_neg4(); - align(16); - op.fp2Dbl_mulPreA_ = getCurr<void3u>(); - gen_fp2Dbl_mulPre(mulPreL); - align(16); - op.fp2Dbl_sqrPreA_ = getCurr<void2u>(); - gen_fp2Dbl_sqrPre(mulPreL); - align(16); - op.fp2_mulA_ = getCurr<void3u>(); - gen_fp2_mul4(fpDbl_modL); - align(16); - op.fp2_sqrA_ = getCurr<void2u>(); - gen_fp2_sqr4(fp_mulL); - align(16); - op.fp2_mul_xiA_ = getCurr<void2u>(); - gen_fp2_mul_xi4(); - } + if (op.xi_a == 0) return; // Fp2 is not used + op.fp2_addA_ = gen_fp2_add(); + prof_.set("Fp2_add", getCurr()); + + op.fp2_subA_ = gen_fp2_sub(); + prof_.set("Fp2_sub", getCurr()); + + op.fp2_negA_ = gen_fp2_neg(); + prof_.set("Fp2_neg", getCurr()); + + op.fp2_mulNF = 0; + op.fp2Dbl_mulPreA_ = gen_fp2Dbl_mulPre(); + prof_.set("Fp2Dbl_mulPre", getCurr()); + + op.fp2Dbl_sqrPreA_ = gen_fp2Dbl_sqrPre(); + prof_.set("Fp2Dbl_sqrPre", getCurr()); + + op.fp2_mulA_ = gen_fp2_mul(); + prof_.set("Fp2_mul", getCurr()); + + op.fp2_sqrA_ = gen_fp2_sqr(); + prof_.set("Fp2_sqr", getCurr()); + + op.fp2_mul_xiA_ = gen_fp2_mul_xi(); + prof_.set("Fp2_mul_xi", getCurr()); } - void gen_addSubPre(bool isAdd, int n) + u3u gen_addSubPre(bool isAdd, int n) { +// if (isFullBit_) return 0; + align(16); + u3u func = getCurr<u3u>(); StackFrame sf(this, 3); if (isAdd) { gen_raw_add(sf.p[0], sf.p[1], sf.p[2], rax, n); } else { gen_raw_sub(sf.p[0], sf.p[1], sf.p[2], rax, n); } + return func; } /* pz[] = px[] + py[] @@ -498,7 +478,7 @@ private: } jmp(exit); L(nonZero); - mov(rax, (size_t)p_); + mov(rax, pL_); for (size_t i = 0; i < t.size(); i++) { mov(rdx, ptr [rax + i * 8]); if (i == 0) { @@ -626,7 +606,7 @@ private: mov(*fullReg, 0); adc(*fullReg, 0); } - mov(rax, (size_t)p_); + mov(rax, pL_); sub_rm(p1, rax); if (fullReg) { sbb(*fullReg, 0); @@ -646,7 +626,7 @@ private: const Pack& p1 = t.sub(pn_, pn_); load_rm(p0, px); sub_rm(p0, py, withCarry); - mov(rax, (size_t)p_); + mov(rax, pL_); load_rm(p1, rax); sbb(rax, rax); // rax = (x > y) ? 0 : -1 for (size_t i = 0; i < p1.size(); i++) { @@ -676,29 +656,29 @@ private: gen_raw_fp_sub(pz, px, py, sf.t, false); } /* - add(pz + offset, px + offset, py + offset); + add(pz, px, py); size of t1, t2 == 6 destroy t0, t1 */ - void gen_raw_fp_add6(const Reg64& pz, const Reg64& px, const Reg64& py, int offset, const Pack& t1, const Pack& t2, bool withCarry) + void gen_raw_fp_add6(const RegExp& pz, const RegExp& px, const RegExp& py, const Pack& t1, const Pack& t2, bool withCarry) { - load_rm(t1, px + offset); - add_rm(t1, py + offset, withCarry); + load_rm(t1, px); + add_rm(t1, py, withCarry); Label exit; if (isFullBit_) { jnc("@f"); - mov(t2[0], *pL_); // t2 is not used + mov(t2[0], pL_); // t2 is not used sub_rm(t1, t2[0]); jmp(exit); L("@@"); } mov_rr(t2, t1); - sub_rm(t2, rip + *pL_); + sub_rm(t2, rip + pL_); for (int i = 0; i < 6; i++) { cmovnc(t1[i], t2[i]); } L(exit); - store_mr(pz + offset, t1); + store_mr(pz, t1); } void gen_fp_add6() { @@ -713,17 +693,19 @@ private: Pack t2 = sf.t.sub(6); t2.append(rax); t2.append(px); // destory after used - gen_raw_fp_add6(pz, px, py, 0, t1, t2, false); + gen_raw_fp_add6(pz, px, py, t1, t2, false); } - void gen_fp_add() + void3u gen_fp_add() { + align(16); + void3u func = getCurr<void3u>(); if (pn_ <= 4) { gen_fp_add_le4(); - return; + return func; } if (pn_ == 6) { gen_fp_add6(); - return; + return func; } StackFrame sf(this, 3, 0, pn_ * 8); const Reg64& pz = sf.p[0]; @@ -733,7 +715,7 @@ private: inLocalLabel(); gen_raw_add(pz, px, py, rax, pn_); - mov(px, (size_t)p_); // destroy px + mov(px, pL_); // destroy px if (isFullBit_) { jc(".over", jmpMode); } @@ -759,9 +741,12 @@ private: L(".exit"); #endif outLocalLabel(); + return func; } - void gen_fpDbl_add() + void3u gen_fpDbl_add() { + align(16); + void3u func = getCurr<void3u>(); if (pn_ <= 4) { int tn = pn_ * 2 + (isFullBit_ ? 1 : 0); StackFrame sf(this, 3, tn); @@ -770,6 +755,7 @@ private: const Reg64& py = sf.p[2]; gen_raw_add(pz, px, py, rax, pn_); gen_raw_fp_add(pz + 8 * pn_, px + 8 * pn_, py + 8 * pn_, sf.t, true); + return func; } else if (pn_ == 6 && !isFullBit_) { StackFrame sf(this, 3, 10); const Reg64& pz = sf.p[0]; @@ -780,14 +766,15 @@ private: Pack t2 = sf.t.sub(6); t2.append(rax); t2.append(py); - gen_raw_fp_add6(pz, px, py, pn_ * 8, t1, t2, true); - } else { - assert(0); - exit(1); + gen_raw_fp_add6(pz + pn_ * 8, px + pn_ * 8, py + pn_ * 8, t1, t2, true); + return func; } + return 0; } - void gen_fpDbl_sub() + void3u gen_fpDbl_sub() { + align(16); + void3u func = getCurr<void3u>(); if (pn_ <= 4) { int tn = pn_ * 2; StackFrame sf(this, 3, tn); @@ -796,6 +783,7 @@ private: const Reg64& py = sf.p[2]; gen_raw_sub(pz, px, py, rax, pn_); gen_raw_fp_sub(pz + 8 * pn_, px + 8 * pn_, py + 8 * pn_, sf.t, true); + return func; } else if (pn_ == 6) { StackFrame sf(this, 3, 4); const Reg64& pz = sf.p[0]; @@ -806,12 +794,11 @@ private: t.append(rax); t.append(px); gen_raw_fp_sub6(pz, px, py, pn_ * 8, t, true); - } else { - assert(0); - exit(1); + return func; } + return 0; } - void gen_raw_fp_sub6(const Reg64& pz, const Reg64& px, const Reg64& py, int offset, const Pack& t, bool withCarry) + void gen_raw_fp_sub6(const RegExp& pz, const RegExp& px, const RegExp& py, int offset, const Pack& t, bool withCarry) { load_rm(t, px + offset); sub_rm(t, py + offset, withCarry); @@ -819,7 +806,7 @@ private: jmp is faster than and-mask without jmp */ jnc("@f"); - add_rm(t, rip + *pL_); + add_rm(t, rip + pL_); L("@@"); store_mr(pz + offset, t); } @@ -834,15 +821,17 @@ private: t.append(px); // |t| = 6 gen_raw_fp_sub6(pz, px, py, 0, t, false); } - void gen_fp_sub() + void3u gen_fp_sub() { + align(16); + void3u func = getCurr<void3u>(); if (pn_ <= 4) { gen_fp_sub_le4(); - return; + return func; } if (pn_ == 6) { gen_fp_sub6(); - return; + return func; } StackFrame sf(this, 3); const Reg64& pz = sf.p[0]; @@ -852,17 +841,23 @@ private: Label exit; gen_raw_sub(pz, px, py, rax, pn_); jnc(exit, jmpMode); - mov(px, (size_t)p_); + mov(px, pL_); gen_raw_add(pz, pz, px, rax, pn_); L(exit); + return func; } - void gen_fp_neg() + void2u gen_fp_neg() { + align(16); + void2u func = getCurr<void2u>(); StackFrame sf(this, 2, UseRDX | pn_); gen_raw_neg(sf.p[0], sf.p[1], sf.t); + return func; } - void gen_shr1() + void2u gen_shr1() { + align(16); + void2u func = getCurr<void2u>(); const int c = 1; StackFrame sf(this, 2, 1); const Reg64 *t0 = &rax; @@ -878,25 +873,54 @@ private: } shr(*t0, c); mov(ptr [pz + (pn_ - 1) * 8], *t0); + return func; } - void gen_mul(Label& fp_mulL) + void3u gen_mul() { + align(16); + void3u func = getCurr<void3u>(); if (op_->primeMode == PM_NIST_P192) { StackFrame sf(this, 3, 10 | UseRDX, 8 * 6); mulPre3(rsp, sf.p[1], sf.p[2], sf.t); fpDbl_mod_NIST_P192(sf.p[0], rsp, sf.t); + return func; } if (pn_ == 3) { - gen_montMul3(p_, rp_); - } else if (pn_ == 4) { - gen_montMul4(fp_mulL, p_, rp_); -// } else if (pn_ == 6 && useAdx_) { -// gen_montMul6(fp_mulL, p_, rp_); - } else if (pn_ <= 9) { + gen_montMul3(); + return func; + } + if (pn_ == 4) { + gen_montMul4(); + return func; + } + if (pn_ == 6 && !isFullBit_ && useMulx_ && useAdx_) { +#if 1 + // a little faster + gen_montMul6(); +#else + if (mulPreL.getAddress() == 0 || fpDbl_modL.getAddress() == 0) return 0; + StackFrame sf(this, 3, 10 | UseRDX, 12 * 8); + /* + use xm3 + rsp + [0, ..12 * 8) ; mul(x, y) + */ + vmovq(xm3, gp0); + mov(gp0, rsp); + call(mulPreL); // gp0, x, y + vmovq(gp0, xm3); + mov(gp1, rsp); + call(fpDbl_modL); +#endif + return func; + } +#if 0 + if (pn_ <= 9) { gen_montMulN(p_, rp_, pn_); - } else { - throw cybozu::Exception("mcl:FpGenerator:gen_mul:not implemented for") << pn_; + return func; } +#endif + return 0; } /* @input (z, xy) @@ -926,7 +950,7 @@ private: mov(a, rp_); mul(t6); - mov(t0, (uint64_t)p_); + mov(t0, pL_); mov(t7, a); // q // [d:t7:t1] = p * q @@ -995,7 +1019,7 @@ private: mov(a, rp_); mul(t10); - mov(t0, (uint64_t)p_); + mov(t0, pL_); mov(t7, a); // q // [d:t7:t2:t1] = p * q @@ -1070,12 +1094,12 @@ private: const Reg64& a = rax; const Reg64& d = rdx; - movq(xm0, z); + vmovq(xm0, z); mov(z, ptr [xy + 8 * 0]); mov(a, rp_); mul(z); - mov(t0, (uint64_t)p_); + mov(t0, pL_); mov(t7, a); // q // [d:t7:t3:t2:t1] = p * q @@ -1097,7 +1121,7 @@ private: if (isFullBit_) { mov(t5, 0); adc(t5, 0); - movq(xm2, t5); + vmovq(xm2, t5); } // free z, t0, t1, t5, t6, xy @@ -1106,18 +1130,18 @@ private: mul(t2); mov(z, a); // q - movq(xm1, t10); + vmovq(xm1, t10); // [d:z:t5:t6:xy] = p * q mul4x1(t0, z, t1, t5, t6, xy, t10); - movq(t10, xm1); + vmovq(t10, xm1); add_rr(Pack(t8, t4, t7, t3, t2), Pack(d, z, t5, t6, xy)); adc(t9, 0); adc(t10, 0); // [t10:t9:t8:t4:t7:t3] if (isFullBit_) { - movq(t5, xm2); + vmovq(t5, xm2); adc(t5, 0); - movq(xm2, t5); + vmovq(xm2, t5); } // free z, t0, t1, t2, t5, t6, xy @@ -1132,7 +1156,7 @@ private: add_rr(Pack(t9, t8, t4, t7, t3), Pack(d, z, t5, xy, t6)); adc(t10, 0); // c' = [t10:t9:t8:t4:t7] if (isFullBit_) { - movq(t3, xm2); + vmovq(t3, xm2); adc(t3, 0); } @@ -1161,60 +1185,112 @@ private: cmovc(t9, t2); cmovc(t10, t6); - movq(z, xm0); + vmovq(z, xm0); store_mr(z, Pack(t10, t9, t8, t4)); } - void gen_fpDbl_mod(const mcl::fp::Op& op) + void2u gen_fpDbl_mod(const fp::Op& op) { + align(16); + void2u func = getCurr<void2u>(); if (op.primeMode == PM_NIST_P192) { StackFrame sf(this, 2, 6 | UseRDX); fpDbl_mod_NIST_P192(sf.p[0], sf.p[1], sf.t); - return; + return func; } #if 0 if (op.primeMode == PM_NIST_P521) { StackFrame sf(this, 2, 8 | UseRDX); fpDbl_mod_NIST_P521(sf.p[0], sf.p[1], sf.t); - return; + return func; } #endif - switch (pn_) { - case 2: + if (pn_ == 2) { gen_fpDbl_mod2(); - break; - case 3: + return func; + } + if (pn_ == 3) { gen_fpDbl_mod3(); - break; -#if 0 - case 4: - { - StackFrame sf(this, 3, 10 | UseRDX); - gen_fpDbl_mod4(gp0, gp1, sf.t, gp2); - } - break; -#endif - default: - throw cybozu::Exception("gen_fpDbl_mod:not support") << pn_; + return func; + } + if (pn_ == 4) { + StackFrame sf(this, 3, 10 | UseRDX, 0, false); + call(fpDbl_modL); + sf.close(); + L(fpDbl_modL); + gen_fpDbl_mod4(gp0, gp1, sf.t, gp2); + ret(); + return func; + } + if (pn_ == 6 && !isFullBit_ && useMulx_ && useAdx_) { + StackFrame sf(this, 3, 10 | UseRDX, 0, false); + call(fpDbl_modL); + sf.close(); + L(fpDbl_modL); + Pack t = sf.t; + t.append(gp2); + gen_fpDbl_mod6(gp0, gp1, t); + ret(); + return func; } + return 0; } - void gen_sqr() + void2u gen_sqr() { + align(16); + void2u func = getCurr<void2u>(); if (op_->primeMode == PM_NIST_P192) { - StackFrame sf(this, 2, 10 | UseRDX | UseRCX, 8 * 6); - sqrPre3(rsp, sf.p[1], sf.t); + StackFrame sf(this, 3, 10 | UseRDX, 6 * 8); + Pack t = sf.t; + t.append(sf.p[2]); + sqrPre3(rsp, sf.p[1], t); fpDbl_mod_NIST_P192(sf.p[0], rsp, sf.t); + return func; } if (pn_ == 3) { - gen_montSqr3(p_, rp_); - return; + gen_montSqr3(); + return func; } - // sqr(y, x) = mul(y, x, x) + if (pn_ == 4 && useMulx_) { +#if 1 + // sqr(y, x) = mul(y, x, x) #ifdef XBYAK64_WIN - mov(r8, rdx); + mov(r8, rdx); #else - mov(rdx, rsi); + mov(rdx, rsi); +#endif + jmp((const void*)op_->fp_mulA_); +#else // (sqrPre + mod) is slower than mul + StackFrame sf(this, 3, 10 | UseRDX, 8 * 8); + Pack t = sf.t; + t.append(sf.p[2]); + sqrPre4(rsp, sf.p[1], t); + mov(gp0, sf.p[0]); + mov(gp1, rsp); + call(fpDbl_modL); #endif - jmp((const void*)op_->fp_mulA_); + return func; + } + if (pn_ == 6 && !isFullBit_ && useMulx_ && useAdx_) { + if (fpDbl_modL.getAddress() == 0) return 0; + StackFrame sf(this, 3, 10 | UseRDX, (12 + 6) * 8); + /* + use xm3 + rsp + [6 * 8, (12 + 6) * 8) ; sqrPre(x, x) + [0..6 * 8) ; stack for sqrPre6 + */ + vmovq(xm3, gp0); + Pack t = sf.t; + t.append(sf.p[2]); + // sqrPre6 uses 6 * 8 bytes stack + sqrPre6(rsp + 6 * 8, sf.p[1], t); + mov(gp0, ptr[rsp + (12 + 6) * 8]); + vmovq(gp0, xm3); + lea(gp1, ptr[rsp + 6 * 8]); + call(fpDbl_modL); + return func; + } + return 0; } /* input (pz[], px[], py[]) @@ -1259,7 +1335,7 @@ private: z[0..3] <- montgomery(x[0..3], y[0..3]) destroy gt0, ..., gt9, xm0, xm1, p2 */ - void gen_montMul4(Label& fp_mulL, const uint64_t *p, uint64_t pp) + void gen_montMul4() { StackFrame sf(this, 3, 10 | UseRDX, 0, false); call(fp_mulL); @@ -1280,23 +1356,23 @@ private: const Reg64& t9 = sf.t[9]; L(fp_mulL); - movq(xm0, p0); // save p0 - mov(p0, (uint64_t)p); - movq(xm1, p2); + vmovq(xm0, p0); // save p0 + mov(p0, pL_); + vmovq(xm1, p2); mov(p2, ptr [p2]); - montgomery4_1(pp, t0, t7, t3, t2, t1, p1, p2, p0, t4, t5, t6, t8, t9, true, xm2); + montgomery4_1(rp_, t0, t7, t3, t2, t1, p1, p2, p0, t4, t5, t6, t8, t9, true, xm2); - movq(p2, xm1); + vmovq(p2, xm1); mov(p2, ptr [p2 + 8]); - montgomery4_1(pp, t1, t0, t7, t3, t2, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2); + montgomery4_1(rp_, t1, t0, t7, t3, t2, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2); - movq(p2, xm1); + vmovq(p2, xm1); mov(p2, ptr [p2 + 16]); - montgomery4_1(pp, t2, t1, t0, t7, t3, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2); + montgomery4_1(rp_, t2, t1, t0, t7, t3, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2); - movq(p2, xm1); + vmovq(p2, xm1); mov(p2, ptr [p2 + 24]); - montgomery4_1(pp, t3, t2, t1, t0, t7, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2); + montgomery4_1(rp_, t3, t2, t1, t0, t7, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2); // [t7:t3:t2:t1:t0] mov(t4, t0); @@ -1310,16 +1386,131 @@ private: cmovc(t2, t6); cmovc(t3, rdx); - movq(p0, xm0); // load p0 + vmovq(p0, xm0); // load p0 store_mr(p0, Pack(t3, t2, t1, t0)); ret(); } /* + c[n+2] = c[n+1] + px[n] * rdx + use rax + */ + void mulAdd(const Pack& c, int n, const RegExp& px) + { + const Reg64& a = rax; + xor_(a, a); + for (int i = 0; i < n; i++) { + mulx(c[n + 1], a, ptr [px + i * 8]); + adox(c[i], a); + adcx(c[i + 1], c[n + 1]); + } + mov(a, 0); + mov(c[n + 1], a); + adox(c[n], a); + adcx(c[n + 1], a); + adox(c[n + 1], a); + } + /* + input + c[6..0] + rdx = yi + use rax, rdx + output + c[7..1] + + if first: + c = x[5..0] * rdx + else: + c += x[5..0] * rdx + q = uint64_t(c0 * rp) + c += p * q + c >>= 64 + */ + void montgomery6_1(const Pack& c, const RegExp& px, const Reg64& t0, const Reg64& t1, bool isFirst) + { + const int n = 6; + const Reg64& a = rax; + const Reg64& d = rdx; + if (isFirst) { + const Reg64 *pt0 = &a; + const Reg64 *pt1 = &t0; + // c[6..0] = px[5..0] * rdx + mulx(*pt0, c[0], ptr [px + 0 * 8]); + for (int i = 1; i < n; i++) { + mulx(*pt1, c[i], ptr[px + i * 8]); + if (i == 1) { + add(c[i], *pt0); + } else { + adc(c[i], *pt0); + } + std::swap(pt0, pt1); + } + mov(c[n], 0); + adc(c[n], *pt0); + } else { + // c[7..0] = c[6..0] + px[5..0] * rdx + mulAdd(c, 6, px); + } + mov(a, rp_); + mul(c[0]); // q = a + mov(d, a); + mov(t1, pL_); + // c += p * q + mulAdd(c, 6, t1); + } + /* + input (z, x, y) = (p0, p1, p2) + z[0..5] <- montgomery(x[0..5], y[0..5]) + destroy t0, ..., t9, rax, rdx + */ + void gen_montMul6() + { + assert(!isFullBit_ && useMulx_ && useAdx_); + StackFrame sf(this, 3, 10 | UseRDX, 0, false); + call(fp_mulL); + sf.close(); + const Reg64& pz = sf.p[0]; + const Reg64& px = sf.p[1]; + const Reg64& py = sf.p[2]; + + const Reg64& t0 = sf.t[0]; + const Reg64& t1 = sf.t[1]; + const Reg64& t2 = sf.t[2]; + const Reg64& t3 = sf.t[3]; + const Reg64& t4 = sf.t[4]; + const Reg64& t5 = sf.t[5]; + const Reg64& t6 = sf.t[6]; + const Reg64& t7 = sf.t[7]; + const Reg64& t8 = sf.t[8]; + const Reg64& t9 = sf.t[9]; + L(fp_mulL); + mov(rdx, ptr [py + 0 * 8]); + montgomery6_1(Pack(t7, t6, t5, t4, t3, t2, t1, t0), px, t8, t9, true); + mov(rdx, ptr [py + 1 * 8]); + montgomery6_1(Pack(t0, t7, t6, t5, t4, t3, t2, t1), px, t8, t9, false); + mov(rdx, ptr [py + 2 * 8]); + montgomery6_1(Pack(t1, t0, t7, t6, t5, t4, t3, t2), px, t8, t9, false); + mov(rdx, ptr [py + 3 * 8]); + montgomery6_1(Pack(t2, t1, t0, t7, t6, t5, t4, t3), px, t8, t9, false); + mov(rdx, ptr [py + 4 * 8]); + montgomery6_1(Pack(t3, t2, t1, t0, t7, t6, t5, t4), px, t8, t9, false); + mov(rdx, ptr [py + 5 * 8]); + montgomery6_1(Pack(t4, t3, t2, t1, t0, t7, t6, t5), px, t8, t9, false); + // [t4:t3:t2:t1:t0:t7:t6] + const Pack z = Pack(t3, t2, t1, t0, t7, t6); + const Pack keep = Pack(rdx, rax, px, py, t8, t9); + mov_rr(keep, z); + mov(t5, pL_); + sub_rm(z, t5); + cmovc_rr(z, keep); + store_mr(pz, z); + ret(); + } + /* input (z, x, y) = (p0, p1, p2) z[0..2] <- montgomery(x[0..2], y[0..2]) destroy gt0, ..., gt9, xm0, xm1, p2 */ - void gen_montMul3(const uint64_t *p, uint64_t pp) + void gen_montMul3() { StackFrame sf(this, 3, 10 | UseRDX); const Reg64& p0 = sf.p[0]; @@ -1337,16 +1528,16 @@ private: const Reg64& t8 = sf.t[8]; const Reg64& t9 = sf.t[9]; - movq(xm0, p0); // save p0 - mov(t7, (uint64_t)p); + vmovq(xm0, p0); // save p0 + mov(t7, pL_); mov(t9, ptr [p2]); // c3, c2, c1, c0, px, y, p, - montgomery3_1(pp, t0, t3, t2, t1, p1, t9, t7, t4, t5, t6, t8, p0, true); + montgomery3_1(rp_, t0, t3, t2, t1, p1, t9, t7, t4, t5, t6, t8, p0, true); mov(t9, ptr [p2 + 8]); - montgomery3_1(pp, t1, t0, t3, t2, p1, t9, t7, t4, t5, t6, t8, p0, false); + montgomery3_1(rp_, t1, t0, t3, t2, p1, t9, t7, t4, t5, t6, t8, p0, false); mov(t9, ptr [p2 + 16]); - montgomery3_1(pp, t2, t1, t0, t3, p1, t9, t7, t4, t5, t6, t8, p0, false); + montgomery3_1(rp_, t2, t1, t0, t3, p1, t9, t7, t4, t5, t6, t8, p0, false); // [(t3):t2:t1:t0] mov(t4, t0); @@ -1357,7 +1548,7 @@ private: cmovc(t0, t4); cmovc(t1, t5); cmovc(t2, t6); - movq(p0, xm0); + vmovq(p0, xm0); store_mr(p0, Pack(t2, t1, t0)); } /* @@ -1365,7 +1556,7 @@ private: z[0..2] <- montgomery(px[0..2], px[0..2]) destroy gt0, ..., gt9, xm0, xm1, p2 */ - void gen_montSqr3(const uint64_t *p, uint64_t pp) + void gen_montSqr3() { StackFrame sf(this, 3, 10 | UseRDX, 16 * 3); const Reg64& pz = sf.p[0]; @@ -1383,24 +1574,24 @@ private: const Reg64& t8 = sf.t[8]; const Reg64& t9 = sf.t[9]; - movq(xm0, pz); // save pz - mov(t7, (uint64_t)p); + vmovq(xm0, pz); // save pz + mov(t7, pL_); mov(t9, ptr [px]); mul3x1_sqr1(px, t9, t3, t2, t1, t0); mov(t0, rdx); - montgomery3_sub(pp, t0, t9, t2, t1, px, t3, t7, t4, t5, t6, t8, pz, true); + montgomery3_sub(rp_, t0, t9, t2, t1, px, t3, t7, t4, t5, t6, t8, pz, true); mov(t3, ptr [px + 8]); mul3x1_sqr2(px, t3, t6, t5, t4); add_rr(Pack(t1, t0, t9, t2), Pack(rdx, rax, t5, t4)); if (isFullBit_) setc(pz.cvt8()); - montgomery3_sub(pp, t1, t3, t9, t2, px, t0, t7, t4, t5, t6, t8, pz, false); + montgomery3_sub(rp_, t1, t3, t9, t2, px, t0, t7, t4, t5, t6, t8, pz, false); mov(t0, ptr [px + 16]); mul3x1_sqr3(t0, t5, t4); add_rr(Pack(t2, t1, t3, t9), Pack(rdx, rax, t5, t4)); if (isFullBit_) setc(pz.cvt8()); - montgomery3_sub(pp, t2, t0, t3, t9, px, t1, t7, t4, t5, t6, t8, pz, false); + montgomery3_sub(rp_, t2, t0, t3, t9, px, t1, t7, t4, t5, t6, t8, pz, false); // [t9:t2:t0:t3] mov(t4, t3); @@ -1411,18 +1602,17 @@ private: cmovc(t3, t4); cmovc(t0, t5); cmovc(t2, t6); - movq(pz, xm0); + vmovq(pz, xm0); store_mr(pz, Pack(t2, t0, t3)); } /* py[5..0] <- px[2..0]^2 - @note use rax, rdx, rcx! + @note use rax, rdx */ void sqrPre3(const RegExp& py, const RegExp& px, const Pack& t) { const Reg64& a = rax; const Reg64& d = rdx; - const Reg64& c = rcx; const Reg64& t0 = t[0]; const Reg64& t1 = t[1]; const Reg64& t2 = t[2]; @@ -1433,6 +1623,7 @@ private: const Reg64& t7 = t[7]; const Reg64& t8 = t[8]; const Reg64& t9 = t[9]; + const Reg64& t10 = t[10]; if (useMulx_) { mov(d, ptr [px + 8 * 0]); @@ -1453,7 +1644,7 @@ private: mov(d, t7); mulx(t8, t7, d); - mulx(c, t9, t9); + mulx(t10, t9, t9); } else { mov(t9, ptr [px + 8 * 0]); mov(a, t9); @@ -1484,11 +1675,11 @@ private: mov(a, ptr [px + 8 * 2]); mul(t9); mov(t9, a); - mov(c, d); + mov(t10, d); } add(t2, t7); adc(t8, t9); - mov(t7, c); + mov(t7, t10); adc(t7, 0); // [t7:t8:t2:t1] add(t0, t1); @@ -1500,7 +1691,7 @@ private: mov(a, ptr [px + 8 * 2]); mul(a); add(t4, t9); - adc(a, c); + adc(a, t10); adc(d, 0); // [d:a:t4:t3] add(t2, t3); @@ -1537,13 +1728,13 @@ private: const Reg64& a = rax; const Reg64& d = rdx; mov(d, ptr [px]); - mulx(hi, a, ptr [py + 8 * 0]); - adox(pd[0], a); - mov(ptr [pz], pd[0]); - for (size_t i = 1; i < pd.size(); i++) { - adcx(pd[i], hi); - mulx(hi, a, ptr [py + 8 * i]); + xor_(a, a); + for (size_t i = 0; i < pd.size(); i++) { + mulx(hi, a, ptr [py + i * 8]); adox(pd[i], a); + if (i == 0) mov(ptr[pz], pd[0]); + if (i == pd.size() - 1) break; + adcx(pd[i + 1], hi); } mov(d, 0); adcx(hi, d); @@ -1611,6 +1802,7 @@ private: if (useMulx_) { mulPack(pz, px, py, Pack(t2, t1, t0)); +#if 0 // a little slow if (useAdx_) { // [t2:t1:t0] mulPackAdd(pz + 8 * 1, px + 8 * 1, py, t3, Pack(t2, t1, t0)); @@ -1620,6 +1812,7 @@ private: store_mr(pz + 8 * 3, Pack(t4, t3, t2)); return; } +#endif } else { mov(t5, ptr [px]); mov(a, ptr [py + 8 * 0]); @@ -1749,6 +1942,7 @@ private: } /* py[7..0] = px[3..0] ^ 2 + use xmm0 */ void sqrPre4(const RegExp& py, const RegExp& px, const Pack& t) { @@ -1762,30 +1956,76 @@ private: const Reg64& t7 = t[7]; const Reg64& t8 = t[8]; const Reg64& t9 = t[9]; + const Reg64& t10 = t[10]; + const Reg64& a = rax; + const Reg64& d = rdx; - // (AN + B)^2 = A^2N^2 + 2AB + B^2 - - mul2x2(px + 8 * 0, px + 8 * 2, t4, t3, t2, t1, t0); - // [t3:t2:t1:t0] = AB - xor_(t4, t4); - add_rr(Pack(t4, t3, t2, t1, t0), Pack(t4, t3, t2, t1, t0)); - // [t4:t3:t2:t1:t0] = 2AB - store_mr(py + 8 * 2, Pack(t4, t3, t2, t1, t0)); - - mov(t8, ptr [px + 8 * 0]); - mov(t9, ptr [px + 8 * 1]); - sqr2(t1, t0, t7, t6, t9, t8, rax, rcx); - // B^2 = [t1:t0:t7:t6] - store_mr(py + 8 * 0, Pack(t7, t6)); - // [t1:t0] - - mov(t8, ptr [px + 8 * 2]); - mov(t9, ptr [px + 8 * 3]); - sqr2(t5, t4, t3, t2, t9, t8, rax, rcx); - // [t5:t4:t3:t2] - add_rm(Pack(t4, t3, t2, t1, t0), py + 8 * 2); - adc(t5, 0); - store_mr(py + 8 * 2, Pack(t5, t4, t3, t2, t1, t0)); + /* + (aN + b)^2 = a^2 N^2 + 2ab N + b^2 + */ + load_rm(Pack(t9, t8), px); + sqr2(t3, t2, t1, t0, t9, t8, t7, t6); + // [t3:t2:t1:t0] = b^2 + store_mr(py, Pack(t1, t0)); + vmovq(xm0, t2); + mul2x2(px, px + 2 * 8, t6, t5, t4, t1, t0); + // [t5:t4:t1:t0] = ab + xor_(t6, t6); + add_rr(Pack(t6, t5, t4, t1, t0), Pack(t6, t5, t4, t1, t0)); + // [t6:t5:t4:t1:t0] = 2ab + load_rm(Pack(t8, t7), px + 2 * 8); + // free t10, t9, rax, rdx + /* + [d:t8:t10:t9] = [t8:t7]^2 + */ + mov(d, t7); + mulx(t10, t9, t7); // [t10:t9] = t7^2 + mulx(t7, t2, t8); // [t7:t2] = t7 t8 + xor_(a, a); + add_rr(Pack(a, t7, t2), Pack(a, t7, t2)); + // [a:t7:t2] = 2 t7 t8 + mov(d, t8); + mulx(d, t8, t8); // [d:t8] = t8^2 + add_rr(Pack(d, t8, t10), Pack(a, t7, t2)); + // [d:t8:t10:t9] = [t8:t7]^2 + vmovq(t2, xm0); + add_rr(Pack(t8, t10, t9, t3, t2), Pack(t6, t5, t4, t1, t0)); + adc(d, 0); + store_mr(py + 2 * 8, Pack(d, t8, t10, t9, t3, t2)); + } + /* + py[11..0] = px[5..0] ^ 2 + use rax, rdx, stack[6 * 8] + */ + void sqrPre6(const RegExp& py, const RegExp& px, const Pack& t) + { + const Reg64& t0 = t[0]; + const Reg64& t1 = t[1]; + const Reg64& t2 = t[2]; + /* + (aN + b)^2 = a^2 N^2 + 2ab N + b^2 + */ + sqrPre3(py, px, t); // [py] <- b^2 + sqrPre3(py + 6 * 8, px + 3 * 8, t); // [py + 6 * 8] <- a^2 + mulPre3(rsp, px, px + 3 * 8, t); // ab + Pack ab = t.sub(0, 6); + load_rm(ab, rsp); + xor_(rax, rax); + for (int i = 0; i < 6; i++) { + if (i == 0) { + add(ab[i], ab[i]); + } else { + adc(ab[i], ab[i]); + } + } + adc(rax, rax); + add_rm(ab, py + 3 * 8); + store_mr(py + 3 * 8, ab); + load_rm(Pack(t2, t1, t0), py + 9 * 8); + adc(t0, rax); + adc(t1, 0); + adc(t2, 0); + store_mr(py + 9 * 8, Pack(t2, t1, t0)); } /* pz[7..0] <- px[3..0] * py[3..0] @@ -1805,6 +2045,16 @@ private: const Reg64& t8 = t[8]; const Reg64& t9 = t[9]; +#if 0 // a little slower + if (useMulx_ && useAdx_) { + mulPack(pz, px, py, Pack(t3, t2, t1, t0)); + mulPackAdd(pz + 8 * 1, px + 8 * 1, py, t4, Pack(t3, t2, t1, t0)); + mulPackAdd(pz + 8 * 2, px + 8 * 2, py, t0, Pack(t4, t3, t2, t1)); + mulPackAdd(pz + 8 * 3, px + 8 * 3, py, t1, Pack(t0, t4, t3, t2)); + store_mr(pz + 8 * 4, Pack(t1, t0, t4, t3)); + return; + } +#endif #if 0 // a little slower if (!useMulx_) { @@ -1818,28 +2068,17 @@ private: mul2x2(px + 8 * 0, py + 8 * 0, t9, t8, t7, t6, t5); store_mr(pz, Pack(t6, t5)); // [t8:t7] - movq(xm0, t7); - movq(xm1, t8); + vmovq(xm0, t7); + vmovq(xm1, t8); mul2x2(px + 8 * 2, py + 8 * 2, t8, t7, t9, t6, t5); - movq(a, xm0); - movq(d, xm1); + vmovq(a, xm0); + vmovq(d, xm1); add_rr(Pack(t4, t3, t2, t1, t0), Pack(t9, t6, t5, d, a)); adc(t7, 0); store_mr(pz + 8 * 2, Pack(t7, t4, t3, t2, t1, t0)); #else if (useMulx_) { mulPack(pz, px, py, Pack(t3, t2, t1, t0)); - if (0 && useAdx_) { // a little slower? - // [t3:t2:t1:t0] - mulPackAdd(pz + 8 * 1, px + 8 * 1, py, t4, Pack(t3, t2, t1, t0)); - // [t4:t3:t2:t1] - mulPackAdd(pz + 8 * 2, px + 8 * 2, py, t5, Pack(t4, t3, t2, t1)); - // [t5:t4:t3:t2] - mulPackAdd(pz + 8 * 3, px + 8 * 3, py, t0, Pack(t5, t4, t3, t2)); - // [t0:t5:t4:t3] - store_mr(pz + 8 * 4, Pack(t0, t5, t4, t3)); - return; - } } else { mov(t5, ptr [px]); mov(a, ptr [py + 8 * 0]); @@ -1894,12 +2133,111 @@ private: mov(ptr [pz + 8 * 7], d); #endif } - void mulPre6(const RegExp& pz, const RegExp& px, const RegExp& py, const Pack& t) + // [gp0] <- [gp1] * [gp2] + void mulPre6(const Pack& t) { + const Reg64& pz = gp0; + const Reg64& px = gp1; + const Reg64& py = gp2; const Reg64& t0 = t[0]; const Reg64& t1 = t[1]; const Reg64& t2 = t[2]; const Reg64& t3 = t[3]; +#if 0 // slower than basic multiplication(56clk -> 67clk) +// const Reg64& t7 = t[7]; +// const Reg64& t8 = t[8]; +// const Reg64& t9 = t[9]; + const Reg64& a = rax; + const Reg64& d = rdx; + const int stackSize = (3 + 3 + 6 + 1 + 1 + 1) * 8; // a+b, c+d, (a+b)(c+d), x, y, z + const int abPos = 0; + const int cdPos = abPos + 3 * 8; + const int abcdPos = cdPos + 3 * 8; + const int zPos = abcdPos + 6 * 8; + const int yPos = zPos + 8; + const int xPos = yPos + 8; + + sub(rsp, stackSize); + mov(ptr[rsp + zPos], pz); + mov(ptr[rsp + xPos], px); + mov(ptr[rsp + yPos], py); + /* + x = aN + b, y = cN + d + xy = abN^2 + ((a+b)(c+d) - ac - bd)N + bd + */ + xor_(a, a); + load_rm(Pack(t2, t1, t0), px); // b + add_rm(Pack(t2, t1, t0), px + 3 * 8); // a + b + adc(a, 0); + store_mr(pz, Pack(t2, t1, t0)); + vmovq(xm0, a); // carry1 + + xor_(a, a); + load_rm(Pack(t2, t1, t0), py); // d + add_rm(Pack(t2, t1, t0), py + 3 * 8); // c + d + adc(a, 0); + store_mr(pz + 3 * 8, Pack(t2, t1, t0)); + vmovq(xm1, a); // carry2 + + mulPre3(rsp + abcdPos, pz, pz + 3 * 8, t); // (a+b)(c+d) + + vmovq(a, xm0); + vmovq(d, xm1); + mov(t3, a); + and_(t3, d); // t3 = carry1 & carry2 + Label doNothing; + je(doNothing); + load_rm(Pack(t2, t1, t0), rsp + abcdPos + 3 * 8); + test(a, a); + je("@f"); + // add (c+d) + add_rm(Pack(t2, t1, t0), pz + 3 * 8); + adc(t3, 0); + L("@@"); + test(d, d); + je("@f"); + // add(a+b) + add_rm(Pack(t2, t1, t0), pz); + adc(t3, 0); + L("@@"); + store_mr(rsp + abcdPos + 3 * 8, Pack(t2, t1, t0)); + L(doNothing); + vmovq(xm0, t3); // save new carry + + + mov(gp0, ptr [rsp + zPos]); + mov(gp1, ptr [rsp + xPos]); + mov(gp2, ptr [rsp + yPos]); + mulPre3(gp0, gp1, gp2, t); // [rsp] <- bd + + mov(gp0, ptr [rsp + zPos]); + mov(gp1, ptr [rsp + xPos]); + mov(gp2, ptr [rsp + yPos]); + mulPre3(gp0 + 6 * 8, gp1 + 3 * 8, gp2 + 3 * 8, t); // [rsp + 6 * 8] <- ac + + mov(pz, ptr[rsp + zPos]); + vmovq(d, xm0); + for (int i = 0; i < 6; i++) { + mov(a, ptr[pz + (3 + i) * 8]); + if (i == 0) { + add(a, ptr[rsp + abcdPos + i * 8]); + } else { + adc(a, ptr[rsp + abcdPos + i * 8]); + } + mov(ptr[pz + (3 + i) * 8], a); + } + mov(a, ptr[pz + 9 * 8]); + adc(a, d); + mov(ptr[pz + 9 * 8], a); + jnc("@f"); + for (int i = 10; i < 12; i++) { + mov(a, ptr[pz + i * 8]); + adc(a, 0); + mov(ptr[pz + i * 8], a); + } + L("@@"); + add(rsp, stackSize); +#else const Reg64& t4 = t[4]; const Reg64& t5 = t[5]; const Reg64& t6 = t[6]; @@ -1911,6 +2249,7 @@ private: mulPackAdd(pz + 8 * 4, px + 8 * 4, py, t2, Pack(t1, t0, t6, t5, t4, t3)); // [t2:t1:t0:t6:t5:t4] mulPackAdd(pz + 8 * 5, px + 8 * 5, py, t3, Pack(t2, t1, t0, t6, t5, t4)); // [t3:t2:t1:t0:t6:t5] store_mr(pz + 8 * 6, Pack(t3, t2, t1, t0, t6, t5)); +#endif } /* @input (z, xy) @@ -1934,11 +2273,11 @@ private: const Reg64& a = rax; const Reg64& d = rdx; - movq(xm0, z); + vmovq(xm0, z); mov(z, ptr [xy + 0 * 8]); mov(a, rp_); mul(z); - lea(t0, ptr [rip + *pL_]); + lea(t0, ptr [rip + pL_]); load_rm(Pack(t7, t6, t5, t4, t3, t2, t1), xy); mov(d, a); // q mulPackAddShr(Pack(t7, t6, t5, t4, t3, t2, t1), t0, t10); @@ -1951,32 +2290,32 @@ private: // z = [t1:t0:t10:t9:t8:t7:t6:t5:t4:t3:t2] mov(a, rp_); mul(t2); - movq(xm1, t0); // save - lea(t0, ptr [rip + *pL_]); + vmovq(xm1, t0); // save + lea(t0, ptr [rip + pL_]); mov(d, a); - movq(xm2, t10); + vmovq(xm2, t10); mulPackAddShr(Pack(t8, t7, t6, t5, t4, t3, t2), t0, t10); - movq(t10, xm2); + vmovq(t10, xm2); adc(t9, rax); adc(t10, rax); - movq(t0, xm1); // load + vmovq(t0, xm1); // load adc(t0, rax); adc(t1, rax); // z = [t1:t0:t10:t9:t8:t7:t6:t5:t4:t3] mov(a, rp_); mul(t3); - lea(t2, ptr [rip + *pL_]); + lea(t2, ptr [rip + pL_]); mov(d, a); - movq(xm2, t10); + vmovq(xm2, t10); mulPackAddShr(Pack(t9, t8, t7, t6, t5, t4, t3), t2, t10); - movq(t10, xm2); + vmovq(t10, xm2); adc(t10, rax); adc(t0, rax); adc(t1, rax); // z = [t1:t0:t10:t9:t8:t7:t6:t5:t4] mov(a, rp_); mul(t4); - lea(t2, ptr [rip + *pL_]); + lea(t2, ptr [rip + pL_]); mov(d, a); mulPackAddShr(Pack(t10, t9, t8, t7, t6, t5, t4), t2, t3); adc(t0, rax); @@ -1984,14 +2323,14 @@ private: // z = [t1:t0:t10:t9:t8:t7:t6:t5] mov(a, rp_); mul(t5); - lea(t2, ptr [rip + *pL_]); + lea(t2, ptr [rip + pL_]); mov(d, a); mulPackAddShr(Pack(t0, t10, t9, t8, t7, t6, t5), t2, t3); adc(t1, a); // z = [t1:t0:t10:t9:t8:t7:t6] mov(a, rp_); mul(t6); - lea(t2, ptr [rip + *pL_]); + lea(t2, ptr [rip + pL_]); mov(d, a); mulPackAddShr(Pack(t1, t0, t10, t9, t8, t7, t6), t2, t3, true); // z = [t1:t0:t10:t9:t8:t7] @@ -2000,47 +2339,87 @@ private: mov_rr(keep, zp); sub_rm(zp, t2); // z -= p cmovc_rr(zp, keep); - movq(z, xm0); + vmovq(z, xm0); store_mr(z, zp); } - void gen_fpDbl_sqrPre(mcl::fp::Op& op) + void2u gen_fpDbl_sqrPre() { - if (useMulx_ && pn_ == 2) { + align(16); + void2u func = getCurr<void2u>(); + if (pn_ == 2 && useMulx_) { StackFrame sf(this, 2, 7 | UseRDX); sqrPre2(sf.p[0], sf.p[1], sf.t); - return; + return func; } if (pn_ == 3) { - StackFrame sf(this, 2, 10 | UseRDX | UseRCX); - sqrPre3(sf.p[0], sf.p[1], sf.t); - return; + StackFrame sf(this, 3, 10 | UseRDX); + Pack t = sf.t; + t.append(sf.p[2]); + sqrPre3(sf.p[0], sf.p[1], t); + return func; } - if (useMulx_ && pn_ == 4) { - StackFrame sf(this, 2, 10 | UseRDX | UseRCX); - sqrPre4(sf.p[0], sf.p[1], sf.t); - return; + if (pn_ == 4 && useMulx_) { + StackFrame sf(this, 3, 10 | UseRDX); + Pack t = sf.t; + t.append(sf.p[2]); + sqrPre4(sf.p[0], sf.p[1], t); + return func; + } + if (pn_ == 6 && useMulx_ && useAdx_) { + StackFrame sf(this, 3, 10 | UseRDX, 6 * 8); + Pack t = sf.t; + t.append(sf.p[2]); + sqrPre6(sf.p[0], sf.p[1], t); + return func; } + return 0; +#if 0 #ifdef XBYAK64_WIN mov(r8, rdx); #else mov(rdx, rsi); #endif jmp((void*)op.fpDbl_mulPreA_); + return func; +#endif } - void gen_fpDbl_mulPre() + void3u gen_fpDbl_mulPre() { - if (useMulx_ && pn_ == 2) { + align(16); + void3u func = getCurr<void3u>(); + if (pn_ == 2 && useMulx_) { StackFrame sf(this, 3, 5 | UseRDX); mulPre2(sf.p[0], sf.p[1], sf.p[2], sf.t); - return; + return func; } if (pn_ == 3) { StackFrame sf(this, 3, 10 | UseRDX); mulPre3(sf.p[0], sf.p[1], sf.p[2], sf.t); - return; + return func; } - assert(0); - exit(1); + if (pn_ == 4) { + /* + fpDbl_mulPre is available as C function + this function calls mulPreL directly. + */ + StackFrame sf(this, 3, 10 | UseRDX, 0, false); + mulPre4(gp0, gp1, gp2, sf.t); + sf.close(); // make epilog + L(mulPreL); // called only from asm code + mulPre4(gp0, gp1, gp2, sf.t); + ret(); + return func; + } + if (pn_ == 6 && useAdx_) { + StackFrame sf(this, 3, 10 | UseRDX, 0, false); + call(mulPreL); + sf.close(); // make epilog + L(mulPreL); // called only from asm code + mulPre6(sf.t); + ret(); + return func; + } + return 0; } static inline void debug_put_inner(const uint64_t *ptr, int n) { @@ -2250,10 +2629,10 @@ private: { if (n >= 10) exit(1); static uint64_t buf[10]; - movq(xm0, rax); + vmovq(xm0, rax); mov(rax, (size_t)buf); store_mp(rax, mp, t); - movq(rax, xm0); + vmovq(rax, xm0); push(rax); mov(rax, (size_t)buf); debug_put(rax, n); @@ -2270,11 +2649,8 @@ private: */ void gen_preInv() { - assert(pn_ >= 1); + assert(1 <= pn_ && pn_ <= 4); const int freeRegNum = 13; - if (pn_ > 9) { - throw cybozu::Exception("mcl:FpGenerator:gen_preInv:large pn_") << pn_; - } StackFrame sf(this, 2, 10 | UseRDX | UseRCX, (std::max<int>(0, pn_ * 5 - freeRegNum) + 1 + (isFullBit_ ? 1 : 0)) * 8); const Reg64& pr = sf.p[0]; const Reg64& px = sf.p[1]; @@ -2307,7 +2683,7 @@ private: mov(rax, px); // px is free frome here load_mp(vv, rax, t); // v = x - mov(rax, (size_t)p_); + mov(rax, pL_); load_mp(uu, rax, t); // u = p_ // k = 0 xor_(rax, rax); @@ -2324,46 +2700,6 @@ private: } else { mov(qword [ss.getMem(0)], 1); } -#if 0 - L(".lp"); - or_mp(vv, t); - jz(".exit", T_NEAR); - - g_test(uu[0], 1); - jz(".u_even", T_NEAR); - g_test(vv[0], 1); - jz(".v_even", T_NEAR); - for (int i = pn_ - 1; i >= 0; i--) { - g_cmp(vv[i], uu[i], t); - jc(".v_lt_u", T_NEAR); - if (i > 0) jnz(".v_ge_u", T_NEAR); - } - - L(".v_ge_u"); - sub_mp(vv, uu, t); - add_mp(ss, rr, t); - L(".v_even"); - shr_mp(vv, 1, t); - twice_mp(rr, t); - if (isFullBit_) { - sbb(t, t); - mov(ptr [rTop], t); - } - inc(rax); - jmp(".lp", T_NEAR); - L(".v_lt_u"); - sub_mp(uu, vv, t); - add_mp(rr, ss, t); - if (isFullBit_) { - sbb(t, t); - mov(ptr [rTop], t); - } - L(".u_even"); - shr_mp(uu, 1, t); - twice_mp(ss, t); - inc(rax); - jmp(".lp", T_NEAR); -#else for (int cn = pn_; cn > 0; cn--) { const std::string _lp = mkLabel(".lp", cn); const std::string _u_v_odd = mkLabel(".u_v_odd", cn); @@ -2420,13 +2756,12 @@ private: uu.removeLast(); } } -#endif L(".exit"); assert(ss.isReg(0)); const Reg64& t2 = ss.getReg(0); const Reg64& t3 = rdx; - mov(t2, (size_t)p_); + mov(t2, pL_); if (isFullBit_) { mov(t, ptr [rTop]); test(t, t); @@ -3057,7 +3392,7 @@ private: mul4x1(px, y, t3, t2, t1, t0, t4); // [rdx:y:t2:t1:t0] = px[3..0] * y if (isFullBit_) { - movq(xt, px); + vmovq(xt, px); xor_(px, px); } add_rr(Pack(c4, y, c2, c1, c0), Pack(rdx, c3, t2, t1, t0)); @@ -3081,13 +3416,19 @@ private: adc(c0, 0); } else { adc(c0, px); - movq(px, xt); + vmovq(px, xt); } } } - void gen_fp2Dbl_mulPre(Label& mulPreL) + void3u gen_fp2Dbl_mulPre() { - assert(!isFullBit_); + if (isFullBit_) return 0; +// if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0; + // almost same for pn_ == 6 + if (pn_ != 4) return 0; + align(16); + void3u func = getCurr<void3u>(); + const RegExp z = rsp + 0 * 8; const RegExp x = rsp + 1 * 8; const RegExp y = rsp + 2 * 8; @@ -3100,9 +3441,9 @@ private: mov(ptr [x], gp1); mov(ptr [y], gp2); // s = a + b - gen_raw_add(s, gp1, gp1 + FpByte_, rax, 4); + gen_raw_add(s, gp1, gp1 + FpByte_, rax, pn_); // t = c + d - gen_raw_add(t, gp2, gp2 + FpByte_, rax, 4); + gen_raw_add(t, gp2, gp2 + FpByte_, rax, pn_); // d1 = (a + b)(c + d) mov(gp0, ptr [z]); add(gp0, FpByte_ * 2); // d1 @@ -3127,20 +3468,33 @@ private: add(gp0, FpByte_ * 2); // d1 mov(gp1, gp0); mov(gp2, ptr [z]); - gen_raw_sub(gp0, gp1, gp2, rax, 8); + gen_raw_sub(gp0, gp1, gp2, rax, pn_ * 2); lea(gp2, ptr [d2]); - gen_raw_sub(gp0, gp1, gp2, rax, 8); + gen_raw_sub(gp0, gp1, gp2, rax, pn_ * 2); mov(gp0, ptr [z]); mov(gp1, gp0); lea(gp2, ptr [d2]); - gen_raw_sub(gp0, gp1, gp2, rax, 4); - gen_raw_fp_sub(gp0 + 8 * 4, gp1 + 8 * 4, gp2 + 8 * 4, Pack(gt0, gt1, gt2, gt3, gt4, gt5, gt6, gt7), true); + gen_raw_sub(gp0, gp1, gp2, rax, pn_); + if (pn_ == 4) { + gen_raw_fp_sub(gp0 + pn_ * 8, gp1 + pn_ * 8, gp2 + pn_ * 8, Pack(gt0, gt1, gt2, gt3, gt4, gt5, gt6, gt7), true); + } else { + assert(pn_ == 6); + gen_raw_fp_sub6(gp0, gp1, gp2, pn_ * 8, sf.t.sub(0, 6), true); + } + return func; } - void gen_fp2Dbl_sqrPre(Label& mulPreL) + void2u gen_fp2Dbl_sqrPre() { - assert(!isFullBit_); + if (isFullBit_) return 0; +// if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0; + // almost same for pn_ == 6 + if (pn_ != 4) return 0; + align(16); + void2u func = getCurr<void2u>(); + // almost same for pn_ == 6 + if (pn_ != 4) return 0; const RegExp y = rsp + 0 * 8; const RegExp x = rsp + 1 * 8; const Ext1 t1(FpByte_, rsp, 2 * 8); @@ -3149,10 +3503,15 @@ private: StackFrame sf(this, 3 /* not 2 */, 10 | UseRDX, t2.next); mov(ptr [y], gp0); mov(ptr [x], gp1); - const Pack a = sf.t.sub(0, 4); - const Pack b = sf.t.sub(4, 4); + Pack t = sf.t; + if (pn_ == 6) { + t.append(rax); + t.append(rdx); + } + const Pack a = t.sub(0, pn_); + const Pack b = t.sub(pn_, pn_); load_rm(b, gp1 + FpByte_); - for (int i = 0; i < 4; i++) { + for (int i = 0; i < pn_; i++) { mov(rax, b[i]); if (i == 0) { add(rax, rax); @@ -3170,11 +3529,17 @@ private: mov(gp2, ptr [x]); call(mulPreL); mov(gp0, ptr [x]); - gen_raw_fp_sub(t1, gp0, gp0 + FpByte_, sf.t, false); + if (pn_ == 4) { + gen_raw_fp_sub(t1, gp0, gp0 + FpByte_, sf.t, false); + } else { + assert(pn_ == 6); + gen_raw_fp_sub6(t1, gp0, gp0, FpByte_, a, false); + } mov(gp0, ptr [y]); lea(gp1, ptr [t1]); lea(gp2, ptr [t2]); call(mulPreL); + return func; } void gen_fp2_add4() { @@ -3183,6 +3548,61 @@ private: gen_raw_fp_add(sf.p[0], sf.p[1], sf.p[2], sf.t, false); gen_raw_fp_add(sf.p[0] + FpByte_, sf.p[1] + FpByte_, sf.p[2] + FpByte_, sf.t, false); } + void gen_fp2_add6() + { + assert(!isFullBit_); + StackFrame sf(this, 3, 10); + const Reg64& pz = sf.p[0]; + const Reg64& px = sf.p[1]; + const Reg64& py = sf.p[2]; + Pack t1 = sf.t.sub(0, 6); + Pack t2 = sf.t.sub(6); + t2.append(rax); + t2.append(px); // destory after used + vmovq(xm0, px); + gen_raw_fp_add6(pz, px, py, t1, t2, false); + vmovq(px, xm0); + gen_raw_fp_add6(pz + FpByte_, px + FpByte_, py + FpByte_, t1, t2, false); + } + void gen_fp2_sub6() + { + StackFrame sf(this, 3, 5); + const Reg64& pz = sf.p[0]; + const Reg64& px = sf.p[1]; + const Reg64& py = sf.p[2]; + Pack t = sf.t; + t.append(rax); + gen_raw_fp_sub6(pz, px, py, 0, t, false); + gen_raw_fp_sub6(pz, px, py, FpByte_, t, false); + } + void3u gen_fp2_add() + { + align(16); + void3u func = getCurr<void3u>(); + if (pn_ == 4 && !isFullBit_) { + gen_fp2_add4(); + return func; + } + if (pn_ == 6 && !isFullBit_) { + gen_fp2_add6(); + return func; + } + return 0; + } + void3u gen_fp2_sub() + { + align(16); + void3u func = getCurr<void3u>(); + if (pn_ == 4 && !isFullBit_) { + gen_fp2_sub4(); + return func; + } + if (pn_ == 6 && !isFullBit_) { + gen_fp2_sub6(); + return func; + } + return 0; + } void gen_fp2_sub4() { assert(!isFullBit_); @@ -3198,14 +3618,17 @@ private: void gen_fp2_mul_xi4() { assert(!isFullBit_); -#if 0 - StackFrame sf(this, 2, 10 | UseRDX | UseRCX); + StackFrame sf(this, 2, 11 | UseRDX); + const Reg64& py = sf.p[0]; + const Reg64& px = sf.p[1]; Pack a = sf.t.sub(0, 4); Pack b = sf.t.sub(4, 4); - Pack t(rdx, rcx, sf.t[8], sf.t[9]); - load_rm(a, sf.p[1]); - load_rm(b, sf.p[1] + FpByte_); - for (int i = 0; i < 4; i++) { + Pack t = sf.t.sub(8); + t.append(rdx); + assert(t.size() == 4); + load_rm(a, px); + load_rm(b, px + FpByte_); + for (int i = 0; i < pn_; i++) { mov(t[i], a[i]); if (i == 0) { add(t[i], b[i]); @@ -3214,41 +3637,78 @@ private: } } sub_rr(a, b); - mov(rax, (size_t)p_); + mov(rax, pL_); load_rm(b, rax); sbb(rax, rax); - for (int i = 0; i < 4; i++) { + for (int i = 0; i < pn_; i++) { and_(b[i], rax); } add_rr(a, b); - store_mr(sf.p[0], a); - mov(rax, (size_t)p_); + store_mr(py, a); + mov(rax, pL_); mov_rr(a, t); sub_rm(t, rax); - for (int i = 0; i < 4; i++) { - cmovc(t[i], a[i]); + cmovc_rr(t, a); + store_mr(py + FpByte_, t); + } + void gen_fp2_mul_xi6() + { + assert(!isFullBit_); + StackFrame sf(this, 2, 12); + const Reg64& py = sf.p[0]; + const Reg64& px = sf.p[1]; + Pack a = sf.t.sub(0, 6); + Pack b = sf.t.sub(6); + load_rm(a, px); + mov_rr(b, a); + add_rm(b, px + FpByte_); + sub_rm(a, px + FpByte_); + mov(rax, pL_); + jnc("@f"); + add_rm(a, rax); + L("@@"); + store_mr(py, a); + mov_rr(a, b); + sub_rm(b, rax); + cmovc_rr(b, a); + store_mr(py + FpByte_, b); + } + void2u gen_fp2_mul_xi() + { + if (isFullBit_) return 0; + if (op_->xi_a != 1) return 0; + align(16); + void2u func = getCurr<void2u>(); + if (pn_ == 4) { + gen_fp2_mul_xi4(); + return func; } - store_mr(sf.p[0] + FpByte_, t); -#else - StackFrame sf(this, 2, 8, 8 * 4); - gen_raw_fp_add(rsp, sf.p[1], sf.p[1] + FpByte_, sf.t, false); - gen_raw_fp_sub(sf.p[0], sf.p[1], sf.p[1] + FpByte_, sf.t, false); - for (int i = 0; i < 4; i++) { - mov(rax, ptr [rsp + i * 8]); - mov(ptr[sf.p[0] + FpByte_ + i * 8], rax); + if (pn_ == 6) { + gen_fp2_mul_xi6(); + return func; } -#endif + return 0; } - void gen_fp2_neg4() + void2u gen_fp2_neg() { - assert(!isFullBit_); - StackFrame sf(this, 2, UseRDX | pn_); - gen_raw_neg(sf.p[0], sf.p[1], sf.t); - gen_raw_neg(sf.p[0] + FpByte_, sf.p[1] + FpByte_, sf.t); + align(16); + void2u func = getCurr<void2u>(); + if (pn_ <= 6) { + StackFrame sf(this, 2, UseRDX | pn_); + gen_raw_neg(sf.p[0], sf.p[1], sf.t); + gen_raw_neg(sf.p[0] + FpByte_, sf.p[1] + FpByte_, sf.t); + return func; + } + return 0; } - void gen_fp2_mul4(Label& fpDbl_modL) + void3u gen_fp2_mul() { - assert(!isFullBit_); + if (isFullBit_) return 0; + if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0; + align(16); + void3u func = getCurr<void3u>(); + bool embedded = pn_ == 4; + const RegExp z = rsp + 0 * 8; const RegExp x = rsp + 1 * 8; const RegExp y = rsp + 2 * 8; @@ -3263,27 +3723,50 @@ private: mov(ptr[x], gp1); mov(ptr[y], gp2); // s = a + b - gen_raw_add(s, gp1, gp1 + FpByte_, rax, 4); + gen_raw_add(s, gp1, gp1 + FpByte_, rax, pn_); // t = c + d - gen_raw_add(t, gp2, gp2 + FpByte_, rax, 4); + gen_raw_add(t, gp2, gp2 + FpByte_, rax, pn_); // d1 = (a + b)(c + d) - mulPre4(d1, s, t, sf.t); + if (embedded) { + mulPre4(d1, s, t, sf.t); + } else { + lea(gp0, ptr [d1]); + lea(gp1, ptr [s]); + lea(gp2, ptr [t]); + call(mulPreL); + } // d0 = a c mov(gp1, ptr [x]); mov(gp2, ptr [y]); - mulPre4(d0, gp1, gp2, sf.t); + if (embedded) { + mulPre4(d0, gp1, gp2, sf.t); + } else { + lea(gp0, ptr [d0]); + call(mulPreL); + } // d2 = b d mov(gp1, ptr [x]); add(gp1, FpByte_); mov(gp2, ptr [y]); add(gp2, FpByte_); - mulPre4(d2, gp1, gp2, sf.t); + if (embedded) { + mulPre4(d2, gp1, gp2, sf.t); + } else { + lea(gp0, ptr [d2]); + call(mulPreL); + } - gen_raw_sub(d1, d1, d0, rax, 8); - gen_raw_sub(d1, d1, d2, rax, 8); + gen_raw_sub(d1, d1, d0, rax, pn_ * 2); + gen_raw_sub(d1, d1, d2, rax, pn_ * 2); - gen_raw_sub(d0, d0, d2, rax, 4); - gen_raw_fp_sub((RegExp)d0 + 8 * 4, (RegExp)d0 + 8 * 4, (RegExp)d2 + 8 * 4, Pack(gt0, gt1, gt2, gt3, gt4, gt5, gt6, gt7), true); + gen_raw_sub(d0, d0, d2, rax, pn_); + if (pn_ == 4) { + gen_raw_fp_sub((RegExp)d0 + pn_ * 8, (RegExp)d0 + pn_ * 8, (RegExp)d2 + pn_ * 8, Pack(gt0, gt1, gt2, gt3, gt4, gt5, gt6, gt7), true); + } else { + lea(gp0, ptr[d0]); + lea(gp2, ptr[d2]); + gen_raw_fp_sub6(gp0, gp0, gp2, pn_ * 8, sf.t.sub(0, 6), true); + } mov(gp0, ptr [z]); lea(gp1, ptr[d0]); @@ -3293,10 +3776,15 @@ private: add(gp0, FpByte_); lea(gp1, ptr[d1]); call(fpDbl_modL); + return func; } - void gen_fp2_sqr4(Label& fp_mulL) + void2u gen_fp2_sqr() { - assert(!isFullBit_); + if (isFullBit_) return 0; + if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0; + align(16); + void2u func = getCurr<void2u>(); + const RegExp y = rsp + 0 * 8; const RegExp x = rsp + 1 * 8; const Ext1 t1(FpByte_, rsp, 2 * 8); @@ -3309,7 +3797,7 @@ private: // t1 = b + b lea(gp0, ptr [t1]); if (nocarry) { - for (int i = 0; i < 4; i++) { + for (int i = 0; i < pn_; i++) { mov(rax, ptr [gp1 + FpByte_ + i * 8]); if (i == 0) { add(rax, rax); @@ -3319,7 +3807,15 @@ private: mov(ptr [gp0 + i * 8], rax); } } else { - gen_raw_fp_add(gp0, gp1 + FpByte_, gp1 + FpByte_, sf.t, false); + if (pn_ == 4) { + gen_raw_fp_add(gp0, gp1 + FpByte_, gp1 + FpByte_, sf.t, false); + } else { + assert(pn_ == 6); + Pack t = sf.t.sub(6, 4); + t.append(rax); + t.append(rdx); + gen_raw_fp_add6(gp0, gp1 + FpByte_, gp1 + FpByte_, sf.t.sub(0, 6), t, false); + } } // t1 = 2ab mov(gp1, gp0); @@ -3327,13 +3823,16 @@ private: call(fp_mulL); if (nocarry) { - Pack a = sf.t.sub(0, 4); - Pack b = sf.t.sub(4, 4); + Pack t = sf.t; + t.append(rdx); + t.append(gp1); + Pack a = t.sub(0, pn_); + Pack b = t.sub(pn_, pn_); mov(gp0, ptr [x]); load_rm(a, gp0); load_rm(b, gp0 + FpByte_); // t2 = a + b - for (int i = 0; i < 4; i++) { + for (int i = 0; i < pn_; i++) { mov(rax, a[i]); if (i == 0) { add(rax, b[i]); @@ -3343,14 +3842,24 @@ private: mov(ptr [(RegExp)t2 + i * 8], rax); } // t3 = a + p - b - mov(gp1, (size_t)p_); - add_rm(a, gp1); + mov(rax, pL_); + add_rm(a, rax); sub_rr(a, b); store_mr(t3, a); } else { mov(gp0, ptr [x]); - gen_raw_fp_add(t2, gp0, gp0 + FpByte_, sf.t, false); - gen_raw_fp_sub(t3, gp0, gp0 + FpByte_, sf.t, false); + if (pn_ == 4) { + gen_raw_fp_add(t2, gp0, gp0 + FpByte_, sf.t, false); + gen_raw_fp_sub(t3, gp0, gp0 + FpByte_, sf.t, false); + } else { + assert(pn_ == 6); + Pack p1 = sf.t.sub(0, 6); + Pack p2 = sf.t.sub(6, 4); + p2.append(rax); + p2.append(rdx); + gen_raw_fp_add6(t2, gp0, gp0 + FpByte_, p1, p2, false); + gen_raw_fp_sub6(t3, gp0, gp0 + FpByte_, 0, p1, false); + } } mov(gp0, ptr [y]); @@ -3358,10 +3867,11 @@ private: lea(gp2, ptr [t3]); call(fp_mulL); mov(gp0, ptr [y]); - for (int i = 0; i < 4; i++) { + for (int i = 0; i < pn_; i++) { mov(rax, ptr [(RegExp)t1 + i * 8]); mov(ptr [gp0 + FpByte_ + i * 8], rax); } + return func; } }; |