diff options
Diffstat (limited to 'accounts/keystore')
-rw-r--r-- | accounts/keystore/account_cache.go (renamed from accounts/keystore/address_cache.go) | 59 | ||||
-rw-r--r-- | accounts/keystore/account_cache_test.go (renamed from accounts/keystore/address_cache_test.go) | 53 | ||||
-rw-r--r-- | accounts/keystore/keystore.go | 178 | ||||
-rw-r--r-- | accounts/keystore/keystore_test.go | 156 | ||||
-rw-r--r-- | accounts/keystore/keystore_wallet.go | 133 | ||||
-rw-r--r-- | accounts/keystore/watch.go | 4 | ||||
-rw-r--r-- | accounts/keystore/watch_fallback.go | 2 |
7 files changed, 519 insertions, 66 deletions
diff --git a/accounts/keystore/address_cache.go b/accounts/keystore/account_cache.go index eb3e3263b..cc8626afc 100644 --- a/accounts/keystore/address_cache.go +++ b/accounts/keystore/account_cache.go @@ -39,11 +39,11 @@ import ( // exist yet, the code will attempt to create a watcher at most this often. const minReloadInterval = 2 * time.Second -type accountsByFile []accounts.Account +type accountsByURL []accounts.Account -func (s accountsByFile) Len() int { return len(s) } -func (s accountsByFile) Less(i, j int) bool { return s[i].URL < s[j].URL } -func (s accountsByFile) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s accountsByURL) Len() int { return len(s) } +func (s accountsByURL) Less(i, j int) bool { return s[i].URL < s[j].URL } +func (s accountsByURL) Swap(i, j int) { s[i], s[j] = s[j], s[i] } // AmbiguousAddrError is returned when attempting to unlock // an address for which more than one file exists. @@ -63,26 +63,28 @@ func (err *AmbiguousAddrError) Error() string { return fmt.Sprintf("multiple keys match address (%s)", files) } -// addressCache is a live index of all accounts in the keystore. -type addressCache struct { +// accountCache is a live index of all accounts in the keystore. +type accountCache struct { keydir string watcher *watcher mu sync.Mutex - all accountsByFile + all accountsByURL byAddr map[common.Address][]accounts.Account throttle *time.Timer + notify chan struct{} } -func newAddrCache(keydir string) *addressCache { - ac := &addressCache{ +func newAccountCache(keydir string) (*accountCache, chan struct{}) { + ac := &accountCache{ keydir: keydir, byAddr: make(map[common.Address][]accounts.Account), + notify: make(chan struct{}, 1), } ac.watcher = newWatcher(ac) - return ac + return ac, ac.notify } -func (ac *addressCache) accounts() []accounts.Account { +func (ac *accountCache) accounts() []accounts.Account { ac.maybeReload() ac.mu.Lock() defer ac.mu.Unlock() @@ -91,14 +93,14 @@ func (ac *addressCache) accounts() []accounts.Account { return cpy } -func (ac *addressCache) hasAddress(addr common.Address) bool { +func (ac *accountCache) hasAddress(addr common.Address) bool { ac.maybeReload() ac.mu.Lock() defer ac.mu.Unlock() return len(ac.byAddr[addr]) > 0 } -func (ac *addressCache) add(newAccount accounts.Account) { +func (ac *accountCache) add(newAccount accounts.Account) { ac.mu.Lock() defer ac.mu.Unlock() @@ -111,18 +113,28 @@ func (ac *addressCache) add(newAccount accounts.Account) { copy(ac.all[i+1:], ac.all[i:]) ac.all[i] = newAccount ac.byAddr[newAccount.Address] = append(ac.byAddr[newAccount.Address], newAccount) + + select { + case ac.notify <- struct{}{}: + default: + } } // note: removed needs to be unique here (i.e. both File and Address must be set). -func (ac *addressCache) delete(removed accounts.Account) { +func (ac *accountCache) delete(removed accounts.Account) { ac.mu.Lock() defer ac.mu.Unlock() + ac.all = removeAccount(ac.all, removed) if ba := removeAccount(ac.byAddr[removed.Address], removed); len(ba) == 0 { delete(ac.byAddr, removed.Address) } else { ac.byAddr[removed.Address] = ba } + select { + case ac.notify <- struct{}{}: + default: + } } func removeAccount(slice []accounts.Account, elem accounts.Account) []accounts.Account { @@ -137,7 +149,7 @@ func removeAccount(slice []accounts.Account, elem accounts.Account) []accounts.A // find returns the cached account for address if there is a unique match. // The exact matching rules are explained by the documentation of accounts.Account. // Callers must hold ac.mu. -func (ac *addressCache) find(a accounts.Account) (accounts.Account, error) { +func (ac *accountCache) find(a accounts.Account) (accounts.Account, error) { // Limit search to address candidates if possible. matches := ac.all if (a.Address != common.Address{}) { @@ -169,9 +181,10 @@ func (ac *addressCache) find(a accounts.Account) (accounts.Account, error) { } } -func (ac *addressCache) maybeReload() { +func (ac *accountCache) maybeReload() { ac.mu.Lock() defer ac.mu.Unlock() + if ac.watcher.running { return // A watcher is running and will keep the cache up-to-date. } @@ -189,18 +202,22 @@ func (ac *addressCache) maybeReload() { ac.throttle.Reset(minReloadInterval) } -func (ac *addressCache) close() { +func (ac *accountCache) close() { ac.mu.Lock() ac.watcher.close() if ac.throttle != nil { ac.throttle.Stop() } + if ac.notify != nil { + close(ac.notify) + ac.notify = nil + } ac.mu.Unlock() } // reload caches addresses of existing accounts. // Callers must hold ac.mu. -func (ac *addressCache) reload() { +func (ac *accountCache) reload() { accounts, err := ac.scan() if err != nil && glog.V(logger.Debug) { glog.Errorf("can't load keys: %v", err) @@ -213,10 +230,14 @@ func (ac *addressCache) reload() { for _, a := range accounts { ac.byAddr[a.Address] = append(ac.byAddr[a.Address], a) } + select { + case ac.notify <- struct{}{}: + default: + } glog.V(logger.Debug).Infof("reloaded keys, cache has %d accounts", len(ac.all)) } -func (ac *addressCache) scan() ([]accounts.Account, error) { +func (ac *accountCache) scan() ([]accounts.Account, error) { files, err := ioutil.ReadDir(ac.keydir) if err != nil { return nil, err diff --git a/accounts/keystore/address_cache_test.go b/accounts/keystore/account_cache_test.go index 68af74338..ea6f7d011 100644 --- a/accounts/keystore/address_cache_test.go +++ b/accounts/keystore/account_cache_test.go @@ -53,11 +53,11 @@ var ( func TestWatchNewFile(t *testing.T) { t.Parallel() - dir, am := tmpKeyStore(t, false) + dir, ks := tmpKeyStore(t, false) defer os.RemoveAll(dir) // Ensure the watcher is started before adding any files. - am.Accounts() + ks.Accounts() time.Sleep(200 * time.Millisecond) // Move in the files. @@ -71,11 +71,17 @@ func TestWatchNewFile(t *testing.T) { } } - // am should see the accounts. + // ks should see the accounts. var list []accounts.Account for d := 200 * time.Millisecond; d < 5*time.Second; d *= 2 { - list = am.Accounts() + list = ks.Accounts() if reflect.DeepEqual(list, wantAccounts) { + // ks should have also received change notifications + select { + case <-ks.changes: + default: + t.Fatalf("wasn't notified of new accounts") + } return } time.Sleep(d) @@ -86,12 +92,12 @@ func TestWatchNewFile(t *testing.T) { func TestWatchNoDir(t *testing.T) { t.Parallel() - // Create am but not the directory that it watches. + // Create ks but not the directory that it watches. rand.Seed(time.Now().UnixNano()) dir := filepath.Join(os.TempDir(), fmt.Sprintf("eth-keystore-watch-test-%d-%d", os.Getpid(), rand.Int())) - am := NewKeyStore(dir, LightScryptN, LightScryptP) + ks := NewKeyStore(dir, LightScryptN, LightScryptP) - list := am.Accounts() + list := ks.Accounts() if len(list) > 0 { t.Error("initial account list not empty:", list) } @@ -105,12 +111,18 @@ func TestWatchNoDir(t *testing.T) { t.Fatal(err) } - // am should see the account. + // ks should see the account. wantAccounts := []accounts.Account{cachetestAccounts[0]} wantAccounts[0].URL = file for d := 200 * time.Millisecond; d < 8*time.Second; d *= 2 { - list = am.Accounts() + list = ks.Accounts() if reflect.DeepEqual(list, wantAccounts) { + // ks should have also received change notifications + select { + case <-ks.changes: + default: + t.Fatalf("wasn't notified of new accounts") + } return } time.Sleep(d) @@ -119,7 +131,7 @@ func TestWatchNoDir(t *testing.T) { } func TestCacheInitialReload(t *testing.T) { - cache := newAddrCache(cachetestDir) + cache, _ := newAccountCache(cachetestDir) accounts := cache.accounts() if !reflect.DeepEqual(accounts, cachetestAccounts) { t.Fatalf("got initial accounts: %swant %s", spew.Sdump(accounts), spew.Sdump(cachetestAccounts)) @@ -127,7 +139,7 @@ func TestCacheInitialReload(t *testing.T) { } func TestCacheAddDeleteOrder(t *testing.T) { - cache := newAddrCache("testdata/no-such-dir") + cache, notify := newAccountCache("testdata/no-such-dir") cache.watcher.running = true // prevent unexpected reloads accs := []accounts.Account{ @@ -163,14 +175,24 @@ func TestCacheAddDeleteOrder(t *testing.T) { for _, a := range accs { cache.add(a) } + select { + case <-notify: + default: + t.Fatalf("notifications didn't fire for adding new accounts") + } // Add some of them twice to check that they don't get reinserted. cache.add(accs[0]) cache.add(accs[2]) + select { + case <-notify: + t.Fatalf("notifications fired for adding existing accounts") + default: + } // Check that the account list is sorted by filename. wantAccounts := make([]accounts.Account, len(accs)) copy(wantAccounts, accs) - sort.Sort(accountsByFile(wantAccounts)) + sort.Sort(accountsByURL(wantAccounts)) list := cache.accounts() if !reflect.DeepEqual(list, wantAccounts) { t.Fatalf("got accounts: %s\nwant %s", spew.Sdump(accs), spew.Sdump(wantAccounts)) @@ -190,6 +212,11 @@ func TestCacheAddDeleteOrder(t *testing.T) { } cache.delete(accounts.Account{Address: common.HexToAddress("fd9bd350f08ee3c0c19b85a8e16114a11a60aa4e"), URL: "something"}) + select { + case <-notify: + default: + t.Fatalf("notifications didn't fire for deleting accounts") + } // Check content again after deletion. wantAccountsAfterDelete := []accounts.Account{ wantAccounts[1], @@ -212,7 +239,7 @@ func TestCacheAddDeleteOrder(t *testing.T) { func TestCacheFind(t *testing.T) { dir := filepath.Join("testdata", "dir") - cache := newAddrCache(dir) + cache, _ := newAccountCache(dir) cache.watcher.running = true // prevent unexpected reloads accs := []accounts.Account{ diff --git a/accounts/keystore/keystore.go b/accounts/keystore/keystore.go index d125f7d62..ce4e87ce9 100644 --- a/accounts/keystore/keystore.go +++ b/accounts/keystore/keystore.go @@ -37,23 +37,34 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/event" ) var ( - ErrNeedPasswordOrUnlock = accounts.NewAuthNeededError("password or unlock") - ErrNoMatch = errors.New("no key for given address or file") - ErrDecrypt = errors.New("could not decrypt key with given passphrase") + ErrLocked = accounts.NewAuthNeededError("password or unlock") + ErrNoMatch = errors.New("no key for given address or file") + ErrDecrypt = errors.New("could not decrypt key with given passphrase") ) -// BackendType can be used to query the account manager for encrypted keystores. -var BackendType = reflect.TypeOf(new(KeyStore)) +// KeyStoreType is the reflect type of a keystore backend. +var KeyStoreType = reflect.TypeOf(&KeyStore{}) + +// Maximum time between wallet refreshes (if filesystem notifications don't work). +const walletRefreshCycle = 3 * time.Second // KeyStore manages a key storage directory on disk. type KeyStore struct { - cache *addressCache - keyStore keyStore - mu sync.RWMutex - unlocked map[common.Address]*unlocked + storage keyStore // Storage backend, might be cleartext or encrypted + cache *accountCache // In-memory account cache over the filesystem storage + changes chan struct{} // Channel receiving change notifications from the cache + unlocked map[common.Address]*unlocked // Currently unlocked account (decrypted private keys) + + wallets []accounts.Wallet // Wallet wrappers around the individual key files + updateFeed event.Feed // Event feed to notify wallet additions/removals + updateScope event.SubscriptionScope // Subscription scope tracking current live listeners + updating bool // Whether the event notification loop is running + + mu sync.RWMutex } type unlocked struct { @@ -64,7 +75,7 @@ type unlocked struct { // NewKeyStore creates a keystore for the given directory. func NewKeyStore(keydir string, scryptN, scryptP int) *KeyStore { keydir, _ = filepath.Abs(keydir) - ks := &KeyStore{keyStore: &keyStorePassphrase{keydir, scryptN, scryptP}} + ks := &KeyStore{storage: &keyStorePassphrase{keydir, scryptN, scryptP}} ks.init(keydir) return ks } @@ -73,20 +84,136 @@ func NewKeyStore(keydir string, scryptN, scryptP int) *KeyStore { // Deprecated: Use NewKeyStore. func NewPlaintextKeyStore(keydir string) *KeyStore { keydir, _ = filepath.Abs(keydir) - ks := &KeyStore{keyStore: &keyStorePlain{keydir}} + ks := &KeyStore{storage: &keyStorePlain{keydir}} ks.init(keydir) return ks } func (ks *KeyStore) init(keydir string) { + // Lock the mutex since the account cache might call back with events + ks.mu.Lock() + defer ks.mu.Unlock() + + // Initialize the set of unlocked keys and the account cache ks.unlocked = make(map[common.Address]*unlocked) - ks.cache = newAddrCache(keydir) + ks.cache, ks.changes = newAccountCache(keydir) + // TODO: In order for this finalizer to work, there must be no references // to ks. addressCache doesn't keep a reference but unlocked keys do, // so the finalizer will not trigger until all timed unlocks have expired. runtime.SetFinalizer(ks, func(m *KeyStore) { m.cache.close() }) + // Create the initial list of wallets from the cache + accs := ks.cache.accounts() + ks.wallets = make([]accounts.Wallet, len(accs)) + for i := 0; i < len(accs); i++ { + ks.wallets[i] = &keystoreWallet{account: accs[i], keystore: ks} + } +} + +// Wallets implements accounts.Backend, returning all single-key wallets from the +// keystore directory. +func (ks *KeyStore) Wallets() []accounts.Wallet { + // Make sure the list of wallets is in sync with the account cache + ks.refreshWallets() + + ks.mu.RLock() + defer ks.mu.RUnlock() + + cpy := make([]accounts.Wallet, len(ks.wallets)) + copy(cpy, ks.wallets) + return cpy +} + +// refreshWallets retrieves the current account list and based on that does any +// necessary wallet refreshes. +func (ks *KeyStore) refreshWallets() { + // Retrieve the current list of accounts + accs := ks.cache.accounts() + + // Transform the current list of wallets into the new one + ks.mu.Lock() + + wallets := make([]accounts.Wallet, 0, len(accs)) + events := []accounts.WalletEvent{} + + for _, account := range accs { + // Drop wallets while they were in front of the next account + for len(ks.wallets) > 0 && ks.wallets[0].URL() < account.URL { + events = append(events, accounts.WalletEvent{Wallet: ks.wallets[0], Arrive: false}) + ks.wallets = ks.wallets[1:] + } + // If there are no more wallets or the account is before the next, wrap new wallet + if len(ks.wallets) == 0 || ks.wallets[0].URL() > account.URL { + wallet := &keystoreWallet{account: account, keystore: ks} + + events = append(events, accounts.WalletEvent{Wallet: wallet, Arrive: true}) + wallets = append(wallets, wallet) + continue + } + // If the account is the same as the first wallet, keep it + if ks.wallets[0].Accounts()[0] == account { + wallets = append(wallets, ks.wallets[0]) + ks.wallets = ks.wallets[1:] + continue + } + } + // Drop any leftover wallets and set the new batch + for _, wallet := range ks.wallets { + events = append(events, accounts.WalletEvent{Wallet: wallet, Arrive: false}) + } + ks.wallets = wallets + ks.mu.Unlock() + + // Fire all wallet events and return + for _, event := range events { + ks.updateFeed.Send(event) + } +} + +// Subscribe implements accounts.Backend, creating an async subscription to +// receive notifications on the addition or removal of keystore wallets. +func (ks *KeyStore) Subscribe(sink chan<- accounts.WalletEvent) event.Subscription { + // We need the mutex to reliably start/stop the update loop + ks.mu.Lock() + defer ks.mu.Unlock() + + // Subscribe the caller and track the subscriber count + sub := ks.updateScope.Track(ks.updateFeed.Subscribe(sink)) + + // Subscribers require an active notification loop, start it + if !ks.updating { + ks.updating = true + go ks.updater() + } + return sub +} + +// updater is responsible for maintaining an up-to-date list of wallets stored in +// the keystore, and for firing wallet addition/removal events. It listens for +// account change events from the underlying account cache, and also periodically +// forces a manual refresh (only triggers for systems where the filesystem notifier +// is not running). +func (ks *KeyStore) updater() { + for { + // Wait for an account update or a refresh timeout + select { + case <-ks.changes: + case <-time.After(walletRefreshCycle): + } + // Run the wallet refresher + ks.refreshWallets() + + // If all our subscribers left, stop the updater + ks.mu.Lock() + if ks.updateScope.Count() == 0 { + ks.updating = false + ks.mu.Unlock() + return + } + ks.mu.Unlock() + } } // HasAddress reports whether a key with the given address is present. @@ -118,6 +245,7 @@ func (ks *KeyStore) Delete(a accounts.Account, passphrase string) error { err = os.Remove(a.URL) if err == nil { ks.cache.delete(a) + ks.refreshWallets() } return err } @@ -131,7 +259,7 @@ func (ks *KeyStore) SignHash(a accounts.Account, hash []byte) ([]byte, error) { unlockedKey, found := ks.unlocked[a.Address] if !found { - return nil, ErrNeedPasswordOrUnlock + return nil, ErrLocked } // Sign the hash using plain ECDSA operations return crypto.Sign(hash, unlockedKey.PrivateKey) @@ -145,7 +273,7 @@ func (ks *KeyStore) SignTx(a accounts.Account, tx *types.Transaction, chainID *b unlockedKey, found := ks.unlocked[a.Address] if !found { - return nil, ErrNeedPasswordOrUnlock + return nil, ErrLocked } // Depending on the presence of the chain ID, sign with EIP155 or homestead if chainID != nil { @@ -221,10 +349,9 @@ func (ks *KeyStore) TimedUnlock(a accounts.Account, passphrase string, timeout t // it with a timeout would be confusing. zeroKey(key.PrivateKey) return nil - } else { - // Terminate the expire goroutine and replace it below. - close(u.abort) } + // Terminate the expire goroutine and replace it below. + close(u.abort) } if timeout > 0 { u = &unlocked{Key: key, abort: make(chan struct{})} @@ -250,7 +377,7 @@ func (ks *KeyStore) getDecryptedKey(a accounts.Account, auth string) (accounts.A if err != nil { return a, nil, err } - key, err := ks.keyStore.GetKey(a.Address, a.URL, auth) + key, err := ks.storage.GetKey(a.Address, a.URL, auth) return a, key, err } @@ -277,13 +404,14 @@ func (ks *KeyStore) expire(addr common.Address, u *unlocked, timeout time.Durati // NewAccount generates a new key and stores it into the key directory, // encrypting it with the passphrase. func (ks *KeyStore) NewAccount(passphrase string) (accounts.Account, error) { - _, account, err := storeNewKey(ks.keyStore, crand.Reader, passphrase) + _, account, err := storeNewKey(ks.storage, crand.Reader, passphrase) if err != nil { return accounts.Account{}, err } // Add the account to the cache immediately rather // than waiting for file system notifications to pick it up. ks.cache.add(account) + ks.refreshWallets() return account, nil } @@ -294,7 +422,7 @@ func (ks *KeyStore) Export(a accounts.Account, passphrase, newPassphrase string) return nil, err } var N, P int - if store, ok := ks.keyStore.(*keyStorePassphrase); ok { + if store, ok := ks.storage.(*keyStorePassphrase); ok { N, P = store.scryptN, store.scryptP } else { N, P = StandardScryptN, StandardScryptP @@ -325,11 +453,12 @@ func (ks *KeyStore) ImportECDSA(priv *ecdsa.PrivateKey, passphrase string) (acco } func (ks *KeyStore) importKey(key *Key, passphrase string) (accounts.Account, error) { - a := accounts.Account{Address: key.Address, URL: ks.keyStore.JoinPath(keyFileName(key.Address))} - if err := ks.keyStore.StoreKey(a.URL, key, passphrase); err != nil { + a := accounts.Account{Address: key.Address, URL: ks.storage.JoinPath(keyFileName(key.Address))} + if err := ks.storage.StoreKey(a.URL, key, passphrase); err != nil { return accounts.Account{}, err } ks.cache.add(a) + ks.refreshWallets() return a, nil } @@ -339,17 +468,18 @@ func (ks *KeyStore) Update(a accounts.Account, passphrase, newPassphrase string) if err != nil { return err } - return ks.keyStore.StoreKey(a.URL, key, newPassphrase) + return ks.storage.StoreKey(a.URL, key, newPassphrase) } // ImportPreSaleKey decrypts the given Ethereum presale wallet and stores // a key file in the key directory. The key file is encrypted with the same passphrase. func (ks *KeyStore) ImportPreSaleKey(keyJSON []byte, passphrase string) (accounts.Account, error) { - a, _, err := importPreSaleKey(ks.keyStore, keyJSON, passphrase) + a, _, err := importPreSaleKey(ks.storage, keyJSON, passphrase) if err != nil { return a, err } ks.cache.add(a) + ks.refreshWallets() return a, nil } diff --git a/accounts/keystore/keystore_test.go b/accounts/keystore/keystore_test.go index af2140c31..6b7170a2f 100644 --- a/accounts/keystore/keystore_test.go +++ b/accounts/keystore/keystore_test.go @@ -18,14 +18,17 @@ package keystore import ( "io/ioutil" + "math/rand" "os" "runtime" + "sort" "strings" "testing" "time" "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/event" ) var testSigData = make([]byte, 32) @@ -122,8 +125,8 @@ func TestTimedUnlock(t *testing.T) { // Signing without passphrase fails because account is locked _, err = ks.SignHash(accounts.Account{Address: a1.Address}, testSigData) - if err != ErrNeedPasswordOrUnlock { - t.Fatal("Signing should've failed with ErrNeedPasswordOrUnlock before unlocking, got ", err) + if err != ErrLocked { + t.Fatal("Signing should've failed with ErrLocked before unlocking, got ", err) } // Signing with passphrase works @@ -140,8 +143,8 @@ func TestTimedUnlock(t *testing.T) { // Signing fails again after automatic locking time.Sleep(250 * time.Millisecond) _, err = ks.SignHash(accounts.Account{Address: a1.Address}, testSigData) - if err != ErrNeedPasswordOrUnlock { - t.Fatal("Signing should've failed with ErrNeedPasswordOrUnlock timeout expired, got ", err) + if err != ErrLocked { + t.Fatal("Signing should've failed with ErrLocked timeout expired, got ", err) } } @@ -180,8 +183,8 @@ func TestOverrideUnlock(t *testing.T) { // Signing fails again after automatic locking time.Sleep(250 * time.Millisecond) _, err = ks.SignHash(accounts.Account{Address: a1.Address}, testSigData) - if err != ErrNeedPasswordOrUnlock { - t.Fatal("Signing should've failed with ErrNeedPasswordOrUnlock timeout expired, got ", err) + if err != ErrLocked { + t.Fatal("Signing should've failed with ErrLocked timeout expired, got ", err) } } @@ -201,7 +204,7 @@ func TestSignRace(t *testing.T) { } end := time.Now().Add(500 * time.Millisecond) for time.Now().Before(end) { - if _, err := ks.SignHash(accounts.Account{Address: a1.Address}, testSigData); err == ErrNeedPasswordOrUnlock { + if _, err := ks.SignHash(accounts.Account{Address: a1.Address}, testSigData); err == ErrLocked { return } else if err != nil { t.Errorf("Sign error: %v", err) @@ -212,6 +215,145 @@ func TestSignRace(t *testing.T) { t.Errorf("Account did not lock within the timeout") } +// Tests that the wallet notifier loop starts and stops correctly based on the +// addition and removal of wallet event subscriptions. +func TestWalletNotifierLifecycle(t *testing.T) { + // Create a temporary kesytore to test with + dir, ks := tmpKeyStore(t, false) + defer os.RemoveAll(dir) + + // Ensure that the notification updater is not running yet + time.Sleep(250 * time.Millisecond) + ks.mu.RLock() + updating := ks.updating + ks.mu.RUnlock() + + if updating { + t.Errorf("wallet notifier running without subscribers") + } + // Subscribe to the wallet feed and ensure the updater boots up + updates := make(chan accounts.WalletEvent) + + subs := make([]event.Subscription, 2) + for i := 0; i < len(subs); i++ { + // Create a new subscription + subs[i] = ks.Subscribe(updates) + + // Ensure the notifier comes online + time.Sleep(250 * time.Millisecond) + ks.mu.RLock() + updating = ks.updating + ks.mu.RUnlock() + + if !updating { + t.Errorf("sub %d: wallet notifier not running after subscription", i) + } + } + // Unsubscribe and ensure the updater terminates eventually + for i := 0; i < len(subs); i++ { + // Close an existing subscription + subs[i].Unsubscribe() + + // Ensure the notifier shuts down at and only at the last close + for k := 0; k < int(walletRefreshCycle/(250*time.Millisecond))+2; k++ { + ks.mu.RLock() + updating = ks.updating + ks.mu.RUnlock() + + if i < len(subs)-1 && !updating { + t.Fatalf("sub %d: event notifier stopped prematurely", i) + } + if i == len(subs)-1 && !updating { + return + } + time.Sleep(250 * time.Millisecond) + } + } + t.Errorf("wallet notifier didn't terminate after unsubscribe") +} + +// Tests that wallet notifications and correctly fired when accounts are added +// or deleted from the keystore. +func TestWalletNotifications(t *testing.T) { + // Create a temporary kesytore to test with + dir, ks := tmpKeyStore(t, false) + defer os.RemoveAll(dir) + + // Subscribe to the wallet feed + updates := make(chan accounts.WalletEvent, 1) + sub := ks.Subscribe(updates) + defer sub.Unsubscribe() + + // Randomly add and remove account and make sure events and wallets are in sync + live := make(map[common.Address]accounts.Account) + for i := 0; i < 1024; i++ { + // Execute a creation or deletion and ensure event arrival + if create := len(live) == 0 || rand.Int()%4 > 0; create { + // Add a new account and ensure wallet notifications arrives + account, err := ks.NewAccount("") + if err != nil { + t.Fatalf("failed to create test account: %v", err) + } + select { + case event := <-updates: + if !event.Arrive { + t.Errorf("departure event on account creation") + } + if event.Wallet.Accounts()[0] != account { + t.Errorf("account mismatch on created wallet: have %v, want %v", event.Wallet.Accounts()[0], account) + } + default: + t.Errorf("wallet arrival event not fired on account creation") + } + live[account.Address] = account + } else { + // Select a random account to delete (crude, but works) + var account accounts.Account + for _, a := range live { + account = a + break + } + // Remove an account and ensure wallet notifiaction arrives + if err := ks.Delete(account, ""); err != nil { + t.Fatalf("failed to delete test account: %v", err) + } + select { + case event := <-updates: + if event.Arrive { + t.Errorf("arrival event on account deletion") + } + if event.Wallet.Accounts()[0] != account { + t.Errorf("account mismatch on deleted wallet: have %v, want %v", event.Wallet.Accounts()[0], account) + } + default: + t.Errorf("wallet departure event not fired on account creation") + } + delete(live, account.Address) + } + // Retrieve the list of wallets and ensure it matches with our required live set + liveList := make([]accounts.Account, 0, len(live)) + for _, account := range live { + liveList = append(liveList, account) + } + sort.Sort(accountsByURL(liveList)) + + wallets := ks.Wallets() + if len(liveList) != len(wallets) { + t.Errorf("wallet list doesn't match required accounts: have %v, want %v", wallets, liveList) + } else { + for j, wallet := range wallets { + if accs := wallet.Accounts(); len(accs) != 1 { + t.Errorf("wallet %d: contains invalid number of accounts: have %d, want 1", j, len(accs)) + } else if accs[0] != liveList[j] { + t.Errorf("wallet %d: account mismatch: have %v, want %v", j, accs[0], liveList[j]) + } + } + } + // Sleep a bit to avoid same-timestamp keyfiles + time.Sleep(10 * time.Millisecond) + } +} + func tmpKeyStore(t *testing.T, encrypted bool) (string, *KeyStore) { d, err := ioutil.TempDir("", "eth-keystore-test") if err != nil { diff --git a/accounts/keystore/keystore_wallet.go b/accounts/keystore/keystore_wallet.go new file mode 100644 index 000000000..d92926478 --- /dev/null +++ b/accounts/keystore/keystore_wallet.go @@ -0,0 +1,133 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package keystore + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/accounts" + "github.com/ethereum/go-ethereum/core/types" +) + +// keystoreWallet implements the accounts.Wallet interface for the original +// keystore. +type keystoreWallet struct { + account accounts.Account // Single account contained in this wallet + keystore *KeyStore // Keystore where the account originates from +} + +// Type implements accounts.Wallet, returning the textual type of the wallet. +func (w *keystoreWallet) Type() string { + return "secret-storage" +} + +// URL implements accounts.Wallet, returning the URL of the account within. +func (w *keystoreWallet) URL() string { + return w.account.URL +} + +// Status implements accounts.Wallet, always returning "open", since there is no +// concept of open/close for plain keystore accounts. +func (w *keystoreWallet) Status() string { + return "Open" +} + +// Open implements accounts.Wallet, but is a noop for plain wallets since there +// is no connection or decryption step necessary to access the list of accounts. +func (w *keystoreWallet) Open(passphrase string) error { return nil } + +// Close implements accounts.Wallet, but is a noop for plain wallets since is no +// meaningful open operation. +func (w *keystoreWallet) Close() error { return nil } + +// Accounts implements accounts.Wallet, returning an account list consisting of +// a single account that the plain kestore wallet contains. +func (w *keystoreWallet) Accounts() []accounts.Account { + return []accounts.Account{w.account} +} + +// Contains implements accounts.Wallet, returning whether a particular account is +// or is not wrapped by this wallet instance. +func (w *keystoreWallet) Contains(account accounts.Account) bool { + return account.Address == w.account.Address && (account.URL == "" || account.URL == w.account.URL) +} + +// Derive implements accounts.Wallet, but is a noop for plain wallets since there +// is no notion of hierarchical account derivation for plain keystore accounts. +func (w *keystoreWallet) Derive(path string, pin bool) (accounts.Account, error) { + return accounts.Account{}, accounts.ErrNotSupported +} + +// SignHash implements accounts.Wallet, attempting to sign the given hash with +// the given account. If the wallet does not wrap this particular account, an +// error is returned to avoid account leakage (even though in theory we may be +// able to sign via our shared keystore backend). +func (w *keystoreWallet) SignHash(account accounts.Account, hash []byte) ([]byte, error) { + // Make sure the requested account is contained within + if account.Address != w.account.Address { + return nil, accounts.ErrUnknownAccount + } + if account.URL != "" && account.URL != w.account.URL { + return nil, accounts.ErrUnknownAccount + } + // Account seems valid, request the keystore to sign + return w.keystore.SignHash(account, hash) +} + +// SignTx implements accounts.Wallet, attempting to sign the given transaction +// with the given account. If the wallet does not wrap this particular account, +// an error is returned to avoid account leakage (even though in theory we may +// be able to sign via our shared keystore backend). +func (w *keystoreWallet) SignTx(account accounts.Account, tx *types.Transaction, chainID *big.Int) (*types.Transaction, error) { + // Make sure the requested account is contained within + if account.Address != w.account.Address { + return nil, accounts.ErrUnknownAccount + } + if account.URL != "" && account.URL != w.account.URL { + return nil, accounts.ErrUnknownAccount + } + // Account seems valid, request the keystore to sign + return w.keystore.SignTx(account, tx, chainID) +} + +// SignHashWithPassphrase implements accounts.Wallet, attempting to sign the +// given hash with the given account using passphrase as extra authentication. +func (w *keystoreWallet) SignHashWithPassphrase(account accounts.Account, passphrase string, hash []byte) ([]byte, error) { + // Make sure the requested account is contained within + if account.Address != w.account.Address { + return nil, accounts.ErrUnknownAccount + } + if account.URL != "" && account.URL != w.account.URL { + return nil, accounts.ErrUnknownAccount + } + // Account seems valid, request the keystore to sign + return w.keystore.SignHashWithPassphrase(account, passphrase, hash) +} + +// SignTxWithPassphrase implements accounts.Wallet, attempting to sign the given +// transaction with the given account using passphrase as extra authentication. +func (w *keystoreWallet) SignTxWithPassphrase(account accounts.Account, passphrase string, tx *types.Transaction, chainID *big.Int) (*types.Transaction, error) { + // Make sure the requested account is contained within + if account.Address != w.account.Address { + return nil, accounts.ErrUnknownAccount + } + if account.URL != "" && account.URL != w.account.URL { + return nil, accounts.ErrUnknownAccount + } + // Account seems valid, request the keystore to sign + return w.keystore.SignTxWithPassphrase(account, passphrase, tx, chainID) +} diff --git a/accounts/keystore/watch.go b/accounts/keystore/watch.go index 04a87b12e..0b4401255 100644 --- a/accounts/keystore/watch.go +++ b/accounts/keystore/watch.go @@ -27,14 +27,14 @@ import ( ) type watcher struct { - ac *addressCache + ac *accountCache starting bool running bool ev chan notify.EventInfo quit chan struct{} } -func newWatcher(ac *addressCache) *watcher { +func newWatcher(ac *accountCache) *watcher { return &watcher{ ac: ac, ev: make(chan notify.EventInfo, 10), diff --git a/accounts/keystore/watch_fallback.go b/accounts/keystore/watch_fallback.go index 6412f3b33..7c5e9cb2e 100644 --- a/accounts/keystore/watch_fallback.go +++ b/accounts/keystore/watch_fallback.go @@ -23,6 +23,6 @@ package keystore type watcher struct{ running bool } -func newWatcher(*addressCache) *watcher { return new(watcher) } +func newWatcher(*accountCache) *watcher { return new(watcher) } func (*watcher) start() {} func (*watcher) close() {} |