diff options
-rw-r--r-- | core/vm/sqlvm/common/storage.go | 149 | ||||
-rw-r--r-- | core/vm/sqlvm/common/storage_test.go | 50 | ||||
-rw-r--r-- | core/vm/sqlvm/runtime/instructions_test.go | 3 |
3 files changed, 136 insertions, 66 deletions
diff --git a/core/vm/sqlvm/common/storage.go b/core/vm/sqlvm/common/storage.go index e046b0411..be8a074ad 100644 --- a/core/vm/sqlvm/common/storage.go +++ b/core/vm/sqlvm/common/storage.go @@ -534,43 +534,99 @@ func (s *Storage) IncSequence( return val } -// DecodePKHeader decodes primary key hash header to lastRowID and rowCount. -func (s *Storage) DecodePKHeader(header common.Hash) (lastRowID, rowCount uint64) { - lastRowID = binary.BigEndian.Uint64(header[:8]) - rowCount = binary.BigEndian.Uint64(header[8:16]) - return +func setBit(n byte, pos uint) byte { + n |= (1 << pos) + return n } -// EncodePKHeader encodes lastRowID and rowCount to primary key hash header. -func (s *Storage) EncodePKHeader(lastRowID, rowCount uint64) (header common.Hash) { - binary.BigEndian.PutUint64(header[:8], lastRowID) - binary.BigEndian.PutUint64(header[8:16], rowCount) - return +func hasBit(n byte, pos uint) bool { + val := n & (1 << pos) + return (val > 0) } -// UpdateHash updates hash to stateDB. -func (s *Storage) UpdateHash(m map[common.Hash]common.Hash, address common.Address) { - for key, val := range m { - s.SetState(address, key, val) +func getOffset(d common.Hash) (offset []uint64) { + for j, b := range d { + for i := 0; i < 8; i++ { + if hasBit(b, uint(i)) { + offset = append(offset, uint64(j*8+i)) + } + } } + return } -func setBit(n byte, pos uint) byte { - n |= (1 << pos) - return n +// RepeatPK returns primary IDs by table reference. +func (s *Storage) RepeatPK(address common.Address, tableRef schema.TableRef) []uint64 { + hash := s.GetPrimaryPathHash(tableRef) + bm := newBitMap(hash, address, s) + return bm.loadPK() } -func hasBit(n byte, pos uint) bool { - val := n & (1 << pos) - return (val > 0) +// IncreasePK increases the primary ID and return it. +func (s *Storage) IncreasePK( + address common.Address, + tableRef schema.TableRef, +) uint64 { + hash := s.GetPrimaryPathHash(tableRef) + bm := newBitMap(hash, address, s) + return bm.increasePK() } // SetPK sets IDs to primary bit map. -func (s *Storage) SetPK(address common.Address, tableRef schema.TableRef, IDs []uint64) { - hash := s.GetPrimaryPathHash(tableRef) - header := s.GetState(address, hash) - lastRowID, rowCount := s.DecodePKHeader(header) - slotHashToData := make(map[common.Hash]common.Hash) +func (s *Storage) SetPK(address common.Address, headerHash common.Hash, IDs []uint64) { + bm := newBitMap(headerHash, address, s) + bm.setPK(IDs) +} + +type bitMap struct { + storage *Storage + headerSlot common.Hash + headerData common.Hash + address common.Address + dirtySlot map[uint64]common.Hash +} + +func (bm *bitMap) decodeHeader() (lastRowID, rowCount uint64) { + lastRowID = binary.BigEndian.Uint64(bm.headerData[:8]) + rowCount = binary.BigEndian.Uint64(bm.headerData[8:16]) + return +} + +func (bm *bitMap) encodeHeader(lastRowID, rowCount uint64) { + binary.BigEndian.PutUint64(bm.headerData[:8], lastRowID) + binary.BigEndian.PutUint64(bm.headerData[8:16], rowCount) +} + +func (bm *bitMap) increasePK() uint64 { + lastRowID, rowCount := bm.decodeHeader() + lastRowID++ + rowCount++ + bm.encodeHeader(lastRowID, rowCount) + shift := lastRowID/256 + 1 + slot := bm.storage.ShiftHashUint64(bm.headerSlot, shift) + data := bm.storage.GetState(bm.address, slot) + byteShift := (lastRowID & 255) / 8 + data[byteShift] = setBit(data[byteShift], uint(lastRowID&7)) + bm.dirtySlot[shift] = data + bm.flushAll() + return lastRowID +} + +func (bm *bitMap) flushHeader() { + bm.storage.SetState(bm.address, bm.headerSlot, bm.headerData) +} + +func (bm *bitMap) flushAll() { + for k, v := range bm.dirtySlot { + slot := bm.storage.ShiftHashUint64(bm.headerSlot, k) + bm.storage.SetState(bm.address, slot, v) + } + bm.flushHeader() + bm.dirtySlot = make(map[uint64]common.Hash) +} + +func (bm *bitMap) setPK(IDs []uint64) { + lastRowID, rowCount := bm.decodeHeader() for _, id := range IDs { if lastRowID < id { lastRowID = id @@ -578,45 +634,30 @@ func (s *Storage) SetPK(address common.Address, tableRef schema.TableRef, IDs [] slotNum := id/256 + 1 byteLoc := (id & 255) / 8 bitLoc := uint(id & 7) - slotHash := s.ShiftHashUint64(hash, slotNum) - data, exist := slotHashToData[slotHash] + data, exist := bm.dirtySlot[slotNum] if !exist { - data = s.GetState(address, slotHash) + slotHash := bm.storage.ShiftHashUint64(bm.headerSlot, slotNum) + data = bm.storage.GetState(bm.address, slotHash) } if !hasBit(data[byteLoc], bitLoc) { rowCount++ data[byteLoc] = setBit(data[byteLoc], bitLoc) } - slotHashToData[slotHash] = data + bm.dirtySlot[slotNum] = data } - s.UpdateHash(slotHashToData, address) - header = s.EncodePKHeader(lastRowID, rowCount) - s.SetState(address, hash, header) + bm.encodeHeader(lastRowID, rowCount) + bm.flushAll() } -func getCountAndOffset(d common.Hash) (offset []uint64) { - for j, b := range d { - for i := 0; i < 8; i++ { - if hasBit(b, uint(i)) { - offset = append(offset, uint64(j*8+i)) - } - } - } - return -} - -// RepeatPK returns primary IDs by table reference. -func (s *Storage) RepeatPK(address common.Address, tableRef schema.TableRef) []uint64 { - hash := s.GetPrimaryPathHash(tableRef) - header := s.GetState(address, hash) - lastRowID, rowCount := s.DecodePKHeader(header) +func (bm *bitMap) loadPK() []uint64 { + lastRowID, rowCount := bm.decodeHeader() maxSlotNum := lastRowID/256 + 1 result := make([]uint64, rowCount) ptr := 0 for slotNum := uint64(0); slotNum < maxSlotNum; slotNum++ { - slotHash := s.ShiftHashUint64(hash, slotNum+1) - slotData := s.GetState(address, slotHash) - offsets := getCountAndOffset(slotData) + slotHash := bm.storage.ShiftHashUint64(bm.headerSlot, slotNum+1) + slotData := bm.storage.GetState(bm.address, slotHash) + offsets := getOffset(slotData) for i, o := range offsets { result[i+ptr] = o + slotNum*256 } @@ -624,3 +665,9 @@ func (s *Storage) RepeatPK(address common.Address, tableRef schema.TableRef) []u } return result } + +func newBitMap(headerSlot common.Hash, address common.Address, s *Storage) *bitMap { + headerData := s.GetState(address, headerSlot) + bm := bitMap{s, headerSlot, headerData, address, make(map[uint64]common.Hash)} + return &bm +} diff --git a/core/vm/sqlvm/common/storage_test.go b/core/vm/sqlvm/common/storage_test.go index d402868e3..b02705c11 100644 --- a/core/vm/sqlvm/common/storage_test.go +++ b/core/vm/sqlvm/common/storage_test.go @@ -208,23 +208,13 @@ func (s *StorageTestSuite) TestSequence() { func (s *StorageTestSuite) TestPKHeaderEncodeDecode() { lastRowID := uint64(5566) rowCount := uint64(6655) - newLastRowID, newRowCount := s.storage.DecodePKHeader(s.storage.EncodePKHeader(lastRowID, rowCount)) + bm := bitMap{} + bm.encodeHeader(lastRowID, rowCount) + newLastRowID, newRowCount := bm.decodeHeader() s.Require().Equal(lastRowID, newLastRowID) s.Require().Equal(rowCount, newRowCount) } -func (s *StorageTestSuite) TestUpdateHash() { - m := map[common.Hash]common.Hash{ - common.BytesToHash([]byte("hello world")): common.BytesToHash([]byte("hello SQLVM")), - common.BytesToHash([]byte("bye world")): common.BytesToHash([]byte("bye SQLVM")), - } - s.storage.UpdateHash(m, s.address) - for key, val := range m { - rVal := s.storage.GetState(s.address, key) - s.Require().Equal(val, rVal) - } -} - func (s *StorageTestSuite) TestRepeatPK() { type testCase struct { address common.Address @@ -249,12 +239,44 @@ func (s *StorageTestSuite) TestRepeatPK() { }, } for i, t := range testCases { - s.storage.SetPK(t.address, t.tableRef, t.expectIDs) + headerSlot := s.storage.GetPrimaryPathHash(t.tableRef) + s.storage.SetPK(t.address, headerSlot, t.expectIDs) IDs := s.storage.RepeatPK(t.address, t.tableRef) s.Require().Equalf(t.expectIDs, IDs, "testCase #%v\n", i) } } +func (s *StorageTestSuite) TestBitMapIncreasePK() { + type testCase struct { + tableRef schema.TableRef + IDs []uint64 + } + testCases := []testCase{ + { + tableRef: schema.TableRef(0), + IDs: []uint64{0, 1, 2}, + }, + { + tableRef: schema.TableRef(1), + IDs: []uint64{1234, 5566}, + }, + { + tableRef: schema.TableRef(2), + IDs: []uint64{0, 128, 256, 512, 1024}, + }, + } + for i, t := range testCases { + hash := s.storage.GetPrimaryPathHash(t.tableRef) + s.storage.SetPK(s.address, hash, t.IDs) + bm := newBitMap(hash, s.address, s.storage) + newID := bm.increasePK() + + t.IDs = append(t.IDs, newID) + IDs := s.storage.RepeatPK(s.address, t.tableRef) + s.Require().Equalf(t.IDs, IDs, "testCase #%v\n", i) + } +} + func TestStorage(t *testing.T) { suite.Run(t, new(StorageTestSuite)) } diff --git a/core/vm/sqlvm/runtime/instructions_test.go b/core/vm/sqlvm/runtime/instructions_test.go index 77ba0ff70..a00ac12dd 100644 --- a/core/vm/sqlvm/runtime/instructions_test.go +++ b/core/vm/sqlvm/runtime/instructions_test.go @@ -457,7 +457,8 @@ func (s opRepeatPKSuite) getTestCases(storage *common.Storage) []repeatPKTestCas }, } for _, t := range testCases { - storage.SetPK(t.address, t.tableRef, t.expectedIDs) + headerSlot := storage.GetPrimaryPathHash(t.tableRef) + storage.SetPK(t.address, headerSlot, t.expectedIDs) } return testCases } |