aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJhih-Ming Huang <jm.huang@cobinhood.com>2019-02-14 15:21:33 +0800
committerJhih-Ming Huang <jm.huang@cobinhood.com>2019-05-06 10:44:05 +0800
commit996310cbd484b5ff1ea76068578314d71973770f (patch)
tree04850adab7717a7ad4aa6ff1b3464b90e7c5502f
parent4a93916c7450ae197376e65500fef4fab11bf220 (diff)
downloaddexon-996310cbd484b5ff1ea76068578314d71973770f.tar.gz
dexon-996310cbd484b5ff1ea76068578314d71973770f.tar.zst
dexon-996310cbd484b5ff1ea76068578314d71973770f.zip
core: vm: sqlvm: runtime: implement opRepeatPK
-rw-r--r--core/vm/sqlvm/common/storage.go103
-rw-r--r--core/vm/sqlvm/common/storage_test.go173
-rw-r--r--core/vm/sqlvm/runtime/instructions.go36
-rw-r--r--core/vm/sqlvm/runtime/instructions_test.go103
4 files changed, 330 insertions, 85 deletions
diff --git a/core/vm/sqlvm/common/storage.go b/core/vm/sqlvm/common/storage.go
index 0244877a2..e046b0411 100644
--- a/core/vm/sqlvm/common/storage.go
+++ b/core/vm/sqlvm/common/storage.go
@@ -1,6 +1,7 @@
package common
import (
+ "encoding/binary"
"math/big"
"github.com/dexon-foundation/decimal"
@@ -146,6 +147,17 @@ func (s *Storage) GetReverseIndexPathHash(
return s.hashPathKey(key)
}
+// GetPrimaryPathHash returns primary rlp encoded hash.
+func (s *Storage) GetPrimaryPathHash(tableRef schema.TableRef) (h common.Hash) {
+ // PathKey(["tables", "{table_name}", "primary"])
+ key := [][]byte{
+ pathCompTables,
+ tableRefToBytes(tableRef),
+ pathCompPrimary,
+ }
+ return s.hashPathKey(key)
+}
+
// getSequencePathHash return the hash address of a sequence.
func (s *Storage) getSequencePathHash(
tableRef schema.TableRef, seqIdx uint8,
@@ -521,3 +533,94 @@ func (s *Storage) IncSequence(
s.SetState(contract, seqPath, common.BytesToHash(uint64ToBytes(val+inc)))
return val
}
+
+// DecodePKHeader decodes primary key hash header to lastRowID and rowCount.
+func (s *Storage) DecodePKHeader(header common.Hash) (lastRowID, rowCount uint64) {
+ lastRowID = binary.BigEndian.Uint64(header[:8])
+ rowCount = binary.BigEndian.Uint64(header[8:16])
+ return
+}
+
+// EncodePKHeader encodes lastRowID and rowCount to primary key hash header.
+func (s *Storage) EncodePKHeader(lastRowID, rowCount uint64) (header common.Hash) {
+ binary.BigEndian.PutUint64(header[:8], lastRowID)
+ binary.BigEndian.PutUint64(header[8:16], rowCount)
+ return
+}
+
+// UpdateHash updates hash to stateDB.
+func (s *Storage) UpdateHash(m map[common.Hash]common.Hash, address common.Address) {
+ for key, val := range m {
+ s.SetState(address, key, val)
+ }
+}
+
+func setBit(n byte, pos uint) byte {
+ n |= (1 << pos)
+ return n
+}
+
+func hasBit(n byte, pos uint) bool {
+ val := n & (1 << pos)
+ return (val > 0)
+}
+
+// SetPK sets IDs to primary bit map.
+func (s *Storage) SetPK(address common.Address, tableRef schema.TableRef, IDs []uint64) {
+ hash := s.GetPrimaryPathHash(tableRef)
+ header := s.GetState(address, hash)
+ lastRowID, rowCount := s.DecodePKHeader(header)
+ slotHashToData := make(map[common.Hash]common.Hash)
+ for _, id := range IDs {
+ if lastRowID < id {
+ lastRowID = id
+ }
+ slotNum := id/256 + 1
+ byteLoc := (id & 255) / 8
+ bitLoc := uint(id & 7)
+ slotHash := s.ShiftHashUint64(hash, slotNum)
+ data, exist := slotHashToData[slotHash]
+ if !exist {
+ data = s.GetState(address, slotHash)
+ }
+ if !hasBit(data[byteLoc], bitLoc) {
+ rowCount++
+ data[byteLoc] = setBit(data[byteLoc], bitLoc)
+ }
+ slotHashToData[slotHash] = data
+ }
+ s.UpdateHash(slotHashToData, address)
+ header = s.EncodePKHeader(lastRowID, rowCount)
+ s.SetState(address, hash, header)
+}
+
+func getCountAndOffset(d common.Hash) (offset []uint64) {
+ for j, b := range d {
+ for i := 0; i < 8; i++ {
+ if hasBit(b, uint(i)) {
+ offset = append(offset, uint64(j*8+i))
+ }
+ }
+ }
+ return
+}
+
+// RepeatPK returns primary IDs by table reference.
+func (s *Storage) RepeatPK(address common.Address, tableRef schema.TableRef) []uint64 {
+ hash := s.GetPrimaryPathHash(tableRef)
+ header := s.GetState(address, hash)
+ lastRowID, rowCount := s.DecodePKHeader(header)
+ maxSlotNum := lastRowID/256 + 1
+ result := make([]uint64, rowCount)
+ ptr := 0
+ for slotNum := uint64(0); slotNum < maxSlotNum; slotNum++ {
+ slotHash := s.ShiftHashUint64(hash, slotNum+1)
+ slotData := s.GetState(address, slotHash)
+ offsets := getCountAndOffset(slotData)
+ for i, o := range offsets {
+ result[i+ptr] = o + slotNum*256
+ }
+ ptr += len(offsets)
+ }
+ return result
+}
diff --git a/core/vm/sqlvm/common/storage_test.go b/core/vm/sqlvm/common/storage_test.go
index 44aff12f7..d402868e3 100644
--- a/core/vm/sqlvm/common/storage_test.go
+++ b/core/vm/sqlvm/common/storage_test.go
@@ -17,7 +17,18 @@ import (
"github.com/dexon-foundation/dexon/rlp"
)
-type StorageTestSuite struct{ suite.Suite }
+type StorageTestSuite struct {
+ suite.Suite
+ storage *Storage
+ address common.Address
+}
+
+func (s *StorageTestSuite) SetupTest() {
+ db := ethdb.NewMemDatabase()
+ state, _ := state.New(common.Hash{}, state.NewDatabase(db))
+ s.storage = NewStorage(state)
+ s.address = common.BytesToAddress([]byte("5566"))
+}
func (s *StorageTestSuite) TestUint64ToBytes() {
testcases := []uint64{1, 65535, math.MaxUint64}
@@ -38,8 +49,7 @@ func (s *StorageTestSuite) TestGetRowAddress() {
hw := sha3.NewLegacyKeccak256()
rlp.Encode(hw, key)
bytes := hw.Sum(nil)
- storage := &Storage{}
- result := storage.GetRowPathHash(table, id)
+ result := s.storage.GetRowPathHash(table, id)
s.Require().Equal(bytes, result[:])
}
@@ -50,9 +60,6 @@ type decodeTestCase struct {
}
func (s *StorageTestSuite) TestDecodeDByte() {
- db := ethdb.NewMemDatabase()
- state, _ := state.New(common.Hash{}, state.NewDatabase(db))
- storage := NewStorage(state)
address := common.BytesToAddress([]byte("123"))
head := common.HexToHash("0x5566")
testcase := []decodeTestCase{
@@ -77,10 +84,10 @@ func (s *StorageTestSuite) TestDecodeDByte() {
result: []byte(""),
},
}
- SetDataToStorage(head, storage, address, testcase)
+ SetDataToStorage(head, s.storage, address, testcase)
for i, t := range testcase {
- slot := storage.ShiftHashUint64(head, uint64(i))
- result := storage.DecodeDByteBySlot(address, slot)
+ slot := s.storage.ShiftHashUint64(head, uint64(i))
+ result := s.storage.DecodeDByteBySlot(address, slot)
s.Require().Truef(bytes.Equal(result, t.result), fmt.Sprintf("name %v", t.name))
}
}
@@ -111,29 +118,21 @@ func SetDataToStorage(head common.Hash, storage *Storage, addr common.Address,
}
func (s *StorageTestSuite) TestOwner() {
- db := ethdb.NewMemDatabase()
- state, _ := state.New(common.Hash{}, state.NewDatabase(db))
- storage := NewStorage(state)
-
contractA := common.BytesToAddress([]byte("I'm sad."))
ownerA := common.BytesToAddress([]byte{5, 5, 6, 6})
contractB := common.BytesToAddress([]byte{9, 5, 2, 7})
ownerB := common.BytesToAddress([]byte("Tong Pak-Fu"))
- storage.StoreOwner(contractA, ownerA)
- storage.StoreOwner(contractB, ownerB)
- s.Require().Equal(ownerA, storage.LoadOwner(contractA))
- s.Require().Equal(ownerB, storage.LoadOwner(contractB))
+ s.storage.StoreOwner(contractA, ownerA)
+ s.storage.StoreOwner(contractB, ownerB)
+ s.Require().Equal(ownerA, s.storage.LoadOwner(contractA))
+ s.Require().Equal(ownerB, s.storage.LoadOwner(contractB))
- storage.StoreOwner(contractA, ownerB)
- s.Require().Equal(ownerB, storage.LoadOwner(contractA))
+ s.storage.StoreOwner(contractA, ownerB)
+ s.Require().Equal(ownerB, s.storage.LoadOwner(contractA))
}
func (s *StorageTestSuite) TestTableWriter() {
- db := ethdb.NewMemDatabase()
- state, _ := state.New(common.Hash{}, state.NewDatabase(db))
- storage := NewStorage(state)
-
table1 := schema.TableRef(0)
table2 := schema.TableRef(1)
contractA := common.BytesToAddress([]byte("A"))
@@ -145,69 +144,115 @@ func (s *StorageTestSuite) TestTableWriter() {
}
// Genesis.
- s.Require().Len(storage.LoadTableWriters(contractA, table1), 0)
- s.Require().Len(storage.LoadTableWriters(contractB, table1), 0)
+ s.Require().Len(s.storage.LoadTableWriters(contractA, table1), 0)
+ s.Require().Len(s.storage.LoadTableWriters(contractB, table1), 0)
// Check writer list.
- storage.InsertTableWriter(contractA, table1, addrs[0])
- storage.InsertTableWriter(contractA, table1, addrs[1])
- storage.InsertTableWriter(contractA, table1, addrs[2])
- storage.InsertTableWriter(contractB, table2, addrs[0])
- s.Require().Equal(addrs, storage.LoadTableWriters(contractA, table1))
- s.Require().Len(storage.LoadTableWriters(contractA, table2), 0)
- s.Require().Len(storage.LoadTableWriters(contractB, table1), 0)
+ s.storage.InsertTableWriter(contractA, table1, addrs[0])
+ s.storage.InsertTableWriter(contractA, table1, addrs[1])
+ s.storage.InsertTableWriter(contractA, table1, addrs[2])
+ s.storage.InsertTableWriter(contractB, table2, addrs[0])
+ s.Require().Equal(addrs, s.storage.LoadTableWriters(contractA, table1))
+ s.Require().Len(s.storage.LoadTableWriters(contractA, table2), 0)
+ s.Require().Len(s.storage.LoadTableWriters(contractB, table1), 0)
s.Require().Equal([]common.Address{addrs[0]},
- storage.LoadTableWriters(contractB, table2))
+ s.storage.LoadTableWriters(contractB, table2))
// Insert duplicate.
- storage.InsertTableWriter(contractA, table1, addrs[0])
- storage.InsertTableWriter(contractA, table1, addrs[1])
- storage.InsertTableWriter(contractA, table1, addrs[2])
- s.Require().Equal(addrs, storage.LoadTableWriters(contractA, table1))
+ s.storage.InsertTableWriter(contractA, table1, addrs[0])
+ s.storage.InsertTableWriter(contractA, table1, addrs[1])
+ s.storage.InsertTableWriter(contractA, table1, addrs[2])
+ s.Require().Equal(addrs, s.storage.LoadTableWriters(contractA, table1))
// Delete some writer.
- storage.DeleteTableWriter(contractA, table1, addrs[0])
- storage.DeleteTableWriter(contractA, table2, addrs[0])
- storage.DeleteTableWriter(contractB, table2, addrs[0])
+ s.storage.DeleteTableWriter(contractA, table1, addrs[0])
+ s.storage.DeleteTableWriter(contractA, table2, addrs[0])
+ s.storage.DeleteTableWriter(contractB, table2, addrs[0])
s.Require().Equal([]common.Address{addrs[2], addrs[1]},
- storage.LoadTableWriters(contractA, table1))
- s.Require().Len(storage.LoadTableWriters(contractA, table2), 0)
- s.Require().Len(storage.LoadTableWriters(contractB, table1), 0)
- s.Require().Len(storage.LoadTableWriters(contractB, table2), 0)
+ s.storage.LoadTableWriters(contractA, table1))
+ s.Require().Len(s.storage.LoadTableWriters(contractA, table2), 0)
+ s.Require().Len(s.storage.LoadTableWriters(contractB, table1), 0)
+ s.Require().Len(s.storage.LoadTableWriters(contractB, table2), 0)
// Delete again.
- storage.DeleteTableWriter(contractA, table1, addrs[2])
+ s.storage.DeleteTableWriter(contractA, table1, addrs[2])
s.Require().Equal([]common.Address{addrs[1]},
- storage.LoadTableWriters(contractA, table1))
+ s.storage.LoadTableWriters(contractA, table1))
// Check writer.
- s.Require().False(storage.IsTableWriter(contractA, table1, addrs[0]))
- s.Require().True(storage.IsTableWriter(contractA, table1, addrs[1]))
- s.Require().False(storage.IsTableWriter(contractA, table1, addrs[2]))
- s.Require().False(storage.IsTableWriter(contractA, table2, addrs[0]))
- s.Require().False(storage.IsTableWriter(contractB, table2, addrs[0]))
+ s.Require().False(s.storage.IsTableWriter(contractA, table1, addrs[0]))
+ s.Require().True(s.storage.IsTableWriter(contractA, table1, addrs[1]))
+ s.Require().False(s.storage.IsTableWriter(contractA, table1, addrs[2]))
+ s.Require().False(s.storage.IsTableWriter(contractA, table2, addrs[0]))
+ s.Require().False(s.storage.IsTableWriter(contractB, table2, addrs[0]))
}
func (s *StorageTestSuite) TestSequence() {
- db := ethdb.NewMemDatabase()
- state, _ := state.New(common.Hash{}, state.NewDatabase(db))
- storage := NewStorage(state)
-
table1 := schema.TableRef(0)
table2 := schema.TableRef(1)
contract := common.BytesToAddress([]byte("A"))
- s.Require().Equal(uint64(0), storage.IncSequence(contract, table1, 0, 2))
- s.Require().Equal(uint64(2), storage.IncSequence(contract, table1, 0, 1))
- s.Require().Equal(uint64(3), storage.IncSequence(contract, table1, 0, 1))
+ s.Require().Equal(uint64(0), s.storage.IncSequence(contract, table1, 0, 2))
+ s.Require().Equal(uint64(2), s.storage.IncSequence(contract, table1, 0, 1))
+ s.Require().Equal(uint64(3), s.storage.IncSequence(contract, table1, 0, 1))
// Repeat on another sequence.
- s.Require().Equal(uint64(0), storage.IncSequence(contract, table1, 1, 1))
- s.Require().Equal(uint64(1), storage.IncSequence(contract, table1, 1, 2))
- s.Require().Equal(uint64(3), storage.IncSequence(contract, table1, 1, 3))
+ s.Require().Equal(uint64(0), s.storage.IncSequence(contract, table1, 1, 1))
+ s.Require().Equal(uint64(1), s.storage.IncSequence(contract, table1, 1, 2))
+ s.Require().Equal(uint64(3), s.storage.IncSequence(contract, table1, 1, 3))
// Repeat on another table.
- s.Require().Equal(uint64(0), storage.IncSequence(contract, table2, 0, 3))
- s.Require().Equal(uint64(3), storage.IncSequence(contract, table2, 0, 4))
- s.Require().Equal(uint64(7), storage.IncSequence(contract, table2, 0, 5))
+ s.Require().Equal(uint64(0), s.storage.IncSequence(contract, table2, 0, 3))
+ s.Require().Equal(uint64(3), s.storage.IncSequence(contract, table2, 0, 4))
+ s.Require().Equal(uint64(7), s.storage.IncSequence(contract, table2, 0, 5))
+}
+
+func (s *StorageTestSuite) TestPKHeaderEncodeDecode() {
+ lastRowID := uint64(5566)
+ rowCount := uint64(6655)
+ newLastRowID, newRowCount := s.storage.DecodePKHeader(s.storage.EncodePKHeader(lastRowID, rowCount))
+ s.Require().Equal(lastRowID, newLastRowID)
+ s.Require().Equal(rowCount, newRowCount)
+}
+
+func (s *StorageTestSuite) TestUpdateHash() {
+ m := map[common.Hash]common.Hash{
+ common.BytesToHash([]byte("hello world")): common.BytesToHash([]byte("hello SQLVM")),
+ common.BytesToHash([]byte("bye world")): common.BytesToHash([]byte("bye SQLVM")),
+ }
+ s.storage.UpdateHash(m, s.address)
+ for key, val := range m {
+ rVal := s.storage.GetState(s.address, key)
+ s.Require().Equal(val, rVal)
+ }
+}
+
+func (s *StorageTestSuite) TestRepeatPK() {
+ type testCase struct {
+ address common.Address
+ tableRef schema.TableRef
+ expectIDs []uint64
+ }
+ testCases := []testCase{
+ {
+ address: common.BytesToAddress([]byte("0")),
+ tableRef: schema.TableRef(0),
+ expectIDs: []uint64{0, 1, 2},
+ },
+ {
+ address: common.BytesToAddress([]byte("1")),
+ tableRef: schema.TableRef(1),
+ expectIDs: []uint64{1234, 5566},
+ },
+ {
+ address: common.BytesToAddress([]byte("2")),
+ tableRef: schema.TableRef(2),
+ expectIDs: []uint64{0, 128, 256, 512, 1024},
+ },
+ }
+ for i, t := range testCases {
+ s.storage.SetPK(t.address, t.tableRef, t.expectIDs)
+ IDs := s.storage.RepeatPK(t.address, t.tableRef)
+ s.Require().Equalf(t.expectIDs, IDs, "testCase #%v\n", i)
+ }
}
func TestStorage(t *testing.T) {
diff --git a/core/vm/sqlvm/runtime/instructions.go b/core/vm/sqlvm/runtime/instructions.go
index d61bacdde..6fa72d61a 100644
--- a/core/vm/sqlvm/runtime/instructions.go
+++ b/core/vm/sqlvm/runtime/instructions.go
@@ -4,6 +4,7 @@ import (
"bytes"
"errors"
"fmt"
+ "math/big"
"regexp"
"sort"
"strings"
@@ -98,6 +99,14 @@ func (op *Operand) toUint64() (result []uint64, err error) {
return
}
+func (op *Operand) toTableRef() (schema.TableRef, error) {
+ t, err := op.toUint8()
+ if err != nil {
+ return 0, err
+ }
+ return schema.TableRef(t[0]), nil
+}
+
func (op *Operand) toUint8() ([]uint8, error) {
result := make([]uint8, len(op.Data))
for i, tuple := range op.Data {
@@ -1881,3 +1890,30 @@ func opFunc(ctx *common.Context, ops, registers []*Operand, output uint) (err er
registers[output] = result
return
}
+
+func uint64ToOperands(numbers []uint64) (*Operand, error) {
+ result := &Operand{
+ Meta: []ast.DataType{ast.ComposeDataType(ast.DataTypeMajorUint, 7)},
+ Data: []Tuple{},
+ }
+ result.Data = make([]Tuple, len(numbers))
+ for i, n := range numbers {
+ result.Data[i] = []*Raw{
+ {
+ Value: decimal.NewFromBigInt(new(big.Int).SetUint64(n), 0),
+ Bytes: nil,
+ },
+ }
+ }
+ return result, nil
+}
+
+func opRepeatPK(ctx *common.Context, input []*Operand, registers []*Operand, output int) (err error) {
+ tableRef, err := input[0].toTableRef()
+ if err != nil {
+ return err
+ }
+ IDs := ctx.Storage.RepeatPK(ctx.Contract.Address(), tableRef)
+ registers[output], err = uint64ToOperands(IDs)
+ return
+}
diff --git a/core/vm/sqlvm/runtime/instructions_test.go b/core/vm/sqlvm/runtime/instructions_test.go
index a399ad557..77ba0ff70 100644
--- a/core/vm/sqlvm/runtime/instructions_test.go
+++ b/core/vm/sqlvm/runtime/instructions_test.go
@@ -189,15 +189,6 @@ func hexToBytes(s string) []byte {
return b
}
-type decodeTestCase struct {
- dt ast.DataType
- expectData *Raw
- expectSlotHash dexCommon.Hash
- shift uint64
- inputBytes []byte
- dBytes []byte
-}
-
type opLoadTestCase struct {
title string
outputIdx uint
@@ -205,7 +196,7 @@ type opLoadTestCase struct {
expectedErr error
ids []uint64
fields []uint8
- tableIdx int8
+ tableRef schema.TableRef
}
func (s *opLoadSuite) SetupTest() {
@@ -236,7 +227,7 @@ func (s *opLoadSuite) getOpLoadTestCases(raws []*raw) []opLoadTestCase {
expectedErr: nil,
ids: nil,
fields: nil,
- tableIdx: 0,
+ tableRef: 0,
},
{
title: "NOT_EXIST_TABLE",
@@ -245,7 +236,7 @@ func (s *opLoadSuite) getOpLoadTestCases(raws []*raw) []opLoadTestCase {
expectedErr: errors.ErrorCodeIndexOutOfRange,
ids: nil,
fields: nil,
- tableIdx: 13,
+ tableRef: 13,
},
{
title: "OK_CASE",
@@ -254,7 +245,7 @@ func (s *opLoadSuite) getOpLoadTestCases(raws []*raw) []opLoadTestCase {
expectedErr: nil,
ids: []uint64{123456, 654321},
fields: s.getOKCaseFields(raws),
- tableIdx: 1,
+ tableRef: 1,
},
}
return testCases
@@ -283,6 +274,15 @@ func (s *opLoadSuite) getOKCaseFields(raws []*raw) []uint8 {
return rValue
}
+type decodeTestCase struct {
+ dt ast.DataType
+ expectData *Raw
+ expectSlotHash dexCommon.Hash
+ shift uint64
+ inputBytes []byte
+ dBytes []byte
+}
+
func (s *opLoadSuite) getDecodeTestCases(headHash dexCommon.Hash,
address dexCommon.Address, storage *common.Storage) []decodeTestCase {
@@ -304,9 +304,9 @@ func (s *opLoadSuite) getDecodeTestCases(headHash dexCommon.Hash,
return testCases
}
-func (s *opLoadSuite) newRegisters(tableIdx int8, ids []uint64, fields []uint8) []*Operand {
+func (s *opLoadSuite) newRegisters(tableRef schema.TableRef, ids []uint64, fields []uint8) []*Operand {
o := make([]*Operand, 4)
- o[1] = newTableNameOperand(tableIdx)
+ o[1] = newTableRefOperand(tableRef)
o[2] = newIDsOperand(ids)
o[3] = newFieldsOperand(fields)
return o
@@ -323,8 +323,8 @@ func newInput(nums []int) []*Operand {
return o
}
-func newTableNameOperand(tableIdx int8) *Operand {
- if tableIdx < 0 {
+func newTableRefOperand(tableRef schema.TableRef) *Operand {
+ if tableRef < 0 {
return nil
}
o := &Operand{
@@ -334,7 +334,7 @@ func newTableNameOperand(tableIdx int8) *Operand {
Data: []Tuple{
[]*Raw{
{
- Value: decimal.New(int64(tableIdx), 0),
+ Value: decimal.New(int64(tableRef), 0),
},
},
},
@@ -404,7 +404,7 @@ func (s *opLoadSuite) TestOpLoad() {
testCases := s.getOpLoadTestCases(s.raws)
for _, t := range testCases {
input := newInput([]int{1, 2, 3})
- reg := s.newRegisters(t.tableIdx, t.ids, t.fields)
+ reg := s.newRegisters(t.tableRef, t.ids, t.fields)
loadRegister(input, reg)
err := opLoad(s.ctx, input, reg, t.outputIdx)
@@ -417,11 +417,72 @@ func (s *opLoadSuite) TestOpLoad() {
}
}
-func TestOpLoad(t *testing.T) {
- suite.Run(t, new(opLoadSuite))
+type opRepeatPKSuite struct{ suite.Suite }
+
+type repeatPKTestCase struct {
+ tableRef schema.TableRef
+ address dexCommon.Address
+ title string
+ expectedErr error
+ expectedIDs []uint64
+}
+
+func (s opRepeatPKSuite) newInput(tableRef schema.TableRef) []*Operand {
+ o := make([]*Operand, 1)
+ o[0] = newTableRefOperand(tableRef)
+ return o
+}
+
+func (s opRepeatPKSuite) newRegisters(tableRef schema.TableRef) []*Operand {
+ o := make([]*Operand, 2)
+ o[1] = newTableRefOperand(tableRef)
+ return o
+}
+
+func (s opRepeatPKSuite) getTestCases(storage *common.Storage) []repeatPKTestCase {
+ testCases := []repeatPKTestCase{
+ {
+ tableRef: 0,
+ address: dexCommon.BytesToAddress([]byte("0")),
+ title: "no IDs",
+ expectedErr: nil,
+ expectedIDs: []uint64{},
+ },
+ {
+ tableRef: 1,
+ address: dexCommon.BytesToAddress([]byte("1")),
+ title: "ok case",
+ expectedErr: nil,
+ expectedIDs: []uint64{1, 2, 3, 4},
+ },
+ }
+ for _, t := range testCases {
+ storage.SetPK(t.address, t.tableRef, t.expectedIDs)
+ }
+ return testCases
+}
+
+func (s opRepeatPKSuite) TestRepeatPK() {
+ ctx := &common.Context{}
+ ctx.Storage = newStorage()
+ testCases := s.getTestCases(ctx.Storage)
+ for _, t := range testCases {
+ address := t.address
+ ctx.Contract = vm.NewContract(vm.AccountRef(address),
+ vm.AccountRef(address), new(big.Int), uint64(0))
+ reg := s.newRegisters(t.tableRef)
+ input := newInput([]int{1})
+ loadRegister(input, reg)
+ err := opRepeatPK(ctx, input, reg, 0)
+ s.Require().Equalf(t.expectedErr, err, "testcase: [%v]", t.title)
+ result, _ := reg[0].toUint64()
+ s.Require().Equalf(t.expectedIDs, result, "testcase: [%v]", t.title)
+ }
}
func TestInstructions(t *testing.T) {
+ suite.Run(t, new(opLoadSuite))
+ suite.Run(t, new(opRepeatPKSuite))
suite.Run(t, new(instructionSuite))
}