diff options
author | Wei-Ning Huang <w@byzantine-lab.io> | 2019-06-12 17:31:08 +0800 |
---|---|---|
committer | Wei-Ning Huang <w@byzantine-lab.io> | 2019-09-17 16:57:29 +0800 |
commit | ac088de6322fc16ebe75c2e5554be73754bf1fe2 (patch) | |
tree | 086b7827d46a4d07b834cd94be73beaabb77b734 /vendor/github.com/byzantine-lab/dexon-consensus/core/db | |
parent | 67d565f3f0e398e99bef96827f729e3e4b0edf31 (diff) | |
download | go-tangerine-ac088de6322fc16ebe75c2e5554be73754bf1fe2.tar.gz go-tangerine-ac088de6322fc16ebe75c2e5554be73754bf1fe2.tar.zst go-tangerine-ac088de6322fc16ebe75c2e5554be73754bf1fe2.zip |
Rebrand as tangerine-network/go-tangerine
Diffstat (limited to 'vendor/github.com/byzantine-lab/dexon-consensus/core/db')
3 files changed, 935 insertions, 0 deletions
diff --git a/vendor/github.com/byzantine-lab/dexon-consensus/core/db/interfaces.go b/vendor/github.com/byzantine-lab/dexon-consensus/core/db/interfaces.go new file mode 100644 index 000000000..1d15c68a0 --- /dev/null +++ b/vendor/github.com/byzantine-lab/dexon-consensus/core/db/interfaces.go @@ -0,0 +1,100 @@ +// Copyright 2018 The dexon-consensus Authors +// This file is part of the dexon-consensus library. +// +// The dexon-consensus library is free software: you can redistribute it +// and/or modify it under the terms of the GNU Lesser General Public License as +// published by the Free Software Foundation, either version 3 of the License, +// or (at your option) any later version. +// +// The dexon-consensus library is distributed in the hope that it will be +// useful, but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser +// General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the dexon-consensus library. If not, see +// <http://www.gnu.org/licenses/>. + +package db + +import ( + "errors" + "fmt" + + "github.com/byzantine-lab/dexon-consensus/common" + "github.com/byzantine-lab/dexon-consensus/core/crypto/dkg" + "github.com/byzantine-lab/dexon-consensus/core/types" +) + +var ( + // ErrBlockExists is the error when block eixsts. + ErrBlockExists = errors.New("block exists") + // ErrBlockDoesNotExist is the error when block does not eixst. + ErrBlockDoesNotExist = errors.New("block does not exist") + // ErrIterationFinished is the error to check if the iteration is finished. + ErrIterationFinished = errors.New("iteration finished") + // ErrEmptyPath is the error when the required path is empty. + ErrEmptyPath = fmt.Errorf("empty path") + // ErrClosed is the error when using DB after it's closed. + ErrClosed = fmt.Errorf("db closed") + // ErrNotImplemented is the error that some interface is not implemented. + ErrNotImplemented = fmt.Errorf("not implemented") + // ErrInvalidCompactionChainTipHeight means the newly updated height of + // the tip of compaction chain is invalid, usually means it's smaller than + // current cached one. + ErrInvalidCompactionChainTipHeight = fmt.Errorf( + "invalid compaction chain tip height") + // ErrDKGPrivateKeyExists raised when attempting to save DKG private key + // that already saved. + ErrDKGPrivateKeyExists = errors.New("dkg private key exists") + // ErrDKGPrivateKeyDoesNotExist raised when the DKG private key of the + // requested round does not exists. + ErrDKGPrivateKeyDoesNotExist = errors.New("dkg private key does not exists") + // ErrDKGProtocolExists raised when attempting to save DKG protocol + // that already saved. + ErrDKGProtocolExists = errors.New("dkg protocol exists") + // ErrDKGProtocolDoesNotExist raised when the DKG protocol of the + // requested round does not exists. + ErrDKGProtocolDoesNotExist = errors.New("dkg protocol does not exists") +) + +// Database is the interface for a Database. +type Database interface { + Reader + Writer + + // Close allows database implementation able to + // release resource when finishing. + Close() error +} + +// Reader defines the interface for reading blocks into DB. +type Reader interface { + HasBlock(hash common.Hash) bool + GetBlock(hash common.Hash) (types.Block, error) + GetAllBlocks() (BlockIterator, error) + + // GetCompactionChainTipInfo returns the block hash and finalization height + // of the tip block of compaction chain. Empty hash and zero height means + // the compaction chain is empty. + GetCompactionChainTipInfo() (common.Hash, uint64) + + // DKG Private Key related methods. + GetDKGPrivateKey(round, reset uint64) (dkg.PrivateKey, error) + GetDKGProtocol() (dkgProtocol DKGProtocolInfo, err error) +} + +// Writer defines the interface for writing blocks into DB. +type Writer interface { + UpdateBlock(block types.Block) error + PutBlock(block types.Block) error + PutCompactionChainTipInfo(common.Hash, uint64) error + PutDKGPrivateKey(round, reset uint64, pk dkg.PrivateKey) error + PutOrUpdateDKGProtocol(dkgProtocol DKGProtocolInfo) error +} + +// BlockIterator defines an iterator on blocks hold +// in a DB. +type BlockIterator interface { + NextBlock() (types.Block, error) +} diff --git a/vendor/github.com/byzantine-lab/dexon-consensus/core/db/level-db.go b/vendor/github.com/byzantine-lab/dexon-consensus/core/db/level-db.go new file mode 100644 index 000000000..9e3564b50 --- /dev/null +++ b/vendor/github.com/byzantine-lab/dexon-consensus/core/db/level-db.go @@ -0,0 +1,573 @@ +// Copyright 2018 The dexon-consensus Authors +// This file is part of the dexon-consensus library. +// +// The dexon-consensus library is free software: you can redistribute it +// and/or modify it under the terms of the GNU Lesser General Public License as +// published by the Free Software Foundation, either version 3 of the License, +// or (at your option) any later version. +// +// The dexon-consensus library is distributed in the hope that it will be +// useful, but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser +// General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the dexon-consensus library. If not, see +// <http://www.gnu.org/licenses/>. + +package db + +import ( + "encoding/binary" + "io" + + "github.com/syndtr/goleveldb/leveldb" + + "github.com/byzantine-lab/dexon-consensus/common" + "github.com/byzantine-lab/dexon-consensus/core/crypto/dkg" + "github.com/byzantine-lab/dexon-consensus/core/types" + "github.com/byzantine-lab/go-tangerine/rlp" +) + +var ( + blockKeyPrefix = []byte("b-") + compactionChainTipInfoKey = []byte("cc-tip") + dkgPrivateKeyKeyPrefix = []byte("dkg-prvs") + dkgProtocolInfoKeyPrefix = []byte("dkg-protocol-info") +) + +type compactionChainTipInfo struct { + Height uint64 `json:"height"` + Hash common.Hash `json:"hash"` +} + +// DKGProtocolInfo DKG protocol info. +type DKGProtocolInfo struct { + ID types.NodeID + Round uint64 + Threshold uint64 + IDMap NodeIDToDKGID + MpkMap NodeIDToPubShares + MasterPrivateShare dkg.PrivateKeyShares + IsMasterPrivateShareEmpty bool + PrvShares dkg.PrivateKeyShares + IsPrvSharesEmpty bool + PrvSharesReceived NodeID + NodeComplained NodeID + AntiComplaintReceived NodeIDToNodeIDs + Step uint64 + Reset uint64 +} + +type dkgPrivateKey struct { + PK dkg.PrivateKey + Reset uint64 +} + +// Equal compare with target DKGProtocolInfo. +func (info *DKGProtocolInfo) Equal(target *DKGProtocolInfo) bool { + if !info.ID.Equal(target.ID) || + info.Round != target.Round || + info.Threshold != target.Threshold || + info.IsMasterPrivateShareEmpty != target.IsMasterPrivateShareEmpty || + info.IsPrvSharesEmpty != target.IsPrvSharesEmpty || + info.Step != target.Step || + info.Reset != target.Reset || + !info.MasterPrivateShare.Equal(&target.MasterPrivateShare) || + !info.PrvShares.Equal(&target.PrvShares) { + return false + } + + if len(info.IDMap) != len(target.IDMap) { + return false + } + for k, v := range info.IDMap { + tV, exist := target.IDMap[k] + if !exist { + return false + } + + if !v.IsEqual(&tV) { + return false + } + } + + if len(info.MpkMap) != len(target.MpkMap) { + return false + } + for k, v := range info.MpkMap { + tV, exist := target.MpkMap[k] + if !exist { + return false + } + + if !v.Equal(tV) { + return false + } + } + + if len(info.PrvSharesReceived) != len(target.PrvSharesReceived) { + return false + } + for k := range info.PrvSharesReceived { + _, exist := target.PrvSharesReceived[k] + if !exist { + return false + } + } + + if len(info.NodeComplained) != len(target.NodeComplained) { + return false + } + for k := range info.NodeComplained { + _, exist := target.NodeComplained[k] + if !exist { + return false + } + } + + if len(info.AntiComplaintReceived) != len(target.AntiComplaintReceived) { + return false + } + for k, v := range info.AntiComplaintReceived { + tV, exist := target.AntiComplaintReceived[k] + if !exist { + return false + } + + if len(v) != len(tV) { + return false + } + for kk := range v { + _, exist := tV[kk] + if !exist { + return false + } + } + } + + return true +} + +// NodeIDToNodeIDs the map with NodeID to NodeIDs. +type NodeIDToNodeIDs map[types.NodeID]map[types.NodeID]struct{} + +// EncodeRLP implements rlp.Encoder +func (m NodeIDToNodeIDs) EncodeRLP(w io.Writer) error { + var allBytes [][][]byte + for k, v := range m { + kBytes, err := k.MarshalText() + if err != nil { + return err + } + allBytes = append(allBytes, [][]byte{kBytes}) + + var vBytes [][]byte + for subK := range v { + bytes, err := subK.MarshalText() + if err != nil { + return err + } + vBytes = append(vBytes, bytes) + } + allBytes = append(allBytes, vBytes) + } + + return rlp.Encode(w, allBytes) +} + +// DecodeRLP implements rlp.Encoder +func (m *NodeIDToNodeIDs) DecodeRLP(s *rlp.Stream) error { + *m = make(NodeIDToNodeIDs) + var dec [][][]byte + if err := s.Decode(&dec); err != nil { + return err + } + + for i := 0; i < len(dec); i += 2 { + key := types.NodeID{} + err := key.UnmarshalText(dec[i][0]) + if err != nil { + return err + } + + valueMap := map[types.NodeID]struct{}{} + for _, v := range dec[i+1] { + value := types.NodeID{} + err := value.UnmarshalText(v) + if err != nil { + return err + } + + valueMap[value] = struct{}{} + } + + (*m)[key] = valueMap + } + + return nil +} + +// NodeID the map with NodeID. +type NodeID map[types.NodeID]struct{} + +// EncodeRLP implements rlp.Encoder +func (m NodeID) EncodeRLP(w io.Writer) error { + var allBytes [][]byte + for k := range m { + kBytes, err := k.MarshalText() + if err != nil { + return err + } + allBytes = append(allBytes, kBytes) + } + + return rlp.Encode(w, allBytes) +} + +// DecodeRLP implements rlp.Encoder +func (m *NodeID) DecodeRLP(s *rlp.Stream) error { + *m = make(NodeID) + var dec [][]byte + if err := s.Decode(&dec); err != nil { + return err + } + + for i := 0; i < len(dec); i++ { + key := types.NodeID{} + err := key.UnmarshalText(dec[i]) + if err != nil { + return err + } + + (*m)[key] = struct{}{} + } + + return nil +} + +// NodeIDToPubShares the map with NodeID to PublicKeyShares. +type NodeIDToPubShares map[types.NodeID]*dkg.PublicKeyShares + +// EncodeRLP implements rlp.Encoder +func (m NodeIDToPubShares) EncodeRLP(w io.Writer) error { + var allBytes [][]byte + for k, v := range m { + kBytes, err := k.MarshalText() + if err != nil { + return err + } + allBytes = append(allBytes, kBytes) + + bytes, err := rlp.EncodeToBytes(v) + if err != nil { + return err + } + allBytes = append(allBytes, bytes) + } + + return rlp.Encode(w, allBytes) +} + +// DecodeRLP implements rlp.Encoder +func (m *NodeIDToPubShares) DecodeRLP(s *rlp.Stream) error { + *m = make(NodeIDToPubShares) + var dec [][]byte + if err := s.Decode(&dec); err != nil { + return err + } + + for i := 0; i < len(dec); i += 2 { + key := types.NodeID{} + err := key.UnmarshalText(dec[i]) + if err != nil { + return err + } + + value := dkg.PublicKeyShares{} + err = rlp.DecodeBytes(dec[i+1], &value) + if err != nil { + return err + } + + (*m)[key] = &value + } + + return nil +} + +// NodeIDToDKGID the map with NodeID to DKGID. +type NodeIDToDKGID map[types.NodeID]dkg.ID + +// EncodeRLP implements rlp.Encoder +func (m NodeIDToDKGID) EncodeRLP(w io.Writer) error { + var allBytes [][]byte + for k, v := range m { + kBytes, err := k.MarshalText() + if err != nil { + return err + } + allBytes = append(allBytes, kBytes) + allBytes = append(allBytes, v.GetLittleEndian()) + } + + return rlp.Encode(w, allBytes) +} + +// DecodeRLP implements rlp.Encoder +func (m *NodeIDToDKGID) DecodeRLP(s *rlp.Stream) error { + *m = make(NodeIDToDKGID) + var dec [][]byte + if err := s.Decode(&dec); err != nil { + return err + } + + for i := 0; i < len(dec); i += 2 { + key := types.NodeID{} + err := key.UnmarshalText(dec[i]) + if err != nil { + return err + } + + value := dkg.ID{} + err = value.SetLittleEndian(dec[i+1]) + if err != nil { + return err + } + + (*m)[key] = value + } + + return nil +} + +// LevelDBBackedDB is a leveldb backed DB implementation. +type LevelDBBackedDB struct { + db *leveldb.DB +} + +// NewLevelDBBackedDB initialize a leveldb-backed database. +func NewLevelDBBackedDB( + path string) (lvl *LevelDBBackedDB, err error) { + + dbInst, err := leveldb.OpenFile(path, nil) + if err != nil { + return + } + lvl = &LevelDBBackedDB{db: dbInst} + return +} + +// Close implement Closer interface, which would release allocated resource. +func (lvl *LevelDBBackedDB) Close() error { + return lvl.db.Close() +} + +// HasBlock implements the Reader.Has method. +func (lvl *LevelDBBackedDB) HasBlock(hash common.Hash) bool { + exists, err := lvl.internalHasBlock(lvl.getBlockKey(hash)) + if err != nil { + panic(err) + } + return exists +} + +func (lvl *LevelDBBackedDB) internalHasBlock(key []byte) (bool, error) { + return lvl.db.Has(key, nil) +} + +// GetBlock implements the Reader.GetBlock method. +func (lvl *LevelDBBackedDB) GetBlock( + hash common.Hash) (block types.Block, err error) { + queried, err := lvl.db.Get(lvl.getBlockKey(hash), nil) + if err != nil { + if err == leveldb.ErrNotFound { + err = ErrBlockDoesNotExist + } + return + } + err = rlp.DecodeBytes(queried, &block) + return +} + +// UpdateBlock implements the Writer.UpdateBlock method. +func (lvl *LevelDBBackedDB) UpdateBlock(block types.Block) (err error) { + // NOTE: we didn't handle changes of block hash (and it + // should not happen). + marshaled, err := rlp.EncodeToBytes(&block) + if err != nil { + return + } + blockKey := lvl.getBlockKey(block.Hash) + exists, err := lvl.internalHasBlock(blockKey) + if err != nil { + return + } + if !exists { + err = ErrBlockDoesNotExist + return + } + err = lvl.db.Put(blockKey, marshaled, nil) + return +} + +// PutBlock implements the Writer.PutBlock method. +func (lvl *LevelDBBackedDB) PutBlock(block types.Block) (err error) { + marshaled, err := rlp.EncodeToBytes(&block) + if err != nil { + return + } + blockKey := lvl.getBlockKey(block.Hash) + exists, err := lvl.internalHasBlock(blockKey) + if err != nil { + return + } + if exists { + err = ErrBlockExists + return + } + err = lvl.db.Put(blockKey, marshaled, nil) + return +} + +// GetAllBlocks implements Reader.GetAllBlocks method, which allows callers +// to retrieve all blocks in DB. +func (lvl *LevelDBBackedDB) GetAllBlocks() (BlockIterator, error) { + return nil, ErrNotImplemented +} + +// PutCompactionChainTipInfo saves tip of compaction chain into the database. +func (lvl *LevelDBBackedDB) PutCompactionChainTipInfo( + blockHash common.Hash, height uint64) error { + marshaled, err := rlp.EncodeToBytes(&compactionChainTipInfo{ + Hash: blockHash, + Height: height, + }) + if err != nil { + return err + } + // Check current cached tip info to make sure the one to be updated is + // valid. + info, err := lvl.internalGetCompactionChainTipInfo() + if err != nil { + return err + } + if info.Height+1 != height { + return ErrInvalidCompactionChainTipHeight + } + return lvl.db.Put(compactionChainTipInfoKey, marshaled, nil) +} + +func (lvl *LevelDBBackedDB) internalGetCompactionChainTipInfo() ( + info compactionChainTipInfo, err error) { + queried, err := lvl.db.Get(compactionChainTipInfoKey, nil) + if err != nil { + if err == leveldb.ErrNotFound { + err = nil + } + return + } + err = rlp.DecodeBytes(queried, &info) + return +} + +// GetCompactionChainTipInfo get the tip info of compaction chain into the +// database. +func (lvl *LevelDBBackedDB) GetCompactionChainTipInfo() ( + hash common.Hash, height uint64) { + info, err := lvl.internalGetCompactionChainTipInfo() + if err != nil { + panic(err) + } + hash, height = info.Hash, info.Height + return +} + +// GetDKGPrivateKey get DKG private key of one round. +func (lvl *LevelDBBackedDB) GetDKGPrivateKey(round, reset uint64) ( + prv dkg.PrivateKey, err error) { + queried, err := lvl.db.Get(lvl.getDKGPrivateKeyKey(round), nil) + if err != nil { + if err == leveldb.ErrNotFound { + err = ErrDKGPrivateKeyDoesNotExist + } + return + } + pk := dkgPrivateKey{} + err = rlp.DecodeBytes(queried, &pk) + if pk.Reset != reset { + err = ErrDKGPrivateKeyDoesNotExist + return + } + prv = pk.PK + return +} + +// PutDKGPrivateKey save DKG private key of one round. +func (lvl *LevelDBBackedDB) PutDKGPrivateKey( + round, reset uint64, prv dkg.PrivateKey) error { + // Check existence. + _, err := lvl.GetDKGPrivateKey(round, reset) + if err == nil { + return ErrDKGPrivateKeyExists + } + if err != ErrDKGPrivateKeyDoesNotExist { + return err + } + pk := &dkgPrivateKey{ + PK: prv, + Reset: reset, + } + marshaled, err := rlp.EncodeToBytes(&pk) + if err != nil { + return err + } + return lvl.db.Put( + lvl.getDKGPrivateKeyKey(round), marshaled, nil) +} + +// GetDKGProtocol get DKG protocol. +func (lvl *LevelDBBackedDB) GetDKGProtocol() ( + info DKGProtocolInfo, err error) { + queried, err := lvl.db.Get(lvl.getDKGProtocolInfoKey(), nil) + if err != nil { + if err == leveldb.ErrNotFound { + err = ErrDKGProtocolDoesNotExist + } + return + } + + err = rlp.DecodeBytes(queried, &info) + return +} + +// PutOrUpdateDKGProtocol save DKG protocol. +func (lvl *LevelDBBackedDB) PutOrUpdateDKGProtocol(info DKGProtocolInfo) error { + marshaled, err := rlp.EncodeToBytes(&info) + if err != nil { + return err + } + return lvl.db.Put(lvl.getDKGProtocolInfoKey(), marshaled, nil) +} + +func (lvl *LevelDBBackedDB) getBlockKey(hash common.Hash) (ret []byte) { + ret = make([]byte, len(blockKeyPrefix)+len(hash[:])) + copy(ret, blockKeyPrefix) + copy(ret[len(blockKeyPrefix):], hash[:]) + return +} + +func (lvl *LevelDBBackedDB) getDKGPrivateKeyKey( + round uint64) (ret []byte) { + ret = make([]byte, len(dkgPrivateKeyKeyPrefix)+8) + copy(ret, dkgPrivateKeyKeyPrefix) + binary.LittleEndian.PutUint64( + ret[len(dkgPrivateKeyKeyPrefix):], round) + return +} + +func (lvl *LevelDBBackedDB) getDKGProtocolInfoKey() (ret []byte) { + ret = make([]byte, len(dkgProtocolInfoKeyPrefix)+8) + copy(ret, dkgProtocolInfoKeyPrefix) + return +} diff --git a/vendor/github.com/byzantine-lab/dexon-consensus/core/db/memory.go b/vendor/github.com/byzantine-lab/dexon-consensus/core/db/memory.go new file mode 100644 index 000000000..2ad5cda9e --- /dev/null +++ b/vendor/github.com/byzantine-lab/dexon-consensus/core/db/memory.go @@ -0,0 +1,262 @@ +// Copyright 2018 The dexon-consensus Authors +// This file is part of the dexon-consensus library. +// +// The dexon-consensus library is free software: you can redistribute it +// and/or modify it under the terms of the GNU Lesser General Public License as +// published by the Free Software Foundation, either version 3 of the License, +// or (at your option) any later version. +// +// The dexon-consensus library is distributed in the hope that it will be +// useful, but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser +// General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the dexon-consensus library. If not, see +// <http://www.gnu.org/licenses/>. + +package db + +import ( + "encoding/json" + "io/ioutil" + "os" + "sync" + + "github.com/byzantine-lab/dexon-consensus/common" + "github.com/byzantine-lab/dexon-consensus/core/crypto/dkg" + "github.com/byzantine-lab/dexon-consensus/core/types" +) + +type blockSeqIterator struct { + idx int + db *MemBackedDB +} + +// NextBlock implemenets BlockIterator.NextBlock method. +func (seq *blockSeqIterator) NextBlock() (types.Block, error) { + curIdx := seq.idx + seq.idx++ + return seq.db.getBlockByIndex(curIdx) +} + +// MemBackedDB is a memory backed DB implementation. +type MemBackedDB struct { + blocksLock sync.RWMutex + blockHashSequence common.Hashes + blocksByHash map[common.Hash]*types.Block + compactionChainTipLock sync.RWMutex + compactionChainTipHash common.Hash + compactionChainTipHeight uint64 + dkgPrivateKeysLock sync.RWMutex + dkgPrivateKeys map[uint64]*dkgPrivateKey + dkgProtocolLock sync.RWMutex + dkgProtocolInfo *DKGProtocolInfo + persistantFilePath string +} + +// NewMemBackedDB initialize a memory-backed database. +func NewMemBackedDB(persistantFilePath ...string) ( + dbInst *MemBackedDB, err error) { + dbInst = &MemBackedDB{ + blockHashSequence: common.Hashes{}, + blocksByHash: make(map[common.Hash]*types.Block), + dkgPrivateKeys: make(map[uint64]*dkgPrivateKey), + } + if len(persistantFilePath) == 0 || len(persistantFilePath[0]) == 0 { + return + } + dbInst.persistantFilePath = persistantFilePath[0] + buf, err := ioutil.ReadFile(dbInst.persistantFilePath) + if err != nil { + if !os.IsNotExist(err) { + // Something unexpected happened. + return + } + // It's expected behavior that file doesn't exists, we should not + // report error on it. + err = nil + return + } + + // Init this instance by file content, it's a temporary way + // to export those private field for JSON encoding. + toLoad := struct { + Sequence common.Hashes + ByHash map[common.Hash]*types.Block + }{} + err = json.Unmarshal(buf, &toLoad) + if err != nil { + return + } + dbInst.blockHashSequence = toLoad.Sequence + dbInst.blocksByHash = toLoad.ByHash + return +} + +// HasBlock returns wheter or not the DB has a block identified with the hash. +func (m *MemBackedDB) HasBlock(hash common.Hash) bool { + m.blocksLock.RLock() + defer m.blocksLock.RUnlock() + + _, ok := m.blocksByHash[hash] + return ok +} + +// GetBlock returns a block given a hash. +func (m *MemBackedDB) GetBlock(hash common.Hash) (types.Block, error) { + m.blocksLock.RLock() + defer m.blocksLock.RUnlock() + + return m.internalGetBlock(hash) +} + +func (m *MemBackedDB) internalGetBlock(hash common.Hash) (types.Block, error) { + b, ok := m.blocksByHash[hash] + if !ok { + return types.Block{}, ErrBlockDoesNotExist + } + return *b, nil +} + +// PutBlock inserts a new block into the database. +func (m *MemBackedDB) PutBlock(block types.Block) error { + if m.HasBlock(block.Hash) { + return ErrBlockExists + } + + m.blocksLock.Lock() + defer m.blocksLock.Unlock() + + m.blockHashSequence = append(m.blockHashSequence, block.Hash) + m.blocksByHash[block.Hash] = &block + return nil +} + +// UpdateBlock updates a block in the database. +func (m *MemBackedDB) UpdateBlock(block types.Block) error { + if !m.HasBlock(block.Hash) { + return ErrBlockDoesNotExist + } + + m.blocksLock.Lock() + defer m.blocksLock.Unlock() + + m.blocksByHash[block.Hash] = &block + return nil +} + +// PutCompactionChainTipInfo saves tip of compaction chain into the database. +func (m *MemBackedDB) PutCompactionChainTipInfo( + blockHash common.Hash, height uint64) error { + m.compactionChainTipLock.Lock() + defer m.compactionChainTipLock.Unlock() + if m.compactionChainTipHeight+1 != height { + return ErrInvalidCompactionChainTipHeight + } + m.compactionChainTipHeight = height + m.compactionChainTipHash = blockHash + return nil +} + +// GetCompactionChainTipInfo get the tip info of compaction chain into the +// database. +func (m *MemBackedDB) GetCompactionChainTipInfo() ( + hash common.Hash, height uint64) { + m.compactionChainTipLock.RLock() + defer m.compactionChainTipLock.RUnlock() + return m.compactionChainTipHash, m.compactionChainTipHeight +} + +// GetDKGPrivateKey get DKG private key of one round. +func (m *MemBackedDB) GetDKGPrivateKey(round, reset uint64) ( + dkg.PrivateKey, error) { + m.dkgPrivateKeysLock.RLock() + defer m.dkgPrivateKeysLock.RUnlock() + if prv, exists := m.dkgPrivateKeys[round]; exists && prv.Reset == reset { + return prv.PK, nil + } + return dkg.PrivateKey{}, ErrDKGPrivateKeyDoesNotExist +} + +// PutDKGPrivateKey save DKG private key of one round. +func (m *MemBackedDB) PutDKGPrivateKey( + round, reset uint64, prv dkg.PrivateKey) error { + m.dkgPrivateKeysLock.Lock() + defer m.dkgPrivateKeysLock.Unlock() + if prv, exists := m.dkgPrivateKeys[round]; exists && prv.Reset == reset { + return ErrDKGPrivateKeyExists + } + m.dkgPrivateKeys[round] = &dkgPrivateKey{ + PK: prv, + Reset: reset, + } + return nil +} + +// GetDKGProtocol get DKG protocol. +func (m *MemBackedDB) GetDKGProtocol() ( + DKGProtocolInfo, error) { + m.dkgProtocolLock.RLock() + defer m.dkgProtocolLock.RUnlock() + if m.dkgProtocolInfo == nil { + return DKGProtocolInfo{}, ErrDKGProtocolDoesNotExist + } + + return *m.dkgProtocolInfo, nil +} + +// PutOrUpdateDKGProtocol save DKG protocol. +func (m *MemBackedDB) PutOrUpdateDKGProtocol(dkgProtocol DKGProtocolInfo) error { + m.dkgProtocolLock.Lock() + defer m.dkgProtocolLock.Unlock() + m.dkgProtocolInfo = &dkgProtocol + return nil +} + +// Close implement Closer interface, which would release allocated resource. +func (m *MemBackedDB) Close() (err error) { + // Save internal state to a pretty-print json file. It's a temporary way + // to dump private file via JSON encoding. + if len(m.persistantFilePath) == 0 { + return + } + + m.blocksLock.RLock() + defer m.blocksLock.RUnlock() + + toDump := struct { + Sequence common.Hashes + ByHash map[common.Hash]*types.Block + }{ + Sequence: m.blockHashSequence, + ByHash: m.blocksByHash, + } + + // Dump to JSON with 2-space indent. + buf, err := json.Marshal(&toDump) + if err != nil { + return + } + + err = ioutil.WriteFile(m.persistantFilePath, buf, 0644) + return +} + +func (m *MemBackedDB) getBlockByIndex(idx int) (types.Block, error) { + m.blocksLock.RLock() + defer m.blocksLock.RUnlock() + + if idx >= len(m.blockHashSequence) { + return types.Block{}, ErrIterationFinished + } + + hash := m.blockHashSequence[idx] + return m.internalGetBlock(hash) +} + +// GetAllBlocks implement Reader.GetAllBlocks method, which allows caller +// to retrieve all blocks in DB. +func (m *MemBackedDB) GetAllBlocks() (BlockIterator, error) { + return &blockSeqIterator{db: m}, nil +} |