aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--core/vm/sqlvm/common/storage.go149
-rw-r--r--core/vm/sqlvm/common/storage_test.go50
-rw-r--r--core/vm/sqlvm/runtime/instructions_test.go3
3 files changed, 136 insertions, 66 deletions
diff --git a/core/vm/sqlvm/common/storage.go b/core/vm/sqlvm/common/storage.go
index e046b0411..be8a074ad 100644
--- a/core/vm/sqlvm/common/storage.go
+++ b/core/vm/sqlvm/common/storage.go
@@ -534,43 +534,99 @@ func (s *Storage) IncSequence(
return val
}
-// DecodePKHeader decodes primary key hash header to lastRowID and rowCount.
-func (s *Storage) DecodePKHeader(header common.Hash) (lastRowID, rowCount uint64) {
- lastRowID = binary.BigEndian.Uint64(header[:8])
- rowCount = binary.BigEndian.Uint64(header[8:16])
- return
+func setBit(n byte, pos uint) byte {
+ n |= (1 << pos)
+ return n
}
-// EncodePKHeader encodes lastRowID and rowCount to primary key hash header.
-func (s *Storage) EncodePKHeader(lastRowID, rowCount uint64) (header common.Hash) {
- binary.BigEndian.PutUint64(header[:8], lastRowID)
- binary.BigEndian.PutUint64(header[8:16], rowCount)
- return
+func hasBit(n byte, pos uint) bool {
+ val := n & (1 << pos)
+ return (val > 0)
}
-// UpdateHash updates hash to stateDB.
-func (s *Storage) UpdateHash(m map[common.Hash]common.Hash, address common.Address) {
- for key, val := range m {
- s.SetState(address, key, val)
+func getOffset(d common.Hash) (offset []uint64) {
+ for j, b := range d {
+ for i := 0; i < 8; i++ {
+ if hasBit(b, uint(i)) {
+ offset = append(offset, uint64(j*8+i))
+ }
+ }
}
+ return
}
-func setBit(n byte, pos uint) byte {
- n |= (1 << pos)
- return n
+// RepeatPK returns primary IDs by table reference.
+func (s *Storage) RepeatPK(address common.Address, tableRef schema.TableRef) []uint64 {
+ hash := s.GetPrimaryPathHash(tableRef)
+ bm := newBitMap(hash, address, s)
+ return bm.loadPK()
}
-func hasBit(n byte, pos uint) bool {
- val := n & (1 << pos)
- return (val > 0)
+// IncreasePK increases the primary ID and return it.
+func (s *Storage) IncreasePK(
+ address common.Address,
+ tableRef schema.TableRef,
+) uint64 {
+ hash := s.GetPrimaryPathHash(tableRef)
+ bm := newBitMap(hash, address, s)
+ return bm.increasePK()
}
// SetPK sets IDs to primary bit map.
-func (s *Storage) SetPK(address common.Address, tableRef schema.TableRef, IDs []uint64) {
- hash := s.GetPrimaryPathHash(tableRef)
- header := s.GetState(address, hash)
- lastRowID, rowCount := s.DecodePKHeader(header)
- slotHashToData := make(map[common.Hash]common.Hash)
+func (s *Storage) SetPK(address common.Address, headerHash common.Hash, IDs []uint64) {
+ bm := newBitMap(headerHash, address, s)
+ bm.setPK(IDs)
+}
+
+type bitMap struct {
+ storage *Storage
+ headerSlot common.Hash
+ headerData common.Hash
+ address common.Address
+ dirtySlot map[uint64]common.Hash
+}
+
+func (bm *bitMap) decodeHeader() (lastRowID, rowCount uint64) {
+ lastRowID = binary.BigEndian.Uint64(bm.headerData[:8])
+ rowCount = binary.BigEndian.Uint64(bm.headerData[8:16])
+ return
+}
+
+func (bm *bitMap) encodeHeader(lastRowID, rowCount uint64) {
+ binary.BigEndian.PutUint64(bm.headerData[:8], lastRowID)
+ binary.BigEndian.PutUint64(bm.headerData[8:16], rowCount)
+}
+
+func (bm *bitMap) increasePK() uint64 {
+ lastRowID, rowCount := bm.decodeHeader()
+ lastRowID++
+ rowCount++
+ bm.encodeHeader(lastRowID, rowCount)
+ shift := lastRowID/256 + 1
+ slot := bm.storage.ShiftHashUint64(bm.headerSlot, shift)
+ data := bm.storage.GetState(bm.address, slot)
+ byteShift := (lastRowID & 255) / 8
+ data[byteShift] = setBit(data[byteShift], uint(lastRowID&7))
+ bm.dirtySlot[shift] = data
+ bm.flushAll()
+ return lastRowID
+}
+
+func (bm *bitMap) flushHeader() {
+ bm.storage.SetState(bm.address, bm.headerSlot, bm.headerData)
+}
+
+func (bm *bitMap) flushAll() {
+ for k, v := range bm.dirtySlot {
+ slot := bm.storage.ShiftHashUint64(bm.headerSlot, k)
+ bm.storage.SetState(bm.address, slot, v)
+ }
+ bm.flushHeader()
+ bm.dirtySlot = make(map[uint64]common.Hash)
+}
+
+func (bm *bitMap) setPK(IDs []uint64) {
+ lastRowID, rowCount := bm.decodeHeader()
for _, id := range IDs {
if lastRowID < id {
lastRowID = id
@@ -578,45 +634,30 @@ func (s *Storage) SetPK(address common.Address, tableRef schema.TableRef, IDs []
slotNum := id/256 + 1
byteLoc := (id & 255) / 8
bitLoc := uint(id & 7)
- slotHash := s.ShiftHashUint64(hash, slotNum)
- data, exist := slotHashToData[slotHash]
+ data, exist := bm.dirtySlot[slotNum]
if !exist {
- data = s.GetState(address, slotHash)
+ slotHash := bm.storage.ShiftHashUint64(bm.headerSlot, slotNum)
+ data = bm.storage.GetState(bm.address, slotHash)
}
if !hasBit(data[byteLoc], bitLoc) {
rowCount++
data[byteLoc] = setBit(data[byteLoc], bitLoc)
}
- slotHashToData[slotHash] = data
+ bm.dirtySlot[slotNum] = data
}
- s.UpdateHash(slotHashToData, address)
- header = s.EncodePKHeader(lastRowID, rowCount)
- s.SetState(address, hash, header)
+ bm.encodeHeader(lastRowID, rowCount)
+ bm.flushAll()
}
-func getCountAndOffset(d common.Hash) (offset []uint64) {
- for j, b := range d {
- for i := 0; i < 8; i++ {
- if hasBit(b, uint(i)) {
- offset = append(offset, uint64(j*8+i))
- }
- }
- }
- return
-}
-
-// RepeatPK returns primary IDs by table reference.
-func (s *Storage) RepeatPK(address common.Address, tableRef schema.TableRef) []uint64 {
- hash := s.GetPrimaryPathHash(tableRef)
- header := s.GetState(address, hash)
- lastRowID, rowCount := s.DecodePKHeader(header)
+func (bm *bitMap) loadPK() []uint64 {
+ lastRowID, rowCount := bm.decodeHeader()
maxSlotNum := lastRowID/256 + 1
result := make([]uint64, rowCount)
ptr := 0
for slotNum := uint64(0); slotNum < maxSlotNum; slotNum++ {
- slotHash := s.ShiftHashUint64(hash, slotNum+1)
- slotData := s.GetState(address, slotHash)
- offsets := getCountAndOffset(slotData)
+ slotHash := bm.storage.ShiftHashUint64(bm.headerSlot, slotNum+1)
+ slotData := bm.storage.GetState(bm.address, slotHash)
+ offsets := getOffset(slotData)
for i, o := range offsets {
result[i+ptr] = o + slotNum*256
}
@@ -624,3 +665,9 @@ func (s *Storage) RepeatPK(address common.Address, tableRef schema.TableRef) []u
}
return result
}
+
+func newBitMap(headerSlot common.Hash, address common.Address, s *Storage) *bitMap {
+ headerData := s.GetState(address, headerSlot)
+ bm := bitMap{s, headerSlot, headerData, address, make(map[uint64]common.Hash)}
+ return &bm
+}
diff --git a/core/vm/sqlvm/common/storage_test.go b/core/vm/sqlvm/common/storage_test.go
index d402868e3..b02705c11 100644
--- a/core/vm/sqlvm/common/storage_test.go
+++ b/core/vm/sqlvm/common/storage_test.go
@@ -208,23 +208,13 @@ func (s *StorageTestSuite) TestSequence() {
func (s *StorageTestSuite) TestPKHeaderEncodeDecode() {
lastRowID := uint64(5566)
rowCount := uint64(6655)
- newLastRowID, newRowCount := s.storage.DecodePKHeader(s.storage.EncodePKHeader(lastRowID, rowCount))
+ bm := bitMap{}
+ bm.encodeHeader(lastRowID, rowCount)
+ newLastRowID, newRowCount := bm.decodeHeader()
s.Require().Equal(lastRowID, newLastRowID)
s.Require().Equal(rowCount, newRowCount)
}
-func (s *StorageTestSuite) TestUpdateHash() {
- m := map[common.Hash]common.Hash{
- common.BytesToHash([]byte("hello world")): common.BytesToHash([]byte("hello SQLVM")),
- common.BytesToHash([]byte("bye world")): common.BytesToHash([]byte("bye SQLVM")),
- }
- s.storage.UpdateHash(m, s.address)
- for key, val := range m {
- rVal := s.storage.GetState(s.address, key)
- s.Require().Equal(val, rVal)
- }
-}
-
func (s *StorageTestSuite) TestRepeatPK() {
type testCase struct {
address common.Address
@@ -249,12 +239,44 @@ func (s *StorageTestSuite) TestRepeatPK() {
},
}
for i, t := range testCases {
- s.storage.SetPK(t.address, t.tableRef, t.expectIDs)
+ headerSlot := s.storage.GetPrimaryPathHash(t.tableRef)
+ s.storage.SetPK(t.address, headerSlot, t.expectIDs)
IDs := s.storage.RepeatPK(t.address, t.tableRef)
s.Require().Equalf(t.expectIDs, IDs, "testCase #%v\n", i)
}
}
+func (s *StorageTestSuite) TestBitMapIncreasePK() {
+ type testCase struct {
+ tableRef schema.TableRef
+ IDs []uint64
+ }
+ testCases := []testCase{
+ {
+ tableRef: schema.TableRef(0),
+ IDs: []uint64{0, 1, 2},
+ },
+ {
+ tableRef: schema.TableRef(1),
+ IDs: []uint64{1234, 5566},
+ },
+ {
+ tableRef: schema.TableRef(2),
+ IDs: []uint64{0, 128, 256, 512, 1024},
+ },
+ }
+ for i, t := range testCases {
+ hash := s.storage.GetPrimaryPathHash(t.tableRef)
+ s.storage.SetPK(s.address, hash, t.IDs)
+ bm := newBitMap(hash, s.address, s.storage)
+ newID := bm.increasePK()
+
+ t.IDs = append(t.IDs, newID)
+ IDs := s.storage.RepeatPK(s.address, t.tableRef)
+ s.Require().Equalf(t.IDs, IDs, "testCase #%v\n", i)
+ }
+}
+
func TestStorage(t *testing.T) {
suite.Run(t, new(StorageTestSuite))
}
diff --git a/core/vm/sqlvm/runtime/instructions_test.go b/core/vm/sqlvm/runtime/instructions_test.go
index 77ba0ff70..a00ac12dd 100644
--- a/core/vm/sqlvm/runtime/instructions_test.go
+++ b/core/vm/sqlvm/runtime/instructions_test.go
@@ -457,7 +457,8 @@ func (s opRepeatPKSuite) getTestCases(storage *common.Storage) []repeatPKTestCas
},
}
for _, t := range testCases {
- storage.SetPK(t.address, t.tableRef, t.expectedIDs)
+ headerSlot := storage.GetPrimaryPathHash(t.tableRef)
+ storage.SetPK(t.address, headerSlot, t.expectedIDs)
}
return testCases
}