diff options
-rw-r--r-- | Makefile | 75 | ||||
-rw-r--r-- | accounts/accounts_test.go | 2 | ||||
-rw-r--r-- | cmd/geth/js.go | 2 | ||||
-rw-r--r-- | cmd/geth/main.go | 7 | ||||
-rw-r--r-- | cmd/utils/flags.go | 32 | ||||
-rw-r--r-- | common/types.go | 4 | ||||
-rw-r--r-- | crypto/crypto.go | 26 | ||||
-rw-r--r-- | crypto/crypto_test.go | 8 | ||||
-rw-r--r-- | crypto/ecies/asn1.go | 7 | ||||
-rw-r--r-- | crypto/ecies/ecies.go | 1 | ||||
-rw-r--r-- | crypto/ecies/ecies_test.go | 121 | ||||
-rw-r--r-- | crypto/ecies/params.go | 14 | ||||
-rw-r--r-- | crypto/key.go | 5 | ||||
-rw-r--r-- | crypto/secp256k1/curve.go (renamed from crypto/curve.go) | 168 | ||||
-rw-r--r-- | crypto/secp256k1/curve_test.go | 39 | ||||
-rw-r--r-- | crypto/secp256k1/pubkey_scalar_mul.h | 56 | ||||
-rw-r--r-- | crypto/secp256k1/secp256.go | 24 | ||||
-rw-r--r-- | crypto/secp256k1/secp256_test.go | 33 | ||||
-rw-r--r-- | eth/backend.go | 2 | ||||
-rw-r--r-- | p2p/discover/node.go | 2 | ||||
-rw-r--r-- | p2p/rlpx.go | 4 | ||||
-rw-r--r-- | p2p/rlpx_test.go | 2 | ||||
-rw-r--r-- | trie/arc.go | 12 | ||||
-rw-r--r-- | trie/errors.go | 41 | ||||
-rw-r--r-- | trie/iterator.go | 19 | ||||
-rw-r--r-- | trie/proof.go | 11 | ||||
-rw-r--r-- | trie/secure_trie.go | 49 | ||||
-rw-r--r-- | trie/trie.go | 204 | ||||
-rw-r--r-- | trie/trie_test.go | 75 | ||||
-rw-r--r-- | whisper/message_test.go | 9 |
30 files changed, 743 insertions, 311 deletions
@@ -3,11 +3,12 @@ # don't need to bother with make. .PHONY: geth geth-cross evm all test travis-test-with-coverage xgo clean -.PHONY: geth-linux geth-linux-arm geth-linux-386 geth-linux-amd64 +.PHONY: geth-linux geth-linux-386 geth-linux-amd64 +.PHONY: geth-linux-arm geth-linux-arm-5 geth-linux-arm-6 geth-linux-arm-7 geth-linux-arm64 .PHONY: geth-darwin geth-darwin-386 geth-darwin-amd64 .PHONY: geth-windows geth-windows-386 geth-windows-amd64 -.PHONY: geth-android geth-android-16 geth-android-21 -.PHONY: geth-ios geth-ios-5.0 geth-ios-8.1 +.PHONY: geth-android +.PHONY: geth-ios geth-ios-arm-7 geth-ios-arm64 GOBIN = build/bin @@ -20,19 +21,14 @@ geth: @echo "Done building." @echo "Run \"$(GOBIN)/geth\" to launch geth." -geth-cross: geth-linux geth-darwin geth-windows geth-android +geth-cross: geth-linux geth-darwin geth-windows geth-android geth-ios @echo "Full cross compilation done:" @ls -l $(GOBIN)/geth-* -geth-linux: xgo geth-linux-arm geth-linux-386 geth-linux-amd64 +geth-linux: geth-linux-386 geth-linux-amd64 geth-linux-arm @echo "Linux cross compilation done:" @ls -l $(GOBIN)/geth-linux-* -geth-linux-arm: xgo - build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=linux/arm -v $(shell build/flags.sh) ./cmd/geth - @echo "Linux ARM cross compilation done:" - @ls -l $(GOBIN)/geth-linux-* | grep arm - geth-linux-386: xgo build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=linux/386 -v $(shell build/flags.sh) ./cmd/geth @echo "Linux 386 cross compilation done:" @@ -43,7 +39,31 @@ geth-linux-amd64: xgo @echo "Linux amd64 cross compilation done:" @ls -l $(GOBIN)/geth-linux-* | grep amd64 -geth-darwin: xgo geth-darwin-386 geth-darwin-amd64 +geth-linux-arm: geth-linux-arm-5 geth-linux-arm-6 geth-linux-arm-7 geth-linux-arm64 + @echo "Linux ARM cross compilation done:" + @ls -l $(GOBIN)/geth-linux-* | grep arm + +geth-linux-arm-5: xgo + build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=linux/arm-5 -v $(shell build/flags.sh) ./cmd/geth + @echo "Linux ARMv5 cross compilation done:" + @ls -l $(GOBIN)/geth-linux-* | grep arm-5 + +geth-linux-arm-6: xgo + build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=linux/arm-6 -v $(shell build/flags.sh) ./cmd/geth + @echo "Linux ARMv6 cross compilation done:" + @ls -l $(GOBIN)/geth-linux-* | grep arm-6 + +geth-linux-arm-7: xgo + build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=linux/arm-7 -v $(shell build/flags.sh) ./cmd/geth + @echo "Linux ARMv7 cross compilation done:" + @ls -l $(GOBIN)/geth-linux-* | grep arm-7 + +geth-linux-arm64: xgo + build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=linux/arm64 -v $(shell build/flags.sh) ./cmd/geth + @echo "Linux ARM64 cross compilation done:" + @ls -l $(GOBIN)/geth-linux-* | grep arm64 + +geth-darwin: geth-darwin-386 geth-darwin-amd64 @echo "Darwin cross compilation done:" @ls -l $(GOBIN)/geth-darwin-* @@ -57,7 +77,7 @@ geth-darwin-amd64: xgo @echo "Darwin amd64 cross compilation done:" @ls -l $(GOBIN)/geth-darwin-* | grep amd64 -geth-windows: xgo geth-windows-386 geth-windows-amd64 +geth-windows: geth-windows-386 geth-windows-amd64 @echo "Windows cross compilation done:" @ls -l $(GOBIN)/geth-windows-* @@ -71,33 +91,24 @@ geth-windows-amd64: xgo @echo "Windows amd64 cross compilation done:" @ls -l $(GOBIN)/geth-windows-* | grep amd64 -geth-android: xgo geth-android-16 geth-android-21 +geth-android: xgo + build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=android/* -v $(shell build/flags.sh) ./cmd/geth @echo "Android cross compilation done:" @ls -l $(GOBIN)/geth-android-* -geth-android-16: xgo - build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=android-16/* -v $(shell build/flags.sh) ./cmd/geth - @echo "Android 16 cross compilation done:" - @ls -l $(GOBIN)/geth-android-16-* - -geth-android-21: xgo - build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --targets=android-21/* -v $(shell build/flags.sh) ./cmd/geth - @echo "Android 21 cross compilation done:" - @ls -l $(GOBIN)/geth-android-21-* - -geth-ios: xgo geth-ios-5.0 geth-ios-8.1 +geth-ios: geth-ios-arm-7 geth-ios-arm64 @echo "iOS cross compilation done:" @ls -l $(GOBIN)/geth-ios-* -geth-ios-5.0: - build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --depsargs=--disable-assembly --targets=ios-5.0/* -v $(shell build/flags.sh) ./cmd/geth - @echo "iOS 5.0 cross compilation done:" - @ls -l $(GOBIN)/geth-ios-5.0-* +geth-ios-arm-7: xgo + build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --depsargs=--disable-assembly --targets=ios/arm-7 -v $(shell build/flags.sh) ./cmd/geth + @echo "iOS ARMv7 cross compilation done:" + @ls -l $(GOBIN)/geth-ios-* | grep arm-7 -geth-ios-8.1: - build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --depsargs=--disable-assembly --targets=ios-8.1/* -v $(shell build/flags.sh) ./cmd/geth - @echo "iOS 8.1 cross compilation done:" - @ls -l $(GOBIN)/geth-ios-8.1-* +geth-ios-arm64: xgo + build/env.sh $(GOBIN)/xgo --go=$(GO) --buildmode=$(MODE) --dest=$(GOBIN) --deps=$(CROSSDEPS) --depsargs=--disable-assembly --targets=ios-7.0/arm64 -v $(shell build/flags.sh) ./cmd/geth + @echo "iOS ARM64 cross compilation done:" + @ls -l $(GOBIN)/geth-ios-* | grep arm64 evm: build/env.sh $(GOROOT)/bin/go install -v $(shell build/flags.sh) ./cmd/evm diff --git a/accounts/accounts_test.go b/accounts/accounts_test.go index d7a8a2b85..55ddecdea 100644 --- a/accounts/accounts_test.go +++ b/accounts/accounts_test.go @@ -68,7 +68,7 @@ func TestTimedUnlock(t *testing.T) { } // Signing fails again after automatic locking - time.Sleep(150 * time.Millisecond) + time.Sleep(350 * time.Millisecond) _, err = am.Sign(a1, testSigData) if err != ErrLocked { t.Fatal("Signing should've failed with ErrLocked timeout expired, got ", err) diff --git a/cmd/geth/js.go b/cmd/geth/js.go index 196f3af59..843c9a5b5 100644 --- a/cmd/geth/js.go +++ b/cmd/geth/js.go @@ -245,7 +245,7 @@ func (self *jsre) batch(statement string) { func (self *jsre) welcome() { self.re.Run(` (function () { - console.log('instance: ' + web3.version.client); + console.log('instance: ' + web3.version.node); console.log(' datadir: ' + admin.datadir); console.log("coinbase: " + eth.coinbase); var ts = 1000 * eth.getBlock(eth.blockNumber).timestamp; diff --git a/cmd/geth/main.go b/cmd/geth/main.go index 3a5471845..6ec30cebc 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -464,9 +464,12 @@ func execScripts(ctx *cli.Context) { node.Stop() } +// tries unlocking the specified account a few times. func unlockAccount(ctx *cli.Context, accman *accounts.Manager, address string, i int, passwords []string) (common.Address, string) { - // Try to unlock the specified account a few times - account := utils.MakeAddress(accman, address) + account, err := utils.MakeAddress(accman, address) + if err != nil { + utils.Fatalf("Unlock error: %v", err) + } for trials := 0; trials < 3; trials++ { prompt := fmt.Sprintf("Unlocking account %s | Attempt %d/%d", address, trials+1, 3) diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 53126f9e5..839ec3f02 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -518,47 +518,41 @@ func MakeAccountManager(ctx *cli.Context) *accounts.Manager { // MakeAddress converts an account specified directly as a hex encoded string or // a key index in the key store to an internal account representation. -func MakeAddress(accman *accounts.Manager, account string) common.Address { +func MakeAddress(accman *accounts.Manager, account string) (a common.Address, err error) { // If the specified account is a valid address, return it if common.IsHexAddress(account) { - return common.HexToAddress(account) + return common.HexToAddress(account), nil } // Otherwise try to interpret the account as a keystore index index, err := strconv.Atoi(account) if err != nil { - Fatalf("Invalid account address or index: '%s'", account) + return a, fmt.Errorf("invalid account address or index %q", account) } hex, err := accman.AddressByIndex(index) if err != nil { - Fatalf("Failed to retrieve requested account #%d: %v", index, err) + return a, fmt.Errorf("can't get account #%d (%v)", index, err) } - return common.HexToAddress(hex) + return common.HexToAddress(hex), nil } // MakeEtherbase retrieves the etherbase either from the directly specified // command line flags or from the keystore if CLI indexed. func MakeEtherbase(accman *accounts.Manager, ctx *cli.Context) common.Address { - // If the specified etherbase is a valid address, return it - etherbase := ctx.GlobalString(EtherbaseFlag.Name) - if common.IsHexAddress(etherbase) { - return common.HexToAddress(etherbase) - } - // If no etherbase was specified and no accounts are known, bail out accounts, _ := accman.Accounts() - if etherbase == "" && len(accounts) == 0 { + if !ctx.GlobalIsSet(EtherbaseFlag.Name) && len(accounts) == 0 { glog.V(logger.Error).Infoln("WARNING: No etherbase set and no accounts found as default") return common.Address{} } - // Otherwise try to interpret the parameter as a keystore index - index, err := strconv.Atoi(etherbase) - if err != nil { - Fatalf("Invalid account address or index: '%s'", etherbase) + etherbase := ctx.GlobalString(EtherbaseFlag.Name) + if etherbase == "" { + return common.Address{} } - hex, err := accman.AddressByIndex(index) + // If the specified etherbase is a valid address, return it + addr, err := MakeAddress(accman, etherbase) if err != nil { - Fatalf("Failed to set requested account #%d as etherbase: %v", index, err) + Fatalf("Option %q: %v", EtherbaseFlag.Name, err) } - return common.HexToAddress(hex) + return addr } // MakeMinerExtra resolves extradata for the miner from the set command line flags diff --git a/common/types.go b/common/types.go index ea5838188..acbd5b28d 100644 --- a/common/types.go +++ b/common/types.go @@ -95,10 +95,10 @@ func HexToAddress(s string) Address { return BytesToAddress(FromHex(s)) } // IsHexAddress verifies whether a string can represent a valid hex-encoded // Ethereum address or not. func IsHexAddress(s string) bool { - if len(s) == 2+2*AddressLength && IsHex(s[2:]) { + if len(s) == 2+2*AddressLength && IsHex(s) { return true } - if len(s) == 2*AddressLength && IsHex(s) { + if len(s) == 2*AddressLength && IsHex("0x"+s) { return true } return false diff --git a/crypto/crypto.go b/crypto/crypto.go index 8685d62d3..7d7623753 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -43,14 +43,6 @@ import ( "golang.org/x/crypto/ripemd160" ) -var secp256k1n *big.Int - -func init() { - // specify the params for the s256 curve - ecies.AddParamsForCurve(S256(), ecies.ECIES_AES128_SHA256) - secp256k1n = common.String2Big("0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141") -} - func Sha3(data ...[]byte) []byte { d := sha3.NewKeccak256() for _, b := range data { @@ -99,9 +91,9 @@ func ToECDSA(prv []byte) *ecdsa.PrivateKey { } priv := new(ecdsa.PrivateKey) - priv.PublicKey.Curve = S256() + priv.PublicKey.Curve = secp256k1.S256() priv.D = common.BigD(prv) - priv.PublicKey.X, priv.PublicKey.Y = S256().ScalarBaseMult(prv) + priv.PublicKey.X, priv.PublicKey.Y = secp256k1.S256().ScalarBaseMult(prv) return priv } @@ -116,15 +108,15 @@ func ToECDSAPub(pub []byte) *ecdsa.PublicKey { if len(pub) == 0 { return nil } - x, y := elliptic.Unmarshal(S256(), pub) - return &ecdsa.PublicKey{S256(), x, y} + x, y := elliptic.Unmarshal(secp256k1.S256(), pub) + return &ecdsa.PublicKey{secp256k1.S256(), x, y} } func FromECDSAPub(pub *ecdsa.PublicKey) []byte { if pub == nil || pub.X == nil || pub.Y == nil { return nil } - return elliptic.Marshal(S256(), pub.X, pub.Y) + return elliptic.Marshal(secp256k1.S256(), pub.X, pub.Y) } // HexToECDSA parses a secp256k1 private key. @@ -168,7 +160,7 @@ func SaveECDSA(file string, key *ecdsa.PrivateKey) error { } func GenerateKey() (*ecdsa.PrivateKey, error) { - return ecdsa.GenerateKey(S256(), rand.Reader) + return ecdsa.GenerateKey(secp256k1.S256(), rand.Reader) } func ValidateSignatureValues(v byte, r, s *big.Int) bool { @@ -176,7 +168,7 @@ func ValidateSignatureValues(v byte, r, s *big.Int) bool { return false } vint := uint32(v) - if r.Cmp(secp256k1n) < 0 && s.Cmp(secp256k1n) < 0 && (vint == 27 || vint == 28) { + if r.Cmp(secp256k1.N) < 0 && s.Cmp(secp256k1.N) < 0 && (vint == 27 || vint == 28) { return true } else { return false @@ -189,8 +181,8 @@ func SigToPub(hash, sig []byte) (*ecdsa.PublicKey, error) { return nil, err } - x, y := elliptic.Unmarshal(S256(), s) - return &ecdsa.PublicKey{S256(), x, y}, nil + x, y := elliptic.Unmarshal(secp256k1.S256(), s) + return &ecdsa.PublicKey{secp256k1.S256(), x, y}, nil } func Sign(hash []byte, prv *ecdsa.PrivateKey) (sig []byte, err error) { diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go index fdd9c1ee8..d5e19a4bb 100644 --- a/crypto/crypto_test.go +++ b/crypto/crypto_test.go @@ -181,7 +181,7 @@ func TestValidateSignatureValues(t *testing.T) { minusOne := big.NewInt(-1) one := common.Big1 zero := common.Big0 - secp256k1nMinus1 := new(big.Int).Sub(secp256k1n, common.Big1) + secp256k1nMinus1 := new(big.Int).Sub(secp256k1.N, common.Big1) // correct v,r,s check(true, 27, one, one) @@ -208,9 +208,9 @@ func TestValidateSignatureValues(t *testing.T) { // correct sig with max r,s check(true, 27, secp256k1nMinus1, secp256k1nMinus1) // correct v, combinations of incorrect r,s at upper limit - check(false, 27, secp256k1n, secp256k1nMinus1) - check(false, 27, secp256k1nMinus1, secp256k1n) - check(false, 27, secp256k1n, secp256k1n) + check(false, 27, secp256k1.N, secp256k1nMinus1) + check(false, 27, secp256k1nMinus1, secp256k1.N) + check(false, 27, secp256k1.N, secp256k1.N) // current callers ensures r,s cannot be negative, but let's test for that too // as crypto package could be used stand-alone diff --git a/crypto/ecies/asn1.go b/crypto/ecies/asn1.go index 6eaf3d2ca..40dabd329 100644 --- a/crypto/ecies/asn1.go +++ b/crypto/ecies/asn1.go @@ -41,6 +41,8 @@ import ( "fmt" "hash" "math/big" + + "github.com/ethereum/go-ethereum/crypto/secp256k1" ) var ( @@ -81,6 +83,7 @@ func doScheme(base, v []int) asn1.ObjectIdentifier { type secgNamedCurve asn1.ObjectIdentifier var ( + secgNamedCurveS256 = secgNamedCurve{1, 3, 132, 0, 10} secgNamedCurveP256 = secgNamedCurve{1, 2, 840, 10045, 3, 1, 7} secgNamedCurveP384 = secgNamedCurve{1, 3, 132, 0, 34} secgNamedCurveP521 = secgNamedCurve{1, 3, 132, 0, 35} @@ -116,6 +119,8 @@ func (curve secgNamedCurve) Equal(curve2 secgNamedCurve) bool { func namedCurveFromOID(curve secgNamedCurve) elliptic.Curve { switch { + case curve.Equal(secgNamedCurveS256): + return secp256k1.S256() case curve.Equal(secgNamedCurveP256): return elliptic.P256() case curve.Equal(secgNamedCurveP384): @@ -134,6 +139,8 @@ func oidFromNamedCurve(curve elliptic.Curve) (secgNamedCurve, bool) { return secgNamedCurveP384, true case elliptic.P521(): return secgNamedCurveP521, true + case secp256k1.S256(): + return secgNamedCurveS256, true } return nil, false diff --git a/crypto/ecies/ecies.go b/crypto/ecies/ecies.go index a3b520dd5..65dc5b38b 100644 --- a/crypto/ecies/ecies.go +++ b/crypto/ecies/ecies.go @@ -125,6 +125,7 @@ func (prv *PrivateKey) GenerateShared(pub *PublicKey, skLen, macLen int) (sk []b if skLen+macLen > MaxSharedKeyLength(pub) { return nil, ErrSharedKeyTooBig } + x, _ := pub.Curve.ScalarMult(pub.X, pub.Y, prv.D.Bytes()) if x == nil { return nil, ErrSharedKeyIsPointAtInfinity diff --git a/crypto/ecies/ecies_test.go b/crypto/ecies/ecies_test.go index 1c391f938..6a0ea3f02 100644 --- a/crypto/ecies/ecies_test.go +++ b/crypto/ecies/ecies_test.go @@ -31,13 +31,18 @@ package ecies import ( "bytes" + "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/sha256" + "encoding/hex" "flag" "fmt" "io/ioutil" + "math/big" "testing" + + "github.com/ethereum/go-ethereum/crypto/secp256k1" ) var dumpEnc bool @@ -65,7 +70,6 @@ func TestKDF(t *testing.T) { } } -var skLen int var ErrBadSharedKeys = fmt.Errorf("ecies: shared keys don't match") // cmpParams compares a set of ECIES parameters. We assume, as per the @@ -117,7 +121,7 @@ func TestSharedKey(t *testing.T) { fmt.Println(err.Error()) t.FailNow() } - skLen = MaxSharedKeyLength(&prv1.PublicKey) / 2 + skLen := MaxSharedKeyLength(&prv1.PublicKey) / 2 prv2, err := GenerateKey(rand.Reader, DefaultCurve, nil) if err != nil { @@ -143,6 +147,44 @@ func TestSharedKey(t *testing.T) { } } +func TestSharedKeyPadding(t *testing.T) { + // sanity checks + prv0 := hexKey("1adf5c18167d96a1f9a0b1ef63be8aa27eaf6032c233b2b38f7850cf5b859fd9") + prv1 := hexKey("97a076fc7fcd9208240668e31c9abee952cbb6e375d1b8febc7499d6e16f1a") + x0, _ := new(big.Int).SetString("1a8ed022ff7aec59dc1b440446bdda5ff6bcb3509a8b109077282b361efffbd8", 16) + x1, _ := new(big.Int).SetString("6ab3ac374251f638d0abb3ef596d1dc67955b507c104e5f2009724812dc027b8", 16) + y0, _ := new(big.Int).SetString("e040bd480b1deccc3bc40bd5b1fdcb7bfd352500b477cb9471366dbd4493f923", 16) + y1, _ := new(big.Int).SetString("8ad915f2b503a8be6facab6588731fefeb584fd2dfa9a77a5e0bba1ec439e4fa", 16) + + if prv0.PublicKey.X.Cmp(x0) != 0 { + t.Errorf("mismatched prv0.X:\nhave: %x\nwant: %x\n", prv0.PublicKey.X.Bytes(), x0.Bytes()) + } + if prv0.PublicKey.Y.Cmp(y0) != 0 { + t.Errorf("mismatched prv0.Y:\nhave: %x\nwant: %x\n", prv0.PublicKey.Y.Bytes(), y0.Bytes()) + } + if prv1.PublicKey.X.Cmp(x1) != 0 { + t.Errorf("mismatched prv1.X:\nhave: %x\nwant: %x\n", prv1.PublicKey.X.Bytes(), x1.Bytes()) + } + if prv1.PublicKey.Y.Cmp(y1) != 0 { + t.Errorf("mismatched prv1.Y:\nhave: %x\nwant: %x\n", prv1.PublicKey.Y.Bytes(), y1.Bytes()) + } + + // test shared secret generation + sk1, err := prv0.GenerateShared(&prv1.PublicKey, 16, 16) + if err != nil { + fmt.Println(err.Error()) + } + + sk2, err := prv1.GenerateShared(&prv0.PublicKey, 16, 16) + if err != nil { + t.Fatal(err.Error()) + } + + if !bytes.Equal(sk1, sk2) { + t.Fatal(ErrBadSharedKeys.Error()) + } +} + // Verify that the key generation code fails when too much key data is // requested. func TestTooBigSharedKey(t *testing.T) { @@ -158,13 +200,13 @@ func TestTooBigSharedKey(t *testing.T) { t.FailNow() } - _, err = prv1.GenerateShared(&prv2.PublicKey, skLen*2, skLen*2) + _, err = prv1.GenerateShared(&prv2.PublicKey, 32, 32) if err != ErrSharedKeyTooBig { fmt.Println("ecdh: shared key should be too large for curve") t.FailNow() } - _, err = prv2.GenerateShared(&prv1.PublicKey, skLen*2, skLen*2) + _, err = prv2.GenerateShared(&prv1.PublicKey, 32, 32) if err != ErrSharedKeyTooBig { fmt.Println("ecdh: shared key should be too large for curve") t.FailNow() @@ -176,25 +218,21 @@ func TestTooBigSharedKey(t *testing.T) { func TestMarshalPublic(t *testing.T) { prv, err := GenerateKey(rand.Reader, DefaultCurve, nil) if err != nil { - fmt.Println(err.Error()) - t.FailNow() + t.Fatalf("GenerateKey error: %s", err) } out, err := MarshalPublic(&prv.PublicKey) if err != nil { - fmt.Println(err.Error()) - t.FailNow() + t.Fatalf("MarshalPublic error: %s", err) } pub, err := UnmarshalPublic(out) if err != nil { - fmt.Println(err.Error()) - t.FailNow() + t.Fatalf("UnmarshalPublic error: %s", err) } if !cmpPublic(prv.PublicKey, *pub) { - fmt.Println("ecies: failed to unmarshal public key") - t.FailNow() + t.Fatal("ecies: failed to unmarshal public key") } } @@ -304,9 +342,26 @@ func BenchmarkGenSharedKeyP256(b *testing.B) { fmt.Println(err.Error()) b.FailNow() } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := prv.GenerateShared(&prv.PublicKey, 16, 16) + if err != nil { + fmt.Println(err.Error()) + b.FailNow() + } + } +} +// Benchmark the generation of S256 shared keys. +func BenchmarkGenSharedKeyS256(b *testing.B) { + prv, err := GenerateKey(rand.Reader, secp256k1.S256(), nil) + if err != nil { + fmt.Println(err.Error()) + b.FailNow() + } + b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := prv.GenerateShared(&prv.PublicKey, skLen, skLen) + _, err := prv.GenerateShared(&prv.PublicKey, 16, 16) if err != nil { fmt.Println(err.Error()) b.FailNow() @@ -511,3 +566,43 @@ func TestBasicKeyValidation(t *testing.T) { } } } + +// Verify GenerateShared against static values - useful when +// debugging changes in underlying libs +func TestSharedKeyStatic(t *testing.T) { + prv1 := hexKey("7ebbc6a8358bc76dd73ebc557056702c8cfc34e5cfcd90eb83af0347575fd2ad") + prv2 := hexKey("6a3d6396903245bba5837752b9e0348874e72db0c4e11e9c485a81b4ea4353b9") + + skLen := MaxSharedKeyLength(&prv1.PublicKey) / 2 + + sk1, err := prv1.GenerateShared(&prv2.PublicKey, skLen, skLen) + if err != nil { + fmt.Println(err.Error()) + t.FailNow() + } + + sk2, err := prv2.GenerateShared(&prv1.PublicKey, skLen, skLen) + if err != nil { + fmt.Println(err.Error()) + t.FailNow() + } + + if !bytes.Equal(sk1, sk2) { + fmt.Println(ErrBadSharedKeys.Error()) + t.FailNow() + } + + sk, _ := hex.DecodeString("167ccc13ac5e8a26b131c3446030c60fbfac6aa8e31149d0869f93626a4cdf62") + if !bytes.Equal(sk1, sk) { + t.Fatalf("shared secret mismatch: want: %x have: %x", sk, sk1) + } +} + +// TODO: remove after refactoring packages crypto and crypto/ecies +func hexKey(prv string) *PrivateKey { + priv := new(ecdsa.PrivateKey) + priv.PublicKey.Curve = secp256k1.S256() + priv.D, _ = new(big.Int).SetString(prv, 16) + priv.PublicKey.X, priv.PublicKey.Y = secp256k1.S256().ScalarBaseMult(priv.D.Bytes()) + return ImportECDSA(priv) +} diff --git a/crypto/ecies/params.go b/crypto/ecies/params.go index 97ddb0973..511c53ebc 100644 --- a/crypto/ecies/params.go +++ b/crypto/ecies/params.go @@ -41,13 +41,12 @@ import ( "crypto/sha512" "fmt" "hash" -) -// The default curve for this package is the NIST P256 curve, which -// provides security equivalent to AES-128. -var DefaultCurve = elliptic.P256() + "github.com/ethereum/go-ethereum/crypto/secp256k1" +) var ( + DefaultCurve = secp256k1.S256() ErrUnsupportedECDHAlgorithm = fmt.Errorf("ecies: unsupported ECDH algorithm") ErrUnsupportedECIESParameters = fmt.Errorf("ecies: unsupported ECIES parameters") ) @@ -101,9 +100,10 @@ var ( ) var paramsFromCurve = map[elliptic.Curve]*ECIESParams{ - elliptic.P256(): ECIES_AES128_SHA256, - elliptic.P384(): ECIES_AES256_SHA384, - elliptic.P521(): ECIES_AES256_SHA512, + secp256k1.S256(): ECIES_AES128_SHA256, + elliptic.P256(): ECIES_AES128_SHA256, + elliptic.P384(): ECIES_AES256_SHA384, + elliptic.P521(): ECIES_AES256_SHA512, } func AddParamsForCurve(curve elliptic.Curve, params *ECIESParams) { diff --git a/crypto/key.go b/crypto/key.go index 4ec43dfd7..ccf284ad8 100644 --- a/crypto/key.go +++ b/crypto/key.go @@ -25,6 +25,7 @@ import ( "strings" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto/secp256k1" "github.com/pborman/uuid" ) @@ -137,7 +138,7 @@ func NewKey(rand io.Reader) *Key { panic("key generation: could not read from random source: " + err.Error()) } reader := bytes.NewReader(randBytes) - privateKeyECDSA, err := ecdsa.GenerateKey(S256(), reader) + privateKeyECDSA, err := ecdsa.GenerateKey(secp256k1.S256(), reader) if err != nil { panic("key generation: ecdsa.GenerateKey failed: " + err.Error()) } @@ -155,7 +156,7 @@ func NewKeyForDirectICAP(rand io.Reader) *Key { panic("key generation: could not read from random source: " + err.Error()) } reader := bytes.NewReader(randBytes) - privateKeyECDSA, err := ecdsa.GenerateKey(S256(), reader) + privateKeyECDSA, err := ecdsa.GenerateKey(secp256k1.S256(), reader) if err != nil { panic("key generation: ecdsa.GenerateKey failed: " + err.Error()) } diff --git a/crypto/curve.go b/crypto/secp256k1/curve.go index 48f3f5e9c..6e44a6771 100644 --- a/crypto/curve.go +++ b/crypto/secp256k1/curve.go @@ -29,15 +29,22 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -package crypto +package secp256k1 import ( "crypto/elliptic" "io" "math/big" "sync" + "unsafe" ) +/* +#include "libsecp256k1/include/secp256k1.h" +extern int secp256k1_pubkey_scalar_mul(const secp256k1_context* ctx, const unsigned char *point, const unsigned char *scalar); +*/ +import "C" + // This code is from https://github.com/ThePiachu/GoBit and implements // several Koblitz elliptic curves over prime fields. // @@ -211,44 +218,37 @@ func (BitCurve *BitCurve) doubleJacobian(x, y, z *big.Int) (*big.Int, *big.Int, return x3, y3, z3 } -//TODO: double check if it is okay -// ScalarMult returns k*(Bx,By) where k is a number in big-endian form. -func (BitCurve *BitCurve) ScalarMult(Bx, By *big.Int, k []byte) (*big.Int, *big.Int) { - // We have a slight problem in that the identity of the group (the - // point at infinity) cannot be represented in (x, y) form on a finite - // machine. Thus the standard add/double algorithm has to be tweaked - // slightly: our initial state is not the identity, but x, and we - // ignore the first true bit in |k|. If we don't find any true bits in - // |k|, then we return nil, nil, because we cannot return the identity - // element. - - Bz := new(big.Int).SetInt64(1) - x := Bx - y := By - z := Bz - - seenFirstTrue := false - for _, byte := range k { - for bitNum := 0; bitNum < 8; bitNum++ { - if seenFirstTrue { - x, y, z = BitCurve.doubleJacobian(x, y, z) - } - if byte&0x80 == 0x80 { - if !seenFirstTrue { - seenFirstTrue = true - } else { - x, y, z = BitCurve.addJacobian(Bx, By, Bz, x, y, z) - } - } - byte <<= 1 - } +func (BitCurve *BitCurve) ScalarMult(Bx, By *big.Int, scalar []byte) (*big.Int, *big.Int) { + // Ensure scalar is exactly 32 bytes. We pad always, even if + // scalar is 32 bytes long, to avoid a timing side channel. + if len(scalar) > 32 { + panic("can't handle scalars > 256 bits") } - - if !seenFirstTrue { + padded := make([]byte, 32) + copy(padded[32-len(scalar):], scalar) + scalar = padded + + // Do the multiplication in C, updating point. + point := make([]byte, 64) + readBits(point[:32], Bx) + readBits(point[32:], By) + pointPtr := (*C.uchar)(unsafe.Pointer(&point[0])) + scalarPtr := (*C.uchar)(unsafe.Pointer(&scalar[0])) + res := C.secp256k1_pubkey_scalar_mul(context, pointPtr, scalarPtr) + + // Unpack the result and clear temporaries. + x := new(big.Int).SetBytes(point[:32]) + y := new(big.Int).SetBytes(point[32:]) + for i := range point { + point[i] = 0 + } + for i := range padded { + scalar[i] = 0 + } + if res != 1 { return nil, nil } - - return BitCurve.affineFromJacobian(x, y, z) + return x, y } // ScalarBaseMult returns k*G, where G is the base point of the group and k is @@ -312,86 +312,24 @@ func (BitCurve *BitCurve) Unmarshal(data []byte) (x, y *big.Int) { return } -//curve parameters taken from: -//http://www.secg.org/collateral/sec2_final.pdf - -var initonce sync.Once -var ecp160k1 *BitCurve -var ecp192k1 *BitCurve -var ecp224k1 *BitCurve -var ecp256k1 *BitCurve - -func initAll() { - initS160() - initS192() - initS224() - initS256() -} - -func initS160() { - // See SEC 2 section 2.4.1 - ecp160k1 = new(BitCurve) - ecp160k1.P, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFAC73", 16) - ecp160k1.N, _ = new(big.Int).SetString("0100000000000000000001B8FA16DFAB9ACA16B6B3", 16) - ecp160k1.B, _ = new(big.Int).SetString("0000000000000000000000000000000000000007", 16) - ecp160k1.Gx, _ = new(big.Int).SetString("3B4C382CE37AA192A4019E763036F4F5DD4D7EBB", 16) - ecp160k1.Gy, _ = new(big.Int).SetString("938CF935318FDCED6BC28286531733C3F03C4FEE", 16) - ecp160k1.BitSize = 160 -} - -func initS192() { - // See SEC 2 section 2.5.1 - ecp192k1 = new(BitCurve) - ecp192k1.P, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFEE37", 16) - ecp192k1.N, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFE26F2FC170F69466A74DEFD8D", 16) - ecp192k1.B, _ = new(big.Int).SetString("000000000000000000000000000000000000000000000003", 16) - ecp192k1.Gx, _ = new(big.Int).SetString("DB4FF10EC057E9AE26B07D0280B7F4341DA5D1B1EAE06C7D", 16) - ecp192k1.Gy, _ = new(big.Int).SetString("9B2F2F6D9C5628A7844163D015BE86344082AA88D95E2F9D", 16) - ecp192k1.BitSize = 192 -} - -func initS224() { - // See SEC 2 section 2.6.1 - ecp224k1 = new(BitCurve) - ecp224k1.P, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFE56D", 16) - ecp224k1.N, _ = new(big.Int).SetString("010000000000000000000000000001DCE8D2EC6184CAF0A971769FB1F7", 16) - ecp224k1.B, _ = new(big.Int).SetString("00000000000000000000000000000000000000000000000000000005", 16) - ecp224k1.Gx, _ = new(big.Int).SetString("A1455B334DF099DF30FC28A169A467E9E47075A90F7E650EB6B7A45C", 16) - ecp224k1.Gy, _ = new(big.Int).SetString("7E089FED7FBA344282CAFBD6F7E319F7C0B0BD59E2CA4BDB556D61A5", 16) - ecp224k1.BitSize = 224 -} - -func initS256() { - // See SEC 2 section 2.7.1 - ecp256k1 = new(BitCurve) - ecp256k1.P, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F", 16) - ecp256k1.N, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141", 16) - ecp256k1.B, _ = new(big.Int).SetString("0000000000000000000000000000000000000000000000000000000000000007", 16) - ecp256k1.Gx, _ = new(big.Int).SetString("79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798", 16) - ecp256k1.Gy, _ = new(big.Int).SetString("483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8", 16) - ecp256k1.BitSize = 256 -} - -// S160 returns a BitCurve which implements secp160k1 (see SEC 2 section 2.4.1) -func S160() *BitCurve { - initonce.Do(initAll) - return ecp160k1 -} - -// S192 returns a BitCurve which implements secp192k1 (see SEC 2 section 2.5.1) -func S192() *BitCurve { - initonce.Do(initAll) - return ecp192k1 -} - -// S224 returns a BitCurve which implements secp224k1 (see SEC 2 section 2.6.1) -func S224() *BitCurve { - initonce.Do(initAll) - return ecp224k1 -} +var ( + initonce sync.Once + theCurve *BitCurve +) // S256 returns a BitCurve which implements secp256k1 (see SEC 2 section 2.7.1) func S256() *BitCurve { - initonce.Do(initAll) - return ecp256k1 + initonce.Do(func() { + // See SEC 2 section 2.7.1 + // curve parameters taken from: + // http://www.secg.org/collateral/sec2_final.pdf + theCurve = new(BitCurve) + theCurve.P, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F", 16) + theCurve.N, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141", 16) + theCurve.B, _ = new(big.Int).SetString("0000000000000000000000000000000000000000000000000000000000000007", 16) + theCurve.Gx, _ = new(big.Int).SetString("79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798", 16) + theCurve.Gy, _ = new(big.Int).SetString("483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8", 16) + theCurve.BitSize = 256 + }) + return theCurve } diff --git a/crypto/secp256k1/curve_test.go b/crypto/secp256k1/curve_test.go new file mode 100644 index 000000000..d915ee852 --- /dev/null +++ b/crypto/secp256k1/curve_test.go @@ -0,0 +1,39 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package secp256k1 + +import ( + "bytes" + "encoding/hex" + "math/big" + "testing" +) + +func TestReadBits(t *testing.T) { + check := func(input string) { + want, _ := hex.DecodeString(input) + int, _ := new(big.Int).SetString(input, 16) + buf := make([]byte, len(want)) + readBits(buf, int) + if !bytes.Equal(buf, want) { + t.Errorf("have: %x\nwant: %x", buf, want) + } + } + check("000000000000000000000000000000000000000000000000000000FEFCF3F8F0") + check("0000000000012345000000000000000000000000000000000000FEFCF3F8F0") + check("18F8F8F1000111000110011100222004330052300000000000000000FEFCF3F8F0") +} diff --git a/crypto/secp256k1/pubkey_scalar_mul.h b/crypto/secp256k1/pubkey_scalar_mul.h new file mode 100644 index 000000000..0511545ec --- /dev/null +++ b/crypto/secp256k1/pubkey_scalar_mul.h @@ -0,0 +1,56 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +/** Multiply point by scalar in constant time. + * Returns: 1: multiplication was successful + * 0: scalar was invalid (zero or overflow) + * Args: ctx: pointer to a context object (cannot be NULL) + * Out: point: the multiplied point (usually secret) + * In: point: pointer to a 64-byte bytepublic point, + encoded as two 256bit big-endian numbers. + * scalar: a 32-byte scalar with which to multiply the point + */ +int secp256k1_pubkey_scalar_mul(const secp256k1_context* ctx, unsigned char *point, const unsigned char *scalar) { + int ret = 0; + int overflow = 0; + secp256k1_fe feX, feY; + secp256k1_gej res; + secp256k1_ge ge; + secp256k1_scalar s; + ARG_CHECK(point != NULL); + ARG_CHECK(scalar != NULL); + (void)ctx; + + secp256k1_fe_set_b32(&feX, point); + secp256k1_fe_set_b32(&feY, point+32); + secp256k1_ge_set_xy(&ge, &feX, &feY); + secp256k1_scalar_set_b32(&s, scalar, &overflow); + if (overflow || secp256k1_scalar_is_zero(&s)) { + ret = 0; + } else { + secp256k1_ecmult_const(&res, &ge, &s); + secp256k1_ge_set_gej(&ge, &res); + /* Note: can't use secp256k1_pubkey_save here because it is not constant time. */ + secp256k1_fe_normalize(&ge.x); + secp256k1_fe_normalize(&ge.y); + secp256k1_fe_get_b32(point, &ge.x); + secp256k1_fe_get_b32(point+32, &ge.y); + ret = 1; + } + secp256k1_scalar_clear(&s); + return ret; +} + diff --git a/crypto/secp256k1/secp256.go b/crypto/secp256k1/secp256.go index 41a5608a5..8dc248145 100644 --- a/crypto/secp256k1/secp256.go +++ b/crypto/secp256k1/secp256.go @@ -20,6 +20,7 @@ package secp256k1 /* #cgo CFLAGS: -I./libsecp256k1 +#cgo CFLAGS: -I./libsecp256k1/src/ #cgo darwin CFLAGS: -I/usr/local/include #cgo freebsd CFLAGS: -I/usr/local/include #cgo linux,arm CFLAGS: -I/usr/local/arm/include @@ -35,6 +36,7 @@ package secp256k1 #define NDEBUG #include "./libsecp256k1/src/secp256k1.c" #include "./libsecp256k1/src/modules/recovery/main_impl.h" +#include "pubkey_scalar_mul.h" typedef void (*callbackFunc) (const char* msg, void* data); extern void secp256k1GoPanicIllegal(const char* msg, void* data); @@ -44,6 +46,7 @@ import "C" import ( "errors" + "math/big" "unsafe" "github.com/ethereum/go-ethereum/crypto/randentropy" @@ -56,13 +59,16 @@ import ( > store private keys in buffer and shuffle (deters persistance on swap disc) > byte permutation (changing) > xor with chaning random block (to deter scanning memory for 0x63) (stream cipher?) - > on disk: store keys in wallets */ // holds ptr to secp256k1_context_struct (see secp256k1/include/secp256k1.h) -var context *C.secp256k1_context +var ( + context *C.secp256k1_context + N *big.Int +) func init() { + N, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16) // around 20 ms on a modern CPU. context = C.secp256k1_context_create(3) // SECP256K1_START_SIGN | SECP256K1_START_VERIFY C.secp256k1_context_set_illegal_callback(context, C.callbackFunc(C.secp256k1GoPanicIllegal), nil) @@ -78,7 +84,6 @@ var ( func GenerateKeyPair() ([]byte, []byte) { var seckey []byte = randentropy.GetEntropyCSPRNG(32) var seckey_ptr *C.uchar = (*C.uchar)(unsafe.Pointer(&seckey[0])) - var pubkey64 []byte = make([]byte, 64) // secp256k1_pubkey var pubkey65 []byte = make([]byte, 65) // 65 byte uncompressed pubkey pubkey64_ptr := (*C.secp256k1_pubkey)(unsafe.Pointer(&pubkey64[0])) @@ -254,3 +259,16 @@ func checkSignature(sig []byte) error { } return nil } + +// reads num into buf as big-endian bytes. +func readBits(buf []byte, num *big.Int) { + const wordLen = int(unsafe.Sizeof(big.Word(0))) + i := len(buf) + for _, d := range num.Bits() { + for j := 0; j < wordLen && i > 0; j++ { + i-- + buf[i] = byte(d) + d >>= 8 + } + } +} diff --git a/crypto/secp256k1/secp256_test.go b/crypto/secp256k1/secp256_test.go index cb71ea5e7..fc6fc9b32 100644 --- a/crypto/secp256k1/secp256_test.go +++ b/crypto/secp256k1/secp256_test.go @@ -24,7 +24,7 @@ import ( "github.com/ethereum/go-ethereum/crypto/randentropy" ) -const TestCount = 10000 +const TestCount = 1000 func TestPrivkeyGenerate(t *testing.T) { _, seckey := GenerateKeyPair() @@ -86,10 +86,7 @@ func TestSignAndRecover(t *testing.T) { func TestRandomMessagesWithSameKey(t *testing.T) { pubkey, seckey := GenerateKeyPair() keys := func() ([]byte, []byte) { - // Sign function zeroes the privkey so we need a new one in each call - newkey := make([]byte, len(seckey)) - copy(newkey, seckey) - return pubkey, newkey + return pubkey, seckey } signAndRecoverWithRandomMessages(t, keys) } @@ -209,30 +206,32 @@ func compactSigCheck(t *testing.T, sig []byte) { } } -// godep go test -v -run=XXX -bench=BenchmarkSignRandomInputEachRound +// godep go test -v -run=XXX -bench=BenchmarkSign // add -benchtime=10s to benchmark longer for more accurate average -func BenchmarkSignRandomInputEachRound(b *testing.B) { + +// to avoid compiler optimizing the benchmarked function call +var err error + +func BenchmarkSign(b *testing.B) { for i := 0; i < b.N; i++ { - b.StopTimer() _, seckey := GenerateKeyPair() msg := randentropy.GetEntropyCSPRNG(32) b.StartTimer() - if _, err := Sign(msg, seckey); err != nil { - b.Fatal(err) - } + _, e := Sign(msg, seckey) + err = e + b.StopTimer() } } -//godep go test -v -run=XXX -bench=BenchmarkRecoverRandomInputEachRound -func BenchmarkRecoverRandomInputEachRound(b *testing.B) { +//godep go test -v -run=XXX -bench=BenchmarkECRec +func BenchmarkRecover(b *testing.B) { for i := 0; i < b.N; i++ { - b.StopTimer() _, seckey := GenerateKeyPair() msg := randentropy.GetEntropyCSPRNG(32) sig, _ := Sign(msg, seckey) b.StartTimer() - if _, err := RecoverPubkey(msg, sig); err != nil { - b.Fatal(err) - } + _, e := RecoverPubkey(msg, sig) + err = e + b.StopTimer() } } diff --git a/eth/backend.go b/eth/backend.go index 0369f6afd..91f02db72 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -191,7 +191,7 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) { shutdownChan: make(chan bool), chainDb: chainDb, dappDb: dappDb, - eventMux: &event.TypeMux{}, + eventMux: ctx.EventMux, accountManager: config.AccountManager, etherbase: config.Etherbase, netVersionId: config.NetworkId, diff --git a/p2p/discover/node.go b/p2p/discover/node.go index a14f29424..dd19df3a2 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -210,7 +210,7 @@ func PubkeyID(pub *ecdsa.PublicKey) NodeID { // Pubkey returns the public key represented by the node ID. // It returns an error if the ID is not a point on the curve. func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) { - p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)} + p := &ecdsa.PublicKey{Curve: secp256k1.S256(), X: new(big.Int), Y: new(big.Int)} half := len(id) / 2 p.X.SetBytes(id[:half]) p.Y.SetBytes(id[half:]) diff --git a/p2p/rlpx.go b/p2p/rlpx.go index aaa733854..8f429d6ec 100644 --- a/p2p/rlpx.go +++ b/p2p/rlpx.go @@ -277,7 +277,7 @@ func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) { return nil, err } // generate random keypair to use for signing - randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil) + randpriv, err := ecies.GenerateKey(rand.Reader, secp256k1.S256(), nil) if err != nil { return nil, err } @@ -376,7 +376,7 @@ func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandsh var err error h := new(encHandshake) // generate random keypair for session - h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil) + h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, secp256k1.S256(), nil) if err != nil { return nil, err } diff --git a/p2p/rlpx_test.go b/p2p/rlpx_test.go index 900353f0e..7cc7548e2 100644 --- a/p2p/rlpx_test.go +++ b/p2p/rlpx_test.go @@ -93,6 +93,7 @@ func testEncHandshake(token []byte) error { go func() { r := result{side: "initiator"} defer func() { output <- r }() + defer fd0.Close() dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)} r.id, r.err = c0.doEncHandshake(prv0, dest) @@ -107,6 +108,7 @@ func testEncHandshake(token []byte) error { go func() { r := result{side: "receiver"} defer func() { output <- r }() + defer fd1.Close() r.id, r.err = c1.doEncHandshake(prv1, nil) if r.err != nil { diff --git a/trie/arc.go b/trie/arc.go index 9da012e16..fc7a3259f 100644 --- a/trie/arc.go +++ b/trie/arc.go @@ -62,6 +62,18 @@ func newARC(c int) *arc { } } +// Clear clears the cache +func (a *arc) Clear() { + a.mutex.Lock() + defer a.mutex.Unlock() + a.p = 0 + a.t1 = list.New() + a.b1 = list.New() + a.t2 = list.New() + a.b2 = list.New() + a.cache = make(map[string]*entry, a.c) +} + // Put inserts a new key-value pair into the cache. // This optimizes future access to this entry (side effect). func (a *arc) Put(key hashNode, value node) bool { diff --git a/trie/errors.go b/trie/errors.go new file mode 100644 index 000000000..a0f58f28f --- /dev/null +++ b/trie/errors.go @@ -0,0 +1,41 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package trie + +import ( + "fmt" + + "github.com/ethereum/go-ethereum/common" +) + +// MissingNodeError is returned by the trie functions (TryGet, TryUpdate, TryDelete) +// in the case where a trie node is not present in the local database. Contains +// information necessary for retrieving the missing node through an ODR service. +// +// NodeHash is the hash of the missing node +// RootHash is the original root of the trie that contains the node +// KeyPrefix is the prefix that leads from the root to the missing node (hex encoded) +// KeySuffix (optional) contains the rest of the key we were looking for, gives a +// hint on which further nodes should also be retrieved (hex encoded) +type MissingNodeError struct { + RootHash, NodeHash common.Hash + KeyPrefix, KeySuffix []byte +} + +func (err *MissingNodeError) Error() string { + return fmt.Sprintf("Missing trie node %064x", err.NodeHash) +} diff --git a/trie/iterator.go b/trie/iterator.go index 38555fe08..5f205e081 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -16,7 +16,12 @@ package trie -import "bytes" +import ( + "bytes" + + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" +) type Iterator struct { trie *Trie @@ -100,7 +105,11 @@ func (self *Iterator) next(node interface{}, key []byte, isIterStart bool) []byt } case hashNode: - return self.next(self.trie.resolveHash(node), key, isIterStart) + rn, err := self.trie.resolveHash(node, nil, nil) + if err != nil && glog.V(logger.Error) { + glog.Errorf("Unhandled trie error: %v", err) + } + return self.next(rn, key, isIterStart) } return nil } @@ -127,7 +136,11 @@ func (self *Iterator) key(node interface{}) []byte { } } case hashNode: - return self.key(self.trie.resolveHash(node)) + rn, err := self.trie.resolveHash(node, nil, nil) + if err != nil && glog.V(logger.Error) { + glog.Errorf("Unhandled trie error: %v", err) + } + return self.key(rn) } return nil diff --git a/trie/proof.go b/trie/proof.go index a705c49db..2e88bb50b 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -7,6 +7,8 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/rlp" ) @@ -39,7 +41,14 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue { case nil: return nil case hashNode: - tn = t.resolveHash(n) + var err error + tn, err = t.resolveHash(n, nil, nil) + if err != nil { + if glog.V(logger.Error) { + glog.Errorf("Unhandled trie error: %v", err) + } + return nil + } default: panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) } diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 47d1934d0..caeef3c3a 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -21,6 +21,8 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" ) var secureKeyPrefix = []byte("secure-key-") @@ -46,8 +48,8 @@ type SecureTrie struct { // NewSecure creates a trie with an existing root node from db. // // If root is the zero hash or the sha3 hash of an empty string, the -// trie is initially empty. Otherwise, New will panics if db is nil -// and returns ErrMissingRoot if the root node cannpt be found. +// trie is initially empty. Otherwise, New will panic if db is nil +// and returns MissingNodeError if the root node cannot be found. // Accessing the trie loads nodes from db on demand. func NewSecure(root common.Hash, db Database) (*SecureTrie, error) { if db == nil { @@ -63,7 +65,18 @@ func NewSecure(root common.Hash, db Database) (*SecureTrie, error) { // Get returns the value for key stored in the trie. // The value bytes must not be modified by the caller. func (t *SecureTrie) Get(key []byte) []byte { - return t.Trie.Get(t.hashKey(key)) + res, err := t.TryGet(key) + if err != nil && glog.V(logger.Error) { + glog.Errorf("Unhandled trie error: %v", err) + } + return res +} + +// TryGet returns the value for key stored in the trie. +// The value bytes must not be modified by the caller. +// If a node was not found in the database, a MissingNodeError is returned. +func (t *SecureTrie) TryGet(key []byte) ([]byte, error) { + return t.Trie.TryGet(t.hashKey(key)) } // Update associates key with value in the trie. Subsequent calls to @@ -73,14 +86,40 @@ func (t *SecureTrie) Get(key []byte) []byte { // The value bytes must not be modified by the caller while they are // stored in the trie. func (t *SecureTrie) Update(key, value []byte) { + if err := t.TryUpdate(key, value); err != nil && glog.V(logger.Error) { + glog.Errorf("Unhandled trie error: %v", err) + } +} + +// TryUpdate associates key with value in the trie. Subsequent calls to +// Get will return value. If value has length zero, any existing value +// is deleted from the trie and calls to Get will return nil. +// +// The value bytes must not be modified by the caller while they are +// stored in the trie. +// +// If a node was not found in the database, a MissingNodeError is returned. +func (t *SecureTrie) TryUpdate(key, value []byte) error { hk := t.hashKey(key) - t.Trie.Update(hk, value) + err := t.Trie.TryUpdate(hk, value) + if err != nil { + return err + } t.Trie.db.Put(t.secKey(hk), key) + return nil } // Delete removes any existing value for key from the trie. func (t *SecureTrie) Delete(key []byte) { - t.Trie.Delete(t.hashKey(key)) + if err := t.TryDelete(key); err != nil && glog.V(logger.Error) { + glog.Errorf("Unhandled trie error: %v", err) + } +} + +// TryDelete removes any existing value for key from the trie. +// If a node was not found in the database, a MissingNodeError is returned. +func (t *SecureTrie) TryDelete(key []byte) error { + return t.Trie.TryDelete(t.hashKey(key)) } // GetKey returns the sha3 preimage of a hashed key that was diff --git a/trie/trie.go b/trie/trie.go index a3a383fb5..717296e27 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -19,7 +19,6 @@ package trie import ( "bytes" - "errors" "fmt" "hash" @@ -44,7 +43,10 @@ var ( emptyState = crypto.Sha3Hash(nil) ) -var ErrMissingRoot = errors.New("missing root node") +// ClearGlobalCache clears the global trie cache +func ClearGlobalCache() { + globalCache.Clear() +} // Database must be implemented by backing stores for the trie. type Database interface { @@ -67,8 +69,9 @@ type DatabaseWriter interface { // // Trie is not safe for concurrent use. type Trie struct { - root node - db Database + root node + db Database + originalRoot common.Hash *hasher } @@ -76,16 +79,19 @@ type Trie struct { // // If root is the zero hash or the sha3 hash of an empty string, the // trie is initially empty and does not require a database. Otherwise, -// New will panics if db is nil or root does not exist in the -// database. Accessing the trie loads nodes from db on demand. +// New will panic if db is nil and returns a MissingNodeError if root does +// not exist in the database. Accessing the trie loads nodes from db on demand. func New(root common.Hash, db Database) (*Trie, error) { - trie := &Trie{db: db} + trie := &Trie{db: db, originalRoot: root} if (root != common.Hash{}) && root != emptyRoot { if db == nil { panic("trie.New: cannot use existing root without a database") } if v, _ := trie.db.Get(root[:]); len(v) == 0 { - return nil, ErrMissingRoot + return nil, &MissingNodeError{ + RootHash: root, + NodeHash: root, + } } trie.root = hashNode(root.Bytes()) } @@ -100,28 +106,44 @@ func (t *Trie) Iterator() *Iterator { // Get returns the value for key stored in the trie. // The value bytes must not be modified by the caller. func (t *Trie) Get(key []byte) []byte { + res, err := t.TryGet(key) + if err != nil && glog.V(logger.Error) { + glog.Errorf("Unhandled trie error: %v", err) + } + return res +} + +// TryGet returns the value for key stored in the trie. +// The value bytes must not be modified by the caller. +// If a node was not found in the database, a MissingNodeError is returned. +func (t *Trie) TryGet(key []byte) ([]byte, error) { key = compactHexDecode(key) + pos := 0 tn := t.root - for len(key) > 0 { + for pos < len(key) { switch n := tn.(type) { case shortNode: - if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { - return nil + if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) { + return nil, nil } tn = n.Val - key = key[len(n.Key):] + pos += len(n.Key) case fullNode: - tn = n[key[0]] - key = key[1:] + tn = n[key[pos]] + pos++ case nil: - return nil + return nil, nil case hashNode: - tn = t.resolveHash(n) + var err error + tn, err = t.resolveHash(n, key[:pos], key[pos:]) + if err != nil { + return nil, err + } default: panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) } } - return tn.(valueNode) + return tn.(valueNode), nil } // Update associates key with value in the trie. Subsequent calls to @@ -131,17 +153,40 @@ func (t *Trie) Get(key []byte) []byte { // The value bytes must not be modified by the caller while they are // stored in the trie. func (t *Trie) Update(key, value []byte) { + if err := t.TryUpdate(key, value); err != nil && glog.V(logger.Error) { + glog.Errorf("Unhandled trie error: %v", err) + } +} + +// TryUpdate associates key with value in the trie. Subsequent calls to +// Get will return value. If value has length zero, any existing value +// is deleted from the trie and calls to Get will return nil. +// +// The value bytes must not be modified by the caller while they are +// stored in the trie. +// +// If a node was not found in the database, a MissingNodeError is returned. +func (t *Trie) TryUpdate(key, value []byte) error { k := compactHexDecode(key) if len(value) != 0 { - t.root = t.insert(t.root, k, valueNode(value)) + n, err := t.insert(t.root, nil, k, valueNode(value)) + if err != nil { + return err + } + t.root = n } else { - t.root = t.delete(t.root, k) + n, err := t.delete(t.root, nil, k) + if err != nil { + return err + } + t.root = n } + return nil } -func (t *Trie) insert(n node, key []byte, value node) node { +func (t *Trie) insert(n node, prefix, key []byte, value node) (node, error) { if len(key) == 0 { - return value + return value, nil } switch n := n.(type) { case shortNode: @@ -149,25 +194,40 @@ func (t *Trie) insert(n node, key []byte, value node) node { // If the whole key matches, keep this short node as is // and only update the value. if matchlen == len(n.Key) { - return shortNode{n.Key, t.insert(n.Val, key[matchlen:], value)} + nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value) + if err != nil { + return nil, err + } + return shortNode{n.Key, nn}, nil } // Otherwise branch out at the index where they differ. var branch fullNode - branch[n.Key[matchlen]] = t.insert(nil, n.Key[matchlen+1:], n.Val) - branch[key[matchlen]] = t.insert(nil, key[matchlen+1:], value) + var err error + branch[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val) + if err != nil { + return nil, err + } + branch[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value) + if err != nil { + return nil, err + } // Replace this shortNode with the branch if it occurs at index 0. if matchlen == 0 { - return branch + return branch, nil } // Otherwise, replace it with a short node leading up to the branch. - return shortNode{key[:matchlen], branch} + return shortNode{key[:matchlen], branch}, nil case fullNode: - n[key[0]] = t.insert(n[key[0]], key[1:], value) - return n + nn, err := t.insert(n[key[0]], append(prefix, key[0]), key[1:], value) + if err != nil { + return nil, err + } + n[key[0]] = nn + return n, nil case nil: - return shortNode{key, value} + return shortNode{key, value}, nil case hashNode: // We've hit a part of the trie that isn't loaded yet. Load @@ -176,7 +236,11 @@ func (t *Trie) insert(n node, key []byte, value node) node { // // TODO: track whether insertion changed the value and keep // n as a hash node if it didn't. - return t.insert(t.resolveHash(n), key, value) + rn, err := t.resolveHash(n, prefix, key) + if err != nil { + return nil, err + } + return t.insert(rn, prefix, key, value) default: panic(fmt.Sprintf("%T: invalid node: %v", n, n)) @@ -185,28 +249,44 @@ func (t *Trie) insert(n node, key []byte, value node) node { // Delete removes any existing value for key from the trie. func (t *Trie) Delete(key []byte) { + if err := t.TryDelete(key); err != nil && glog.V(logger.Error) { + glog.Errorf("Unhandled trie error: %v", err) + } +} + +// TryDelete removes any existing value for key from the trie. +// If a node was not found in the database, a MissingNodeError is returned. +func (t *Trie) TryDelete(key []byte) error { k := compactHexDecode(key) - t.root = t.delete(t.root, k) + n, err := t.delete(t.root, nil, k) + if err != nil { + return err + } + t.root = n + return nil } // delete returns the new root of the trie with key deleted. // It reduces the trie to minimal form by simplifying // nodes on the way up after deleting recursively. -func (t *Trie) delete(n node, key []byte) node { +func (t *Trie) delete(n node, prefix, key []byte) (node, error) { switch n := n.(type) { case shortNode: matchlen := prefixLen(key, n.Key) if matchlen < len(n.Key) { - return n // don't replace n on mismatch + return n, nil // don't replace n on mismatch } if matchlen == len(key) { - return nil // remove n entirely for whole matches + return nil, nil // remove n entirely for whole matches } // The key is longer than n.Key. Remove the remaining suffix // from the subtrie. Child can never be nil here since the // subtrie must contain at least two other values with keys // longer than n.Key. - child := t.delete(n.Val, key[len(n.Key):]) + child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):]) + if err != nil { + return nil, err + } switch child := child.(type) { case shortNode: // Deleting from the subtrie reduced it to another @@ -215,13 +295,17 @@ func (t *Trie) delete(n node, key []byte) node { // always creates a new slice) instead of append to // avoid modifying n.Key since it might be shared with // other nodes. - return shortNode{concat(n.Key, child.Key...), child.Val} + return shortNode{concat(n.Key, child.Key...), child.Val}, nil default: - return shortNode{n.Key, child} + return shortNode{n.Key, child}, nil } case fullNode: - n[key[0]] = t.delete(n[key[0]], key[1:]) + nn, err := t.delete(n[key[0]], append(prefix, key[0]), key[1:]) + if err != nil { + return nil, err + } + n[key[0]] = nn // Check how many non-nil entries are left after deleting and // reduce the full node to a short node if only one entry is // left. Since n must've contained at least two children @@ -250,21 +334,24 @@ func (t *Trie) delete(n node, key []byte) node { // shortNode{..., shortNode{...}}. Since the entry // might not be loaded yet, resolve it just for this // check. - cnode := t.resolve(n[pos]) + cnode, err := t.resolve(n[pos], prefix, []byte{byte(pos)}) + if err != nil { + return nil, err + } if cnode, ok := cnode.(shortNode); ok { k := append([]byte{byte(pos)}, cnode.Key...) - return shortNode{k, cnode.Val} + return shortNode{k, cnode.Val}, nil } } // Otherwise, n is replaced by a one-nibble short node // containing the child. - return shortNode{[]byte{byte(pos)}, n[pos]} + return shortNode{[]byte{byte(pos)}, n[pos]}, nil } // n still contains at least two values and cannot be reduced. - return n + return n, nil case nil: - return nil + return nil, nil case hashNode: // We've hit a part of the trie that isn't loaded yet. Load @@ -273,7 +360,11 @@ func (t *Trie) delete(n node, key []byte) node { // // TODO: track whether deletion actually hit a key and keep // n as a hash node if it didn't. - return t.delete(t.resolveHash(n), key) + rn, err := t.resolveHash(n, prefix, key) + if err != nil { + return nil, err + } + return t.delete(rn, prefix, key) default: panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key)) @@ -287,34 +378,31 @@ func concat(s1 []byte, s2 ...byte) []byte { return r } -func (t *Trie) resolve(n node) node { +func (t *Trie) resolve(n node, prefix, suffix []byte) (node, error) { if n, ok := n.(hashNode); ok { - return t.resolveHash(n) + return t.resolveHash(n, prefix, suffix) } - return n + return n, nil } -func (t *Trie) resolveHash(n hashNode) node { +func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) { if v, ok := globalCache.Get(n); ok { - return v + return v, nil } enc, err := t.db.Get(n) if err != nil || enc == nil { - // TODO: This needs to be improved to properly distinguish errors. - // Disk I/O errors shouldn't produce nil (and cause a - // consensus failure or weird crash), but it is unclear how - // they could be handled because the entire stack above the trie isn't - // prepared to cope with missing state nodes. - if glog.V(logger.Error) { - glog.Errorf("Dangling hash node ref %x: %v", n, err) + return nil, &MissingNodeError{ + RootHash: t.originalRoot, + NodeHash: common.BytesToHash(n), + KeyPrefix: prefix, + KeySuffix: suffix, } - return nil } dec := mustDecodeNode(n, enc) if dec != nil { globalCache.Put(n, dec) } - return dec + return dec, nil } // Root returns the root hash of the trie. diff --git a/trie/trie_test.go b/trie/trie_test.go index c96861bed..35d043cdf 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -64,11 +64,84 @@ func TestMissingRoot(t *testing.T) { if trie != nil { t.Error("New returned non-nil trie for invalid root") } - if err != ErrMissingRoot { + if _, ok := err.(*MissingNodeError); !ok { t.Error("New returned wrong error: %v", err) } } +func TestMissingNode(t *testing.T) { + db, _ := ethdb.NewMemDatabase() + trie, _ := New(common.Hash{}, db) + updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer") + updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf") + root, _ := trie.Commit() + + ClearGlobalCache() + + trie, _ = New(root, db) + _, err := trie.TryGet([]byte("120000")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + trie, _ = New(root, db) + _, err = trie.TryGet([]byte("120099")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + trie, _ = New(root, db) + _, err = trie.TryGet([]byte("123456")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + trie, _ = New(root, db) + err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + trie, _ = New(root, db) + err = trie.TryDelete([]byte("123456")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + db.Delete(common.FromHex("e1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")) + ClearGlobalCache() + + trie, _ = New(root, db) + _, err = trie.TryGet([]byte("120000")) + if _, ok := err.(*MissingNodeError); !ok { + t.Errorf("Wrong error: %v", err) + } + + trie, _ = New(root, db) + _, err = trie.TryGet([]byte("120099")) + if _, ok := err.(*MissingNodeError); !ok { + t.Errorf("Wrong error: %v", err) + } + + trie, _ = New(root, db) + _, err = trie.TryGet([]byte("123456")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + trie, _ = New(root, db) + err = trie.TryUpdate([]byte("120099"), []byte("zxcv")) + if _, ok := err.(*MissingNodeError); !ok { + t.Errorf("Wrong error: %v", err) + } + + trie, _ = New(root, db) + err = trie.TryDelete([]byte("123456")) + if _, ok := err.(*MissingNodeError); !ok { + t.Errorf("Wrong error: %v", err) + } +} + func TestInsert(t *testing.T) { trie := newEmpty() diff --git a/whisper/message_test.go b/whisper/message_test.go index 6ff95efff..d70da40a4 100644 --- a/whisper/message_test.go +++ b/whisper/message_test.go @@ -23,6 +23,7 @@ import ( "time" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/secp256k1" ) // Tests whether a message can be wrapped without any identity or encryption. @@ -72,8 +73,8 @@ func TestMessageCleartextSignRecover(t *testing.T) { if pubKey == nil { t.Fatalf("failed to recover public key") } - p1 := elliptic.Marshal(crypto.S256(), key.PublicKey.X, key.PublicKey.Y) - p2 := elliptic.Marshal(crypto.S256(), pubKey.X, pubKey.Y) + p1 := elliptic.Marshal(secp256k1.S256(), key.PublicKey.X, key.PublicKey.Y) + p2 := elliptic.Marshal(secp256k1.S256(), pubKey.X, pubKey.Y) if !bytes.Equal(p1, p2) { t.Fatalf("public key mismatch: have 0x%x, want 0x%x", p2, p1) } @@ -150,8 +151,8 @@ func TestMessageFullCrypto(t *testing.T) { if pubKey == nil { t.Fatalf("failed to recover public key") } - p1 := elliptic.Marshal(crypto.S256(), fromKey.PublicKey.X, fromKey.PublicKey.Y) - p2 := elliptic.Marshal(crypto.S256(), pubKey.X, pubKey.Y) + p1 := elliptic.Marshal(secp256k1.S256(), fromKey.PublicKey.X, fromKey.PublicKey.Y) + p2 := elliptic.Marshal(secp256k1.S256(), pubKey.X, pubKey.Y) if !bytes.Equal(p1, p2) { t.Fatalf("public key mismatch: have 0x%x, want 0x%x", p2, p1) } |