diff options
author | yenlin.lai <yenlin.lai@cobinhood.com> | 2019-04-08 17:08:47 +0800 |
---|---|---|
committer | Jhih-Ming Huang <jm.huang@cobinhood.com> | 2019-05-06 10:44:04 +0800 |
commit | b99b6272d7c0858f66867cc1ac056686bbf33fe0 (patch) | |
tree | ab2b00f124a4abdea1b5d9b0da3f1ef59ad63579 | |
parent | 9b8e66235752d5334a10023f9e00218904d746e8 (diff) | |
download | dexon-b99b6272d7c0858f66867cc1ac056686bbf33fe0.tar.gz dexon-b99b6272d7c0858f66867cc1ac056686bbf33fe0.tar.zst dexon-b99b6272d7c0858f66867cc1ac056686bbf33fe0.zip |
sqlvm: common: add some shared methods on Storage struct
Add methods for ACL control and index meta loading. These methods will
be used outside runtime, so put them on Storage.
-rw-r--r-- | core/vm/sqlvm/common/storage.go | 427 | ||||
-rw-r--r-- | core/vm/sqlvm/common/storage_test.go | 120 | ||||
-rw-r--r-- | core/vm/sqlvm/runtime/instructions.go | 2 | ||||
-rw-r--r-- | core/vm/sqlvm/runtime/instructions_test.go | 4 |
4 files changed, 538 insertions, 15 deletions
diff --git a/core/vm/sqlvm/common/storage.go b/core/vm/sqlvm/common/storage.go index 22ef85885..e977ee3ed 100644 --- a/core/vm/sqlvm/common/storage.go +++ b/core/vm/sqlvm/common/storage.go @@ -14,6 +14,17 @@ import ( "github.com/dexon-foundation/dexon/rlp" ) +// Constants for path keys. +var ( + pathCompTables = []byte("tables") + pathCompPrimary = []byte("primary") + pathCompIndices = []byte("indices") + pathCompSequence = []byte("sequence") + pathCompOwner = []byte("owner") + pathCompWriters = []byte("writers") + pathCompReverseIndices = []byte("reverse_indices") +) + // Storage holds SQLVM required data and method. type Storage struct { state.StateDB @@ -26,7 +37,8 @@ func NewStorage(state *state.StateDB) Storage { return s } -func convertIDtoBytes(id uint64) []byte { +// TODO(yenlin): Do we really need to use ast encode/decode here? +func uint64ToBytes(id uint64) []byte { bigIntID := new(big.Int).SetUint64(id) decimalID := decimal.NewFromBigInt(bigIntID, 0) dt := ast.ComposeDataType(ast.DataTypeMajorUint, 7) @@ -34,14 +46,23 @@ func convertIDtoBytes(id uint64) []byte { return byteID } -// GetPrimaryKeyHash return primary key hash. -func (s Storage) GetPrimaryKeyHash(tableName []byte, id uint64) (h common.Hash) { - key := [][]byte{ - []byte("tables"), - tableName, - []byte("primary"), - convertIDtoBytes(id), - } +func bytesToUint64(b []byte) uint64 { + dt := ast.ComposeDataType(ast.DataTypeMajorUint, 7) + d, _ := ast.DecimalDecode(dt, b) + // TODO(yenlin): Not yet a convenient way to extract uint64 from decimal... + bigInt := d.Rescale(0).Coefficient() + return bigInt.Uint64() +} + +func hashToAddress(hash common.Hash) common.Address { + return common.BytesToAddress(hash.Bytes()) +} + +func addressToHash(addr common.Address) common.Hash { + return common.BytesToHash(addr.Bytes()) +} + +func (s Storage) hashPathKey(key [][]byte) (h common.Hash) { hw := sha3.NewLegacyKeccak256() rlp.Encode(hw, key) // length of common.Hash is 256bit, @@ -50,6 +71,106 @@ func (s Storage) GetPrimaryKeyHash(tableName []byte, id uint64) (h common.Hash) return } +// GetRowPathHash return primary key hash which points to row data. +func (s Storage) GetRowPathHash(tableName []byte, rowID uint64) common.Hash { + // PathKey(["tables", "{table_name}", "primary", uint64({row_id})]) + key := [][]byte{ + pathCompTables, + tableName, + pathCompPrimary, + uint64ToBytes(rowID), + } + return s.hashPathKey(key) +} + +// GetIndexValuesPathHash return the hash address to IndexValues structure +// which contains all possible values. +func (s Storage) GetIndexValuesPathHash( + tableName, indexName []byte, +) common.Hash { + // PathKey(["tables", "{table_name}", "indices", "{index_name}"]) + key := [][]byte{ + pathCompTables, + tableName, + pathCompIndices, + indexName, + } + return s.hashPathKey(key) +} + +// GetIndexEntryPathHash return the hash address to IndexEntry structure for a +// given value. +func (s Storage) GetIndexEntryPathHash( + tableName, indexName []byte, + values ...[]byte, +) common.Hash { + // PathKey(["tables", "{table_name}", "indices", "{index_name}", field_1, field_2, field_3, ...]) + key := make([][]byte, 0, 4+len(values)) + key = append(key, pathCompTables, tableName, pathCompIndices, indexName) + key = append(key, values...) + return s.hashPathKey(key) +} + +// GetReverseIndexPathHash return the hash address to IndexRev structure for a +// row in a table. +func (s Storage) GetReverseIndexPathHash( + tableName []byte, + rowID uint64, +) common.Hash { + // PathKey(["tables", "{table_name}", "reverse_indices", "{RowID}"]) + key := [][]byte{ + pathCompTables, + tableName, + pathCompReverseIndices, + uint64ToBytes(rowID), + } + return s.hashPathKey(key) +} + +// getSequencePathHash return the hash address of a sequence. +func (s Storage) getSequencePathHash( + tableName []byte, seqIdx uint8, +) common.Hash { + // PathKey(["tables", "{table_name}", "sequence", uint8(sequence_idx)]) + key := [][]byte{ + pathCompTables, + tableName, + pathCompSequence, + {seqIdx}, // TODO(yenlin): use some other encode method on uint8? + } + return s.hashPathKey(key) +} + +func (s Storage) getOwnerPathHash() common.Hash { + // PathKey(["owner"]) + key := [][]byte{pathCompOwner} + return s.hashPathKey(key) +} + +func (s Storage) getTableWritersPathHash(tableName []byte) common.Hash { + // PathKey(["tables", "{table_name}", "writers"]) + key := [][]byte{ + pathCompTables, + tableName, + pathCompWriters, + } + return s.hashPathKey(key) +} + +func (s Storage) getTableWriterRevIdxPathHash( + tableName []byte, + account common.Address, +) common.Hash { + // PathKey(["tables", "{table_name}", "writers", "{addr}"]) + key := [][]byte{ + pathCompTables, + tableName, + pathCompWriters, + account.Bytes(), + } + return s.hashPathKey(key) +} + // ShiftHashUint64 shift hash in uint64. func (s Storage) ShiftHashUint64(hash common.Hash, shift uint64) common.Hash { bigIntOffset := new(big.Int) @@ -64,6 +185,25 @@ func (s Storage) ShiftHashBigInt(hash common.Hash, shift *big.Int) common.Hash { return common.BytesToHash(head.Bytes()) } +// ShiftHashListEntry shift hash from the head of a list to the hash of +// idx-th entry. +func (s Storage) ShiftHashListEntry( + base common.Hash, + headerSize uint64, + entrySize uint64, + idx uint64, +) common.Hash { + // TODO(yenlin): tuning when headerSize+entrySize*idx do not overflow. + shift := new(big.Int) + operand := new(big.Int) + shift.SetUint64(entrySize) + operand.SetUint64(idx) + shift.Mul(shift, operand) + operand.SetUint64(headerSize) + shift.Add(shift, operand) + return s.ShiftHashBigInt(base, shift) +} + func getDByteSize(data common.Hash) uint64 { bytes := data.Bytes() lastByte := bytes[len(bytes)-1] @@ -91,3 +231,272 @@ func (s Storage) DecodeDByteBySlot(address common.Address, slot common.Hash) []b } return rVal[:length] } + +// SQLVM metadata structure operations. + +// IndexValues contain addresses to all possible values of an index. +type IndexValues struct { + // Header. + Length uint64 + // 3 unused uint64 fields here. + // Contents. + ValueHashes []common.Hash +} + +// IndexEntry contain row ids of a given value in an index. +type IndexEntry struct { + // Header. + Length uint64 + IndexToValuesOffset uint64 + ForeignKeyRefCount uint64 + // 1 unused uint64 field here. + // Contents. + RowIDs []uint64 +} + +// LoadIndexValues load IndexValues struct of a given index. +func (s Storage) LoadIndexValues( + contract common.Address, + tableName, indexName []byte, + onlyHeader bool, +) *IndexValues { + ret := &IndexValues{} + slot := s.GetIndexValuesPathHash(tableName, indexName) + data := s.GetState(contract, slot) + ret.Length = bytesToUint64(data[:8]) + if onlyHeader { + return ret + } + // Load all ValueHashes. + ret.ValueHashes = make([]common.Hash, ret.Length) + for i := uint64(0); i < ret.Length; i++ { + slot = s.ShiftHashUint64(slot, 1) + ret.ValueHashes[i] = s.GetState(contract, slot) + } + return ret +} + +// LoadIndexEntry load IndexEntry struct of a given value key on an index. +func (s Storage) LoadIndexEntry( + contract common.Address, + tableName, indexName []byte, + onlyHeader bool, + values ...[]byte, +) *IndexEntry { + ret := &IndexEntry{} + slot := s.GetIndexEntryPathHash(tableName, indexName, values...) + data := s.GetState(contract, slot) + ret.Length = bytesToUint64(data[:8]) + ret.IndexToValuesOffset = bytesToUint64(data[8:16]) + ret.ForeignKeyRefCount = bytesToUint64(data[16:24]) + + if onlyHeader { + return ret + } + // Load all RowIDs. + ret.RowIDs = make([]uint64, 0, ret.Length) + remain := ret.Length + for remain > 0 { + bound := remain + if bound > 4 { + bound = 4 + } + slot = s.ShiftHashUint64(slot, 1) + data := s.GetState(contract, slot).Bytes() + for i := uint64(0); i < bound; i++ { + ret.RowIDs = append(ret.RowIDs, bytesToUint64(data[:8])) + data = data[8:] + } + remain -= bound + } + return ret +} + +// LoadOwner load the owner of a SQLVM contract from storage. +func (s *Storage) LoadOwner(contract common.Address) common.Address { + return hashToAddress(s.GetState(contract, s.getOwnerPathHash())) +} + +// StoreOwner save the owner of a SQLVM contract to storage. +func (s *Storage) StoreOwner(contract, newOwner common.Address) { + s.SetState(contract, s.getOwnerPathHash(), addressToHash(newOwner)) +} + +type tableWriters struct { + Length uint64 + // 3 unused uint64 in slot 1. + Writers []common.Address // Each address consumes one slot, right aligned. +} + +type tableWriterRevIdx struct { + IndexToValuesOffset uint64 + // 3 unused uint64 in the slot. +} + +func (c *tableWriterRevIdx) Valid() bool { + return c.IndexToValuesOffset != 0 +} + +func (s Storage) loadTableWriterRevIdx( + contract common.Address, + path common.Hash, +) *tableWriterRevIdx { + ret := &tableWriterRevIdx{} + data := s.GetState(contract, path) + ret.IndexToValuesOffset = bytesToUint64(data[:8]) + return ret +} + +func (s Storage) storeTableWriterRevIdx( + contract common.Address, + path common.Hash, + rev *tableWriterRevIdx, +) { + var data common.Hash // One slot. + copy(data[:8], uint64ToBytes(rev.IndexToValuesOffset)) + s.SetState(contract, path, data) +} + +func (s Storage) loadTableWriters( + contract common.Address, + pathHash common.Hash, + onlyHeader bool, +) *tableWriters { + ret := &tableWriters{} + header := s.GetState(contract, pathHash) + ret.Length = bytesToUint64(header[:8]) + if onlyHeader { + return ret + } + ret.Writers = make([]common.Address, ret.Length) + for i := uint64(0); i < ret.Length; i++ { + ret.Writers[i] = s.loadSingleTableWriter(contract, pathHash, i) + } + return ret +} + +func (s Storage) storeTableWritersHeader( + contract common.Address, + pathHash common.Hash, + w *tableWriters, +) { + var header common.Hash + copy(header[:8], uint64ToBytes(w.Length)) + s.SetState(contract, pathHash, header) +} + +func (s Storage) shiftTableWriterList( + base common.Hash, + idx uint64, +) common.Hash { + return s.ShiftHashListEntry(base, 1, 1, idx) +} + +func (s Storage) loadSingleTableWriter( + contract common.Address, + writersPathHash common.Hash, + idx uint64, +) common.Address { + slot := s.shiftTableWriterList(writersPathHash, idx) + acc := s.GetState(contract, slot) + return hashToAddress(acc) +} + +func (s Storage) storeSingleTableWriter( + contract common.Address, + writersPathHash common.Hash, + idx uint64, + acc common.Address, +) { + slot := s.shiftTableWriterList(writersPathHash, idx) + s.SetState(contract, slot, addressToHash(acc)) +} + +// IsTableWriter check if an account is writer to the table. +func (s Storage) IsTableWriter( + contract common.Address, + tableName []byte, + account common.Address, +) bool { + path := s.getTableWriterRevIdxPathHash(tableName, account) + rev := s.loadTableWriterRevIdx(contract, path) + return rev.Valid() +} + +// LoadTableWriters load writers of a table. +func (s Storage) LoadTableWriters( + contract common.Address, + tableName []byte, +) (ret []common.Address) { + path := s.getTableWritersPathHash(tableName) + writers := s.loadTableWriters(contract, path, false) + return writers.Writers +} + +// InsertTableWriter insert an account into writer list of the table. +func (s Storage) InsertTableWriter( + contract common.Address, + tableName []byte, + account common.Address, +) { + revPath := s.getTableWriterRevIdxPathHash(tableName, account) + rev := s.loadTableWriterRevIdx(contract, revPath) + if rev.Valid() { + return + } + path := s.getTableWritersPathHash(tableName) + writers := s.loadTableWriters(contract, path, true) + // Store modification. + s.storeSingleTableWriter(contract, path, writers.Length, account) + writers.Length++ + s.storeTableWritersHeader(contract, path, writers) + // Notice: IndexToValuesOffset starts from 1. + s.storeTableWriterRevIdx(contract, revPath, &tableWriterRevIdx{ + IndexToValuesOffset: writers.Length, + }) +} + +// DeleteTableWriter delete an account from writer list of the table. +func (s Storage) DeleteTableWriter( + contract common.Address, + tableName []byte, + account common.Address, +) { + revPath := s.getTableWriterRevIdxPathHash(tableName, account) + rev := s.loadTableWriterRevIdx(contract, revPath) + if !rev.Valid() { + return + } + path := s.getTableWritersPathHash(tableName) + writers := s.loadTableWriters(contract, path, true) + + // Store modification. + if rev.IndexToValuesOffset != writers.Length { + // Move last to deleted slot. + lastAcc := s.loadSingleTableWriter(contract, path, writers.Length-1) + s.storeSingleTableWriter(contract, path, rev.IndexToValuesOffset-1, + lastAcc) + s.storeTableWriterRevIdx(contract, s.getTableWriterRevIdxPathHash( + tableName, lastAcc), rev) + } + // Delete last. + writers.Length-- + s.storeTableWritersHeader(contract, path, writers) + s.storeSingleTableWriter(contract, path, writers.Length, common.Address{}) + s.storeTableWriterRevIdx(contract, revPath, &tableWriterRevIdx{}) +} + +// IncSequence increment value of sequence by inc and return the old value. +func (s Storage) IncSequence( + contract common.Address, + tableName []byte, + seqIdx uint8, + inc uint64, +) uint64 { + seqPath := s.getSequencePathHash(tableName, seqIdx) + slot := s.GetState(contract, seqPath) + val := bytesToUint64(slot.Bytes()) + // TODO(yenlin): Check overflow? + s.SetState(contract, seqPath, common.BytesToHash(uint64ToBytes(val+inc))) + return val +} diff --git a/core/vm/sqlvm/common/storage_test.go b/core/vm/sqlvm/common/storage_test.go index 3a7633496..625d89158 100644 --- a/core/vm/sqlvm/common/storage_test.go +++ b/core/vm/sqlvm/common/storage_test.go @@ -3,6 +3,7 @@ package common import ( "bytes" "fmt" + "math" "testing" "github.com/stretchr/testify/suite" @@ -17,20 +18,27 @@ import ( type StorageTestSuite struct{ suite.Suite } -func (s *StorageTestSuite) TestGetPrimaryKeyHash() { +func (s *StorageTestSuite) TestUint64ToBytes() { + testcases := []uint64{1, 65535, math.MaxUint64} + for _, i := range testcases { + s.Require().Equal(i, bytesToUint64(uint64ToBytes(i))) + } +} + +func (s *StorageTestSuite) TestGetRowAddress() { id := uint64(555666) table := []byte("TABLE_A") key := [][]byte{ []byte("tables"), table, []byte("primary"), - convertIDtoBytes(id), + uint64ToBytes(id), } hw := sha3.NewLegacyKeccak256() rlp.Encode(hw, key) bytes := hw.Sum(nil) storage := Storage{} - result := storage.GetPrimaryKeyHash(table, id) + result := storage.GetRowPathHash(table, id) s.Require().Equal(bytes, result[:]) } @@ -102,6 +110,112 @@ func SetDataToStateDB(head common.Hash, storage Storage, addr common.Address, storage.Commit(false) } +func (s *StorageTestSuite) TestOwner() { + db := ethdb.NewMemDatabase() + state, _ := state.New(common.Hash{}, state.NewDatabase(db)) + storage := NewStorage(state) + + contractA := common.BytesToAddress([]byte("I'm sad.")) + ownerA := common.BytesToAddress([]byte{5, 5, 6, 6}) + contractB := common.BytesToAddress([]byte{9, 5, 2, 7}) + ownerB := common.BytesToAddress([]byte("Tong Pak-Fu")) + + storage.StoreOwner(contractA, ownerA) + storage.StoreOwner(contractB, ownerB) + storage.Commit(false) + s.Require().Equal(ownerA, storage.LoadOwner(contractA)) + s.Require().Equal(ownerB, storage.LoadOwner(contractB)) + + storage.StoreOwner(contractA, ownerB) + storage.Commit(false) + s.Require().Equal(ownerB, storage.LoadOwner(contractA)) +} + +func (s *StorageTestSuite) TestTableWriter() { + db := ethdb.NewMemDatabase() + state, _ := state.New(common.Hash{}, state.NewDatabase(db)) + storage := NewStorage(state) + + table1 := []byte("table1") + table2 := []byte("table2") + contractA := common.BytesToAddress([]byte("A")) + contractB := common.BytesToAddress([]byte("B")) + addrs := []common.Address{ + common.BytesToAddress([]byte("addr1")), + common.BytesToAddress([]byte("addr2")), + common.BytesToAddress([]byte("addr3")), + } + + // Genesis. + s.Require().Len(storage.LoadTableWriters(contractA, table1), 0) + s.Require().Len(storage.LoadTableWriters(contractB, table1), 0) + + // Check writer list. + storage.InsertTableWriter(contractA, table1, addrs[0]) + storage.InsertTableWriter(contractA, table1, addrs[1]) + storage.InsertTableWriter(contractA, table1, addrs[2]) + storage.InsertTableWriter(contractB, table2, addrs[0]) + storage.Commit(false) + s.Require().Equal(addrs, storage.LoadTableWriters(contractA, table1)) + s.Require().Len(storage.LoadTableWriters(contractA, table2), 0) + s.Require().Len(storage.LoadTableWriters(contractB, table1), 0) + s.Require().Equal([]common.Address{addrs[0]}, + storage.LoadTableWriters(contractB, table2)) + + // Insert duplicate. + storage.InsertTableWriter(contractA, table1, addrs[0]) + storage.InsertTableWriter(contractA, table1, addrs[1]) + storage.InsertTableWriter(contractA, table1, addrs[2]) + storage.Commit(false) + s.Require().Equal(addrs, storage.LoadTableWriters(contractA, table1)) + + // Delete some writer. + storage.DeleteTableWriter(contractA, table1, addrs[0]) + storage.DeleteTableWriter(contractA, table2, addrs[0]) + storage.DeleteTableWriter(contractB, table2, addrs[0]) + storage.Commit(false) + s.Require().Equal([]common.Address{addrs[2], addrs[1]}, + storage.LoadTableWriters(contractA, table1)) + s.Require().Len(storage.LoadTableWriters(contractA, table2), 0) + s.Require().Len(storage.LoadTableWriters(contractB, table1), 0) + s.Require().Len(storage.LoadTableWriters(contractB, table2), 0) + + // Delete again. + storage.DeleteTableWriter(contractA, table1, addrs[2]) + storage.Commit(false) + s.Require().Equal([]common.Address{addrs[1]}, + storage.LoadTableWriters(contractA, table1)) + + // Check writer. + s.Require().False(storage.IsTableWriter(contractA, table1, addrs[0])) + s.Require().True(storage.IsTableWriter(contractA, table1, addrs[1])) + s.Require().False(storage.IsTableWriter(contractA, table1, addrs[2])) + s.Require().False(storage.IsTableWriter(contractA, table2, addrs[0])) + s.Require().False(storage.IsTableWriter(contractB, table2, addrs[0])) +} + +func (s *StorageTestSuite) TestSequence() { + db := ethdb.NewMemDatabase() + state, _ := state.New(common.Hash{}, state.NewDatabase(db)) + storage := NewStorage(state) + + table1 := []byte("table1") + table2 := []byte("table2") + contract := common.BytesToAddress([]byte("A")) + + s.Require().Equal(uint64(0), storage.IncSequence(contract, table1, 0, 2)) + s.Require().Equal(uint64(2), storage.IncSequence(contract, table1, 0, 1)) + s.Require().Equal(uint64(3), storage.IncSequence(contract, table1, 0, 1)) + // Repeat on another sequence. + s.Require().Equal(uint64(0), storage.IncSequence(contract, table1, 1, 1)) + s.Require().Equal(uint64(1), storage.IncSequence(contract, table1, 1, 2)) + s.Require().Equal(uint64(3), storage.IncSequence(contract, table1, 1, 3)) + // Repeat on another table. + s.Require().Equal(uint64(0), storage.IncSequence(contract, table2, 0, 3)) + s.Require().Equal(uint64(3), storage.IncSequence(contract, table2, 0, 4)) + s.Require().Equal(uint64(7), storage.IncSequence(contract, table2, 0, 5)) +} + func TestStorage(t *testing.T) { suite.Run(t, new(StorageTestSuite)) } diff --git a/core/vm/sqlvm/runtime/instructions.go b/core/vm/sqlvm/runtime/instructions.go index b6b37052e..92d5753f5 100644 --- a/core/vm/sqlvm/runtime/instructions.go +++ b/core/vm/sqlvm/runtime/instructions.go @@ -127,7 +127,7 @@ func opLoad(ctx *common.Context, input []*Operand, registers []*Operand, output } for i, id := range ids { slotDataCache := make(map[dexCommon.Hash]dexCommon.Hash) - head := ctx.Storage.GetPrimaryKeyHash(table.Name, id) + head := ctx.Storage.GetRowPathHash(table.Name, id) for j := range fields { col := table.Columns[int(fields[j])] byteOffset := col.ByteOffset diff --git a/core/vm/sqlvm/runtime/instructions_test.go b/core/vm/sqlvm/runtime/instructions_test.go index 576a7783a..fe9f15994 100644 --- a/core/vm/sqlvm/runtime/instructions_test.go +++ b/core/vm/sqlvm/runtime/instructions_test.go @@ -213,7 +213,7 @@ type opLoadTestCase struct { func (s *opLoadSuite) SetupTest() { s.ctx = &common.Context{} s.ctx.Storage = s.newStorage() - s.headHash = s.ctx.Storage.GetPrimaryKeyHash([]byte("Table_B"), uint64(123456)) + s.headHash = s.ctx.Storage.GetRowPathHash([]byte("Table_B"), uint64(123456)) s.address = dexCommon.HexToAddress("0x6655") s.ctx.Storage.CreateAccount(s.address) s.ctx.Contract = vm.NewContract(vm.AccountRef(s.address), @@ -224,7 +224,7 @@ func (s *opLoadSuite) SetupTest() { } func (s *opLoadSuite) setColData(tableName string, id uint64) { - h := s.ctx.Storage.GetPrimaryKeyHash([]byte(tableName), id) + h := s.ctx.Storage.GetRowPathHash([]byte(tableName), id) setSlotDataInStateDB(h, s.address, s.ctx.Storage) } |