aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Makefile75
-rw-r--r--accounts/accounts_test.go2
-rw-r--r--cmd/geth/js.go2
-rw-r--r--cmd/geth/main.go7
-rw-r--r--cmd/utils/flags.go32
-rw-r--r--common/types.go4
-rw-r--r--crypto/crypto.go26
-rw-r--r--crypto/crypto_test.go8
-rw-r--r--crypto/ecies/asn1.go7
-rw-r--r--crypto/ecies/ecies.go1
-rw-r--r--crypto/ecies/ecies_test.go121
-rw-r--r--crypto/ecies/params.go14
-rw-r--r--crypto/key.go5
-rw-r--r--crypto/secp256k1/curve.go (renamed from crypto/curve.go)168
-rw-r--r--crypto/secp256k1/curve_test.go39
-rw-r--r--crypto/secp256k1/pubkey_scalar_mul.h56
-rw-r--r--crypto/secp256k1/secp256.go24
-rw-r--r--crypto/secp256k1/secp256_test.go33
-rw-r--r--eth/backend.go2
-rw-r--r--p2p/discover/node.go2
-rw-r--r--p2p/rlpx.go4
-rw-r--r--p2p/rlpx_test.go2
-rw-r--r--trie/arc.go12
-rw-r--r--trie/errors.go41
-rw-r--r--trie/iterator.go19
-rw-r--r--trie/proof.go11
-rw-r--r--trie/secure_trie.go49
-rw-r--r--trie/trie.go204
-rw-r--r--trie/trie_test.go75
-rw-r--r--whisper/message_test.go9
30 files changed, 743 insertions, 311 deletions
diff --git a/Makefile b/Makefile
index 5fc5e9a9d..76cb819b6 100644
--- a/Makefile
+++ b/Makefile
@@ -3,11 +3,12 @@
# don't need to bother with make.
.PHONY: geth geth-cross evm all test travis-test-with-coverage xgo clean
-.PHONY: geth-linux geth-linux-arm geth-linux-386 geth-linux-amd64
+.PHONY: geth-linux geth-linux-386 geth-linux-amd64
+.PHONY: geth-linux-arm geth-linux-arm-5 geth-linux-arm-6 geth-linux-arm-7 geth-linux-arm64
.PHONY: geth-darwin geth-darwin-386 geth-darwin-amd64
.PHONY: geth-windows geth-windows-386 geth-windows-amd64
-.PHONY: geth-android geth-android-16 geth-android-21
-.PHONY: geth-ios geth-ios-5.0 geth-ios-8.1
+.PHONY: geth-android
+.PHONY: geth-ios geth-ios-arm-7 geth-ios-arm64
GOBIN = build/bin
@@ -20,19 +21,14 @@ geth:
@echo "Done building."
@echo "Run \"$(GOBIN)/geth\" to launch geth."
-geth-cross: geth-linux geth-darwin geth-windows geth-android
+geth-cross: geth-linux geth-darwin geth-windows geth-android geth-ios
@echo "Full cross compilation done:"
@ls -l $(GOBIN)/geth-*
-geth-linux: xgo geth-linux-arm geth-linux-386 geth-linux-amd64
+geth-linux: geth-linux-386 geth-linux-amd64 geth-linux-arm
@echo "Linux cross compilation done:"
@ls -l $(GOBIN)/geth-linux-*
-geth-linux-arm: xgo
- build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=linux/arm -v $(shell build/flags.sh) ./cmd/geth
- @echo "Linux ARM cross compilation done:"
- @ls -l $(GOBIN)/geth-linux-* | grep arm
-
geth-linux-386: xgo
build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=linux/386 -v $(shell build/flags.sh) ./cmd/geth
@echo "Linux 386 cross compilation done:"
@@ -43,7 +39,31 @@ geth-linux-amd64: xgo
@echo "Linux amd64 cross compilation done:"
@ls -l $(GOBIN)/geth-linux-* | grep amd64
-geth-darwin: xgo geth-darwin-386 geth-darwin-amd64
+geth-linux-arm: geth-linux-arm-5 geth-linux-arm-6 geth-linux-arm-7 geth-linux-arm64
+ @echo "Linux ARM cross compilation done:"
+ @ls -l $(GOBIN)/geth-linux-* | grep arm
+
+geth-linux-arm-5: xgo
+ build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=linux/arm-5 -v $(shell build/flags.sh) ./cmd/geth
+ @echo "Linux ARMv5 cross compilation done:"
+ @ls -l $(GOBIN)/geth-linux-* | grep arm-5
+
+geth-linux-arm-6: xgo
+ build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=linux/arm-6 -v $(shell build/flags.sh) ./cmd/geth
+ @echo "Linux ARMv6 cross compilation done:"
+ @ls -l $(GOBIN)/geth-linux-* | grep arm-6
+
+geth-linux-arm-7: xgo
+ build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=linux/arm-7 -v $(shell build/flags.sh) ./cmd/geth
+ @echo "Linux ARMv7 cross compilation done:"
+ @ls -l $(GOBIN)/geth-linux-* | grep arm-7
+
+geth-linux-arm64: xgo
+ build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=linux/arm64 -v $(shell build/flags.sh) ./cmd/geth
+ @echo "Linux ARM64 cross compilation done:"
+ @ls -l $(GOBIN)/geth-linux-* | grep arm64
+
+geth-darwin: geth-darwin-386 geth-darwin-amd64
@echo "Darwin cross compilation done:"
@ls -l $(GOBIN)/geth-darwin-*
@@ -57,7 +77,7 @@ geth-darwin-amd64: xgo
@echo "Darwin amd64 cross compilation done:"
@ls -l $(GOBIN)/geth-darwin-* | grep amd64
-geth-windows: xgo geth-windows-386 geth-windows-amd64
+geth-windows: geth-windows-386 geth-windows-amd64
@echo "Windows cross compilation done:"
@ls -l $(GOBIN)/geth-windows-*
@@ -71,33 +91,24 @@ geth-windows-amd64: xgo
@echo "Windows amd64 cross compilation done:"
@ls -l $(GOBIN)/geth-windows-* | grep amd64
-geth-android: xgo geth-android-16 geth-android-21
+geth-android: xgo
+ build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=android/* -v $(shell build/flags.sh) ./cmd/geth
@echo "Android cross compilation done:"
@ls -l $(GOBIN)/geth-android-*
-geth-android-16: xgo
- build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=android-16/* -v $(shell build/flags.sh) ./cmd/geth
- @echo "Android 16 cross compilation done:"
- @ls -l $(GOBIN)/geth-android-16-*
-
-geth-android-21: xgo
- build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=android-21/* -v $(shell build/flags.sh) ./cmd/geth
- @echo "Android 21 cross compilation done:"
- @ls -l $(GOBIN)/geth-android-21-*
-
-geth-ios: xgo geth-ios-5.0 geth-ios-8.1
+geth-ios: geth-ios-arm-7 geth-ios-arm64
@echo "iOS cross compilation done:"
@ls -l $(GOBIN)/geth-ios-*
-geth-ios-5.0:
- build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --depsargs=--disable-assembly --targets=ios-5.0/* -v $(shell build/flags.sh) ./cmd/geth
- @echo "iOS 5.0 cross compilation done:"
- @ls -l $(GOBIN)/geth-ios-5.0-*
+geth-ios-arm-7: xgo
+ build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --depsargs=--disable-assembly --targets=ios/arm-7 -v $(shell build/flags.sh) ./cmd/geth
+ @echo "iOS ARMv7 cross compilation done:"
+ @ls -l $(GOBIN)/geth-ios-* | grep arm-7
-geth-ios-8.1:
- build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --depsargs=--disable-assembly --targets=ios-8.1/* -v $(shell build/flags.sh) ./cmd/geth
- @echo "iOS 8.1 cross compilation done:"
- @ls -l $(GOBIN)/geth-ios-8.1-*
+geth-ios-arm64: xgo
+ build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --depsargs=--disable-assembly --targets=ios-7.0/arm64 -v $(shell build/flags.sh) ./cmd/geth
+ @echo "iOS ARM64 cross compilation done:"
+ @ls -l $(GOBIN)/geth-ios-* | grep arm64
evm:
build/env.sh $(GOROOT)/bin/go install -v $(shell build/flags.sh) ./cmd/evm
diff --git a/accounts/accounts_test.go b/accounts/accounts_test.go
index d7a8a2b85..55ddecdea 100644
--- a/accounts/accounts_test.go
+++ b/accounts/accounts_test.go
@@ -68,7 +68,7 @@ func TestTimedUnlock(t *testing.T) {
}
// Signing fails again after automatic locking
- time.Sleep(150 * time.Millisecond)
+ time.Sleep(350 * time.Millisecond)
_, err = am.Sign(a1, testSigData)
if err != ErrLocked {
t.Fatal("Signing should've failed with ErrLocked timeout expired, got ", err)
diff --git a/cmd/geth/js.go b/cmd/geth/js.go
index 196f3af59..843c9a5b5 100644
--- a/cmd/geth/js.go
+++ b/cmd/geth/js.go
@@ -245,7 +245,7 @@ func (self *jsre) batch(statement string) {
func (self *jsre) welcome() {
self.re.Run(`
(function () {
- console.log('instance: ' + web3.version.client);
+ console.log('instance: ' + web3.version.node);
console.log(' datadir: ' + admin.datadir);
console.log("coinbase: " + eth.coinbase);
var ts = 1000 * eth.getBlock(eth.blockNumber).timestamp;
diff --git a/cmd/geth/main.go b/cmd/geth/main.go
index 3a5471845..6ec30cebc 100644
--- a/cmd/geth/main.go
+++ b/cmd/geth/main.go
@@ -464,9 +464,12 @@ func execScripts(ctx *cli.Context) {
node.Stop()
}
+// tries unlocking the specified account a few times.
func unlockAccount(ctx *cli.Context, accman *accounts.Manager, address string, i int, passwords []string) (common.Address, string) {
- // Try to unlock the specified account a few times
- account := utils.MakeAddress(accman, address)
+ account, err := utils.MakeAddress(accman, address)
+ if err != nil {
+ utils.Fatalf("Unlock error: %v", err)
+ }
for trials := 0; trials < 3; trials++ {
prompt := fmt.Sprintf("Unlocking account %s | Attempt %d/%d", address, trials+1, 3)
diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go
index 53126f9e5..839ec3f02 100644
--- a/cmd/utils/flags.go
+++ b/cmd/utils/flags.go
@@ -518,47 +518,41 @@ func MakeAccountManager(ctx *cli.Context) *accounts.Manager {
// MakeAddress converts an account specified directly as a hex encoded string or
// a key index in the key store to an internal account representation.
-func MakeAddress(accman *accounts.Manager, account string) common.Address {
+func MakeAddress(accman *accounts.Manager, account string) (a common.Address, err error) {
// If the specified account is a valid address, return it
if common.IsHexAddress(account) {
- return common.HexToAddress(account)
+ return common.HexToAddress(account), nil
}
// Otherwise try to interpret the account as a keystore index
index, err := strconv.Atoi(account)
if err != nil {
- Fatalf("Invalid account address or index: '%s'", account)
+ return a, fmt.Errorf("invalid account address or index %q", account)
}
hex, err := accman.AddressByIndex(index)
if err != nil {
- Fatalf("Failed to retrieve requested account #%d: %v", index, err)
+ return a, fmt.Errorf("can't get account #%d (%v)", index, err)
}
- return common.HexToAddress(hex)
+ return common.HexToAddress(hex), nil
}
// MakeEtherbase retrieves the etherbase either from the directly specified
// command line flags or from the keystore if CLI indexed.
func MakeEtherbase(accman *accounts.Manager, ctx *cli.Context) common.Address {
- // If the specified etherbase is a valid address, return it
- etherbase := ctx.GlobalString(EtherbaseFlag.Name)
- if common.IsHexAddress(etherbase) {
- return common.HexToAddress(etherbase)
- }
- // If no etherbase was specified and no accounts are known, bail out
accounts, _ := accman.Accounts()
- if etherbase == "" && len(accounts) == 0 {
+ if !ctx.GlobalIsSet(EtherbaseFlag.Name) && len(accounts) == 0 {
glog.V(logger.Error).Infoln("WARNING: No etherbase set and no accounts found as default")
return common.Address{}
}
- // Otherwise try to interpret the parameter as a keystore index
- index, err := strconv.Atoi(etherbase)
- if err != nil {
- Fatalf("Invalid account address or index: '%s'", etherbase)
+ etherbase := ctx.GlobalString(EtherbaseFlag.Name)
+ if etherbase == "" {
+ return common.Address{}
}
- hex, err := accman.AddressByIndex(index)
+ // If the specified etherbase is a valid address, return it
+ addr, err := MakeAddress(accman, etherbase)
if err != nil {
- Fatalf("Failed to set requested account #%d as etherbase: %v", index, err)
+ Fatalf("Option %q: %v", EtherbaseFlag.Name, err)
}
- return common.HexToAddress(hex)
+ return addr
}
// MakeMinerExtra resolves extradata for the miner from the set command line flags
diff --git a/common/types.go b/common/types.go
index ea5838188..acbd5b28d 100644
--- a/common/types.go
+++ b/common/types.go
@@ -95,10 +95,10 @@ func HexToAddress(s string) Address { return BytesToAddress(FromHex(s)) }
// IsHexAddress verifies whether a string can represent a valid hex-encoded
// Ethereum address or not.
func IsHexAddress(s string) bool {
- if len(s) == 2+2*AddressLength && IsHex(s[2:]) {
+ if len(s) == 2+2*AddressLength && IsHex(s) {
return true
}
- if len(s) == 2*AddressLength && IsHex(s) {
+ if len(s) == 2*AddressLength && IsHex("0x"+s) {
return true
}
return false
diff --git a/crypto/crypto.go b/crypto/crypto.go
index 8685d62d3..7d7623753 100644
--- a/crypto/crypto.go
+++ b/crypto/crypto.go
@@ -43,14 +43,6 @@ import (
"golang.org/x/crypto/ripemd160"
)
-var secp256k1n *big.Int
-
-func init() {
- // specify the params for the s256 curve
- ecies.AddParamsForCurve(S256(), ecies.ECIES_AES128_SHA256)
- secp256k1n = common.String2Big("0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141")
-}
-
func Sha3(data ...[]byte) []byte {
d := sha3.NewKeccak256()
for _, b := range data {
@@ -99,9 +91,9 @@ func ToECDSA(prv []byte) *ecdsa.PrivateKey {
}
priv := new(ecdsa.PrivateKey)
- priv.PublicKey.Curve = S256()
+ priv.PublicKey.Curve = secp256k1.S256()
priv.D = common.BigD(prv)
- priv.PublicKey.X, priv.PublicKey.Y = S256().ScalarBaseMult(prv)
+ priv.PublicKey.X, priv.PublicKey.Y = secp256k1.S256().ScalarBaseMult(prv)
return priv
}
@@ -116,15 +108,15 @@ func ToECDSAPub(pub []byte) *ecdsa.PublicKey {
if len(pub) == 0 {
return nil
}
- x, y := elliptic.Unmarshal(S256(), pub)
- return &ecdsa.PublicKey{S256(), x, y}
+ x, y := elliptic.Unmarshal(secp256k1.S256(), pub)
+ return &ecdsa.PublicKey{secp256k1.S256(), x, y}
}
func FromECDSAPub(pub *ecdsa.PublicKey) []byte {
if pub == nil || pub.X == nil || pub.Y == nil {
return nil
}
- return elliptic.Marshal(S256(), pub.X, pub.Y)
+ return elliptic.Marshal(secp256k1.S256(), pub.X, pub.Y)
}
// HexToECDSA parses a secp256k1 private key.
@@ -168,7 +160,7 @@ func SaveECDSA(file string, key *ecdsa.PrivateKey) error {
}
func GenerateKey() (*ecdsa.PrivateKey, error) {
- return ecdsa.GenerateKey(S256(), rand.Reader)
+ return ecdsa.GenerateKey(secp256k1.S256(), rand.Reader)
}
func ValidateSignatureValues(v byte, r, s *big.Int) bool {
@@ -176,7 +168,7 @@ func ValidateSignatureValues(v byte, r, s *big.Int) bool {
return false
}
vint := uint32(v)
- if r.Cmp(secp256k1n) < 0 && s.Cmp(secp256k1n) < 0 && (vint == 27 || vint == 28) {
+ if r.Cmp(secp256k1.N) < 0 && s.Cmp(secp256k1.N) < 0 && (vint == 27 || vint == 28) {
return true
} else {
return false
@@ -189,8 +181,8 @@ func SigToPub(hash, sig []byte) (*ecdsa.PublicKey, error) {
return nil, err
}
- x, y := elliptic.Unmarshal(S256(), s)
- return &ecdsa.PublicKey{S256(), x, y}, nil
+ x, y := elliptic.Unmarshal(secp256k1.S256(), s)
+ return &ecdsa.PublicKey{secp256k1.S256(), x, y}, nil
}
func Sign(hash []byte, prv *ecdsa.PrivateKey) (sig []byte, err error) {
diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go
index fdd9c1ee8..d5e19a4bb 100644
--- a/crypto/crypto_test.go
+++ b/crypto/crypto_test.go
@@ -181,7 +181,7 @@ func TestValidateSignatureValues(t *testing.T) {
minusOne := big.NewInt(-1)
one := common.Big1
zero := common.Big0
- secp256k1nMinus1 := new(big.Int).Sub(secp256k1n, common.Big1)
+ secp256k1nMinus1 := new(big.Int).Sub(secp256k1.N, common.Big1)
// correct v,r,s
check(true, 27, one, one)
@@ -208,9 +208,9 @@ func TestValidateSignatureValues(t *testing.T) {
// correct sig with max r,s
check(true, 27, secp256k1nMinus1, secp256k1nMinus1)
// correct v, combinations of incorrect r,s at upper limit
- check(false, 27, secp256k1n, secp256k1nMinus1)
- check(false, 27, secp256k1nMinus1, secp256k1n)
- check(false, 27, secp256k1n, secp256k1n)
+ check(false, 27, secp256k1.N, secp256k1nMinus1)
+ check(false, 27, secp256k1nMinus1, secp256k1.N)
+ check(false, 27, secp256k1.N, secp256k1.N)
// current callers ensures r,s cannot be negative, but let's test for that too
// as crypto package could be used stand-alone
diff --git a/crypto/ecies/asn1.go b/crypto/ecies/asn1.go
index 6eaf3d2ca..40dabd329 100644
--- a/crypto/ecies/asn1.go
+++ b/crypto/ecies/asn1.go
@@ -41,6 +41,8 @@ import (
"fmt"
"hash"
"math/big"
+
+ "github.com/ethereum/go-ethereum/crypto/secp256k1"
)
var (
@@ -81,6 +83,7 @@ func doScheme(base, v []int) asn1.ObjectIdentifier {
type secgNamedCurve asn1.ObjectIdentifier
var (
+ secgNamedCurveS256 = secgNamedCurve{1, 3, 132, 0, 10}
secgNamedCurveP256 = secgNamedCurve{1, 2, 840, 10045, 3, 1, 7}
secgNamedCurveP384 = secgNamedCurve{1, 3, 132, 0, 34}
secgNamedCurveP521 = secgNamedCurve{1, 3, 132, 0, 35}
@@ -116,6 +119,8 @@ func (curve secgNamedCurve) Equal(curve2 secgNamedCurve) bool {
func namedCurveFromOID(curve secgNamedCurve) elliptic.Curve {
switch {
+ case curve.Equal(secgNamedCurveS256):
+ return secp256k1.S256()
case curve.Equal(secgNamedCurveP256):
return elliptic.P256()
case curve.Equal(secgNamedCurveP384):
@@ -134,6 +139,8 @@ func oidFromNamedCurve(curve elliptic.Curve) (secgNamedCurve, bool) {
return secgNamedCurveP384, true
case elliptic.P521():
return secgNamedCurveP521, true
+ case secp256k1.S256():
+ return secgNamedCurveS256, true
}
return nil, false
diff --git a/crypto/ecies/ecies.go b/crypto/ecies/ecies.go
index a3b520dd5..65dc5b38b 100644
--- a/crypto/ecies/ecies.go
+++ b/crypto/ecies/ecies.go
@@ -125,6 +125,7 @@ func (prv *PrivateKey) GenerateShared(pub *PublicKey, skLen, macLen int) (sk []b
if skLen+macLen > MaxSharedKeyLength(pub) {
return nil, ErrSharedKeyTooBig
}
+
x, _ := pub.Curve.ScalarMult(pub.X, pub.Y, prv.D.Bytes())
if x == nil {
return nil, ErrSharedKeyIsPointAtInfinity
diff --git a/crypto/ecies/ecies_test.go b/crypto/ecies/ecies_test.go
index 1c391f938..6a0ea3f02 100644
--- a/crypto/ecies/ecies_test.go
+++ b/crypto/ecies/ecies_test.go
@@ -31,13 +31,18 @@ package ecies
import (
"bytes"
+ "crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
+ "encoding/hex"
"flag"
"fmt"
"io/ioutil"
+ "math/big"
"testing"
+
+ "github.com/ethereum/go-ethereum/crypto/secp256k1"
)
var dumpEnc bool
@@ -65,7 +70,6 @@ func TestKDF(t *testing.T) {
}
}
-var skLen int
var ErrBadSharedKeys = fmt.Errorf("ecies: shared keys don't match")
// cmpParams compares a set of ECIES parameters. We assume, as per the
@@ -117,7 +121,7 @@ func TestSharedKey(t *testing.T) {
fmt.Println(err.Error())
t.FailNow()
}
- skLen = MaxSharedKeyLength(&prv1.PublicKey) / 2
+ skLen := MaxSharedKeyLength(&prv1.PublicKey) / 2
prv2, err := GenerateKey(rand.Reader, DefaultCurve, nil)
if err != nil {
@@ -143,6 +147,44 @@ func TestSharedKey(t *testing.T) {
}
}
+func TestSharedKeyPadding(t *testing.T) {
+ // sanity checks
+ prv0 := hexKey("1adf5c18167d96a1f9a0b1ef63be8aa27eaf6032c233b2b38f7850cf5b859fd9")
+ prv1 := hexKey("97a076fc7fcd9208240668e31c9abee952cbb6e375d1b8febc7499d6e16f1a")
+ x0, _ := new(big.Int).SetString("1a8ed022ff7aec59dc1b440446bdda5ff6bcb3509a8b109077282b361efffbd8", 16)
+ x1, _ := new(big.Int).SetString("6ab3ac374251f638d0abb3ef596d1dc67955b507c104e5f2009724812dc027b8", 16)
+ y0, _ := new(big.Int).SetString("e040bd480b1deccc3bc40bd5b1fdcb7bfd352500b477cb9471366dbd4493f923", 16)
+ y1, _ := new(big.Int).SetString("8ad915f2b503a8be6facab6588731fefeb584fd2dfa9a77a5e0bba1ec439e4fa", 16)
+
+ if prv0.PublicKey.X.Cmp(x0) != 0 {
+ t.Errorf("mismatched prv0.X:\nhave: %x\nwant: %x\n", prv0.PublicKey.X.Bytes(), x0.Bytes())
+ }
+ if prv0.PublicKey.Y.Cmp(y0) != 0 {
+ t.Errorf("mismatched prv0.Y:\nhave: %x\nwant: %x\n", prv0.PublicKey.Y.Bytes(), y0.Bytes())
+ }
+ if prv1.PublicKey.X.Cmp(x1) != 0 {
+ t.Errorf("mismatched prv1.X:\nhave: %x\nwant: %x\n", prv1.PublicKey.X.Bytes(), x1.Bytes())
+ }
+ if prv1.PublicKey.Y.Cmp(y1) != 0 {
+ t.Errorf("mismatched prv1.Y:\nhave: %x\nwant: %x\n", prv1.PublicKey.Y.Bytes(), y1.Bytes())
+ }
+
+ // test shared secret generation
+ sk1, err := prv0.GenerateShared(&prv1.PublicKey, 16, 16)
+ if err != nil {
+ fmt.Println(err.Error())
+ }
+
+ sk2, err := prv1.GenerateShared(&prv0.PublicKey, 16, 16)
+ if err != nil {
+ t.Fatal(err.Error())
+ }
+
+ if !bytes.Equal(sk1, sk2) {
+ t.Fatal(ErrBadSharedKeys.Error())
+ }
+}
+
// Verify that the key generation code fails when too much key data is
// requested.
func TestTooBigSharedKey(t *testing.T) {
@@ -158,13 +200,13 @@ func TestTooBigSharedKey(t *testing.T) {
t.FailNow()
}
- _, err = prv1.GenerateShared(&prv2.PublicKey, skLen*2, skLen*2)
+ _, err = prv1.GenerateShared(&prv2.PublicKey, 32, 32)
if err != ErrSharedKeyTooBig {
fmt.Println("ecdh: shared key should be too large for curve")
t.FailNow()
}
- _, err = prv2.GenerateShared(&prv1.PublicKey, skLen*2, skLen*2)
+ _, err = prv2.GenerateShared(&prv1.PublicKey, 32, 32)
if err != ErrSharedKeyTooBig {
fmt.Println("ecdh: shared key should be too large for curve")
t.FailNow()
@@ -176,25 +218,21 @@ func TestTooBigSharedKey(t *testing.T) {
func TestMarshalPublic(t *testing.T) {
prv, err := GenerateKey(rand.Reader, DefaultCurve, nil)
if err != nil {
- fmt.Println(err.Error())
- t.FailNow()
+ t.Fatalf("GenerateKey error: %s", err)
}
out, err := MarshalPublic(&prv.PublicKey)
if err != nil {
- fmt.Println(err.Error())
- t.FailNow()
+ t.Fatalf("MarshalPublic error: %s", err)
}
pub, err := UnmarshalPublic(out)
if err != nil {
- fmt.Println(err.Error())
- t.FailNow()
+ t.Fatalf("UnmarshalPublic error: %s", err)
}
if !cmpPublic(prv.PublicKey, *pub) {
- fmt.Println("ecies: failed to unmarshal public key")
- t.FailNow()
+ t.Fatal("ecies: failed to unmarshal public key")
}
}
@@ -304,9 +342,26 @@ func BenchmarkGenSharedKeyP256(b *testing.B) {
fmt.Println(err.Error())
b.FailNow()
}
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, err := prv.GenerateShared(&prv.PublicKey, 16, 16)
+ if err != nil {
+ fmt.Println(err.Error())
+ b.FailNow()
+ }
+ }
+}
+// Benchmark the generation of S256 shared keys.
+func BenchmarkGenSharedKeyS256(b *testing.B) {
+ prv, err := GenerateKey(rand.Reader, secp256k1.S256(), nil)
+ if err != nil {
+ fmt.Println(err.Error())
+ b.FailNow()
+ }
+ b.ResetTimer()
for i := 0; i < b.N; i++ {
- _, err := prv.GenerateShared(&prv.PublicKey, skLen, skLen)
+ _, err := prv.GenerateShared(&prv.PublicKey, 16, 16)
if err != nil {
fmt.Println(err.Error())
b.FailNow()
@@ -511,3 +566,43 @@ func TestBasicKeyValidation(t *testing.T) {
}
}
}
+
+// Verify GenerateShared against static values - useful when
+// debugging changes in underlying libs
+func TestSharedKeyStatic(t *testing.T) {
+ prv1 := hexKey("7ebbc6a8358bc76dd73ebc557056702c8cfc34e5cfcd90eb83af0347575fd2ad")
+ prv2 := hexKey("6a3d6396903245bba5837752b9e0348874e72db0c4e11e9c485a81b4ea4353b9")
+
+ skLen := MaxSharedKeyLength(&prv1.PublicKey) / 2
+
+ sk1, err := prv1.GenerateShared(&prv2.PublicKey, skLen, skLen)
+ if err != nil {
+ fmt.Println(err.Error())
+ t.FailNow()
+ }
+
+ sk2, err := prv2.GenerateShared(&prv1.PublicKey, skLen, skLen)
+ if err != nil {
+ fmt.Println(err.Error())
+ t.FailNow()
+ }
+
+ if !bytes.Equal(sk1, sk2) {
+ fmt.Println(ErrBadSharedKeys.Error())
+ t.FailNow()
+ }
+
+ sk, _ := hex.DecodeString("167ccc13ac5e8a26b131c3446030c60fbfac6aa8e31149d0869f93626a4cdf62")
+ if !bytes.Equal(sk1, sk) {
+ t.Fatalf("shared secret mismatch: want: %x have: %x", sk, sk1)
+ }
+}
+
+// TODO: remove after refactoring packages crypto and crypto/ecies
+func hexKey(prv string) *PrivateKey {
+ priv := new(ecdsa.PrivateKey)
+ priv.PublicKey.Curve = secp256k1.S256()
+ priv.D, _ = new(big.Int).SetString(prv, 16)
+ priv.PublicKey.X, priv.PublicKey.Y = secp256k1.S256().ScalarBaseMult(priv.D.Bytes())
+ return ImportECDSA(priv)
+}
diff --git a/crypto/ecies/params.go b/crypto/ecies/params.go
index 97ddb0973..511c53ebc 100644
--- a/crypto/ecies/params.go
+++ b/crypto/ecies/params.go
@@ -41,13 +41,12 @@ import (
"crypto/sha512"
"fmt"
"hash"
-)
-// The default curve for this package is the NIST P256 curve, which
-// provides security equivalent to AES-128.
-var DefaultCurve = elliptic.P256()
+ "github.com/ethereum/go-ethereum/crypto/secp256k1"
+)
var (
+ DefaultCurve = secp256k1.S256()
ErrUnsupportedECDHAlgorithm = fmt.Errorf("ecies: unsupported ECDH algorithm")
ErrUnsupportedECIESParameters = fmt.Errorf("ecies: unsupported ECIES parameters")
)
@@ -101,9 +100,10 @@ var (
)
var paramsFromCurve = map[elliptic.Curve]*ECIESParams{
- elliptic.P256(): ECIES_AES128_SHA256,
- elliptic.P384(): ECIES_AES256_SHA384,
- elliptic.P521(): ECIES_AES256_SHA512,
+ secp256k1.S256(): ECIES_AES128_SHA256,
+ elliptic.P256(): ECIES_AES128_SHA256,
+ elliptic.P384(): ECIES_AES256_SHA384,
+ elliptic.P521(): ECIES_AES256_SHA512,
}
func AddParamsForCurve(curve elliptic.Curve, params *ECIESParams) {
diff --git a/crypto/key.go b/crypto/key.go
index 4ec43dfd7..ccf284ad8 100644
--- a/crypto/key.go
+++ b/crypto/key.go
@@ -25,6 +25,7 @@ import (
"strings"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/crypto/secp256k1"
"github.com/pborman/uuid"
)
@@ -137,7 +138,7 @@ func NewKey(rand io.Reader) *Key {
panic("key generation: could not read from random source: " + err.Error())
}
reader := bytes.NewReader(randBytes)
- privateKeyECDSA, err := ecdsa.GenerateKey(S256(), reader)
+ privateKeyECDSA, err := ecdsa.GenerateKey(secp256k1.S256(), reader)
if err != nil {
panic("key generation: ecdsa.GenerateKey failed: " + err.Error())
}
@@ -155,7 +156,7 @@ func NewKeyForDirectICAP(rand io.Reader) *Key {
panic("key generation: could not read from random source: " + err.Error())
}
reader := bytes.NewReader(randBytes)
- privateKeyECDSA, err := ecdsa.GenerateKey(S256(), reader)
+ privateKeyECDSA, err := ecdsa.GenerateKey(secp256k1.S256(), reader)
if err != nil {
panic("key generation: ecdsa.GenerateKey failed: " + err.Error())
}
diff --git a/crypto/curve.go b/crypto/secp256k1/curve.go
index 48f3f5e9c..6e44a6771 100644
--- a/crypto/curve.go
+++ b/crypto/secp256k1/curve.go
@@ -29,15 +29,22 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-package crypto
+package secp256k1
import (
"crypto/elliptic"
"io"
"math/big"
"sync"
+ "unsafe"
)
+/*
+#include "libsecp256k1/include/secp256k1.h"
+extern int secp256k1_pubkey_scalar_mul(const secp256k1_context* ctx, const unsigned char *point, const unsigned char *scalar);
+*/
+import "C"
+
// This code is from https://github.com/ThePiachu/GoBit and implements
// several Koblitz elliptic curves over prime fields.
//
@@ -211,44 +218,37 @@ func (BitCurve *BitCurve) doubleJacobian(x, y, z *big.Int) (*big.Int, *big.Int,
return x3, y3, z3
}
-//TODO: double check if it is okay
-// ScalarMult returns k*(Bx,By) where k is a number in big-endian form.
-func (BitCurve *BitCurve) ScalarMult(Bx, By *big.Int, k []byte) (*big.Int, *big.Int) {
- // We have a slight problem in that the identity of the group (the
- // point at infinity) cannot be represented in (x, y) form on a finite
- // machine. Thus the standard add/double algorithm has to be tweaked
- // slightly: our initial state is not the identity, but x, and we
- // ignore the first true bit in |k|. If we don't find any true bits in
- // |k|, then we return nil, nil, because we cannot return the identity
- // element.
-
- Bz := new(big.Int).SetInt64(1)
- x := Bx
- y := By
- z := Bz
-
- seenFirstTrue := false
- for _, byte := range k {
- for bitNum := 0; bitNum < 8; bitNum++ {
- if seenFirstTrue {
- x, y, z = BitCurve.doubleJacobian(x, y, z)
- }
- if byte&0x80 == 0x80 {
- if !seenFirstTrue {
- seenFirstTrue = true
- } else {
- x, y, z = BitCurve.addJacobian(Bx, By, Bz, x, y, z)
- }
- }
- byte <<= 1
- }
+func (BitCurve *BitCurve) ScalarMult(Bx, By *big.Int, scalar []byte) (*big.Int, *big.Int) {
+ // Ensure scalar is exactly 32 bytes. We pad always, even if
+ // scalar is 32 bytes long, to avoid a timing side channel.
+ if len(scalar) > 32 {
+ panic("can't handle scalars > 256 bits")
}
-
- if !seenFirstTrue {
+ padded := make([]byte, 32)
+ copy(padded[32-len(scalar):], scalar)
+ scalar = padded
+
+ // Do the multiplication in C, updating point.
+ point := make([]byte, 64)
+ readBits(point[:32], Bx)
+ readBits(point[32:], By)
+ pointPtr := (*C.uchar)(unsafe.Pointer(&point[0]))
+ scalarPtr := (*C.uchar)(unsafe.Pointer(&scalar[0]))
+ res := C.secp256k1_pubkey_scalar_mul(context, pointPtr, scalarPtr)
+
+ // Unpack the result and clear temporaries.
+ x := new(big.Int).SetBytes(point[:32])
+ y := new(big.Int).SetBytes(point[32:])
+ for i := range point {
+ point[i] = 0
+ }
+ for i := range padded {
+ scalar[i] = 0
+ }
+ if res != 1 {
return nil, nil
}
-
- return BitCurve.affineFromJacobian(x, y, z)
+ return x, y
}
// ScalarBaseMult returns k*G, where G is the base point of the group and k is
@@ -312,86 +312,24 @@ func (BitCurve *BitCurve) Unmarshal(data []byte) (x, y *big.Int) {
return
}
-//curve parameters taken from:
-//http://www.secg.org/collateral/sec2_final.pdf
-
-var initonce sync.Once
-var ecp160k1 *BitCurve
-var ecp192k1 *BitCurve
-var ecp224k1 *BitCurve
-var ecp256k1 *BitCurve
-
-func initAll() {
- initS160()
- initS192()
- initS224()
- initS256()
-}
-
-func initS160() {
- // See SEC 2 section 2.4.1
- ecp160k1 = new(BitCurve)
- ecp160k1.P, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFAC73", 16)
- ecp160k1.N, _ = new(big.Int).SetString("0100000000000000000001B8FA16DFAB9ACA16B6B3", 16)
- ecp160k1.B, _ = new(big.Int).SetString("0000000000000000000000000000000000000007", 16)
- ecp160k1.Gx, _ = new(big.Int).SetString("3B4C382CE37AA192A4019E763036F4F5DD4D7EBB", 16)
- ecp160k1.Gy, _ = new(big.Int).SetString("938CF935318FDCED6BC28286531733C3F03C4FEE", 16)
- ecp160k1.BitSize = 160
-}
-
-func initS192() {
- // See SEC 2 section 2.5.1
- ecp192k1 = new(BitCurve)
- ecp192k1.P, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFEE37", 16)
- ecp192k1.N, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFE26F2FC170F69466A74DEFD8D", 16)
- ecp192k1.B, _ = new(big.Int).SetString("000000000000000000000000000000000000000000000003", 16)
- ecp192k1.Gx, _ = new(big.Int).SetString("DB4FF10EC057E9AE26B07D0280B7F4341DA5D1B1EAE06C7D", 16)
- ecp192k1.Gy, _ = new(big.Int).SetString("9B2F2F6D9C5628A7844163D015BE86344082AA88D95E2F9D", 16)
- ecp192k1.BitSize = 192
-}
-
-func initS224() {
- // See SEC 2 section 2.6.1
- ecp224k1 = new(BitCurve)
- ecp224k1.P, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFE56D", 16)
- ecp224k1.N, _ = new(big.Int).SetString("010000000000000000000000000001DCE8D2EC6184CAF0A971769FB1F7", 16)
- ecp224k1.B, _ = new(big.Int).SetString("00000000000000000000000000000000000000000000000000000005", 16)
- ecp224k1.Gx, _ = new(big.Int).SetString("A1455B334DF099DF30FC28A169A467E9E47075A90F7E650EB6B7A45C", 16)
- ecp224k1.Gy, _ = new(big.Int).SetString("7E089FED7FBA344282CAFBD6F7E319F7C0B0BD59E2CA4BDB556D61A5", 16)
- ecp224k1.BitSize = 224
-}
-
-func initS256() {
- // See SEC 2 section 2.7.1
- ecp256k1 = new(BitCurve)
- ecp256k1.P, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F", 16)
- ecp256k1.N, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141", 16)
- ecp256k1.B, _ = new(big.Int).SetString("0000000000000000000000000000000000000000000000000000000000000007", 16)
- ecp256k1.Gx, _ = new(big.Int).SetString("79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798", 16)
- ecp256k1.Gy, _ = new(big.Int).SetString("483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8", 16)
- ecp256k1.BitSize = 256
-}
-
-// S160 returns a BitCurve which implements secp160k1 (see SEC 2 section 2.4.1)
-func S160() *BitCurve {
- initonce.Do(initAll)
- return ecp160k1
-}
-
-// S192 returns a BitCurve which implements secp192k1 (see SEC 2 section 2.5.1)
-func S192() *BitCurve {
- initonce.Do(initAll)
- return ecp192k1
-}
-
-// S224 returns a BitCurve which implements secp224k1 (see SEC 2 section 2.6.1)
-func S224() *BitCurve {
- initonce.Do(initAll)
- return ecp224k1
-}
+var (
+ initonce sync.Once
+ theCurve *BitCurve
+)
// S256 returns a BitCurve which implements secp256k1 (see SEC 2 section 2.7.1)
func S256() *BitCurve {
- initonce.Do(initAll)
- return ecp256k1
+ initonce.Do(func() {
+ // See SEC 2 section 2.7.1
+ // curve parameters taken from:
+ // http://www.secg.org/collateral/sec2_final.pdf
+ theCurve = new(BitCurve)
+ theCurve.P, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F", 16)
+ theCurve.N, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141", 16)
+ theCurve.B, _ = new(big.Int).SetString("0000000000000000000000000000000000000000000000000000000000000007", 16)
+ theCurve.Gx, _ = new(big.Int).SetString("79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798", 16)
+ theCurve.Gy, _ = new(big.Int).SetString("483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8", 16)
+ theCurve.BitSize = 256
+ })
+ return theCurve
}
diff --git a/crypto/secp256k1/curve_test.go b/crypto/secp256k1/curve_test.go
new file mode 100644
index 000000000..d915ee852
--- /dev/null
+++ b/crypto/secp256k1/curve_test.go
@@ -0,0 +1,39 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package secp256k1
+
+import (
+ "bytes"
+ "encoding/hex"
+ "math/big"
+ "testing"
+)
+
+func TestReadBits(t *testing.T) {
+ check := func(input string) {
+ want, _ := hex.DecodeString(input)
+ int, _ := new(big.Int).SetString(input, 16)
+ buf := make([]byte, len(want))
+ readBits(buf, int)
+ if !bytes.Equal(buf, want) {
+ t.Errorf("have: %x\nwant: %x", buf, want)
+ }
+ }
+ check("000000000000000000000000000000000000000000000000000000FEFCF3F8F0")
+ check("0000000000012345000000000000000000000000000000000000FEFCF3F8F0")
+ check("18F8F8F1000111000110011100222004330052300000000000000000FEFCF3F8F0")
+}
diff --git a/crypto/secp256k1/pubkey_scalar_mul.h b/crypto/secp256k1/pubkey_scalar_mul.h
new file mode 100644
index 000000000..0511545ec
--- /dev/null
+++ b/crypto/secp256k1/pubkey_scalar_mul.h
@@ -0,0 +1,56 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+/** Multiply point by scalar in constant time.
+ * Returns: 1: multiplication was successful
+ * 0: scalar was invalid (zero or overflow)
+ * Args: ctx: pointer to a context object (cannot be NULL)
+ * Out: point: the multiplied point (usually secret)
+ * In: point: pointer to a 64-byte bytepublic point,
+ encoded as two 256bit big-endian numbers.
+ * scalar: a 32-byte scalar with which to multiply the point
+ */
+int secp256k1_pubkey_scalar_mul(const secp256k1_context* ctx, unsigned char *point, const unsigned char *scalar) {
+ int ret = 0;
+ int overflow = 0;
+ secp256k1_fe feX, feY;
+ secp256k1_gej res;
+ secp256k1_ge ge;
+ secp256k1_scalar s;
+ ARG_CHECK(point != NULL);
+ ARG_CHECK(scalar != NULL);
+ (void)ctx;
+
+ secp256k1_fe_set_b32(&feX, point);
+ secp256k1_fe_set_b32(&feY, point+32);
+ secp256k1_ge_set_xy(&ge, &feX, &feY);
+ secp256k1_scalar_set_b32(&s, scalar, &overflow);
+ if (overflow || secp256k1_scalar_is_zero(&s)) {
+ ret = 0;
+ } else {
+ secp256k1_ecmult_const(&res, &ge, &s);
+ secp256k1_ge_set_gej(&ge, &res);
+ /* Note: can't use secp256k1_pubkey_save here because it is not constant time. */
+ secp256k1_fe_normalize(&ge.x);
+ secp256k1_fe_normalize(&ge.y);
+ secp256k1_fe_get_b32(point, &ge.x);
+ secp256k1_fe_get_b32(point+32, &ge.y);
+ ret = 1;
+ }
+ secp256k1_scalar_clear(&s);
+ return ret;
+}
+
diff --git a/crypto/secp256k1/secp256.go b/crypto/secp256k1/secp256.go
index 41a5608a5..8dc248145 100644
--- a/crypto/secp256k1/secp256.go
+++ b/crypto/secp256k1/secp256.go
@@ -20,6 +20,7 @@ package secp256k1
/*
#cgo CFLAGS: -I./libsecp256k1
+#cgo CFLAGS: -I./libsecp256k1/src/
#cgo darwin CFLAGS: -I/usr/local/include
#cgo freebsd CFLAGS: -I/usr/local/include
#cgo linux,arm CFLAGS: -I/usr/local/arm/include
@@ -35,6 +36,7 @@ package secp256k1
#define NDEBUG
#include "./libsecp256k1/src/secp256k1.c"
#include "./libsecp256k1/src/modules/recovery/main_impl.h"
+#include "pubkey_scalar_mul.h"
typedef void (*callbackFunc) (const char* msg, void* data);
extern void secp256k1GoPanicIllegal(const char* msg, void* data);
@@ -44,6 +46,7 @@ import "C"
import (
"errors"
+ "math/big"
"unsafe"
"github.com/ethereum/go-ethereum/crypto/randentropy"
@@ -56,13 +59,16 @@ import (
> store private keys in buffer and shuffle (deters persistance on swap disc)
> byte permutation (changing)
> xor with chaning random block (to deter scanning memory for 0x63) (stream cipher?)
- > on disk: store keys in wallets
*/
// holds ptr to secp256k1_context_struct (see secp256k1/include/secp256k1.h)
-var context *C.secp256k1_context
+var (
+ context *C.secp256k1_context
+ N *big.Int
+)
func init() {
+ N, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16)
// around 20 ms on a modern CPU.
context = C.secp256k1_context_create(3) // SECP256K1_START_SIGN | SECP256K1_START_VERIFY
C.secp256k1_context_set_illegal_callback(context, C.callbackFunc(C.secp256k1GoPanicIllegal), nil)
@@ -78,7 +84,6 @@ var (
func GenerateKeyPair() ([]byte, []byte) {
var seckey []byte = randentropy.GetEntropyCSPRNG(32)
var seckey_ptr *C.uchar = (*C.uchar)(unsafe.Pointer(&seckey[0]))
-
var pubkey64 []byte = make([]byte, 64) // secp256k1_pubkey
var pubkey65 []byte = make([]byte, 65) // 65 byte uncompressed pubkey
pubkey64_ptr := (*C.secp256k1_pubkey)(unsafe.Pointer(&pubkey64[0]))
@@ -254,3 +259,16 @@ func checkSignature(sig []byte) error {
}
return nil
}
+
+// reads num into buf as big-endian bytes.
+func readBits(buf []byte, num *big.Int) {
+ const wordLen = int(unsafe.Sizeof(big.Word(0)))
+ i := len(buf)
+ for _, d := range num.Bits() {
+ for j := 0; j < wordLen && i > 0; j++ {
+ i--
+ buf[i] = byte(d)
+ d >>= 8
+ }
+ }
+}
diff --git a/crypto/secp256k1/secp256_test.go b/crypto/secp256k1/secp256_test.go
index cb71ea5e7..fc6fc9b32 100644
--- a/crypto/secp256k1/secp256_test.go
+++ b/crypto/secp256k1/secp256_test.go
@@ -24,7 +24,7 @@ import (
"github.com/ethereum/go-ethereum/crypto/randentropy"
)
-const TestCount = 10000
+const TestCount = 1000
func TestPrivkeyGenerate(t *testing.T) {
_, seckey := GenerateKeyPair()
@@ -86,10 +86,7 @@ func TestSignAndRecover(t *testing.T) {
func TestRandomMessagesWithSameKey(t *testing.T) {
pubkey, seckey := GenerateKeyPair()
keys := func() ([]byte, []byte) {
- // Sign function zeroes the privkey so we need a new one in each call
- newkey := make([]byte, len(seckey))
- copy(newkey, seckey)
- return pubkey, newkey
+ return pubkey, seckey
}
signAndRecoverWithRandomMessages(t, keys)
}
@@ -209,30 +206,32 @@ func compactSigCheck(t *testing.T, sig []byte) {
}
}
-// godep go test -v -run=XXX -bench=BenchmarkSignRandomInputEachRound
+// godep go test -v -run=XXX -bench=BenchmarkSign
// add -benchtime=10s to benchmark longer for more accurate average
-func BenchmarkSignRandomInputEachRound(b *testing.B) {
+
+// to avoid compiler optimizing the benchmarked function call
+var err error
+
+func BenchmarkSign(b *testing.B) {
for i := 0; i < b.N; i++ {
- b.StopTimer()
_, seckey := GenerateKeyPair()
msg := randentropy.GetEntropyCSPRNG(32)
b.StartTimer()
- if _, err := Sign(msg, seckey); err != nil {
- b.Fatal(err)
- }
+ _, e := Sign(msg, seckey)
+ err = e
+ b.StopTimer()
}
}
-//godep go test -v -run=XXX -bench=BenchmarkRecoverRandomInputEachRound
-func BenchmarkRecoverRandomInputEachRound(b *testing.B) {
+//godep go test -v -run=XXX -bench=BenchmarkECRec
+func BenchmarkRecover(b *testing.B) {
for i := 0; i < b.N; i++ {
- b.StopTimer()
_, seckey := GenerateKeyPair()
msg := randentropy.GetEntropyCSPRNG(32)
sig, _ := Sign(msg, seckey)
b.StartTimer()
- if _, err := RecoverPubkey(msg, sig); err != nil {
- b.Fatal(err)
- }
+ _, e := RecoverPubkey(msg, sig)
+ err = e
+ b.StopTimer()
}
}
diff --git a/eth/backend.go b/eth/backend.go
index 0369f6afd..91f02db72 100644
--- a/eth/backend.go
+++ b/eth/backend.go
@@ -191,7 +191,7 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) {
shutdownChan: make(chan bool),
chainDb: chainDb,
dappDb: dappDb,
- eventMux: &event.TypeMux{},
+ eventMux: ctx.EventMux,
accountManager: config.AccountManager,
etherbase: config.Etherbase,
netVersionId: config.NetworkId,
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
index a14f29424..dd19df3a2 100644
--- a/p2p/discover/node.go
+++ b/p2p/discover/node.go
@@ -210,7 +210,7 @@ func PubkeyID(pub *ecdsa.PublicKey) NodeID {
// Pubkey returns the public key represented by the node ID.
// It returns an error if the ID is not a point on the curve.
func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) {
- p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)}
+ p := &ecdsa.PublicKey{Curve: secp256k1.S256(), X: new(big.Int), Y: new(big.Int)}
half := len(id) / 2
p.X.SetBytes(id[:half])
p.Y.SetBytes(id[half:])
diff --git a/p2p/rlpx.go b/p2p/rlpx.go
index aaa733854..8f429d6ec 100644
--- a/p2p/rlpx.go
+++ b/p2p/rlpx.go
@@ -277,7 +277,7 @@ func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) {
return nil, err
}
// generate random keypair to use for signing
- randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
+ randpriv, err := ecies.GenerateKey(rand.Reader, secp256k1.S256(), nil)
if err != nil {
return nil, err
}
@@ -376,7 +376,7 @@ func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandsh
var err error
h := new(encHandshake)
// generate random keypair for session
- h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
+ h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, secp256k1.S256(), nil)
if err != nil {
return nil, err
}
diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go
index 900353f0e..7cc7548e2 100644
--- a/p2p/rlpx_test.go
+++ b/p2p/rlpx_test.go
@@ -93,6 +93,7 @@ func testEncHandshake(token []byte) error {
go func() {
r := result{side: "initiator"}
defer func() { output <- r }()
+ defer fd0.Close()
dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)}
r.id, r.err = c0.doEncHandshake(prv0, dest)
@@ -107,6 +108,7 @@ func testEncHandshake(token []byte) error {
go func() {
r := result{side: "receiver"}
defer func() { output <- r }()
+ defer fd1.Close()
r.id, r.err = c1.doEncHandshake(prv1, nil)
if r.err != nil {
diff --git a/trie/arc.go b/trie/arc.go
index 9da012e16..fc7a3259f 100644
--- a/trie/arc.go
+++ b/trie/arc.go
@@ -62,6 +62,18 @@ func newARC(c int) *arc {
}
}
+// Clear clears the cache
+func (a *arc) Clear() {
+ a.mutex.Lock()
+ defer a.mutex.Unlock()
+ a.p = 0
+ a.t1 = list.New()
+ a.b1 = list.New()
+ a.t2 = list.New()
+ a.b2 = list.New()
+ a.cache = make(map[string]*entry, a.c)
+}
+
// Put inserts a new key-value pair into the cache.
// This optimizes future access to this entry (side effect).
func (a *arc) Put(key hashNode, value node) bool {
diff --git a/trie/errors.go b/trie/errors.go
new file mode 100644
index 000000000..a0f58f28f
--- /dev/null
+++ b/trie/errors.go
@@ -0,0 +1,41 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "fmt"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+// MissingNodeError is returned by the trie functions (TryGet, TryUpdate, TryDelete)
+// in the case where a trie node is not present in the local database. Contains
+// information necessary for retrieving the missing node through an ODR service.
+//
+// NodeHash is the hash of the missing node
+// RootHash is the original root of the trie that contains the node
+// KeyPrefix is the prefix that leads from the root to the missing node (hex encoded)
+// KeySuffix (optional) contains the rest of the key we were looking for, gives a
+// hint on which further nodes should also be retrieved (hex encoded)
+type MissingNodeError struct {
+ RootHash, NodeHash common.Hash
+ KeyPrefix, KeySuffix []byte
+}
+
+func (err *MissingNodeError) Error() string {
+ return fmt.Sprintf("Missing trie node %064x", err.NodeHash)
+}
diff --git a/trie/iterator.go b/trie/iterator.go
index 38555fe08..5f205e081 100644
--- a/trie/iterator.go
+++ b/trie/iterator.go
@@ -16,7 +16,12 @@
package trie
-import "bytes"
+import (
+ "bytes"
+
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/logger/glog"
+)
type Iterator struct {
trie *Trie
@@ -100,7 +105,11 @@ func (self *Iterator) next(node interface{}, key []byte, isIterStart bool) []byt
}
case hashNode:
- return self.next(self.trie.resolveHash(node), key, isIterStart)
+ rn, err := self.trie.resolveHash(node, nil, nil)
+ if err != nil && glog.V(logger.Error) {
+ glog.Errorf("Unhandled trie error: %v", err)
+ }
+ return self.next(rn, key, isIterStart)
}
return nil
}
@@ -127,7 +136,11 @@ func (self *Iterator) key(node interface{}) []byte {
}
}
case hashNode:
- return self.key(self.trie.resolveHash(node))
+ rn, err := self.trie.resolveHash(node, nil, nil)
+ if err != nil && glog.V(logger.Error) {
+ glog.Errorf("Unhandled trie error: %v", err)
+ }
+ return self.key(rn)
}
return nil
diff --git a/trie/proof.go b/trie/proof.go
index a705c49db..2e88bb50b 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -7,6 +7,8 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto/sha3"
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -39,7 +41,14 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue {
case nil:
return nil
case hashNode:
- tn = t.resolveHash(n)
+ var err error
+ tn, err = t.resolveHash(n, nil, nil)
+ if err != nil {
+ if glog.V(logger.Error) {
+ glog.Errorf("Unhandled trie error: %v", err)
+ }
+ return nil
+ }
default:
panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
}
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
index 47d1934d0..caeef3c3a 100644
--- a/trie/secure_trie.go
+++ b/trie/secure_trie.go
@@ -21,6 +21,8 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto/sha3"
+ "github.com/ethereum/go-ethereum/logger"
+ "github.com/ethereum/go-ethereum/logger/glog"
)
var secureKeyPrefix = []byte("secure-key-")
@@ -46,8 +48,8 @@ type SecureTrie struct {
// NewSecure creates a trie with an existing root node from db.
//
// If root is the zero hash or the sha3 hash of an empty string, the
-// trie is initially empty. Otherwise, New will panics if db is nil
-// and returns ErrMissingRoot if the root node cannpt be found.
+// trie is initially empty. Otherwise, New will panic if db is nil
+// and returns MissingNodeError if the root node cannot be found.
// Accessing the trie loads nodes from db on demand.
func NewSecure(root common.Hash, db Database) (*SecureTrie, error) {
if db == nil {
@@ -63,7 +65,18 @@ func NewSecure(root common.Hash, db Database) (*SecureTrie, error) {
// Get returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
func (t *SecureTrie) Get(key []byte) []byte {
- return t.Trie.Get(t.hashKey(key))
+ res, err := t.TryGet(key)
+ if err != nil && glog.V(logger.Error) {
+ glog.Errorf("Unhandled trie error: %v", err)
+ }
+ return res
+}
+
+// TryGet returns the value for key stored in the trie.
+// The value bytes must not be modified by the caller.
+// If a node was not found in the database, a MissingNodeError is returned.
+func (t *SecureTrie) TryGet(key []byte) ([]byte, error) {
+ return t.Trie.TryGet(t.hashKey(key))
}
// Update associates key with value in the trie. Subsequent calls to
@@ -73,14 +86,40 @@ func (t *SecureTrie) Get(key []byte) []byte {
// The value bytes must not be modified by the caller while they are
// stored in the trie.
func (t *SecureTrie) Update(key, value []byte) {
+ if err := t.TryUpdate(key, value); err != nil && glog.V(logger.Error) {
+ glog.Errorf("Unhandled trie error: %v", err)
+ }
+}
+
+// TryUpdate associates key with value in the trie. Subsequent calls to
+// Get will return value. If value has length zero, any existing value
+// is deleted from the trie and calls to Get will return nil.
+//
+// The value bytes must not be modified by the caller while they are
+// stored in the trie.
+//
+// If a node was not found in the database, a MissingNodeError is returned.
+func (t *SecureTrie) TryUpdate(key, value []byte) error {
hk := t.hashKey(key)
- t.Trie.Update(hk, value)
+ err := t.Trie.TryUpdate(hk, value)
+ if err != nil {
+ return err
+ }
t.Trie.db.Put(t.secKey(hk), key)
+ return nil
}
// Delete removes any existing value for key from the trie.
func (t *SecureTrie) Delete(key []byte) {
- t.Trie.Delete(t.hashKey(key))
+ if err := t.TryDelete(key); err != nil && glog.V(logger.Error) {
+ glog.Errorf("Unhandled trie error: %v", err)
+ }
+}
+
+// TryDelete removes any existing value for key from the trie.
+// If a node was not found in the database, a MissingNodeError is returned.
+func (t *SecureTrie) TryDelete(key []byte) error {
+ return t.Trie.TryDelete(t.hashKey(key))
}
// GetKey returns the sha3 preimage of a hashed key that was
diff --git a/trie/trie.go b/trie/trie.go
index a3a383fb5..717296e27 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -19,7 +19,6 @@ package trie
import (
"bytes"
- "errors"
"fmt"
"hash"
@@ -44,7 +43,10 @@ var (
emptyState = crypto.Sha3Hash(nil)
)
-var ErrMissingRoot = errors.New("missing root node")
+// ClearGlobalCache clears the global trie cache
+func ClearGlobalCache() {
+ globalCache.Clear()
+}
// Database must be implemented by backing stores for the trie.
type Database interface {
@@ -67,8 +69,9 @@ type DatabaseWriter interface {
//
// Trie is not safe for concurrent use.
type Trie struct {
- root node
- db Database
+ root node
+ db Database
+ originalRoot common.Hash
*hasher
}
@@ -76,16 +79,19 @@ type Trie struct {
//
// If root is the zero hash or the sha3 hash of an empty string, the
// trie is initially empty and does not require a database. Otherwise,
-// New will panics if db is nil or root does not exist in the
-// database. Accessing the trie loads nodes from db on demand.
+// New will panic if db is nil and returns a MissingNodeError if root does
+// not exist in the database. Accessing the trie loads nodes from db on demand.
func New(root common.Hash, db Database) (*Trie, error) {
- trie := &Trie{db: db}
+ trie := &Trie{db: db, originalRoot: root}
if (root != common.Hash{}) && root != emptyRoot {
if db == nil {
panic("trie.New: cannot use existing root without a database")
}
if v, _ := trie.db.Get(root[:]); len(v) == 0 {
- return nil, ErrMissingRoot
+ return nil, &MissingNodeError{
+ RootHash: root,
+ NodeHash: root,
+ }
}
trie.root = hashNode(root.Bytes())
}
@@ -100,28 +106,44 @@ func (t *Trie) Iterator() *Iterator {
// Get returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
func (t *Trie) Get(key []byte) []byte {
+ res, err := t.TryGet(key)
+ if err != nil && glog.V(logger.Error) {
+ glog.Errorf("Unhandled trie error: %v", err)
+ }
+ return res
+}
+
+// TryGet returns the value for key stored in the trie.
+// The value bytes must not be modified by the caller.
+// If a node was not found in the database, a MissingNodeError is returned.
+func (t *Trie) TryGet(key []byte) ([]byte, error) {
key = compactHexDecode(key)
+ pos := 0
tn := t.root
- for len(key) > 0 {
+ for pos < len(key) {
switch n := tn.(type) {
case shortNode:
- if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
- return nil
+ if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) {
+ return nil, nil
}
tn = n.Val
- key = key[len(n.Key):]
+ pos += len(n.Key)
case fullNode:
- tn = n[key[0]]
- key = key[1:]
+ tn = n[key[pos]]
+ pos++
case nil:
- return nil
+ return nil, nil
case hashNode:
- tn = t.resolveHash(n)
+ var err error
+ tn, err = t.resolveHash(n, key[:pos], key[pos:])
+ if err != nil {
+ return nil, err
+ }
default:
panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
}
}
- return tn.(valueNode)
+ return tn.(valueNode), nil
}
// Update associates key with value in the trie. Subsequent calls to
@@ -131,17 +153,40 @@ func (t *Trie) Get(key []byte) []byte {
// The value bytes must not be modified by the caller while they are
// stored in the trie.
func (t *Trie) Update(key, value []byte) {
+ if err := t.TryUpdate(key, value); err != nil && glog.V(logger.Error) {
+ glog.Errorf("Unhandled trie error: %v", err)
+ }
+}
+
+// TryUpdate associates key with value in the trie. Subsequent calls to
+// Get will return value. If value has length zero, any existing value
+// is deleted from the trie and calls to Get will return nil.
+//
+// The value bytes must not be modified by the caller while they are
+// stored in the trie.
+//
+// If a node was not found in the database, a MissingNodeError is returned.
+func (t *Trie) TryUpdate(key, value []byte) error {
k := compactHexDecode(key)
if len(value) != 0 {
- t.root = t.insert(t.root, k, valueNode(value))
+ n, err := t.insert(t.root, nil, k, valueNode(value))
+ if err != nil {
+ return err
+ }
+ t.root = n
} else {
- t.root = t.delete(t.root, k)
+ n, err := t.delete(t.root, nil, k)
+ if err != nil {
+ return err
+ }
+ t.root = n
}
+ return nil
}
-func (t *Trie) insert(n node, key []byte, value node) node {
+func (t *Trie) insert(n node, prefix, key []byte, value node) (node, error) {
if len(key) == 0 {
- return value
+ return value, nil
}
switch n := n.(type) {
case shortNode:
@@ -149,25 +194,40 @@ func (t *Trie) insert(n node, key []byte, value node) node {
// If the whole key matches, keep this short node as is
// and only update the value.
if matchlen == len(n.Key) {
- return shortNode{n.Key, t.insert(n.Val, key[matchlen:], value)}
+ nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value)
+ if err != nil {
+ return nil, err
+ }
+ return shortNode{n.Key, nn}, nil
}
// Otherwise branch out at the index where they differ.
var branch fullNode
- branch[n.Key[matchlen]] = t.insert(nil, n.Key[matchlen+1:], n.Val)
- branch[key[matchlen]] = t.insert(nil, key[matchlen+1:], value)
+ var err error
+ branch[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val)
+ if err != nil {
+ return nil, err
+ }
+ branch[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value)
+ if err != nil {
+ return nil, err
+ }
// Replace this shortNode with the branch if it occurs at index 0.
if matchlen == 0 {
- return branch
+ return branch, nil
}
// Otherwise, replace it with a short node leading up to the branch.
- return shortNode{key[:matchlen], branch}
+ return shortNode{key[:matchlen], branch}, nil
case fullNode:
- n[key[0]] = t.insert(n[key[0]], key[1:], value)
- return n
+ nn, err := t.insert(n[key[0]], append(prefix, key[0]), key[1:], value)
+ if err != nil {
+ return nil, err
+ }
+ n[key[0]] = nn
+ return n, nil
case nil:
- return shortNode{key, value}
+ return shortNode{key, value}, nil
case hashNode:
// We've hit a part of the trie that isn't loaded yet. Load
@@ -176,7 +236,11 @@ func (t *Trie) insert(n node, key []byte, value node) node {
//
// TODO: track whether insertion changed the value and keep
// n as a hash node if it didn't.
- return t.insert(t.resolveHash(n), key, value)
+ rn, err := t.resolveHash(n, prefix, key)
+ if err != nil {
+ return nil, err
+ }
+ return t.insert(rn, prefix, key, value)
default:
panic(fmt.Sprintf("%T: invalid node: %v", n, n))
@@ -185,28 +249,44 @@ func (t *Trie) insert(n node, key []byte, value node) node {
// Delete removes any existing value for key from the trie.
func (t *Trie) Delete(key []byte) {
+ if err := t.TryDelete(key); err != nil && glog.V(logger.Error) {
+ glog.Errorf("Unhandled trie error: %v", err)
+ }
+}
+
+// TryDelete removes any existing value for key from the trie.
+// If a node was not found in the database, a MissingNodeError is returned.
+func (t *Trie) TryDelete(key []byte) error {
k := compactHexDecode(key)
- t.root = t.delete(t.root, k)
+ n, err := t.delete(t.root, nil, k)
+ if err != nil {
+ return err
+ }
+ t.root = n
+ return nil
}
// delete returns the new root of the trie with key deleted.
// It reduces the trie to minimal form by simplifying
// nodes on the way up after deleting recursively.
-func (t *Trie) delete(n node, key []byte) node {
+func (t *Trie) delete(n node, prefix, key []byte) (node, error) {
switch n := n.(type) {
case shortNode:
matchlen := prefixLen(key, n.Key)
if matchlen < len(n.Key) {
- return n // don't replace n on mismatch
+ return n, nil // don't replace n on mismatch
}
if matchlen == len(key) {
- return nil // remove n entirely for whole matches
+ return nil, nil // remove n entirely for whole matches
}
// The key is longer than n.Key. Remove the remaining suffix
// from the subtrie. Child can never be nil here since the
// subtrie must contain at least two other values with keys
// longer than n.Key.
- child := t.delete(n.Val, key[len(n.Key):])
+ child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):])
+ if err != nil {
+ return nil, err
+ }
switch child := child.(type) {
case shortNode:
// Deleting from the subtrie reduced it to another
@@ -215,13 +295,17 @@ func (t *Trie) delete(n node, key []byte) node {
// always creates a new slice) instead of append to
// avoid modifying n.Key since it might be shared with
// other nodes.
- return shortNode{concat(n.Key, child.Key...), child.Val}
+ return shortNode{concat(n.Key, child.Key...), child.Val}, nil
default:
- return shortNode{n.Key, child}
+ return shortNode{n.Key, child}, nil
}
case fullNode:
- n[key[0]] = t.delete(n[key[0]], key[1:])
+ nn, err := t.delete(n[key[0]], append(prefix, key[0]), key[1:])
+ if err != nil {
+ return nil, err
+ }
+ n[key[0]] = nn
// Check how many non-nil entries are left after deleting and
// reduce the full node to a short node if only one entry is
// left. Since n must've contained at least two children
@@ -250,21 +334,24 @@ func (t *Trie) delete(n node, key []byte) node {
// shortNode{..., shortNode{...}}. Since the entry
// might not be loaded yet, resolve it just for this
// check.
- cnode := t.resolve(n[pos])
+ cnode, err := t.resolve(n[pos], prefix, []byte{byte(pos)})
+ if err != nil {
+ return nil, err
+ }
if cnode, ok := cnode.(shortNode); ok {
k := append([]byte{byte(pos)}, cnode.Key...)
- return shortNode{k, cnode.Val}
+ return shortNode{k, cnode.Val}, nil
}
}
// Otherwise, n is replaced by a one-nibble short node
// containing the child.
- return shortNode{[]byte{byte(pos)}, n[pos]}
+ return shortNode{[]byte{byte(pos)}, n[pos]}, nil
}
// n still contains at least two values and cannot be reduced.
- return n
+ return n, nil
case nil:
- return nil
+ return nil, nil
case hashNode:
// We've hit a part of the trie that isn't loaded yet. Load
@@ -273,7 +360,11 @@ func (t *Trie) delete(n node, key []byte) node {
//
// TODO: track whether deletion actually hit a key and keep
// n as a hash node if it didn't.
- return t.delete(t.resolveHash(n), key)
+ rn, err := t.resolveHash(n, prefix, key)
+ if err != nil {
+ return nil, err
+ }
+ return t.delete(rn, prefix, key)
default:
panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key))
@@ -287,34 +378,31 @@ func concat(s1 []byte, s2 ...byte) []byte {
return r
}
-func (t *Trie) resolve(n node) node {
+func (t *Trie) resolve(n node, prefix, suffix []byte) (node, error) {
if n, ok := n.(hashNode); ok {
- return t.resolveHash(n)
+ return t.resolveHash(n, prefix, suffix)
}
- return n
+ return n, nil
}
-func (t *Trie) resolveHash(n hashNode) node {
+func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) {
if v, ok := globalCache.Get(n); ok {
- return v
+ return v, nil
}
enc, err := t.db.Get(n)
if err != nil || enc == nil {
- // TODO: This needs to be improved to properly distinguish errors.
- // Disk I/O errors shouldn't produce nil (and cause a
- // consensus failure or weird crash), but it is unclear how
- // they could be handled because the entire stack above the trie isn't
- // prepared to cope with missing state nodes.
- if glog.V(logger.Error) {
- glog.Errorf("Dangling hash node ref %x: %v", n, err)
+ return nil, &MissingNodeError{
+ RootHash: t.originalRoot,
+ NodeHash: common.BytesToHash(n),
+ KeyPrefix: prefix,
+ KeySuffix: suffix,
}
- return nil
}
dec := mustDecodeNode(n, enc)
if dec != nil {
globalCache.Put(n, dec)
}
- return dec
+ return dec, nil
}
// Root returns the root hash of the trie.
diff --git a/trie/trie_test.go b/trie/trie_test.go
index c96861bed..35d043cdf 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -64,11 +64,84 @@ func TestMissingRoot(t *testing.T) {
if trie != nil {
t.Error("New returned non-nil trie for invalid root")
}
- if err != ErrMissingRoot {
+ if _, ok := err.(*MissingNodeError); !ok {
t.Error("New returned wrong error: %v", err)
}
}
+func TestMissingNode(t *testing.T) {
+ db, _ := ethdb.NewMemDatabase()
+ trie, _ := New(common.Hash{}, db)
+ updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer")
+ updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf")
+ root, _ := trie.Commit()
+
+ ClearGlobalCache()
+
+ trie, _ = New(root, db)
+ _, err := trie.TryGet([]byte("120000"))
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+
+ trie, _ = New(root, db)
+ _, err = trie.TryGet([]byte("120099"))
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+
+ trie, _ = New(root, db)
+ _, err = trie.TryGet([]byte("123456"))
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+
+ trie, _ = New(root, db)
+ err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv"))
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+
+ trie, _ = New(root, db)
+ err = trie.TryDelete([]byte("123456"))
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+
+ db.Delete(common.FromHex("e1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9"))
+ ClearGlobalCache()
+
+ trie, _ = New(root, db)
+ _, err = trie.TryGet([]byte("120000"))
+ if _, ok := err.(*MissingNodeError); !ok {
+ t.Errorf("Wrong error: %v", err)
+ }
+
+ trie, _ = New(root, db)
+ _, err = trie.TryGet([]byte("120099"))
+ if _, ok := err.(*MissingNodeError); !ok {
+ t.Errorf("Wrong error: %v", err)
+ }
+
+ trie, _ = New(root, db)
+ _, err = trie.TryGet([]byte("123456"))
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+
+ trie, _ = New(root, db)
+ err = trie.TryUpdate([]byte("120099"), []byte("zxcv"))
+ if _, ok := err.(*MissingNodeError); !ok {
+ t.Errorf("Wrong error: %v", err)
+ }
+
+ trie, _ = New(root, db)
+ err = trie.TryDelete([]byte("123456"))
+ if _, ok := err.(*MissingNodeError); !ok {
+ t.Errorf("Wrong error: %v", err)
+ }
+}
+
func TestInsert(t *testing.T) {
trie := newEmpty()
diff --git a/whisper/message_test.go b/whisper/message_test.go
index 6ff95efff..d70da40a4 100644
--- a/whisper/message_test.go
+++ b/whisper/message_test.go
@@ -23,6 +23,7 @@ import (
"time"
"github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/secp256k1"
)
// Tests whether a message can be wrapped without any identity or encryption.
@@ -72,8 +73,8 @@ func TestMessageCleartextSignRecover(t *testing.T) {
if pubKey == nil {
t.Fatalf("failed to recover public key")
}
- p1 := elliptic.Marshal(crypto.S256(), key.PublicKey.X, key.PublicKey.Y)
- p2 := elliptic.Marshal(crypto.S256(), pubKey.X, pubKey.Y)
+ p1 := elliptic.Marshal(secp256k1.S256(), key.PublicKey.X, key.PublicKey.Y)
+ p2 := elliptic.Marshal(secp256k1.S256(), pubKey.X, pubKey.Y)
if !bytes.Equal(p1, p2) {
t.Fatalf("public key mismatch: have 0x%x, want 0x%x", p2, p1)
}
@@ -150,8 +151,8 @@ func TestMessageFullCrypto(t *testing.T) {
if pubKey == nil {
t.Fatalf("failed to recover public key")
}
- p1 := elliptic.Marshal(crypto.S256(), fromKey.PublicKey.X, fromKey.PublicKey.Y)
- p2 := elliptic.Marshal(crypto.S256(), pubKey.X, pubKey.Y)
+ p1 := elliptic.Marshal(secp256k1.S256(), fromKey.PublicKey.X, fromKey.PublicKey.Y)
+ p2 := elliptic.Marshal(secp256k1.S256(), pubKey.X, pubKey.Y)
if !bytes.Equal(p1, p2) {
t.Fatalf("public key mismatch: have 0x%x, want 0x%x", p2, p1)
}