diff options
-rw-r--r-- | core/vm/sqlvm/common/storage.go | 103 | ||||
-rw-r--r-- | core/vm/sqlvm/common/storage_test.go | 173 | ||||
-rw-r--r-- | core/vm/sqlvm/runtime/instructions.go | 36 | ||||
-rw-r--r-- | core/vm/sqlvm/runtime/instructions_test.go | 103 |
4 files changed, 330 insertions, 85 deletions
diff --git a/core/vm/sqlvm/common/storage.go b/core/vm/sqlvm/common/storage.go index 0244877a2..e046b0411 100644 --- a/core/vm/sqlvm/common/storage.go +++ b/core/vm/sqlvm/common/storage.go @@ -1,6 +1,7 @@ package common import ( + "encoding/binary" "math/big" "github.com/dexon-foundation/decimal" @@ -146,6 +147,17 @@ func (s *Storage) GetReverseIndexPathHash( return s.hashPathKey(key) } +// GetPrimaryPathHash returns primary rlp encoded hash. +func (s *Storage) GetPrimaryPathHash(tableRef schema.TableRef) (h common.Hash) { + // PathKey(["tables", "{table_name}", "primary"]) + key := [][]byte{ + pathCompTables, + tableRefToBytes(tableRef), + pathCompPrimary, + } + return s.hashPathKey(key) +} + // getSequencePathHash return the hash address of a sequence. func (s *Storage) getSequencePathHash( tableRef schema.TableRef, seqIdx uint8, @@ -521,3 +533,94 @@ func (s *Storage) IncSequence( s.SetState(contract, seqPath, common.BytesToHash(uint64ToBytes(val+inc))) return val } + +// DecodePKHeader decodes primary key hash header to lastRowID and rowCount. +func (s *Storage) DecodePKHeader(header common.Hash) (lastRowID, rowCount uint64) { + lastRowID = binary.BigEndian.Uint64(header[:8]) + rowCount = binary.BigEndian.Uint64(header[8:16]) + return +} + +// EncodePKHeader encodes lastRowID and rowCount to primary key hash header. +func (s *Storage) EncodePKHeader(lastRowID, rowCount uint64) (header common.Hash) { + binary.BigEndian.PutUint64(header[:8], lastRowID) + binary.BigEndian.PutUint64(header[8:16], rowCount) + return +} + +// UpdateHash updates hash to stateDB. +func (s *Storage) UpdateHash(m map[common.Hash]common.Hash, address common.Address) { + for key, val := range m { + s.SetState(address, key, val) + } +} + +func setBit(n byte, pos uint) byte { + n |= (1 << pos) + return n +} + +func hasBit(n byte, pos uint) bool { + val := n & (1 << pos) + return (val > 0) +} + +// SetPK sets IDs to primary bit map. +func (s *Storage) SetPK(address common.Address, tableRef schema.TableRef, IDs []uint64) { + hash := s.GetPrimaryPathHash(tableRef) + header := s.GetState(address, hash) + lastRowID, rowCount := s.DecodePKHeader(header) + slotHashToData := make(map[common.Hash]common.Hash) + for _, id := range IDs { + if lastRowID < id { + lastRowID = id + } + slotNum := id/256 + 1 + byteLoc := (id & 255) / 8 + bitLoc := uint(id & 7) + slotHash := s.ShiftHashUint64(hash, slotNum) + data, exist := slotHashToData[slotHash] + if !exist { + data = s.GetState(address, slotHash) + } + if !hasBit(data[byteLoc], bitLoc) { + rowCount++ + data[byteLoc] = setBit(data[byteLoc], bitLoc) + } + slotHashToData[slotHash] = data + } + s.UpdateHash(slotHashToData, address) + header = s.EncodePKHeader(lastRowID, rowCount) + s.SetState(address, hash, header) +} + +func getCountAndOffset(d common.Hash) (offset []uint64) { + for j, b := range d { + for i := 0; i < 8; i++ { + if hasBit(b, uint(i)) { + offset = append(offset, uint64(j*8+i)) + } + } + } + return +} + +// RepeatPK returns primary IDs by table reference. +func (s *Storage) RepeatPK(address common.Address, tableRef schema.TableRef) []uint64 { + hash := s.GetPrimaryPathHash(tableRef) + header := s.GetState(address, hash) + lastRowID, rowCount := s.DecodePKHeader(header) + maxSlotNum := lastRowID/256 + 1 + result := make([]uint64, rowCount) + ptr := 0 + for slotNum := uint64(0); slotNum < maxSlotNum; slotNum++ { + slotHash := s.ShiftHashUint64(hash, slotNum+1) + slotData := s.GetState(address, slotHash) + offsets := getCountAndOffset(slotData) + for i, o := range offsets { + result[i+ptr] = o + slotNum*256 + } + ptr += len(offsets) + } + return result +} diff --git a/core/vm/sqlvm/common/storage_test.go b/core/vm/sqlvm/common/storage_test.go index 44aff12f7..d402868e3 100644 --- a/core/vm/sqlvm/common/storage_test.go +++ b/core/vm/sqlvm/common/storage_test.go @@ -17,7 +17,18 @@ import ( "github.com/dexon-foundation/dexon/rlp" ) -type StorageTestSuite struct{ suite.Suite } +type StorageTestSuite struct { + suite.Suite + storage *Storage + address common.Address +} + +func (s *StorageTestSuite) SetupTest() { + db := ethdb.NewMemDatabase() + state, _ := state.New(common.Hash{}, state.NewDatabase(db)) + s.storage = NewStorage(state) + s.address = common.BytesToAddress([]byte("5566")) +} func (s *StorageTestSuite) TestUint64ToBytes() { testcases := []uint64{1, 65535, math.MaxUint64} @@ -38,8 +49,7 @@ func (s *StorageTestSuite) TestGetRowAddress() { hw := sha3.NewLegacyKeccak256() rlp.Encode(hw, key) bytes := hw.Sum(nil) - storage := &Storage{} - result := storage.GetRowPathHash(table, id) + result := s.storage.GetRowPathHash(table, id) s.Require().Equal(bytes, result[:]) } @@ -50,9 +60,6 @@ type decodeTestCase struct { } func (s *StorageTestSuite) TestDecodeDByte() { - db := ethdb.NewMemDatabase() - state, _ := state.New(common.Hash{}, state.NewDatabase(db)) - storage := NewStorage(state) address := common.BytesToAddress([]byte("123")) head := common.HexToHash("0x5566") testcase := []decodeTestCase{ @@ -77,10 +84,10 @@ func (s *StorageTestSuite) TestDecodeDByte() { result: []byte(""), }, } - SetDataToStorage(head, storage, address, testcase) + SetDataToStorage(head, s.storage, address, testcase) for i, t := range testcase { - slot := storage.ShiftHashUint64(head, uint64(i)) - result := storage.DecodeDByteBySlot(address, slot) + slot := s.storage.ShiftHashUint64(head, uint64(i)) + result := s.storage.DecodeDByteBySlot(address, slot) s.Require().Truef(bytes.Equal(result, t.result), fmt.Sprintf("name %v", t.name)) } } @@ -111,29 +118,21 @@ func SetDataToStorage(head common.Hash, storage *Storage, addr common.Address, } func (s *StorageTestSuite) TestOwner() { - db := ethdb.NewMemDatabase() - state, _ := state.New(common.Hash{}, state.NewDatabase(db)) - storage := NewStorage(state) - contractA := common.BytesToAddress([]byte("I'm sad.")) ownerA := common.BytesToAddress([]byte{5, 5, 6, 6}) contractB := common.BytesToAddress([]byte{9, 5, 2, 7}) ownerB := common.BytesToAddress([]byte("Tong Pak-Fu")) - storage.StoreOwner(contractA, ownerA) - storage.StoreOwner(contractB, ownerB) - s.Require().Equal(ownerA, storage.LoadOwner(contractA)) - s.Require().Equal(ownerB, storage.LoadOwner(contractB)) + s.storage.StoreOwner(contractA, ownerA) + s.storage.StoreOwner(contractB, ownerB) + s.Require().Equal(ownerA, s.storage.LoadOwner(contractA)) + s.Require().Equal(ownerB, s.storage.LoadOwner(contractB)) - storage.StoreOwner(contractA, ownerB) - s.Require().Equal(ownerB, storage.LoadOwner(contractA)) + s.storage.StoreOwner(contractA, ownerB) + s.Require().Equal(ownerB, s.storage.LoadOwner(contractA)) } func (s *StorageTestSuite) TestTableWriter() { - db := ethdb.NewMemDatabase() - state, _ := state.New(common.Hash{}, state.NewDatabase(db)) - storage := NewStorage(state) - table1 := schema.TableRef(0) table2 := schema.TableRef(1) contractA := common.BytesToAddress([]byte("A")) @@ -145,69 +144,115 @@ func (s *StorageTestSuite) TestTableWriter() { } // Genesis. - s.Require().Len(storage.LoadTableWriters(contractA, table1), 0) - s.Require().Len(storage.LoadTableWriters(contractB, table1), 0) + s.Require().Len(s.storage.LoadTableWriters(contractA, table1), 0) + s.Require().Len(s.storage.LoadTableWriters(contractB, table1), 0) // Check writer list. - storage.InsertTableWriter(contractA, table1, addrs[0]) - storage.InsertTableWriter(contractA, table1, addrs[1]) - storage.InsertTableWriter(contractA, table1, addrs[2]) - storage.InsertTableWriter(contractB, table2, addrs[0]) - s.Require().Equal(addrs, storage.LoadTableWriters(contractA, table1)) - s.Require().Len(storage.LoadTableWriters(contractA, table2), 0) - s.Require().Len(storage.LoadTableWriters(contractB, table1), 0) + s.storage.InsertTableWriter(contractA, table1, addrs[0]) + s.storage.InsertTableWriter(contractA, table1, addrs[1]) + s.storage.InsertTableWriter(contractA, table1, addrs[2]) + s.storage.InsertTableWriter(contractB, table2, addrs[0]) + s.Require().Equal(addrs, s.storage.LoadTableWriters(contractA, table1)) + s.Require().Len(s.storage.LoadTableWriters(contractA, table2), 0) + s.Require().Len(s.storage.LoadTableWriters(contractB, table1), 0) s.Require().Equal([]common.Address{addrs[0]}, - storage.LoadTableWriters(contractB, table2)) + s.storage.LoadTableWriters(contractB, table2)) // Insert duplicate. - storage.InsertTableWriter(contractA, table1, addrs[0]) - storage.InsertTableWriter(contractA, table1, addrs[1]) - storage.InsertTableWriter(contractA, table1, addrs[2]) - s.Require().Equal(addrs, storage.LoadTableWriters(contractA, table1)) + s.storage.InsertTableWriter(contractA, table1, addrs[0]) + s.storage.InsertTableWriter(contractA, table1, addrs[1]) + s.storage.InsertTableWriter(contractA, table1, addrs[2]) + s.Require().Equal(addrs, s.storage.LoadTableWriters(contractA, table1)) // Delete some writer. - storage.DeleteTableWriter(contractA, table1, addrs[0]) - storage.DeleteTableWriter(contractA, table2, addrs[0]) - storage.DeleteTableWriter(contractB, table2, addrs[0]) + s.storage.DeleteTableWriter(contractA, table1, addrs[0]) + s.storage.DeleteTableWriter(contractA, table2, addrs[0]) + s.storage.DeleteTableWriter(contractB, table2, addrs[0]) s.Require().Equal([]common.Address{addrs[2], addrs[1]}, - storage.LoadTableWriters(contractA, table1)) - s.Require().Len(storage.LoadTableWriters(contractA, table2), 0) - s.Require().Len(storage.LoadTableWriters(contractB, table1), 0) - s.Require().Len(storage.LoadTableWriters(contractB, table2), 0) + s.storage.LoadTableWriters(contractA, table1)) + s.Require().Len(s.storage.LoadTableWriters(contractA, table2), 0) + s.Require().Len(s.storage.LoadTableWriters(contractB, table1), 0) + s.Require().Len(s.storage.LoadTableWriters(contractB, table2), 0) // Delete again. - storage.DeleteTableWriter(contractA, table1, addrs[2]) + s.storage.DeleteTableWriter(contractA, table1, addrs[2]) s.Require().Equal([]common.Address{addrs[1]}, - storage.LoadTableWriters(contractA, table1)) + s.storage.LoadTableWriters(contractA, table1)) // Check writer. - s.Require().False(storage.IsTableWriter(contractA, table1, addrs[0])) - s.Require().True(storage.IsTableWriter(contractA, table1, addrs[1])) - s.Require().False(storage.IsTableWriter(contractA, table1, addrs[2])) - s.Require().False(storage.IsTableWriter(contractA, table2, addrs[0])) - s.Require().False(storage.IsTableWriter(contractB, table2, addrs[0])) + s.Require().False(s.storage.IsTableWriter(contractA, table1, addrs[0])) + s.Require().True(s.storage.IsTableWriter(contractA, table1, addrs[1])) + s.Require().False(s.storage.IsTableWriter(contractA, table1, addrs[2])) + s.Require().False(s.storage.IsTableWriter(contractA, table2, addrs[0])) + s.Require().False(s.storage.IsTableWriter(contractB, table2, addrs[0])) } func (s *StorageTestSuite) TestSequence() { - db := ethdb.NewMemDatabase() - state, _ := state.New(common.Hash{}, state.NewDatabase(db)) - storage := NewStorage(state) - table1 := schema.TableRef(0) table2 := schema.TableRef(1) contract := common.BytesToAddress([]byte("A")) - s.Require().Equal(uint64(0), storage.IncSequence(contract, table1, 0, 2)) - s.Require().Equal(uint64(2), storage.IncSequence(contract, table1, 0, 1)) - s.Require().Equal(uint64(3), storage.IncSequence(contract, table1, 0, 1)) + s.Require().Equal(uint64(0), s.storage.IncSequence(contract, table1, 0, 2)) + s.Require().Equal(uint64(2), s.storage.IncSequence(contract, table1, 0, 1)) + s.Require().Equal(uint64(3), s.storage.IncSequence(contract, table1, 0, 1)) // Repeat on another sequence. - s.Require().Equal(uint64(0), storage.IncSequence(contract, table1, 1, 1)) - s.Require().Equal(uint64(1), storage.IncSequence(contract, table1, 1, 2)) - s.Require().Equal(uint64(3), storage.IncSequence(contract, table1, 1, 3)) + s.Require().Equal(uint64(0), s.storage.IncSequence(contract, table1, 1, 1)) + s.Require().Equal(uint64(1), s.storage.IncSequence(contract, table1, 1, 2)) + s.Require().Equal(uint64(3), s.storage.IncSequence(contract, table1, 1, 3)) // Repeat on another table. - s.Require().Equal(uint64(0), storage.IncSequence(contract, table2, 0, 3)) - s.Require().Equal(uint64(3), storage.IncSequence(contract, table2, 0, 4)) - s.Require().Equal(uint64(7), storage.IncSequence(contract, table2, 0, 5)) + s.Require().Equal(uint64(0), s.storage.IncSequence(contract, table2, 0, 3)) + s.Require().Equal(uint64(3), s.storage.IncSequence(contract, table2, 0, 4)) + s.Require().Equal(uint64(7), s.storage.IncSequence(contract, table2, 0, 5)) +} + +func (s *StorageTestSuite) TestPKHeaderEncodeDecode() { + lastRowID := uint64(5566) + rowCount := uint64(6655) + newLastRowID, newRowCount := s.storage.DecodePKHeader(s.storage.EncodePKHeader(lastRowID, rowCount)) + s.Require().Equal(lastRowID, newLastRowID) + s.Require().Equal(rowCount, newRowCount) +} + +func (s *StorageTestSuite) TestUpdateHash() { + m := map[common.Hash]common.Hash{ + common.BytesToHash([]byte("hello world")): common.BytesToHash([]byte("hello SQLVM")), + common.BytesToHash([]byte("bye world")): common.BytesToHash([]byte("bye SQLVM")), + } + s.storage.UpdateHash(m, s.address) + for key, val := range m { + rVal := s.storage.GetState(s.address, key) + s.Require().Equal(val, rVal) + } +} + +func (s *StorageTestSuite) TestRepeatPK() { + type testCase struct { + address common.Address + tableRef schema.TableRef + expectIDs []uint64 + } + testCases := []testCase{ + { + address: common.BytesToAddress([]byte("0")), + tableRef: schema.TableRef(0), + expectIDs: []uint64{0, 1, 2}, + }, + { + address: common.BytesToAddress([]byte("1")), + tableRef: schema.TableRef(1), + expectIDs: []uint64{1234, 5566}, + }, + { + address: common.BytesToAddress([]byte("2")), + tableRef: schema.TableRef(2), + expectIDs: []uint64{0, 128, 256, 512, 1024}, + }, + } + for i, t := range testCases { + s.storage.SetPK(t.address, t.tableRef, t.expectIDs) + IDs := s.storage.RepeatPK(t.address, t.tableRef) + s.Require().Equalf(t.expectIDs, IDs, "testCase #%v\n", i) + } } func TestStorage(t *testing.T) { diff --git a/core/vm/sqlvm/runtime/instructions.go b/core/vm/sqlvm/runtime/instructions.go index d61bacdde..6fa72d61a 100644 --- a/core/vm/sqlvm/runtime/instructions.go +++ b/core/vm/sqlvm/runtime/instructions.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "math/big" "regexp" "sort" "strings" @@ -98,6 +99,14 @@ func (op *Operand) toUint64() (result []uint64, err error) { return } +func (op *Operand) toTableRef() (schema.TableRef, error) { + t, err := op.toUint8() + if err != nil { + return 0, err + } + return schema.TableRef(t[0]), nil +} + func (op *Operand) toUint8() ([]uint8, error) { result := make([]uint8, len(op.Data)) for i, tuple := range op.Data { @@ -1881,3 +1890,30 @@ func opFunc(ctx *common.Context, ops, registers []*Operand, output uint) (err er registers[output] = result return } + +func uint64ToOperands(numbers []uint64) (*Operand, error) { + result := &Operand{ + Meta: []ast.DataType{ast.ComposeDataType(ast.DataTypeMajorUint, 7)}, + Data: []Tuple{}, + } + result.Data = make([]Tuple, len(numbers)) + for i, n := range numbers { + result.Data[i] = []*Raw{ + { + Value: decimal.NewFromBigInt(new(big.Int).SetUint64(n), 0), + Bytes: nil, + }, + } + } + return result, nil +} + +func opRepeatPK(ctx *common.Context, input []*Operand, registers []*Operand, output int) (err error) { + tableRef, err := input[0].toTableRef() + if err != nil { + return err + } + IDs := ctx.Storage.RepeatPK(ctx.Contract.Address(), tableRef) + registers[output], err = uint64ToOperands(IDs) + return +} diff --git a/core/vm/sqlvm/runtime/instructions_test.go b/core/vm/sqlvm/runtime/instructions_test.go index a399ad557..77ba0ff70 100644 --- a/core/vm/sqlvm/runtime/instructions_test.go +++ b/core/vm/sqlvm/runtime/instructions_test.go @@ -189,15 +189,6 @@ func hexToBytes(s string) []byte { return b } -type decodeTestCase struct { - dt ast.DataType - expectData *Raw - expectSlotHash dexCommon.Hash - shift uint64 - inputBytes []byte - dBytes []byte -} - type opLoadTestCase struct { title string outputIdx uint @@ -205,7 +196,7 @@ type opLoadTestCase struct { expectedErr error ids []uint64 fields []uint8 - tableIdx int8 + tableRef schema.TableRef } func (s *opLoadSuite) SetupTest() { @@ -236,7 +227,7 @@ func (s *opLoadSuite) getOpLoadTestCases(raws []*raw) []opLoadTestCase { expectedErr: nil, ids: nil, fields: nil, - tableIdx: 0, + tableRef: 0, }, { title: "NOT_EXIST_TABLE", @@ -245,7 +236,7 @@ func (s *opLoadSuite) getOpLoadTestCases(raws []*raw) []opLoadTestCase { expectedErr: errors.ErrorCodeIndexOutOfRange, ids: nil, fields: nil, - tableIdx: 13, + tableRef: 13, }, { title: "OK_CASE", @@ -254,7 +245,7 @@ func (s *opLoadSuite) getOpLoadTestCases(raws []*raw) []opLoadTestCase { expectedErr: nil, ids: []uint64{123456, 654321}, fields: s.getOKCaseFields(raws), - tableIdx: 1, + tableRef: 1, }, } return testCases @@ -283,6 +274,15 @@ func (s *opLoadSuite) getOKCaseFields(raws []*raw) []uint8 { return rValue } +type decodeTestCase struct { + dt ast.DataType + expectData *Raw + expectSlotHash dexCommon.Hash + shift uint64 + inputBytes []byte + dBytes []byte +} + func (s *opLoadSuite) getDecodeTestCases(headHash dexCommon.Hash, address dexCommon.Address, storage *common.Storage) []decodeTestCase { @@ -304,9 +304,9 @@ func (s *opLoadSuite) getDecodeTestCases(headHash dexCommon.Hash, return testCases } -func (s *opLoadSuite) newRegisters(tableIdx int8, ids []uint64, fields []uint8) []*Operand { +func (s *opLoadSuite) newRegisters(tableRef schema.TableRef, ids []uint64, fields []uint8) []*Operand { o := make([]*Operand, 4) - o[1] = newTableNameOperand(tableIdx) + o[1] = newTableRefOperand(tableRef) o[2] = newIDsOperand(ids) o[3] = newFieldsOperand(fields) return o @@ -323,8 +323,8 @@ func newInput(nums []int) []*Operand { return o } -func newTableNameOperand(tableIdx int8) *Operand { - if tableIdx < 0 { +func newTableRefOperand(tableRef schema.TableRef) *Operand { + if tableRef < 0 { return nil } o := &Operand{ @@ -334,7 +334,7 @@ func newTableNameOperand(tableIdx int8) *Operand { Data: []Tuple{ []*Raw{ { - Value: decimal.New(int64(tableIdx), 0), + Value: decimal.New(int64(tableRef), 0), }, }, }, @@ -404,7 +404,7 @@ func (s *opLoadSuite) TestOpLoad() { testCases := s.getOpLoadTestCases(s.raws) for _, t := range testCases { input := newInput([]int{1, 2, 3}) - reg := s.newRegisters(t.tableIdx, t.ids, t.fields) + reg := s.newRegisters(t.tableRef, t.ids, t.fields) loadRegister(input, reg) err := opLoad(s.ctx, input, reg, t.outputIdx) @@ -417,11 +417,72 @@ func (s *opLoadSuite) TestOpLoad() { } } -func TestOpLoad(t *testing.T) { - suite.Run(t, new(opLoadSuite)) +type opRepeatPKSuite struct{ suite.Suite } + +type repeatPKTestCase struct { + tableRef schema.TableRef + address dexCommon.Address + title string + expectedErr error + expectedIDs []uint64 +} + +func (s opRepeatPKSuite) newInput(tableRef schema.TableRef) []*Operand { + o := make([]*Operand, 1) + o[0] = newTableRefOperand(tableRef) + return o +} + +func (s opRepeatPKSuite) newRegisters(tableRef schema.TableRef) []*Operand { + o := make([]*Operand, 2) + o[1] = newTableRefOperand(tableRef) + return o +} + +func (s opRepeatPKSuite) getTestCases(storage *common.Storage) []repeatPKTestCase { + testCases := []repeatPKTestCase{ + { + tableRef: 0, + address: dexCommon.BytesToAddress([]byte("0")), + title: "no IDs", + expectedErr: nil, + expectedIDs: []uint64{}, + }, + { + tableRef: 1, + address: dexCommon.BytesToAddress([]byte("1")), + title: "ok case", + expectedErr: nil, + expectedIDs: []uint64{1, 2, 3, 4}, + }, + } + for _, t := range testCases { + storage.SetPK(t.address, t.tableRef, t.expectedIDs) + } + return testCases +} + +func (s opRepeatPKSuite) TestRepeatPK() { + ctx := &common.Context{} + ctx.Storage = newStorage() + testCases := s.getTestCases(ctx.Storage) + for _, t := range testCases { + address := t.address + ctx.Contract = vm.NewContract(vm.AccountRef(address), + vm.AccountRef(address), new(big.Int), uint64(0)) + reg := s.newRegisters(t.tableRef) + input := newInput([]int{1}) + loadRegister(input, reg) + err := opRepeatPK(ctx, input, reg, 0) + s.Require().Equalf(t.expectedErr, err, "testcase: [%v]", t.title) + result, _ := reg[0].toUint64() + s.Require().Equalf(t.expectedIDs, result, "testcase: [%v]", t.title) + } } func TestInstructions(t *testing.T) { + suite.Run(t, new(opLoadSuite)) + suite.Run(t, new(opRepeatPKSuite)) suite.Run(t, new(instructionSuite)) } |