aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTing-Wei Lan <tingwei.lan@cobinhood.com>2019-04-23 15:44:55 +0800
committerTing-Wei Lan <tingwei.lan@cobinhood.com>2019-05-14 11:04:15 +0800
commit521b542feaaad50b2f6aca8fdff5b8fcb7578593 (patch)
treeae43b805f46b67dd5a09adc36fb977c09cf604cd
parent7a237ad79f7d7e6ed1a875ee700805b6d9d3791b (diff)
downloaddexon-521b542feaaad50b2f6aca8fdff5b8fcb7578593.tar.gz
dexon-521b542feaaad50b2f6aca8fdff5b8fcb7578593.tar.zst
dexon-521b542feaaad50b2f6aca8fdff5b8fcb7578593.zip
core: vm: sqlvm: implement schema compiler and type checker
-rw-r--r--core/vm/sqlvm/ast/ast.go135
-rw-r--r--core/vm/sqlvm/ast/types.go182
-rw-r--r--core/vm/sqlvm/checkers/actions.go147
-rw-r--r--core/vm/sqlvm/checkers/checkers.go1496
-rw-r--r--core/vm/sqlvm/checkers/utils.go471
-rw-r--r--core/vm/sqlvm/cmd/ast-checker/main.go123
-rw-r--r--core/vm/sqlvm/errors/errors.go90
-rw-r--r--core/vm/sqlvm/schema/schema.go135
8 files changed, 2766 insertions, 13 deletions
diff --git a/core/vm/sqlvm/ast/ast.go b/core/vm/sqlvm/ast/ast.go
index 2215f4480..5868549a8 100644
--- a/core/vm/sqlvm/ast/ast.go
+++ b/core/vm/sqlvm/ast/ast.go
@@ -1,7 +1,10 @@
package ast
import (
+ "bytes"
"fmt"
+ "unicode"
+ "unicode/utf8"
"github.com/dexon-foundation/decimal"
@@ -66,6 +69,113 @@ func (n *NodeBase) SetToken(token []byte) {
n.Token = token
}
+func safeIdentifierStart(r rune) bool {
+ return (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || (r >= 0x80)
+}
+
+func safeIdentifierExtend(r rune) bool {
+ return (r >= '0' && r <= '9') || (r == '_')
+}
+
+func safeString(s []byte, quote byte, mustQuote bool) []byte {
+ o := bytes.Buffer{}
+ o.Grow(len(s) + 1)
+ o.WriteByte(quote)
+
+ if len(s) > 0 {
+ for r, i, size := rune(0), 0, 0; i < len(s); i += size {
+ r, size = utf8.DecodeRune(s[i:])
+ if r == utf8.RuneError {
+ switch size {
+ case 0:
+ panic("we should never pass an empty slice to DecodeRune")
+ case 1:
+ mustQuote = true
+ x := fmt.Sprintf(`\x%02x`, s[i])
+ o.WriteString(x)
+ continue
+ // The default case is deliberately omitted here. It is
+ // possible for a valid UTF-8 string to include the code
+ // point utf8.RuneError. If it is the case, the size must
+ // be larger than 1.
+ }
+ }
+ safeIdentifier :=
+ safeIdentifierStart(r) || (i > 0 && safeIdentifierExtend(r))
+ if !safeIdentifier {
+ mustQuote = true
+ }
+ switch r {
+ case '\\':
+ mustQuote = true
+ o.WriteString(`\\`)
+ case '\'':
+ mustQuote = true
+ o.WriteString(`\'`)
+ case '"':
+ mustQuote = true
+ o.WriteString(`\"`)
+ case '\b':
+ mustQuote = true
+ o.WriteString(`\b`)
+ case '\f':
+ mustQuote = true
+ o.WriteString(`\f`)
+ case '\n':
+ mustQuote = true
+ o.WriteString(`\n`)
+ case '\r':
+ mustQuote = true
+ o.WriteString(`\r`)
+ case '\t':
+ mustQuote = true
+ o.WriteString(`\t`)
+ case '\v':
+ mustQuote = true
+ o.WriteString(`\v`)
+ default:
+ if unicode.IsPrint(r) {
+ x := s[i : i+size]
+ o.Write(x)
+ } else {
+ mustQuote = true
+ if r > 0xffff {
+ x := fmt.Sprintf(`\U%08x`, r)
+ o.WriteString(x)
+ } else {
+ x := fmt.Sprintf(`\u%04x`, r)
+ o.WriteString(x)
+ }
+ }
+ }
+ }
+ } else {
+ mustQuote = true
+ }
+
+ if mustQuote {
+ o.WriteByte(quote)
+ return o.Bytes()
+ }
+ return o.Bytes()[1:]
+}
+
+// QuoteIdentifier quotes a string for safely using it as a SQL identifier.
+func QuoteIdentifier(s []byte) []byte {
+ return safeString(s, '"', true)
+}
+
+// QuoteIdentifierOptional is the same as QuoteIdentifier except that it does
+// not quote the string if the string itself can be used safely.
+func QuoteIdentifierOptional(s []byte) []byte {
+ return safeString(s, '"', false)
+}
+
+// QuoteString quotes a string for safely using it as a SQL string literal.
+func QuoteString(s []byte) []byte {
+ return safeString(s, '\'', true)
+}
+
// ---------------------------------------------------------------------------
// Identifiers
// ---------------------------------------------------------------------------
@@ -186,6 +296,31 @@ func (n *BoolValueNode) GetType() DataType {
return ComposeDataType(DataTypeMajorBool, DataTypeMinorDontCare)
}
+// AddressValueNode is an address constant.
+type AddressValueNode struct {
+ UntaggedExprNodeBase
+ V []byte
+}
+
+var _ Valuer = (*AddressValueNode)(nil)
+
+func (n *AddressValueNode) ˉValuer() {}
+
+// GetChildren returns a list of child nodes used for traversing.
+func (n *AddressValueNode) GetChildren() []Node {
+ return nil
+}
+
+// IsConstant returns whether a node is a constant.
+func (n *AddressValueNode) IsConstant() bool {
+ return true
+}
+
+// GetType returns the type of 'bool'.
+func (n *AddressValueNode) GetType() DataType {
+ return ComposeDataType(DataTypeMajorAddress, DataTypeMinorDontCare)
+}
+
// IntegerValueNode is an integer constant.
type IntegerValueNode struct {
TaggedExprNodeBase
diff --git a/core/vm/sqlvm/ast/types.go b/core/vm/sqlvm/ast/types.go
index a441b4b2c..fc403fcb3 100644
--- a/core/vm/sqlvm/ast/types.go
+++ b/core/vm/sqlvm/ast/types.go
@@ -59,6 +59,7 @@ const (
// Special data types which are commonly used.
const (
DataTypePending DataType = (DataType(DataTypeMajorPending) << 8) | DataType(DataTypeMinorDontCare)
+ DataTypeNull DataType = (DataType(DataTypeMajorSpecial) << 8) | DataType(DataTypeMinorSpecialNull)
DataTypeBad DataType = math.MaxUint16
)
@@ -189,6 +190,136 @@ func (dt DataType) Size() uint8 {
}
}
+// Valid checks whether a data type is set to a defined value.
+func (dt DataType) Valid() bool {
+ major, minor := DecomposeDataType(dt)
+ switch major {
+ case DataTypeMajorPending,
+ DataTypeMajorBool,
+ DataTypeMajorAddress,
+ DataTypeMajorDynamicBytes:
+ return true
+ case DataTypeMajorSpecial:
+ switch minor {
+ case DataTypeMinorSpecialNull:
+ return true
+ }
+ case DataTypeMajorInt,
+ DataTypeMajorUint,
+ DataTypeMajorFixedBytes:
+ if minor <= 0x1f {
+ return true
+ }
+ }
+ if (major.IsFixedRange() || major.IsUfixedRange()) && minor <= 80 {
+ return true
+ }
+ return false
+}
+
+// ValidColumn checks whether a data type is a valid column type.
+func (dt DataType) ValidColumn() bool {
+ major, minor := DecomposeDataType(dt)
+ switch major {
+ case DataTypeMajorBool,
+ DataTypeMajorAddress,
+ DataTypeMajorDynamicBytes:
+ return true
+ case DataTypeMajorInt,
+ DataTypeMajorUint,
+ DataTypeMajorFixedBytes:
+ if minor <= 0x1f {
+ return true
+ }
+ }
+ if (major.IsFixedRange() || major.IsUfixedRange()) && minor <= 80 {
+ return true
+ }
+ return false
+}
+
+// ValidExpr checks whether a data type is a valid type of an expression.
+func (dt DataType) ValidExpr() bool {
+ return dt.ValidColumn() || dt == DataTypeNull
+}
+
+// Equal checks whether two data types are equal and valid. If any of them is
+// invalid, false is returned.
+func (dt DataType) Equal(dt2 DataType) bool {
+ // Rename variables.
+ a := dt
+ b := dt2
+ // Process the common case.
+ if a == b {
+ return a.Valid()
+ }
+ // a ≠ b
+ aMajor, _ := DecomposeDataType(a)
+ bMajor, _ := DecomposeDataType(b)
+ if aMajor != bMajor {
+ return false
+ }
+ // a ≠ b, aMajor = bMajor ⇒ aMinor ≠ bMinor
+ switch aMajor {
+ case DataTypeMajorPending,
+ DataTypeMajorBool,
+ DataTypeMajorAddress,
+ DataTypeMajorDynamicBytes:
+ return true
+ default:
+ return false
+ }
+}
+
+func (dt DataType) String() string {
+ major, minor := DecomposeDataType(dt)
+ switch major {
+ case DataTypeMajorPending:
+ return "<PENDING>"
+ case DataTypeMajorSpecial:
+ switch minor {
+ case DataTypeMinorSpecialNull:
+ return "NULL"
+ }
+ case DataTypeMajorBool:
+ return "BOOL"
+ case DataTypeMajorAddress:
+ return "ADDRESS"
+ case DataTypeMajorInt:
+ if minor <= 0x1f {
+ size := (uint32(minor) + 1) * 8
+ return fmt.Sprintf("INT%d", size)
+ }
+ case DataTypeMajorUint:
+ if minor <= 0x1f {
+ size := (uint32(minor) + 1) * 8
+ return fmt.Sprintf("UINT%d", size)
+ }
+ case DataTypeMajorFixedBytes:
+ if minor <= 0x1f {
+ size := uint32(minor) + 1
+ return fmt.Sprintf("BYTES%d", size)
+ }
+ case DataTypeMajorDynamicBytes:
+ return "BYTES"
+ }
+ switch {
+ case major.IsFixedRange():
+ if minor <= 80 {
+ size := (uint32(major-DataTypeMajorFixed) + 1) * 8
+ fractionalDigits := uint32(minor)
+ return fmt.Sprintf("FIXED%dX%d", size, fractionalDigits)
+ }
+ case major.IsUfixedRange():
+ if minor <= 80 {
+ size := (uint32(major-DataTypeMajorUfixed) + 1) * 8
+ fractionalDigits := uint32(minor)
+ return fmt.Sprintf("UFIXED%dX%d", size, fractionalDigits)
+ }
+ }
+ return "<INVALID>"
+}
+
// GetNode constructs an AST node from a data type.
func (dt DataType) GetNode() TypeNode {
major, minor := DecomposeDataType(dt)
@@ -293,12 +424,39 @@ var decimalMinMaxMap = func() map[DataType]decimalMinMaxPair {
// GetMinMax returns min, max pair according to given data type.
func (dt DataType) GetMinMax() (decimal.Decimal, decimal.Decimal, bool) {
- var (
- pair decimalMinMaxPair
- ok bool
- )
- pair, ok = decimalMinMaxMap[dt]
- return pair.Min, pair.Max, ok
+ pair, ok := decimalMinMaxMap[dt]
+ if ok {
+ return pair.Min, pair.Max, true
+ }
+
+ // Compute the range of fixed and ufixed types on demand.
+ major, minor := DecomposeDataType(dt)
+ switch {
+ case major.IsFixedRange():
+ mapInt := ComposeDataType(
+ DataTypeMajorInt, DataTypeMinor(major-DataTypeMajorFixed))
+ pair, ok = decimalMinMaxMap[mapInt]
+ if !ok || minor > 80 {
+ return decimal.Zero, decimal.Zero, false
+ }
+ min := pair.Min.Shift(-int32(minor))
+ max := pair.Max.Shift(-int32(minor))
+ return min, max, true
+
+ case major.IsUfixedRange():
+ mapUint := ComposeDataType(
+ DataTypeMajorUint, DataTypeMinor(major-DataTypeMajorUfixed))
+ pair, ok = decimalMinMaxMap[mapUint]
+ if !ok || minor > 80 {
+ return decimal.Zero, decimal.Zero, false
+ }
+ min := pair.Min.Shift(-int32(minor))
+ max := pair.Max.Shift(-int32(minor))
+ return min, max, true
+
+ default:
+ return decimal.Zero, decimal.Zero, false
+ }
}
func decimalToBig(d decimal.Decimal) *big.Int {
@@ -318,13 +476,21 @@ func decimalEncode(size int, d decimal.Decimal) []byte {
if s > 0 {
bs := b.Bytes()
- copy(ret[size-len(bs):], bs)
+ if size >= len(bs) {
+ copy(ret[size-len(bs):], bs)
+ } else {
+ copy(ret, bs[len(bs)-size:])
+ }
return ret
}
b.Add(b, bigIntOne)
bs := b.Bytes()
- copy(ret[size-len(bs):], bs)
+ if size >= len(bs) {
+ copy(ret[size-len(bs):], bs)
+ } else {
+ copy(ret, bs[len(bs)-size:])
+ }
for idx := range ret {
ret[idx] = ^ret[idx]
}
diff --git a/core/vm/sqlvm/checkers/actions.go b/core/vm/sqlvm/checkers/actions.go
new file mode 100644
index 000000000..c15029d9a
--- /dev/null
+++ b/core/vm/sqlvm/checkers/actions.go
@@ -0,0 +1,147 @@
+package checkers
+
+import (
+ "fmt"
+
+ "github.com/dexon-foundation/dexon/core/vm/sqlvm/ast"
+ "github.com/dexon-foundation/dexon/core/vm/sqlvm/errors"
+ "github.com/dexon-foundation/dexon/core/vm/sqlvm/schema"
+)
+
+// CheckOptions stores boolean options for Check* functions.
+type CheckOptions uint32
+
+const (
+ // CheckWithSafeMath enables overflow and underflow checks during expression
+ // evaluation. An error will be thrown when the result is out of range.
+ CheckWithSafeMath CheckOptions = 1 << iota
+ // CheckWithSafeCast enables overflow and underflow checks during casting.
+ // An error will be thrown if the value does not fit in the target type.
+ CheckWithSafeCast
+ // CheckWithConstantOnly restricts the expression to be a constant. An error
+ // will be thrown if the expression cannot be folded into a constant.
+ CheckWithConstantOnly
+)
+
+// CheckCreate runs CREATE commands to generate a database schema. It modifies
+// AST in-place during evaluation of expressions.
+func CheckCreate(ss []ast.StmtNode, o CheckOptions) (schema.Schema, error) {
+ fn := "CheckCreate"
+ s := schema.Schema{}
+ c := newSchemaCache()
+ el := errors.ErrorList{}
+
+ for idx := range ss {
+ if ss[idx] == nil {
+ continue
+ }
+
+ switch n := ss[idx].(type) {
+ case *ast.CreateTableStmtNode:
+ checkCreateTableStmt(n, &s, o, c, &el)
+ case *ast.CreateIndexStmtNode:
+ checkCreateIndexStmt(n, &s, o, c, &el)
+ default:
+ el.Append(errors.Error{
+ Position: ss[idx].GetPosition(),
+ Length: ss[idx].GetLength(),
+ Category: errors.ErrorCategoryCommand,
+ Code: errors.ErrorCodeDisallowedCommand,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "command %s is not allowed when creating a contract",
+ ast.QuoteIdentifier(ss[idx].GetVerb())),
+ }, nil)
+ }
+ }
+
+ if len(s) == 0 && len(el) == 0 {
+ el.Append(errors.Error{
+ Position: 0,
+ Length: 0,
+ Category: errors.ErrorCategoryCommand,
+ Code: errors.ErrorCodeNoCommand,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: "creating a contract without a table is not allowed",
+ }, nil)
+ }
+ if len(el) != 0 {
+ return s, el
+ }
+ return s, nil
+}
+
+// CheckQuery checks and modifies SELECT commands with a given database schema.
+func CheckQuery(ss []ast.StmtNode, s schema.Schema, o CheckOptions) error {
+ fn := "CheckQuery"
+ c := newSchemaCache()
+ el := errors.ErrorList{}
+
+ for idx := range ss {
+ if ss[idx] == nil {
+ continue
+ }
+
+ switch n := ss[idx].(type) {
+ case *ast.SelectStmtNode:
+ checkSelectStmt(n, s, o, c, &el)
+ default:
+ el.Append(errors.Error{
+ Position: ss[idx].GetPosition(),
+ Length: ss[idx].GetLength(),
+ Category: errors.ErrorCategoryCommand,
+ Code: errors.ErrorCodeDisallowedCommand,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "command %s is not allowed when calling query",
+ ast.QuoteIdentifier(ss[idx].GetVerb())),
+ }, nil)
+ }
+ }
+ if len(el) != 0 {
+ return el
+ }
+ return nil
+}
+
+// CheckExec checks and modifies UPDATE, DELETE, INSERT commands with a given
+// database schema.
+func CheckExec(ss []ast.StmtNode, s schema.Schema, o CheckOptions) error {
+ fn := "CheckExec"
+ c := newSchemaCache()
+ el := errors.ErrorList{}
+
+ for idx := range ss {
+ if ss[idx] == nil {
+ continue
+ }
+
+ switch n := ss[idx].(type) {
+ case *ast.UpdateStmtNode:
+ checkUpdateStmt(n, s, o, c, &el)
+ case *ast.DeleteStmtNode:
+ checkDeleteStmt(n, s, o, c, &el)
+ case *ast.InsertStmtNode:
+ checkInsertStmt(n, s, o, c, &el)
+ default:
+ el.Append(errors.Error{
+ Position: ss[idx].GetPosition(),
+ Length: ss[idx].GetLength(),
+ Category: errors.ErrorCategoryCommand,
+ Code: errors.ErrorCodeDisallowedCommand,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "command %s is not allowed when calling exec",
+ ast.QuoteIdentifier(ss[idx].GetVerb())),
+ }, nil)
+ }
+ }
+ if len(el) != 0 {
+ return el
+ }
+ return nil
+}
diff --git a/core/vm/sqlvm/checkers/checkers.go b/core/vm/sqlvm/checkers/checkers.go
new file mode 100644
index 000000000..429d08a5c
--- /dev/null
+++ b/core/vm/sqlvm/checkers/checkers.go
@@ -0,0 +1,1496 @@
+package checkers
+
+import (
+ "fmt"
+ "sort"
+
+ "github.com/dexon-foundation/decimal"
+
+ "github.com/dexon-foundation/dexon/core/vm/sqlvm/ast"
+ "github.com/dexon-foundation/dexon/core/vm/sqlvm/errors"
+ "github.com/dexon-foundation/dexon/core/vm/sqlvm/schema"
+)
+
+// In addition to the convention mentioned in utils.go, we have these variable
+// names in this file:
+//
+// ftd -> foreign table descriptor
+// ftn -> foreign table name
+// fcd -> foreign column descriptor
+// fcn -> foreign column name
+// fid -> foreign index descriptor
+// fin -> foreign index name
+//
+// fmid -> first matching index descriptor
+// fmir -> first matching index reference
+// fmin -> first matching index name
+
+// findFirstMatchingIndex finds the first index in 'haystack' matching the
+// declaration of 'needle' with attributes specified in 'attrDontCare' ignored.
+// This function is considered as a part of the interface, so it have to work
+// deterministically.
+func findFirstMatchingIndex(haystack []schema.Index, needle schema.Index,
+ attrDontCare schema.IndexAttr) (schema.IndexRef, bool) {
+
+ compareAttr := func(a1, a2 schema.IndexAttr) bool {
+ a1 = a1.GetDeclaredFlags() | attrDontCare
+ a2 = a2.GetDeclaredFlags() | attrDontCare
+ return a1 == a2
+ }
+
+ compareColumns := func(c1, c2 []schema.ColumnRef) bool {
+ if len(c1) != len(c2) {
+ return false
+ }
+ for ci := range c1 {
+ if c1[ci] != c2[ci] {
+ return false
+ }
+ }
+ return true
+ }
+
+ for ii := range haystack {
+ if compareAttr(haystack[ii].Attr, needle.Attr) &&
+ compareColumns(haystack[ii].Columns, needle.Columns) {
+ ir := schema.IndexRef(ii)
+ return ir, true
+ }
+ }
+ return 0, false
+}
+
+func checkCreateTableStmt(n *ast.CreateTableStmtNode, s *schema.Schema,
+ o CheckOptions, c *schemaCache, el *errors.ErrorList) {
+
+ fn := "CheckCreateTableStmt"
+ hasError := false
+
+ if c.Begin() != 0 {
+ panic("schema cache must not have any open scope")
+ }
+ defer func() {
+ if hasError {
+ c.Rollback()
+ return
+ }
+ c.Commit()
+ }()
+
+ // Return early if there are too many tables. We cannot ignore this error
+ // because it will overflow schema.TableRef, which is used as a part of
+ // column key in schemaCache.
+ if len(*s) > schema.MaxTableRef {
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: errors.ErrorCategoryLimit,
+ Code: errors.ErrorCodeTooManyTables,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf("cannot have more than %d tables",
+ schema.MaxTableRef+1),
+ }, &hasError)
+ return
+ }
+
+ table := schema.Table{}
+ tr := schema.TableRef(len(*s))
+ td := schema.TableDescriptor{Table: tr}
+
+ if len(n.Table.Name) == 0 {
+ el.Append(errors.Error{
+ Position: n.Table.GetPosition(),
+ Length: n.Table.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeEmptyTableName,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: "cannot create a table with an empty name",
+ }, &hasError)
+ }
+
+ tn := n.Table.Name
+ if !c.AddTable(string(tn), td) {
+ el.Append(errors.Error{
+ Position: n.Table.GetPosition(),
+ Length: n.Table.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeDuplicateTableName,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf("table %s already exists",
+ ast.QuoteIdentifier(tn)),
+ }, &hasError)
+ }
+ table.Name = n.Table.Name
+ table.Columns = make([]schema.Column, 0, len(n.Column))
+
+ // Handle the primary key index.
+ pk := []schema.ColumnRef{}
+ // Handle sequences.
+ seq := 0
+ // Handle indices for unique constraints.
+ type localIndex struct {
+ index schema.Index
+ node ast.Node
+ }
+ localIndices := []localIndex{}
+ // Handle indices for foreign key constraints.
+ type foreignNewIndex struct {
+ table schema.TableDescriptor
+ index schema.Index
+ node ast.Node
+ }
+ foreignNewIndices := []foreignNewIndex{}
+ type foreignExistingIndex struct {
+ index schema.IndexDescriptor
+ node ast.Node
+ }
+ foreignExistingIndices := []foreignExistingIndex{}
+
+ for ci := range n.Column {
+ if len(table.Columns) > schema.MaxColumnRef {
+ el.Append(errors.Error{
+ Position: n.Column[ci].GetPosition(),
+ Length: n.Column[ci].GetLength(),
+ Category: errors.ErrorCategoryLimit,
+ Code: errors.ErrorCodeTooManyColumns,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf("cannot have more than %d columns",
+ schema.MaxColumnRef+1),
+ }, &hasError)
+ return
+ }
+
+ column := schema.Column{}
+ ok := func() (ok bool) {
+ innerHasError := false
+ defer func() { ok = !innerHasError }()
+
+ // Block access to the outer hasError variable.
+ hasError := struct{}{}
+ // Suppress “declared and not used” error.
+ _ = hasError
+
+ c.Begin()
+ defer func() {
+ if innerHasError {
+ c.Rollback()
+ return
+ }
+ c.Commit()
+ }()
+
+ cr := schema.ColumnRef(len(table.Columns))
+ cd := schema.ColumnDescriptor{Table: tr, Column: cr}
+
+ if len(n.Column[ci].Column.Name) == 0 {
+ el.Append(errors.Error{
+ Position: n.Column[ci].Column.GetPosition(),
+ Length: n.Column[ci].Column.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeEmptyColumnName,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: "cannot declare a column with an empty name",
+ }, &innerHasError)
+ }
+
+ cn := n.Column[ci].Column.Name
+ if !c.AddColumn(string(cn), cd) {
+ el.Append(errors.Error{
+ Position: n.Column[ci].Column.GetPosition(),
+ Length: n.Column[ci].Column.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeDuplicateColumnName,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf("column %s already exists",
+ ast.QuoteIdentifier(cn)),
+ }, &innerHasError)
+ } else {
+ column.Name = n.Column[ci].Column.Name
+ }
+
+ dt, code, message := n.Column[ci].DataType.GetType()
+ if code == errors.ErrorCodeNil {
+ if !dt.ValidColumn() {
+ el.Append(errors.Error{
+ Position: n.Column[ci].DataType.GetPosition(),
+ Length: n.Column[ci].DataType.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeInvalidColumnDataType,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "cannot declare a column with type %s", dt.String()),
+ }, &innerHasError)
+ }
+ } else {
+ el.Append(errors.Error{
+ Position: n.Column[ci].DataType.GetPosition(),
+ Length: n.Column[ci].DataType.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: code,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: message,
+ }, &innerHasError)
+ }
+ column.Type = dt
+
+ // Backup lengths of slices in case we have to rollback. We don't
+ // have to copy slice headers or data stored in underlying arrays
+ // because we always append data at the end.
+ defer func(LPK, SEQ, LLI, LFNI, LFEI int) {
+ if innerHasError {
+ pk = pk[:LPK]
+ seq = SEQ
+ localIndices = localIndices[:LLI]
+ foreignNewIndices = foreignNewIndices[:LFNI]
+ foreignExistingIndices = foreignExistingIndices[:LFEI]
+ }
+ }(
+ len(pk), seq, len(localIndices), len(foreignNewIndices),
+ len(foreignExistingIndices),
+ )
+
+ // cs -> constraint
+ // csi -> constraint index
+ for csi := range n.Column[ci].Constraint {
+ // Cases are sorted in the same order as internal/grammar.peg.
+ cs:
+ switch cs := n.Column[ci].Constraint[csi].(type) {
+ case *ast.PrimaryOptionNode:
+ pk = append(pk, cr)
+ column.Attr |= schema.ColumnAttrPrimaryKey
+
+ case *ast.NotNullOptionNode:
+ column.Attr |= schema.ColumnAttrNotNull
+
+ case *ast.UniqueOptionNode:
+ if (column.Attr & schema.ColumnAttrUnique) != 0 {
+ // Don't create duplicate indices on a column.
+ break cs
+ }
+ column.Attr |= schema.ColumnAttrUnique
+ indexName := fmt.Sprintf("%s_%s_unique",
+ table.Name, column.Name)
+ idx := schema.Index{
+ Name: []byte(indexName),
+ Attr: schema.IndexAttrUnique,
+ Columns: []schema.ColumnRef{cr},
+ }
+ localIndices = append(localIndices, localIndex{
+ index: idx,
+ node: cs,
+ })
+
+ case *ast.DefaultOptionNode:
+ if (column.Attr & schema.ColumnAttrHasDefault) != 0 {
+ el.Append(errors.Error{
+ Position: cs.GetPosition(),
+ Length: cs.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeMultipleDefaultValues,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: "cannot have multiple default values",
+ }, &innerHasError)
+ break cs
+ }
+ column.Attr |= schema.ColumnAttrHasDefault
+
+ value := checkExpr(cs.Value, *s, o|CheckWithConstantOnly,
+ c, el, 0, newTypeActionAssign(column.Type))
+ if value == nil {
+ innerHasError = true
+ break cs
+ }
+ cs.Value = value
+
+ switch v := cs.Value.(ast.Valuer).(type) {
+ case *ast.BoolValueNode:
+ sb := v.V.NullBool()
+ if !sb.Valid {
+ el.Append(errors.Error{
+ Position: cs.GetPosition(),
+ Length: cs.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeNullDefaultValue,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: "default value must not be NULL",
+ }, &innerHasError)
+ break cs
+ }
+ column.Default = sb.Bool
+
+ case *ast.AddressValueNode:
+ column.Default = v.V
+
+ case *ast.IntegerValueNode:
+ column.Default = v.V
+
+ case *ast.DecimalValueNode:
+ column.Default = v.V
+
+ case *ast.BytesValueNode:
+ column.Default = v.V
+
+ case *ast.NullValueNode:
+ el.Append(errors.Error{
+ Position: cs.GetPosition(),
+ Length: cs.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeNullDefaultValue,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: "default value must not be NULL",
+ }, &innerHasError)
+ break cs
+
+ default:
+ panic(fmt.Sprintf("unknown constant value type %T", n))
+ }
+
+ case *ast.ForeignOptionNode:
+ if len(column.ForeignKeys) > schema.MaxForeignKeys {
+ el.Append(errors.Error{
+ Position: cs.GetPosition(),
+ Length: cs.GetLength(),
+ Category: errors.ErrorCategoryLimit,
+ Code: errors.ErrorCodeTooManyForeignKeys,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "cannot have more than %d foreign key "+
+ "constraints in a column",
+ schema.MaxForeignKeys+1),
+ }, &innerHasError)
+ break cs
+ }
+ column.Attr |= schema.ColumnAttrHasForeignKey
+ ftn := cs.Table.Name
+ ftd, found := c.FindTableInBase(string(ftn))
+ if !found {
+ el.Append(errors.Error{
+ Position: cs.Table.GetPosition(),
+ Length: cs.Table.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeTableNotFound,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "foreign table %s does not exist",
+ ast.QuoteIdentifier(ftn)),
+ }, &innerHasError)
+ break cs
+ }
+ fcn := cs.Column.Name
+ fcd, found := c.FindColumnInBase(ftd.Table, string(fcn))
+ if !found {
+ el.Append(errors.Error{
+ Position: cs.Column.GetPosition(),
+ Length: cs.Column.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeColumnNotFound,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "column %s does not exist in foreign table %s",
+ ast.QuoteIdentifier(fcn),
+ ast.QuoteIdentifier(ftn)),
+ }, &innerHasError)
+ break cs
+ }
+ foreignType := (*s)[fcd.Table].Columns[fcd.Column].Type
+ if !foreignType.Equal(column.Type) {
+ el.Append(errors.Error{
+ Position: n.Column[ci].DataType.GetPosition(),
+ Length: n.Column[ci].DataType.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeForeignKeyDataTypeMismatch,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "foreign column has type %s (%04x), but "+
+ "this column has type %s (%04x)",
+ foreignType.String(), uint16(foreignType),
+ column.Type.String(), uint16(column.Type)),
+ }, &innerHasError)
+ break cs
+ }
+
+ idx := schema.Index{
+ Attr: schema.IndexAttrReferenced,
+ Columns: []schema.ColumnRef{fcd.Column},
+ }
+ fmir, found := findFirstMatchingIndex(
+ (*s)[ftd.Table].Indices, idx, schema.IndexAttrUnique)
+ if found {
+ fmid := schema.IndexDescriptor{
+ Table: ftd.Table,
+ Index: fmir,
+ }
+ foreignExistingIndices = append(
+ foreignExistingIndices, foreignExistingIndex{
+ index: fmid,
+ node: cs,
+ })
+ } else {
+ if len(column.ForeignKeys) > 0 {
+ idx.Name = []byte(fmt.Sprintf("%s_%s_foreign_key_%d",
+ table.Name, column.Name, len(column.ForeignKeys)))
+ } else {
+ idx.Name = []byte(fmt.Sprintf("%s_%s_foreign_key",
+ table.Name, column.Name))
+ }
+ foreignNewIndices = append(
+ foreignNewIndices, foreignNewIndex{
+ table: ftd,
+ index: idx,
+ node: cs,
+ })
+ }
+ column.ForeignKeys = append(column.ForeignKeys, fcd)
+
+ case *ast.AutoIncrementOptionNode:
+ if (column.Attr & schema.ColumnAttrHasSequence) != 0 {
+ // Don't process AUTOINCREMENT twice.
+ break cs
+ }
+ // We set the flag regardless of the error because we
+ // don't want to produce duplicate errors.
+ column.Attr |= schema.ColumnAttrHasSequence
+ if seq > schema.MaxSequenceRef {
+ el.Append(errors.Error{
+ Position: cs.GetPosition(),
+ Length: cs.GetLength(),
+ Category: errors.ErrorCategoryLimit,
+ Code: errors.ErrorCodeTooManySequences,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "cannot have more than %d sequences",
+ schema.MaxSequenceRef+1),
+ }, &innerHasError)
+ break cs
+ }
+ major, _ := ast.DecomposeDataType(column.Type)
+ switch major {
+ case ast.DataTypeMajorInt, ast.DataTypeMajorUint:
+ default:
+ el.Append(errors.Error{
+ Position: cs.GetPosition(),
+ Length: cs.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeInvalidAutoIncrementDataType,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "AUTOINCREMENT is only supported on "+
+ "INT and UINT types, but this column "+
+ "has type %s (%04x)",
+ column.Type.String(), uint16(column.Type)),
+ }, &innerHasError)
+ break cs
+ }
+ column.Sequence = schema.SequenceRef(seq)
+ seq++
+
+ default:
+ panic(fmt.Sprintf("unknown constraint type %T", c))
+ }
+ }
+
+ // The return value will be set by the first defer function.
+ return
+ }()
+
+ // If an error occurs in the function, stop here and continue
+ // processing the next column.
+ if !ok {
+ hasError = true
+ continue
+ }
+
+ // Commit the column.
+ table.Columns = append(table.Columns, column)
+ }
+
+ // Return early if there is any error.
+ if hasError {
+ return
+ }
+
+ mustAddIndex := func(name *[]byte, id schema.IndexDescriptor) {
+ for !c.AddIndex(string(*name), id, true) {
+ *name = append(*name, '_')
+ }
+ }
+
+ // Create the primary key index. This is the first index on the table, so
+ // it is not possible to exceed the limit on the number of indices.
+ ir := schema.IndexRef(len(table.Indices))
+ if len(pk) > 0 {
+ idx := schema.Index{
+ Name: []byte(fmt.Sprintf("%s_primary_key", table.Name)),
+ Attr: schema.IndexAttrUnique,
+ Columns: pk,
+ }
+ id := schema.IndexDescriptor{Table: tr, Index: ir}
+ mustAddIndex(&idx.Name, id)
+ table.Indices = append(table.Indices, idx)
+ }
+
+ // Create indices for the current table.
+ for ii := range localIndices {
+ if len(table.Indices) > schema.MaxIndexRef {
+ el.Append(errors.Error{
+ Position: localIndices[ii].node.GetPosition(),
+ Length: localIndices[ii].node.GetLength(),
+ Category: errors.ErrorCategoryLimit,
+ Code: errors.ErrorCodeTooManyIndices,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf("cannot have more than %d indices",
+ schema.MaxIndexRef+1),
+ }, &hasError)
+ return
+ }
+ idx := localIndices[ii].index
+ ir := schema.IndexRef(len(table.Indices))
+ id := schema.IndexDescriptor{Table: tr, Index: ir}
+ mustAddIndex(&idx.Name, id)
+ table.Indices = append(table.Indices, idx)
+ }
+
+ // Create indices for foreign tables.
+ for ii := range foreignNewIndices {
+ ftd := foreignNewIndices[ii].table
+ if len((*s)[ftd.Table].Indices) > schema.MaxIndexRef {
+ el.Append(errors.Error{
+ Position: foreignNewIndices[ii].node.GetPosition(),
+ Length: foreignNewIndices[ii].node.GetLength(),
+ Category: errors.ErrorCategoryLimit,
+ Code: errors.ErrorCodeTooManyIndices,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "table %s already has %d indices",
+ ast.QuoteIdentifier((*s)[ftd.Table].Name),
+ schema.MaxIndexRef+1),
+ }, &hasError)
+ return
+ }
+ idx := foreignNewIndices[ii].index
+ ir := schema.IndexRef(len((*s)[ftd.Table].Indices))
+ id := schema.IndexDescriptor{Table: ftd.Table, Index: ir}
+ mustAddIndex(&idx.Name, id)
+ (*s)[ftd.Table].Indices = append((*s)[ftd.Table].Indices, idx)
+ defer func(tr schema.TableRef, length schema.IndexRef) {
+ if hasError {
+ (*s)[tr].Indices = (*s)[tr].Indices[:ir]
+ }
+ }(ftd.Table, ir)
+ }
+
+ // Mark existing indices as referenced.
+ for ii := range foreignExistingIndices {
+ fid := foreignExistingIndices[ii].index
+ (*s)[fid.Table].Indices[fid.Index].Attr |= schema.IndexAttrReferenced
+ }
+
+ // Finally, we can commit the table definition to the schema.
+ *s = append(*s, table)
+}
+
+func checkCreateIndexStmt(n *ast.CreateIndexStmtNode, s *schema.Schema,
+ o CheckOptions, c *schemaCache, el *errors.ErrorList) {
+
+ fn := "CheckCreateIndexStmt"
+ hasError := false
+
+ if c.Begin() != 0 {
+ panic("schema cache must not have any open scope")
+ }
+ defer func() {
+ if hasError {
+ c.Rollback()
+ return
+ }
+ c.Commit()
+ }()
+
+ tn := n.Table.Name
+ td, found := c.FindTableInBase(string(tn))
+ if !found {
+ el.Append(errors.Error{
+ Position: n.Table.GetPosition(),
+ Length: n.Table.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeTableNotFound,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "index table %s does not exist",
+ ast.QuoteIdentifier(tn)),
+ }, &hasError)
+ return
+ }
+
+ if len(n.Column) > schema.MaxColumnRef {
+ begin := n.Column[0].GetPosition()
+ last := len(n.Column) - 1
+ end := n.Column[last].GetPosition() + n.Column[last].GetLength()
+ el.Append(errors.Error{
+ Position: begin,
+ Length: end - begin,
+ Category: errors.ErrorCategoryLimit,
+ Code: errors.ErrorCodeTooManyColumns,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "cannot create an index on more than %d columns",
+ schema.MaxColumnRef+1),
+ }, &hasError)
+ return
+ }
+
+ columnRefs := newColumnRefSlice(uint8(len(n.Column)))
+ for ci := range n.Column {
+ cn := n.Column[ci].Name
+ cd, found := c.FindColumnInBase(td.Table, string(cn))
+ if !found {
+ el.Append(errors.Error{
+ Position: n.Column[ci].GetPosition(),
+ Length: n.Column[ci].GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeColumnNotFound,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "column %s does not exist in index table %s",
+ ast.QuoteIdentifier(cn),
+ ast.QuoteIdentifier(tn)),
+ }, &hasError)
+ continue
+ }
+ columnRefs.Append(cd.Column, uint8(ci))
+ }
+ if hasError {
+ return
+ }
+
+ sort.Stable(columnRefs)
+ for ci := 1; ci < len(n.Column); ci++ {
+ if columnRefs.columns[ci] == columnRefs.columns[ci-1] {
+ el.Append(errors.Error{
+ Position: n.Column[columnRefs.nodes[ci]].GetPosition(),
+ Length: n.Column[columnRefs.nodes[ci]].GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeDuplicateIndexColumn,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "column %s already exists in the column list",
+ ast.QuoteIdentifier(n.Column[columnRefs.nodes[ci]].Name)),
+ }, &hasError)
+ return
+ }
+ }
+
+ index := schema.Index{}
+ index.Columns = columnRefs.columns
+
+ if len((*s)[td.Table].Indices) > schema.MaxIndexRef {
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: errors.ErrorCategoryLimit,
+ Code: errors.ErrorCodeTooManyIndices,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "cannot have more than %d indices in table %s",
+ schema.MaxIndexRef+1,
+ ast.QuoteIdentifier(tn)),
+ }, &hasError)
+ return
+ }
+
+ ir := schema.IndexRef(len((*s)[td.Table].Indices))
+ id := schema.IndexDescriptor{Table: td.Table, Index: ir}
+ if n.Unique != nil {
+ index.Attr |= schema.IndexAttrUnique
+ }
+
+ if len(n.Index.Name) == 0 {
+ el.Append(errors.Error{
+ Position: n.Index.GetPosition(),
+ Length: n.Table.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeEmptyIndexName,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: "cannot create an index with an empty name",
+ }, &hasError)
+ return
+ }
+
+ // If there is an existing index that is automatically created, rename it
+ // instead of creating a new one.
+ rename := false
+ fmir, found := findFirstMatchingIndex((*s)[id.Table].Indices, index, 0)
+ if found {
+ fmid := schema.IndexDescriptor{Table: id.Table, Index: fmir}
+ fmin := (*s)[id.Table].Indices[fmir].Name
+ fminString := string(fmin)
+ fmidCache, auto, found := c.FindIndexInBase(fminString)
+ if !found {
+ panic(fmt.Sprintf("index %s exists in the schema, "+
+ "but it cannot be found in the schema cache",
+ ast.QuoteIdentifier(fmin)))
+ }
+ if fmidCache != fmid {
+ panic(fmt.Sprintf("index %s has descriptor %+v, "+
+ "but the schema cache records it as %+v",
+ ast.QuoteIdentifier(fmin), fmid, fmidCache))
+ }
+ if auto {
+ if !c.DeleteIndex(fminString) {
+ panic(fmt.Sprintf("unable to mark index %s for deletion",
+ ast.QuoteIdentifier(fmin)))
+ }
+ rename = true
+ id = fmid
+ ir = id.Index
+ }
+ }
+
+ in := n.Index.Name
+ if !c.AddIndex(string(in), id, false) {
+ el.Append(errors.Error{
+ Position: n.Index.GetPosition(),
+ Length: n.Index.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeDuplicateIndexName,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf("index %s already exists",
+ ast.QuoteIdentifier(in)),
+ }, &hasError)
+ return
+ }
+
+ // Commit the change into the schema.
+ if rename {
+ (*s)[id.Table].Indices[id.Index].Name = n.Index.Name
+ } else {
+ index.Name = n.Index.Name
+ (*s)[id.Table].Indices = append((*s)[id.Table].Indices, index)
+ }
+}
+
+func checkSelectStmt(n *ast.SelectStmtNode, s schema.Schema,
+ o CheckOptions, c *schemaCache, el *errors.ErrorList) {
+}
+
+func checkUpdateStmt(n *ast.UpdateStmtNode, s schema.Schema,
+ o CheckOptions, c *schemaCache, el *errors.ErrorList) {
+}
+
+func checkDeleteStmt(n *ast.DeleteStmtNode, s schema.Schema,
+ o CheckOptions, c *schemaCache, el *errors.ErrorList) {
+}
+
+func checkInsertStmt(n *ast.InsertStmtNode, s schema.Schema,
+ o CheckOptions, c *schemaCache, el *errors.ErrorList) {
+}
+
+func checkExpr(n ast.ExprNode,
+ s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList,
+ tr schema.TableRef, ta typeAction) ast.ExprNode {
+
+ switch n := n.(type) {
+ case *ast.IdentifierNode:
+ return checkVariable(n, s, o, c, el, tr, ta)
+
+ case *ast.BoolValueNode:
+ return checkBoolValue(n, o, el, ta)
+
+ case *ast.AddressValueNode:
+ return checkAddressValue(n, o, el, ta)
+
+ case *ast.IntegerValueNode:
+ return checkIntegerValue(n, o, el, ta)
+
+ case *ast.DecimalValueNode:
+ return checkDecimalValue(n, o, el, ta)
+
+ case *ast.BytesValueNode:
+ return checkBytesValue(n, o, el, ta)
+
+ case *ast.NullValueNode:
+ return checkNullValue(n, o, el, ta)
+
+ case *ast.PosOperatorNode:
+ return checkPosOperator(n, s, o, c, el, tr, ta)
+
+ case *ast.NegOperatorNode:
+ return n
+
+ case *ast.NotOperatorNode:
+ return n
+
+ case *ast.ParenOperatorNode:
+ return n
+
+ case *ast.AndOperatorNode:
+ return n
+
+ case *ast.OrOperatorNode:
+ return n
+
+ case *ast.GreaterOrEqualOperatorNode:
+ return n
+
+ case *ast.LessOrEqualOperatorNode:
+ return n
+
+ case *ast.NotEqualOperatorNode:
+ return n
+
+ case *ast.EqualOperatorNode:
+ return n
+
+ case *ast.GreaterOperatorNode:
+ return n
+
+ case *ast.LessOperatorNode:
+ return n
+
+ case *ast.ConcatOperatorNode:
+ return n
+
+ case *ast.AddOperatorNode:
+ return n
+
+ case *ast.SubOperatorNode:
+ return n
+
+ case *ast.MulOperatorNode:
+ return n
+
+ case *ast.DivOperatorNode:
+ return n
+
+ case *ast.ModOperatorNode:
+ return n
+
+ case *ast.IsOperatorNode:
+ return n
+
+ case *ast.LikeOperatorNode:
+ return n
+
+ case *ast.CastOperatorNode:
+ return n
+
+ case *ast.InOperatorNode:
+ return n
+
+ case *ast.FunctionOperatorNode:
+ return n
+
+ default:
+ panic(fmt.Sprintf("unknown expression type %T", n))
+ }
+}
+
+func checkVariable(n *ast.IdentifierNode,
+ s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList,
+ tr schema.TableRef, ta typeAction) ast.ExprNode {
+
+ fn := "CheckVariable"
+
+ if (o & CheckWithConstantOnly) != 0 {
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeNonConstantExpression,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf("%s is not a constant",
+ ast.QuoteIdentifier(n.Name)),
+ }, nil)
+ return nil
+ }
+
+ cn := string(n.Name)
+ cd, found := c.FindColumnInBase(tr, cn)
+ if !found {
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeColumnNotFound,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "cannot find column %s in table %s",
+ ast.QuoteIdentifier(n.Name),
+ ast.QuoteIdentifier(s[tr].Name)),
+ }, nil)
+ return nil
+ }
+
+ cr := cd.Column
+ dt := s[tr].Columns[cr].Type
+ switch a := ta.(type) {
+ case typeActionInferDefault:
+ case typeActionInferWithSize:
+ case typeActionAssign:
+ if !dt.Equal(a.dt) {
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeTypeError,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "expect %s (%04x), but %s (%04x) is given",
+ a.dt.String(), uint16(a.dt), dt.String(), uint16(dt)),
+ }, nil)
+ return nil
+ }
+ }
+
+ n.SetType(dt)
+ n.Desc = cd
+ return n
+}
+
+func checkBoolValue(n *ast.BoolValueNode,
+ o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode {
+
+ fn := "CheckBoolValue"
+
+ switch a := ta.(type) {
+ case typeActionInferDefault:
+ case typeActionInferWithSize:
+ case typeActionAssign:
+ major, _ := ast.DecomposeDataType(a.dt)
+ if major != ast.DataTypeMajorBool {
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeTypeError,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "expect %s (%04x), but boolean value is given",
+ a.dt.String(), uint16(a.dt)),
+ }, nil)
+ return nil
+ }
+ }
+ return n
+}
+
+func checkAddressValue(n *ast.AddressValueNode,
+ o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode {
+
+ fn := "CheckAddressValue"
+
+ switch a := ta.(type) {
+ case typeActionInferDefault:
+ case typeActionInferWithSize:
+ case typeActionAssign:
+ major, _ := ast.DecomposeDataType(a.dt)
+ if major != ast.DataTypeMajorAddress {
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeTypeError,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "expect %s (%04x), but address value is given",
+ a.dt.String(), uint16(a.dt)),
+ }, nil)
+ return nil
+ }
+ }
+ return n
+}
+
+func mustGetMinMax(dt ast.DataType) (decimal.Decimal, decimal.Decimal) {
+ min, max, ok := dt.GetMinMax()
+ if !ok {
+ panic(fmt.Sprintf("GetMinMax does not handle %v", dt))
+ }
+ return min, max
+}
+
+func mustDecimalEncode(dt ast.DataType, d decimal.Decimal) []byte {
+ b, ok := ast.DecimalEncode(dt, d)
+ if !ok {
+ panic(fmt.Sprintf("DecimalEncode does not handle %v", dt))
+ }
+ return b
+}
+
+func mustDecimalDecode(dt ast.DataType, b []byte) decimal.Decimal {
+ d, ok := ast.DecimalDecode(dt, b)
+ if !ok {
+ panic(fmt.Sprintf("DecimalDecode does not handle %v", dt))
+ }
+ return d
+}
+
+func cropDecimal(dt ast.DataType, d decimal.Decimal) decimal.Decimal {
+ b := mustDecimalEncode(dt, d)
+ return mustDecimalDecode(dt, b)
+}
+
+func elAppendConstantTooLongError(el *errors.ErrorList, n ast.Node,
+ fn string, v decimal.Decimal) {
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeConstantTooLong,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "constant expression %s has more than %d digits",
+ ast.QuoteString(n.GetToken()), MaxIntegerPartDigits),
+ }, nil)
+}
+
+func elAppendOverflowError(el *errors.ErrorList, n ast.Node,
+ fn string, dt ast.DataType, v, min, max decimal.Decimal) {
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeOverflow,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "number %s (%s) overflow %s (%04x)",
+ v.String(), ast.QuoteString(n.GetToken()),
+ dt.String(), uint16(dt)),
+ }, nil)
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: 0,
+ Code: 0,
+ Severity: errors.ErrorSeverityNote,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "the range of %s is [%s, %s]",
+ dt.String(), min.String(), max.String()),
+ }, nil)
+}
+
+func elAppendOverflowWarning(el *errors.ErrorList, n ast.Node,
+ fn string, dt ast.DataType, from, to decimal.Decimal) {
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: 0,
+ Code: 0,
+ Severity: errors.ErrorSeverityWarning,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "number %s (%s) overflow %s (%04x), converted to %s",
+ from.String(), ast.QuoteString(n.GetToken()),
+ dt.String(), uint16(dt), to.String()),
+ }, nil)
+}
+
+func checkIntegerValue(n *ast.IntegerValueNode,
+ o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode {
+
+ fn := "CheckIntegerValue"
+
+ normalizeDecimal(&n.V)
+ if !safeDecimalRange(n.V) {
+ elAppendConstantTooLongError(el, n, fn, n.V)
+ return nil
+ }
+
+ infer := func(size int) (ast.DataType, bool) {
+ // The first case: assume V fits in the signed integer.
+ minor := ast.DataTypeMinor(size - 1)
+ dt := ast.ComposeDataType(ast.DataTypeMajorInt, minor)
+ min, max := mustGetMinMax(dt)
+ // Return if V < min. V must be negative so it must be signed.
+ if n.V.LessThan(min) {
+ if (o & CheckWithSafeMath) != 0 {
+ elAppendOverflowError(el, n, fn, dt, n.V, min, max)
+ return dt, false
+ }
+ cropped := cropDecimal(dt, n.V)
+ elAppendOverflowWarning(el, n, fn, dt, n.V, cropped)
+ normalizeDecimal(&cropped)
+ n.V = cropped
+ return dt, true
+ }
+ // We are done if V fits in the signed integer.
+ if n.V.LessThanOrEqual(max) {
+ return dt, true
+ }
+
+ // The second case: V is a non-negative integer, but it does not fit
+ // in the signed integer. Test whether the unsigned integer works.
+ dt = ast.ComposeDataType(ast.DataTypeMajorUint, minor)
+ min, max = mustGetMinMax(dt)
+ // Return if V > max. We don't have to test whether V < min because min
+ // is always zero and we already know V is non-negative.
+ if n.V.GreaterThan(max) {
+ if (o & CheckWithSafeMath) != 0 {
+ elAppendOverflowError(el, n, fn, dt, n.V, min, max)
+ return dt, false
+ }
+ cropped := cropDecimal(dt, n.V)
+ elAppendOverflowWarning(el, n, fn, dt, n.V, cropped)
+ normalizeDecimal(&cropped)
+ n.V = cropped
+ return dt, true
+ }
+ return dt, true
+ }
+
+ dt := ast.DataTypePending
+ switch a := ta.(type) {
+ case typeActionInferDefault:
+ // Default to int256 or uint256.
+ var ok bool
+ dt, ok = infer(256 / 8)
+ if !ok {
+ return nil
+ }
+
+ case typeActionInferWithSize:
+ var ok bool
+ dt, ok = infer(a.size)
+ if !ok {
+ return nil
+ }
+
+ case typeActionAssign:
+ dt = a.dt
+ major, _ := ast.DecomposeDataType(dt)
+ switch {
+ case major == ast.DataTypeMajorAddress:
+ if !n.IsAddress {
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeInvalidAddressChecksum,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "expect %s (%04x), but %s has invalid checksum",
+ dt.String(), uint16(dt), n.GetToken()),
+ }, nil)
+ return nil
+ }
+ // Redirect to checkAddressValue if it becomes an address.
+ addrNode := &ast.AddressValueNode{}
+ addrNode.SetPosition(addrNode.GetPosition())
+ addrNode.SetLength(addrNode.GetLength())
+ addrNode.SetToken(addrNode.GetToken())
+ addrNode.V = mustDecimalEncode(ast.ComposeDataType(
+ ast.DataTypeMajorUint, ast.DataTypeMinor(160/8-1)), n.V)
+ return checkAddressValue(addrNode, o, el, ta)
+
+ case major == ast.DataTypeMajorInt,
+ major == ast.DataTypeMajorUint,
+ major.IsFixedRange(),
+ major.IsUfixedRange():
+ min, max := mustGetMinMax(dt)
+ if n.V.LessThan(min) || n.V.GreaterThan(max) {
+ if (o & CheckWithSafeMath) != 0 {
+ elAppendOverflowError(el, n, fn, dt, n.V, min, max)
+ return nil
+ }
+ cropped := cropDecimal(dt, n.V)
+ elAppendOverflowWarning(el, n, fn, dt, n.V, cropped)
+ normalizeDecimal(&cropped)
+ n.V = cropped
+ }
+
+ default:
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeTypeError,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "expect %s (%04x), but number value is given",
+ dt.String(), uint16(dt)),
+ }, nil)
+ return nil
+ }
+ }
+
+ if dt != ast.DataTypePending {
+ n.SetType(dt)
+ }
+ return n
+}
+
+func checkDecimalValue(n *ast.DecimalValueNode,
+ o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode {
+
+ fn := "CheckDecimalValue"
+
+ normalizeDecimal(&n.V)
+ if !safeDecimalRange(n.V) {
+ elAppendConstantTooLongError(el, n, fn, n.V)
+ return nil
+ }
+
+ // Redirect to checkIntegerValue if the value is an integer.
+ if intPart := n.V.Truncate(0); n.V.Equal(intPart) {
+ intNode := &ast.IntegerValueNode{}
+ intNode.SetPosition(n.GetPosition())
+ intNode.SetLength(n.GetLength())
+ intNode.SetToken(n.GetToken())
+ intNode.IsAddress = false
+ intNode.V = n.V
+ return checkIntegerValue(intNode, o, el, ta)
+ }
+
+ infer := func(size, fractionalDigits int) (ast.DataType, bool) {
+ major := ast.DataTypeMajorFixed + ast.DataTypeMajor(size-1)
+ minor := ast.DataTypeMinor(fractionalDigits)
+ dt := ast.ComposeDataType(major, minor)
+ min, max := mustGetMinMax(dt)
+ if n.V.LessThan(min) {
+ if (o & CheckWithSafeMath) != 0 {
+ elAppendOverflowError(el, n, fn, dt, n.V, min, max)
+ return dt, false
+ }
+ cropped := cropDecimal(dt, n.V)
+ elAppendOverflowWarning(el, n, fn, dt, n.V, cropped)
+ normalizeDecimal(&cropped)
+ n.V = cropped
+ return dt, false
+ }
+ if n.V.LessThanOrEqual(max) {
+ return dt, true
+ }
+
+ major = ast.DataTypeMajorUfixed + ast.DataTypeMajor(size-1)
+ minor = ast.DataTypeMinor(fractionalDigits)
+ dt = ast.ComposeDataType(major, minor)
+ min, max = mustGetMinMax(dt)
+ if n.V.GreaterThan(max) {
+ if (o & CheckWithSafeMath) != 0 {
+ elAppendOverflowError(el, n, fn, dt, n.V, min, max)
+ return dt, false
+ }
+ cropped := cropDecimal(dt, n.V)
+ elAppendOverflowWarning(el, n, fn, dt, n.V, cropped)
+ normalizeDecimal(&cropped)
+ n.V = cropped
+ return dt, true
+ }
+ return dt, true
+ }
+
+ // Now we are sure the number we are dealing has fractional part.
+ dt := ast.DataTypePending
+ switch a := ta.(type) {
+ case typeActionInferDefault:
+ // Default to fixed128x18 and ufixed128x18.
+ var ok bool
+ dt, ok = infer(128/8, 18)
+ if !ok {
+ return nil
+ }
+
+ case typeActionInferWithSize:
+ // It is unclear that what the size hint means for fixed-point numbers,
+ // so we just ignore it and do the same thing as the above case.
+ var ok bool
+ dt, ok = infer(128/8, 18)
+ if !ok {
+ return nil
+ }
+
+ case typeActionAssign:
+ dt = a.dt
+ major, _ := ast.DecomposeDataType(dt)
+ switch {
+ case major.IsFixedRange(),
+ major.IsUfixedRange():
+ min, max := mustGetMinMax(dt)
+ if n.V.LessThan(min) || n.V.GreaterThan(max) {
+ if (o & CheckWithSafeMath) != 0 {
+ elAppendOverflowError(el, n, fn, dt, n.V, min, max)
+ return nil
+ }
+ cropped := cropDecimal(dt, n.V)
+ elAppendOverflowWarning(el, n, fn, dt, n.V, cropped)
+ normalizeDecimal(&cropped)
+ n.V = cropped
+ }
+
+ case major == ast.DataTypeMajorInt,
+ major == ast.DataTypeMajorUint:
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeTypeError,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "expect %s (%04x), but the number %s has fractional part",
+ dt.String(), uint16(dt), n.V.String()),
+ }, nil)
+ return nil
+
+ default:
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeTypeError,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "expect %s (%04x), but number value is given",
+ dt.String(), uint16(dt)),
+ }, nil)
+ return nil
+ }
+ }
+
+ if dt != ast.DataTypePending {
+ n.SetType(dt)
+ _, minor := ast.DecomposeDataType(dt)
+ fractionalDigits := int32(minor)
+ n.V = n.V.Round(fractionalDigits)
+ normalizeDecimal(&n.V)
+ }
+ return n
+}
+
+func checkBytesValue(n *ast.BytesValueNode,
+ o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode {
+
+ fn := "CheckBytesValue"
+
+ dt := ast.DataTypePending
+ switch a := ta.(type) {
+ case typeActionInferDefault:
+ // Default to bytes.
+ major := ast.DataTypeMajorDynamicBytes
+ minor := ast.DataTypeMinorDontCare
+ dt = ast.ComposeDataType(major, minor)
+
+ case typeActionInferWithSize:
+ // Still default to bytes. The size hint does not matter at all.
+ major := ast.DataTypeMajorDynamicBytes
+ minor := ast.DataTypeMinorDontCare
+ dt = ast.ComposeDataType(major, minor)
+
+ case typeActionAssign:
+ dt = a.dt
+ major, minor := ast.DecomposeDataType(dt)
+ switch major {
+ case ast.DataTypeMajorDynamicBytes:
+ // Do nothing because it is always valid.
+
+ case ast.DataTypeMajorFixedBytes:
+ sizeGiven := len(n.V)
+ sizeExpected := int(minor) + 1
+ if sizeGiven != sizeExpected {
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeTypeError,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "expect %s (%04x), but %s has %d bytes",
+ dt.String(), uint16(dt),
+ ast.QuoteString(n.V), sizeGiven),
+ }, nil)
+ return nil
+ }
+
+ default:
+ el.Append(errors.Error{
+ Position: n.GetPosition(),
+ Length: n.GetLength(),
+ Category: errors.ErrorCategorySemantic,
+ Code: errors.ErrorCodeTypeError,
+ Severity: errors.ErrorSeverityError,
+ Prefix: fn,
+ Message: fmt.Sprintf(
+ "expect %s (%04x), but bytes value is given",
+ dt.String(), uint16(dt)),
+ }, nil)
+ return nil
+ }
+ }
+
+ if dt != ast.DataTypePending {
+ n.SetType(dt)
+ }
+ return n
+}
+
+func checkNullValue(n *ast.NullValueNode,
+ o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode {
+
+ dt := ast.DataTypePending
+ switch a := ta.(type) {
+ case typeActionInferDefault:
+ dt = ast.DataTypeNull
+ case typeActionInferWithSize:
+ dt = ast.DataTypeNull
+ case typeActionAssign:
+ dt = a.dt
+ }
+
+ if dt != ast.DataTypePending {
+ n.SetType(dt)
+ }
+ return n
+}
+
+func checkPosOperator(n *ast.PosOperatorNode,
+ s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList,
+ tr schema.TableRef, ta typeAction) ast.ExprNode {
+
+ /*
+ fn := "CheckPosOperator"
+
+ switch a := ta.(type) {
+ case typeActionInferDefault:
+ case typeActionInferWithSize:
+ case typeActionAssign:
+ }
+ */
+ return n
+}
diff --git a/core/vm/sqlvm/checkers/utils.go b/core/vm/sqlvm/checkers/utils.go
new file mode 100644
index 000000000..4a8bbd96e
--- /dev/null
+++ b/core/vm/sqlvm/checkers/utils.go
@@ -0,0 +1,471 @@
+package checkers
+
+import (
+ "github.com/dexon-foundation/decimal"
+
+ "github.com/dexon-foundation/dexon/core/vm/sqlvm/ast"
+ "github.com/dexon-foundation/dexon/core/vm/sqlvm/schema"
+)
+
+// Variable name convention:
+//
+// fn -> function name
+// el -> error list
+//
+// td -> table descriptor
+// tr -> table reference
+// ti -> table index
+// tn -< table name
+//
+// cd -> column descriptor
+// cr -> column reference
+// ci -> column index
+// cn -> column name
+//
+// id -> index descriptor
+// ir -> index reference
+// ii -> index index
+// in -> index name
+
+const (
+ MaxIntegerPartDigits int32 = 200
+ MaxFractionalPartDigits int32 = 200
+)
+
+var (
+ MaxConstant = func() decimal.Decimal {
+ max := (decimal.New(1, MaxIntegerPartDigits).
+ Sub(decimal.New(1, -MaxFractionalPartDigits)))
+ normalizeDecimal(&max)
+ return max
+ }()
+ MinConstant = MaxConstant.Neg()
+)
+
+func normalizeDecimal(d *decimal.Decimal) {
+ if d.Exponent() != -MaxFractionalPartDigits {
+ *d = d.Rescale(-MaxFractionalPartDigits)
+ }
+}
+
+func safeDecimalRange(d decimal.Decimal) bool {
+ return d.GreaterThanOrEqual(MinConstant) && d.LessThanOrEqual(MaxConstant)
+}
+
+type schemaCache struct {
+ base schemaCacheBase
+ scopes []schemaCacheScope
+}
+
+type schemaCacheIndexValue struct {
+ id schema.IndexDescriptor
+ auto bool
+}
+
+type schemaCacheColumnKey struct {
+ tr schema.TableRef
+ n string
+}
+
+type schemaCacheBase struct {
+ table map[string]schema.TableDescriptor
+ index map[string]schemaCacheIndexValue
+ column map[schemaCacheColumnKey]schema.ColumnDescriptor
+}
+
+func (lower *schemaCacheBase) Merge(upper schemaCacheScope) {
+ // Process deletions.
+ for n := range upper.tableDeleted {
+ delete(lower.table, n)
+ }
+ for n := range upper.indexDeleted {
+ delete(lower.index, n)
+ }
+ for ck := range upper.columnDeleted {
+ delete(lower.column, ck)
+ }
+
+ // Process additions.
+ for n, td := range upper.table {
+ lower.table[n] = td
+ }
+ for n, iv := range upper.index {
+ lower.index[n] = iv
+ }
+ for ck, cd := range upper.column {
+ lower.column[ck] = cd
+ }
+}
+
+type schemaCacheScope struct {
+ table map[string]schema.TableDescriptor
+ tableDeleted map[string]struct{}
+ index map[string]schemaCacheIndexValue
+ indexDeleted map[string]struct{}
+ column map[schemaCacheColumnKey]schema.ColumnDescriptor
+ columnDeleted map[schemaCacheColumnKey]struct{}
+}
+
+func (lower *schemaCacheScope) Merge(upper schemaCacheScope) {
+ // Process deletions.
+ for n := range upper.tableDeleted {
+ delete(lower.table, n)
+ lower.tableDeleted[n] = struct{}{}
+ }
+ for n := range upper.indexDeleted {
+ delete(lower.index, n)
+ lower.indexDeleted[n] = struct{}{}
+ }
+ for ck := range upper.columnDeleted {
+ delete(lower.column, ck)
+ lower.columnDeleted[ck] = struct{}{}
+ }
+
+ // Process additions.
+ for n, td := range upper.table {
+ lower.table[n] = td
+ }
+ for n, iv := range upper.index {
+ lower.index[n] = iv
+ }
+ for ck, cd := range upper.column {
+ lower.column[ck] = cd
+ }
+}
+
+func newSchemaCache() *schemaCache {
+ return &schemaCache{
+ base: schemaCacheBase{
+ table: map[string]schema.TableDescriptor{},
+ index: map[string]schemaCacheIndexValue{},
+ column: map[schemaCacheColumnKey]schema.ColumnDescriptor{},
+ },
+ }
+}
+
+func (c *schemaCache) Begin() int {
+ position := len(c.scopes)
+ scope := schemaCacheScope{
+ table: map[string]schema.TableDescriptor{},
+ tableDeleted: map[string]struct{}{},
+ index: map[string]schemaCacheIndexValue{},
+ indexDeleted: map[string]struct{}{},
+ column: map[schemaCacheColumnKey]schema.ColumnDescriptor{},
+ columnDeleted: map[schemaCacheColumnKey]struct{}{},
+ }
+ c.scopes = append(c.scopes, scope)
+ return position
+}
+
+func (c *schemaCache) Rollback() {
+ if len(c.scopes) == 0 {
+ panic("there is no scope to rollback")
+ }
+ c.scopes = c.scopes[:len(c.scopes)-1]
+}
+
+func (c *schemaCache) RollbackTo(position int) {
+ for position <= len(c.scopes) {
+ c.Rollback()
+ }
+}
+
+func (c *schemaCache) Commit() {
+ if len(c.scopes) == 0 {
+ panic("there is no scope to commit")
+ }
+ if len(c.scopes) == 1 {
+ c.base.Merge(c.scopes[0])
+ } else {
+ src := len(c.scopes) - 1
+ dst := len(c.scopes) - 2
+ c.scopes[dst].Merge(c.scopes[src])
+ }
+ c.scopes = c.scopes[:len(c.scopes)-1]
+}
+
+func (c *schemaCache) CommitTo(position int) {
+ for position <= len(c.scopes) {
+ c.Commit()
+ }
+}
+
+func (c *schemaCache) FindTableInBase(n string) (
+ schema.TableDescriptor, bool) {
+
+ td, exists := c.base.table[n]
+ return td, exists
+}
+
+func (c *schemaCache) FindTableInScope(n string) (
+ schema.TableDescriptor, bool) {
+
+ for si := range c.scopes {
+ si = len(c.scopes) - si - 1
+ if td, exists := c.scopes[si].table[n]; exists {
+ return td, true
+ }
+ if _, exists := c.scopes[si].tableDeleted[n]; exists {
+ return schema.TableDescriptor{}, false
+ }
+ }
+ return c.FindTableInBase(n)
+}
+
+func (c *schemaCache) FindTableInBaseWithFallback(n string,
+ fallback schema.Schema) (schema.TableDescriptor, bool) {
+
+ if td, found := c.FindTableInBase(n); found {
+ return td, true
+ }
+ if fallback == nil {
+ return schema.TableDescriptor{}, false
+ }
+
+ s := fallback
+ for ti := range s {
+ if n == string(s[ti].Name) {
+ tr := schema.TableRef(ti)
+ td := schema.TableDescriptor{Table: tr}
+ c.base.table[n] = td
+ return td, true
+ }
+ }
+ return schema.TableDescriptor{}, false
+}
+
+func (c *schemaCache) FindIndexInBase(n string) (
+ schema.IndexDescriptor, bool, bool) {
+
+ iv, exists := c.base.index[n]
+ return iv.id, iv.auto, exists
+}
+
+func (c *schemaCache) FindIndexInScope(n string) (
+ schema.IndexDescriptor, bool, bool) {
+
+ for si := range c.scopes {
+ si = len(c.scopes) - si - 1
+ if iv, exists := c.scopes[si].index[n]; exists {
+ return iv.id, iv.auto, true
+ }
+ if _, exists := c.scopes[si].indexDeleted[n]; exists {
+ return schema.IndexDescriptor{}, false, false
+ }
+ }
+ return c.FindIndexInBase(n)
+}
+
+func (c *schemaCache) FindIndexInBaseWithFallback(n string,
+ fallback schema.Schema) (schema.IndexDescriptor, bool, bool) {
+
+ if id, auto, found := c.FindIndexInBase(n); found {
+ return id, auto, true
+ }
+ if fallback == nil {
+ return schema.IndexDescriptor{}, false, false
+ }
+
+ s := fallback
+ for ti := range s {
+ for ii := range s[ti].Indices {
+ if n == string(s[ti].Indices[ii].Name) {
+ tr := schema.TableRef(ti)
+ ir := schema.IndexRef(ii)
+ id := schema.IndexDescriptor{Table: tr, Index: ir}
+ iv := schemaCacheIndexValue{id: id, auto: false}
+ c.base.index[n] = iv
+ return id, false, true
+ }
+ }
+ }
+ return schema.IndexDescriptor{}, false, false
+}
+
+func (c *schemaCache) FindColumnInBase(tr schema.TableRef, n string) (
+ schema.ColumnDescriptor, bool) {
+
+ cd, exists := c.base.column[schemaCacheColumnKey{tr: tr, n: n}]
+ return cd, exists
+}
+
+func (c *schemaCache) FindColumnInScope(tr schema.TableRef, n string) (
+ schema.ColumnDescriptor, bool) {
+
+ ck := schemaCacheColumnKey{tr: tr, n: n}
+ for si := range c.scopes {
+ si = len(c.scopes) - si - 1
+ if cd, exists := c.scopes[si].column[ck]; exists {
+ return cd, true
+ }
+ if _, exists := c.scopes[si].columnDeleted[ck]; exists {
+ return schema.ColumnDescriptor{}, false
+ }
+ }
+ return c.FindColumnInBase(tr, n)
+}
+
+func (c *schemaCache) FindColumnInBaseWithFallback(tr schema.TableRef, n string,
+ fallback schema.Schema) (schema.ColumnDescriptor, bool) {
+
+ if cd, found := c.FindColumnInBase(tr, n); found {
+ return cd, true
+ }
+ if fallback == nil {
+ return schema.ColumnDescriptor{}, false
+ }
+
+ s := fallback
+ for ci := range s[tr].Columns {
+ if n == string(s[tr].Columns[ci].Name) {
+ cr := schema.ColumnRef(ci)
+ cd := schema.ColumnDescriptor{Table: tr, Column: cr}
+ ck := schemaCacheColumnKey{tr: tr, n: n}
+ c.base.column[ck] = cd
+ return cd, true
+ }
+ }
+ return schema.ColumnDescriptor{}, false
+}
+
+func (c *schemaCache) AddTable(n string,
+ td schema.TableDescriptor) bool {
+
+ top := len(c.scopes) - 1
+ if _, found := c.FindTableInScope(n); found {
+ return false
+ }
+
+ c.scopes[top].table[n] = td
+ return true
+}
+
+func (c *schemaCache) AddIndex(n string,
+ id schema.IndexDescriptor, auto bool) bool {
+
+ top := len(c.scopes) - 1
+ if _, _, found := c.FindIndexInScope(n); found {
+ return false
+ }
+
+ iv := schemaCacheIndexValue{id: id, auto: auto}
+ c.scopes[top].index[n] = iv
+ return true
+}
+
+func (c *schemaCache) AddColumn(n string,
+ cd schema.ColumnDescriptor) bool {
+
+ top := len(c.scopes) - 1
+ tr := cd.Table
+ if _, found := c.FindColumnInScope(tr, n); found {
+ return false
+ }
+
+ ck := schemaCacheColumnKey{tr: tr, n: n}
+ c.scopes[top].column[ck] = cd
+ return true
+}
+
+func (c *schemaCache) DeleteTable(n string) bool {
+ top := len(c.scopes) - 1
+ if _, found := c.FindTableInScope(n); !found {
+ return false
+ }
+
+ delete(c.scopes[top].table, n)
+ c.scopes[top].tableDeleted[n] = struct{}{}
+ return true
+}
+
+func (c *schemaCache) DeleteIndex(n string) bool {
+ top := len(c.scopes) - 1
+ if _, _, found := c.FindIndexInScope(n); !found {
+ return false
+ }
+
+ delete(c.scopes[top].index, n)
+ c.scopes[top].indexDeleted[n] = struct{}{}
+ return true
+}
+
+func (c *schemaCache) DeleteColumn(tr schema.TableRef, n string) bool {
+ top := len(c.scopes) - 1
+ if _, found := c.FindColumnInScope(tr, n); !found {
+ return false
+ }
+
+ ck := schemaCacheColumnKey{tr: tr, n: n}
+ delete(c.scopes[top].column, ck)
+ c.scopes[top].columnDeleted[ck] = struct{}{}
+ return true
+}
+
+type columnRefSlice struct {
+ columns []schema.ColumnRef
+ nodes []uint8
+}
+
+func newColumnRefSlice(c uint8) columnRefSlice {
+ return columnRefSlice{
+ columns: make([]schema.ColumnRef, 0, c),
+ nodes: make([]uint8, 0, c),
+ }
+}
+
+func (s *columnRefSlice) Append(c schema.ColumnRef, i uint8) {
+ s.columns = append(s.columns, c)
+ s.nodes = append(s.nodes, i)
+}
+
+func (s columnRefSlice) Len() int {
+ return len(s.columns)
+}
+
+func (s columnRefSlice) Less(i, j int) bool {
+ return s.columns[i] < s.columns[j]
+}
+
+func (s columnRefSlice) Swap(i, j int) {
+ s.columns[i], s.columns[j] = s.columns[j], s.columns[i]
+ s.nodes[i], s.nodes[j] = s.nodes[j], s.nodes[i]
+}
+
+//go-sumtype:decl typeAction
+type typeAction interface {
+ ˉtypeAction()
+}
+
+type typeActionInferDefault struct{}
+
+func newTypeActionInferDefaultSize() typeActionInferDefault {
+ return typeActionInferDefault{}
+}
+
+var _ typeAction = typeActionInferDefault{}
+
+func (typeActionInferDefault) ˉtypeAction() {}
+
+type typeActionInferWithSize struct {
+ size int
+}
+
+func newTypeActionInferWithSize(bytes int) typeActionInferWithSize {
+ return typeActionInferWithSize{size: bytes}
+}
+
+var _ typeAction = typeActionInferWithSize{}
+
+func (typeActionInferWithSize) ˉtypeAction() {}
+
+type typeActionAssign struct {
+ dt ast.DataType
+}
+
+func newTypeActionAssign(expected ast.DataType) typeActionAssign {
+ return typeActionAssign{dt: expected}
+}
+
+var _ typeAction = typeActionAssign{}
+
+func (typeActionAssign) ˉtypeAction() {}
diff --git a/core/vm/sqlvm/cmd/ast-checker/main.go b/core/vm/sqlvm/cmd/ast-checker/main.go
new file mode 100644
index 000000000..c02b58f0f
--- /dev/null
+++ b/core/vm/sqlvm/cmd/ast-checker/main.go
@@ -0,0 +1,123 @@
+package main
+
+import (
+ "bytes"
+ "encoding/hex"
+ "flag"
+ "fmt"
+ "os"
+
+ "github.com/dexon-foundation/dexon/core/vm/sqlvm/checkers"
+ "github.com/dexon-foundation/dexon/core/vm/sqlvm/parser"
+ "github.com/dexon-foundation/dexon/core/vm/sqlvm/schema"
+ "github.com/dexon-foundation/dexon/rlp"
+)
+
+func create(sql string, o checkers.CheckOptions) int {
+ n, parseErr := parser.Parse([]byte(sql))
+ if parseErr != nil {
+ fmt.Fprintf(os.Stderr, "Parse error:\n%+v\n", parseErr)
+ }
+ s, checkErr := checkers.CheckCreate(n, o)
+ if checkErr != nil {
+ fmt.Fprintf(os.Stderr, "Check error:\n%+v\n", checkErr)
+ }
+ b := bytes.Buffer{}
+ rlpErr := rlp.Encode(&b, s)
+ if rlpErr != nil {
+ fmt.Fprintf(os.Stderr, "RLP encode error: %v\n", rlpErr)
+ return 1
+ }
+ fmt.Println(hex.EncodeToString(b.Bytes()))
+ if parseErr != nil || checkErr != nil {
+ return 1
+ }
+ return 0
+}
+
+func decode(ss string) int {
+ b, hexErr := hex.DecodeString(ss)
+ if hexErr != nil {
+ fmt.Fprintf(os.Stderr, "Hex decode error: %v\n", hexErr)
+ return 1
+ }
+ s := schema.Schema{}
+ rlpErr := rlp.Decode(bytes.NewReader(b), &s)
+ if rlpErr != nil {
+ fmt.Fprintf(os.Stderr, "RLP decode error: %v\n", rlpErr)
+ return 1
+ }
+ s.SetupColumnOffset()
+ fmt.Print(s.String())
+ return 0
+}
+
+func query(ss, sql string, o checkers.CheckOptions) int {
+ fmt.Fprintln(os.Stderr, "Function not implemented")
+ return 1
+}
+
+func exec(ss, sql string, o checkers.CheckOptions) int {
+ fmt.Fprintln(os.Stderr, "Function not implemented")
+ return 1
+}
+
+func main() {
+ var noSafeMath bool
+ var noSafeCast bool
+ flag.BoolVar(&noSafeMath, "no-safe-math", false, "disable safe math")
+ flag.BoolVar(&noSafeCast, "no-safe-cast", false, "disable safe cast")
+
+ flag.Parse()
+
+ if flag.NArg() < 1 {
+ fmt.Fprintf(os.Stderr,
+ "Usage: %s <action> <arguments>\n"+
+ "Actions:\n"+
+ " create <SQL> returns schema\n"+
+ " decode <schema> returns SQL\n"+
+ " query <schema> <SQL> returns AST\n"+
+ " exec <schema> <SQL> returns AST\n",
+ os.Args[0])
+ os.Exit(1)
+ }
+
+ o := checkers.CheckWithSafeMath | checkers.CheckWithSafeCast
+ if noSafeMath {
+ o &= ^(checkers.CheckWithSafeMath)
+ }
+ if noSafeCast {
+ o &= ^(checkers.CheckWithSafeCast)
+ }
+
+ action := flag.Arg(0)
+ switch action {
+ case "create":
+ if flag.NArg() < 2 {
+ fmt.Fprintln(os.Stderr, "create needs 1 argument")
+ os.Exit(1)
+ }
+ os.Exit(create(flag.Arg(1), o))
+ case "decode":
+ if flag.NArg() < 2 {
+ fmt.Fprintln(os.Stderr, "decode needs 1 argument")
+ os.Exit(1)
+ }
+ os.Exit(decode(flag.Arg(1)))
+ case "query":
+ if flag.NArg() < 3 {
+ fmt.Fprintln(os.Stderr, "query needs 2 arguments")
+ os.Exit(1)
+ }
+ os.Exit(query(flag.Arg(1), flag.Arg(2), o))
+ case "exec":
+ if flag.NArg() < 3 {
+ fmt.Fprintln(os.Stderr, "exec needs 2 arguments")
+ os.Exit(1)
+ }
+ os.Exit(exec(flag.Arg(1), flag.Arg(2), o))
+ default:
+ fmt.Fprintf(os.Stderr, "Invalid action %s\n", action)
+ os.Exit(1)
+ }
+}
diff --git a/core/vm/sqlvm/errors/errors.go b/core/vm/sqlvm/errors/errors.go
index c2e9dbab2..dcc8e6a48 100644
--- a/core/vm/sqlvm/errors/errors.go
+++ b/core/vm/sqlvm/errors/errors.go
@@ -58,17 +58,44 @@ func (e Error) Error() string {
// ErrorList is a list of Error.
type ErrorList []Error
-func (e ErrorList) Error() string {
+func (el ErrorList) Error() string {
b := strings.Builder{}
- for i := range e {
+ for i, e := range el {
if i > 0 {
b.WriteByte('\n')
}
- b.WriteString(e[i].Error())
+ b.WriteString(e.Error())
}
return b.String()
}
+func (el ErrorList) HasError() bool {
+ for _, e := range el {
+ if e.Severity == ErrorSeverityError {
+ return true
+ }
+ }
+ return false
+}
+
+func (el ErrorList) GetFirstError() (Error, bool) {
+ for _, e := range el {
+ if e.Severity == ErrorSeverityError {
+ return e, true
+ }
+ }
+ return Error{}, false
+}
+
+func (el *ErrorList) Append(e Error, hasError *bool) {
+ *el = append(*el, e)
+ if hasError != nil {
+ if e.Severity == ErrorSeverityError {
+ *hasError = true
+ }
+ }
+}
+
// ErrorCategory is used to distinguish errors come from different phases.
type ErrorCategory uint16
@@ -77,6 +104,7 @@ const (
ErrorCategoryNil ErrorCategory = iota
ErrorCategoryLimit
ErrorCategoryGrammar
+ ErrorCategoryCommand
ErrorCategorySemantic
ErrorCategoryRuntime
)
@@ -84,6 +112,7 @@ const (
var errorCategoryMap = [...]string{
ErrorCategoryLimit: "limit",
ErrorCategoryGrammar: "grammar",
+ ErrorCategoryCommand: "command",
ErrorCategorySemantic: "semantic",
ErrorCategoryRuntime: "runtime",
}
@@ -98,8 +127,14 @@ type ErrorCode uint16
// Error code starts from 1. Zero value is invalid.
const (
ErrorCodeNil ErrorCode = iota
- ErrorCodeDepthLimitReached
ErrorCodeParser
+ ErrorCodeDepthLimitReached
+ ErrorCodeTooManyTables
+ ErrorCodeTooManyColumns
+ ErrorCodeTooManyIndices
+ ErrorCodeTooManyForeignKeys
+ ErrorCodeTooManySequences
+ ErrorCodeTooManySelectColumns
ErrorCodeInvalidIntegerSyntax
ErrorCodeInvalidNumberSyntax
ErrorCodeIntegerOutOfRange
@@ -115,6 +150,26 @@ const (
ErrorCodeInvalidUfixedSize
ErrorCodeInvalidFixedFractionalDigits
ErrorCodeInvalidUfixedFractionalDigits
+ ErrorCodeNoCommand
+ ErrorCodeDisallowedCommand
+ ErrorCodeEmptyTableName
+ ErrorCodeDuplicateTableName
+ ErrorCodeEmptyColumnName
+ ErrorCodeDuplicateColumnName
+ ErrorCodeTableNotFound
+ ErrorCodeColumnNotFound
+ ErrorCodeInvalidColumnDataType
+ ErrorCodeForeignKeyDataTypeMismatch
+ ErrorCodeNullDefaultValue
+ ErrorCodeMultipleDefaultValues
+ ErrorCodeInvalidAutoIncrementDataType
+ ErrorCodeEmptyIndexName
+ ErrorCodeDuplicateIndexName
+ ErrorCodeDuplicateIndexColumn
+ ErrorCodeTypeError
+ ErrorCodeConstantTooLong
+ ErrorCodeNonConstantExpression
+ ErrorCodeInvalidAddressChecksum
// Runtime Error
ErrorCodeInvalidOperandNum
@@ -132,6 +187,12 @@ const (
var errorCodeMap = [...]string{
ErrorCodeDepthLimitReached: "depth limit reached",
+ ErrorCodeTooManyTables: "too many tables",
+ ErrorCodeTooManyColumns: "too many columns",
+ ErrorCodeTooManyIndices: "too many indices",
+ ErrorCodeTooManyForeignKeys: "too many foreign keys",
+ ErrorCodeTooManySequences: "too many sequences",
+ ErrorCodeTooManySelectColumns: "too many select columns",
ErrorCodeParser: "parser error",
ErrorCodeInvalidIntegerSyntax: "invalid integer syntax",
ErrorCodeInvalidNumberSyntax: "invalid number syntax",
@@ -148,6 +209,27 @@ var errorCodeMap = [...]string{
ErrorCodeInvalidUfixedSize: "invalid ufixed size",
ErrorCodeInvalidFixedFractionalDigits: "invalid fixed fractional digits",
ErrorCodeInvalidUfixedFractionalDigits: "invalid ufixed fractional digits",
+ ErrorCodeNoCommand: "no command",
+ ErrorCodeDisallowedCommand: "disallowed command",
+ ErrorCodeEmptyTableName: "empty table name",
+ ErrorCodeDuplicateTableName: "duplicate table name",
+ ErrorCodeEmptyColumnName: "empty column name",
+ ErrorCodeDuplicateColumnName: "duplicate column name",
+ ErrorCodeTableNotFound: "table not found",
+ ErrorCodeColumnNotFound: "column not found",
+ ErrorCodeInvalidColumnDataType: "invalid column data type",
+ ErrorCodeForeignKeyDataTypeMismatch: "foreign key data type mismatch",
+ ErrorCodeNullDefaultValue: "null default value",
+ ErrorCodeMultipleDefaultValues: "multiple default values",
+ ErrorCodeInvalidAutoIncrementDataType: "invalid auto increment data type",
+ ErrorCodeEmptyIndexName: "empty index name",
+ ErrorCodeDuplicateIndexName: "duplicate index name",
+ ErrorCodeDuplicateIndexColumn: "duplicate index column",
+ ErrorCodeTypeError: "type error",
+ ErrorCodeConstantTooLong: "constant too long",
+ ErrorCodeNonConstantExpression: "non-constant expression",
+ ErrorCodeInvalidAddressChecksum: "invalid address checksum",
+
// Runtime Error
ErrorCodeInvalidOperandNum: "invalid operand number",
ErrorCodeInvalidDataType: "invalid data type",
diff --git a/core/vm/sqlvm/schema/schema.go b/core/vm/sqlvm/schema/schema.go
index 71122cc26..6d279aba0 100644
--- a/core/vm/sqlvm/schema/schema.go
+++ b/core/vm/sqlvm/schema/schema.go
@@ -2,7 +2,10 @@ package schema
import (
"errors"
+ "fmt"
"io"
+ "math"
+ "strings"
"github.com/dexon-foundation/decimal"
@@ -66,21 +69,39 @@ func (a ColumnAttr) GetDerivedFlags() ColumnAttr {
// FunctionRef defines the type for number of builtin function.
type FunctionRef uint16
+// MaxFunctionRef is the maximum value of FunctionRef.
+const MaxFunctionRef = math.MaxUint16
+
// TableRef defines the type for table index in Schema.
type TableRef uint8
+// MaxTableRef is the maximum value of TableRef.
+const MaxTableRef = math.MaxUint8
+
// ColumnRef defines the type for column index in Table.Columns.
type ColumnRef uint8
+// MaxColumnRef is the maximum value of ColumnRef.
+const MaxColumnRef = math.MaxUint8
+
// IndexRef defines the type for array index of Column.Indices.
type IndexRef uint8
+// MaxIndexRef is the maximum value of IndexRef.
+const MaxIndexRef = math.MaxUint8
+
// SequenceRef defines the type for sequence index in Table.
type SequenceRef uint8
+// MaxSequenceRef is the maximum value of SequenceRef.
+const MaxSequenceRef = math.MaxUint8
+
// SelectColumnRef defines the type for column index in SelectStmtNode.Column.
type SelectColumnRef uint16
+// MaxSelectColumnRef is the maximum value of SelectColumnRef.
+const MaxSelectColumnRef = math.MaxUint16
+
// IndexAttr defines bit flags for describing index attribute.
type IndexAttr uint16
@@ -116,6 +137,111 @@ func (s Schema) SetupColumnOffset() {
}
}
+func (s Schema) String() string {
+ b := strings.Builder{}
+ b.WriteString(fmt.Sprintf(
+ "-- DEXON SQLVM database schema dump (%d tables)\n", len(s)))
+ for _, t := range s {
+ b.WriteString("CREATE TABLE ")
+ b.Write(ast.QuoteIdentifierOptional(t.Name))
+ b.WriteString(" (\n")
+ for ci, c := range t.Columns {
+ b.WriteString(" ")
+ b.Write(ast.QuoteIdentifierOptional(c.Name))
+ b.WriteByte(' ')
+ b.WriteString(c.Type.String())
+ comments := []string{
+ fmt.Sprintf("slot %d", c.SlotOffset),
+ fmt.Sprintf("byte %d", c.ByteOffset),
+ }
+ if (c.Attr & ColumnAttrPrimaryKey) != 0 {
+ b.WriteString(" PRIMARY KEY")
+ }
+ if (c.Attr & ColumnAttrNotNull) != 0 {
+ b.WriteString(" NOT NULL")
+ }
+ if (c.Attr & ColumnAttrUnique) != 0 {
+ b.WriteString(" UNIQUE")
+ }
+ if (c.Attr & ColumnAttrHasDefault) != 0 {
+ b.WriteString(" DEFAULT ")
+ switch v := c.Default.(type) {
+ case nil:
+ b.WriteString("NULL")
+ case bool:
+ if v {
+ b.WriteString("TRUE")
+ } else {
+ b.WriteString("FALSE")
+ }
+ case []byte:
+ major, _ := ast.DecomposeDataType(c.Type)
+ if major == ast.DataTypeMajorAddress {
+ b.WriteString(common.BytesToAddress(v).String())
+ break
+ }
+ b.Write(ast.QuoteString(v))
+ case decimal.Decimal:
+ b.WriteString(v.String())
+ default:
+ b.WriteString("<?>")
+ }
+ }
+ if (c.Attr & ColumnAttrHasForeignKey) != 0 {
+ for _, fk := range c.ForeignKeys {
+ b.WriteString(" REFERENCES ")
+ b.Write(ast.QuoteIdentifierOptional(
+ s[fk.Table].Name))
+ b.WriteByte('(')
+ b.Write(ast.QuoteIdentifierOptional(
+ s[fk.Table].Columns[fk.Column].Name))
+ b.WriteByte(')')
+ }
+ }
+ if (c.Attr & ColumnAttrHasSequence) != 0 {
+ b.WriteString(" AUTOINCREMENT")
+ comments = append(comments,
+ fmt.Sprintf("sequence %d", c.Sequence))
+ }
+ if ci < len(t.Columns)-1 {
+ b.WriteByte(',')
+ }
+ b.WriteString(" -- ")
+ b.WriteString(strings.Join(comments, ", "))
+ b.WriteByte('\n')
+ }
+ b.WriteString(");\n")
+ for _, i := range t.Indices {
+ comments := []string{}
+ b.WriteString("CREATE")
+ if (i.Attr & IndexAttrUnique) != 0 {
+ b.WriteString(" UNIQUE")
+ }
+ if (i.Attr & IndexAttrReferenced) != 0 {
+ comments = append(comments, "referenced")
+ }
+ b.WriteString(" INDEX ")
+ b.Write(ast.QuoteIdentifierOptional(i.Name))
+ b.WriteString(" ON ")
+ b.Write(ast.QuoteIdentifierOptional(t.Name))
+ b.WriteByte('(')
+ for ci, c := range i.Columns {
+ b.Write(ast.QuoteIdentifierOptional(t.Columns[c].Name))
+ if ci < len(i.Columns)-1 {
+ b.WriteString(", ")
+ }
+ }
+ b.WriteString(");")
+ if len(comments) > 0 {
+ b.WriteString(" -- ")
+ b.WriteString(strings.Join(comments, ", "))
+ }
+ b.WriteByte('\n')
+ }
+ }
+ return b.String()
+}
+
// Table defiens sqlvm table struct.
type Table struct {
Name []byte
@@ -172,6 +298,10 @@ type column struct {
Rest interface{}
}
+// MaxForeignKeys is the maximum number of foreign key constraints which can be
+// defined on a column.
+const MaxForeignKeys = math.MaxUint8
+
// Column defines sqlvm index struct.
type Column struct {
column
@@ -240,6 +370,7 @@ func (c *Column) DecodeRLP(s *rlp.Stream) error {
// nil is converted to empty list by encoder, while empty list is
// converted to []interface{} by decoder.
// So we view this case as nil and skip it.
+ c.Default = nil
case []byte:
major, _ := ast.DecomposeDataType(c.Type)
switch major {
@@ -249,7 +380,9 @@ func (c *Column) DecodeRLP(s *rlp.Stream) error {
} else {
c.Default = false
}
- case ast.DataTypeMajorFixedBytes, ast.DataTypeMajorDynamicBytes:
+ case ast.DataTypeMajorAddress,
+ ast.DataTypeMajorFixedBytes,
+ ast.DataTypeMajorDynamicBytes:
c.Default = rest
default:
d, ok := ast.DecimalDecode(c.Type, rest)