diff options
-rw-r--r-- | p2p/discover/node.go | 83 | ||||
-rw-r--r-- | p2p/discover/table.go | 3 |
2 files changed, 73 insertions, 13 deletions
diff --git a/p2p/discover/node.go b/p2p/discover/node.go index 6662a6cb7..d8a5cc351 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -1,8 +1,10 @@ package discover import ( + "bytes" "crypto/ecdsa" "crypto/elliptic" + "encoding/binary" "encoding/hex" "errors" "fmt" @@ -11,13 +13,16 @@ import ( "math/rand" "net" "net/url" + "os" "strconv" "strings" - "sync" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/secp256k1" "github.com/ethereum/go-ethereum/rlp" + "github.com/syndtr/goleveldb/leveldb" + "github.com/syndtr/goleveldb/leveldb/opt" + "github.com/syndtr/goleveldb/leveldb/storage" ) const nodeIDBits = 512 @@ -308,23 +313,77 @@ func randomID(a NodeID, n int) (b NodeID) { // nodeDB stores all nodes we know about. type nodeDB struct { - mu sync.RWMutex - byID map[NodeID]*Node + ldb *leveldb.DB +} + +var dbVersionKey = []byte("pv") + +// Opens the backing LevelDB. If path is "", we use an in-memory database. +func newNodeDB(path string, version int64) (db *nodeDB, err error) { + db = new(nodeDB) + opts := new(opt.Options) + if path == "" { + db.ldb, err = leveldb.Open(storage.NewMemStorage(), opts) + } else { + db.ldb, err = openLDB(path, opts, version) + } + return db, err +} + +func openLDB(path string, opts *opt.Options, version int64) (*leveldb.DB, error) { + ldb, err := leveldb.OpenFile(path, opts) + if _, iscorrupted := err.(leveldb.ErrCorrupted); iscorrupted { + ldb, err = leveldb.RecoverFile(path, opts) + } + if err != nil { + return nil, err + } + // The nodes contained in the database correspond to a certain + // protocol version. Flush all nodes if the DB version doesn't match. + // There is no need to do this for memory databases because they + // won't ever be used with a different protocol version. + shouldVal := make([]byte, binary.MaxVarintLen64) + shouldVal = shouldVal[:binary.PutVarint(shouldVal, version)] + val, err := ldb.Get(dbVersionKey, nil) + if err == leveldb.ErrNotFound { + err = ldb.Put(dbVersionKey, shouldVal, nil) + } else if err == nil && !bytes.Equal(val, shouldVal) { + // Delete and start over. + ldb.Close() + if err = os.RemoveAll(path); err != nil { + return nil, err + } + return openLDB(path, opts, version) + } + if err != nil { + ldb.Close() + ldb = nil + } + return ldb, err } func (db *nodeDB) get(id NodeID) *Node { - db.mu.RLock() - defer db.mu.RUnlock() - return db.byID[id] + v, err := db.ldb.Get(id[:], nil) + if err != nil { + return nil + } + n := new(Node) + if err := rlp.DecodeBytes(v, n); err != nil { + return nil + } + return n } -func (db *nodeDB) add(id NodeID, addr *net.UDPAddr, tcpPort uint16) *Node { - db.mu.Lock() - defer db.mu.Unlock() - if db.byID == nil { - db.byID = make(map[NodeID]*Node) +func (db *nodeDB) update(n *Node) error { + v, err := rlp.EncodeToBytes(n) + if err != nil { + return err } + return db.ldb.Put(n.ID[:], v, nil) +} + +func (db *nodeDB) add(id NodeID, addr *net.UDPAddr, tcpPort uint16) *Node { n := &Node{ID: id, IP: addr.IP, DiscPort: addr.Port, TCPPort: int(tcpPort)} - db.byID[n.ID] = n + db.update(n) return n } diff --git a/p2p/discover/table.go b/p2p/discover/table.go index e2e846456..ba2f9b8ec 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -59,9 +59,10 @@ type bucket struct { } func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table { + db, _ := newNodeDB("", Version) tab := &Table{ net: t, - db: new(nodeDB), + db: db, self: newNode(ourID, ourAddr), bonding: make(map[NodeID]*bondproc), bondslots: make(chan struct{}, maxBondingPingPongs), |