author | Ting-Wei Lan <tingwei.lan@cobinhood.com> | 2019-04-23 15:44:55 +0800
---|---|---
committer | Ting-Wei Lan <tingwei.lan@cobinhood.com> | 2019-05-14 11:04:15 +0800
commit | 521b542feaaad50b2f6aca8fdff5b8fcb7578593 |
tree | ae43b805f46b67dd5a09adc36fb977c09cf604cd |
parent | 7a237ad79f7d7e6ed1a875ee700805b6d9d3791b |
core: vm: sqlvm: implement schema compiler and type checker
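
This commit adds the `checkers` package with three entry points, `CheckCreate`, `CheckQuery`, and `CheckExec`, plus a `CheckOptions` bit set controlling safe-math, safe-cast, and constant-only behaviour. The following is a minimal sketch of how these entry points might be driven; the statement list is assumed to come from the SQL VM parser, which is not part of this diff, and is left empty here.

```go
package main

import (
	"fmt"

	"github.com/dexon-foundation/dexon/core/vm/sqlvm/ast"
	"github.com/dexon-foundation/dexon/core/vm/sqlvm/checkers"
)

func main() {
	// stmts would normally be produced by the SQL VM parser; it is left
	// empty here because parsing is outside the scope of this commit.
	var stmts []ast.StmtNode

	// Enable overflow/underflow checks for expression evaluation and casts.
	opts := checkers.CheckWithSafeMath | checkers.CheckWithSafeCast

	// CheckCreate folds constant expressions in-place and builds a schema
	// from CREATE TABLE / CREATE INDEX statements.
	sc, err := checkers.CheckCreate(stmts, opts)
	if err != nil {
		// On failure, err is an errors.ErrorList describing every problem.
		fmt.Println(err)
		return
	}

	// SELECT (and, via CheckExec, UPDATE/DELETE/INSERT) statements are then
	// validated against the generated schema.
	if err := checkers.CheckQuery(nil, sc, opts); err != nil {
		fmt.Println(err)
	}
}
```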
-rw-r--r-- | core/vm/sqlvm/ast/ast.go | 135
-rw-r--r-- | core/vm/sqlvm/ast/types.go | 182
-rw-r--r-- | core/vm/sqlvm/checkers/actions.go | 147
-rw-r--r-- | core/vm/sqlvm/checkers/checkers.go | 1496
-rw-r--r-- | core/vm/sqlvm/checkers/utils.go | 471
-rw-r--r-- | core/vm/sqlvm/cmd/ast-checker/main.go | 123
-rw-r--r-- | core/vm/sqlvm/errors/errors.go | 90
-rw-r--r-- | core/vm/sqlvm/schema/schema.go | 135
8 files changed, 2766 insertions, 13 deletions
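
The `ast.go` changes below introduce SQL quoting helpers (`QuoteIdentifier`, `QuoteIdentifierOptional`, `QuoteString`) that the checker uses when building error messages. The sketch below illustrates their behaviour as implemented by `safeString` in this diff; the expected outputs in the comments follow directly from that implementation.

```go
package main

import (
	"fmt"

	"github.com/dexon-foundation/dexon/core/vm/sqlvm/ast"
)

func main() {
	// QuoteIdentifier always wraps the name in double quotes.
	fmt.Println(string(ast.QuoteIdentifier([]byte("order")))) // "order"

	// QuoteIdentifierOptional leaves names alone when they can be used
	// safely without quoting, and quotes them otherwise.
	fmt.Println(string(ast.QuoteIdentifierOptional([]byte("users")))) // users
	fmt.Println(string(ast.QuoteIdentifierOptional([]byte("1st"))))   // "1st"

	// QuoteString produces a single-quoted literal with control
	// characters escaped.
	fmt.Println(string(ast.QuoteString([]byte("a\nb")))) // 'a\nb'
}
```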
diff --git a/core/vm/sqlvm/ast/ast.go b/core/vm/sqlvm/ast/ast.go index 2215f4480..5868549a8 100644 --- a/core/vm/sqlvm/ast/ast.go +++ b/core/vm/sqlvm/ast/ast.go @@ -1,7 +1,10 @@ package ast import ( + "bytes" "fmt" + "unicode" + "unicode/utf8" "github.com/dexon-foundation/decimal" @@ -66,6 +69,113 @@ func (n *NodeBase) SetToken(token []byte) { n.Token = token } +func safeIdentifierStart(r rune) bool { + return (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || (r >= 0x80) +} + +func safeIdentifierExtend(r rune) bool { + return (r >= '0' && r <= '9') || (r == '_') +} + +func safeString(s []byte, quote byte, mustQuote bool) []byte { + o := bytes.Buffer{} + o.Grow(len(s) + 1) + o.WriteByte(quote) + + if len(s) > 0 { + for r, i, size := rune(0), 0, 0; i < len(s); i += size { + r, size = utf8.DecodeRune(s[i:]) + if r == utf8.RuneError { + switch size { + case 0: + panic("we should never pass an empty slice to DecodeRune") + case 1: + mustQuote = true + x := fmt.Sprintf(`\x%02x`, s[i]) + o.WriteString(x) + continue + // The default case is deliberately omitted here. It is + // possible for a valid UTF-8 string to include the code + // point utf8.RuneError. If it is the case, the size must + // be larger than 1. + } + } + safeIdentifier := + safeIdentifierStart(r) || (i > 0 && safeIdentifierExtend(r)) + if !safeIdentifier { + mustQuote = true + } + switch r { + case '\\': + mustQuote = true + o.WriteString(`\\`) + case '\'': + mustQuote = true + o.WriteString(`\'`) + case '"': + mustQuote = true + o.WriteString(`\"`) + case '\b': + mustQuote = true + o.WriteString(`\b`) + case '\f': + mustQuote = true + o.WriteString(`\f`) + case '\n': + mustQuote = true + o.WriteString(`\n`) + case '\r': + mustQuote = true + o.WriteString(`\r`) + case '\t': + mustQuote = true + o.WriteString(`\t`) + case '\v': + mustQuote = true + o.WriteString(`\v`) + default: + if unicode.IsPrint(r) { + x := s[i : i+size] + o.Write(x) + } else { + mustQuote = true + if r > 0xffff { + x := fmt.Sprintf(`\U%08x`, r) + o.WriteString(x) + } else { + x := fmt.Sprintf(`\u%04x`, r) + o.WriteString(x) + } + } + } + } + } else { + mustQuote = true + } + + if mustQuote { + o.WriteByte(quote) + return o.Bytes() + } + return o.Bytes()[1:] +} + +// QuoteIdentifier quotes a string for safely using it as a SQL identifier. +func QuoteIdentifier(s []byte) []byte { + return safeString(s, '"', true) +} + +// QuoteIdentifierOptional is the same as QuoteIdentifier except that it does +// not quote the string if the string itself can be used safely. +func QuoteIdentifierOptional(s []byte) []byte { + return safeString(s, '"', false) +} + +// QuoteString quotes a string for safely using it as a SQL string literal. +func QuoteString(s []byte) []byte { + return safeString(s, '\'', true) +} + // --------------------------------------------------------------------------- // Identifiers // --------------------------------------------------------------------------- @@ -186,6 +296,31 @@ func (n *BoolValueNode) GetType() DataType { return ComposeDataType(DataTypeMajorBool, DataTypeMinorDontCare) } +// AddressValueNode is an address constant. +type AddressValueNode struct { + UntaggedExprNodeBase + V []byte +} + +var _ Valuer = (*AddressValueNode)(nil) + +func (n *AddressValueNode) ˉValuer() {} + +// GetChildren returns a list of child nodes used for traversing. +func (n *AddressValueNode) GetChildren() []Node { + return nil +} + +// IsConstant returns whether a node is a constant. 
+func (n *AddressValueNode) IsConstant() bool { + return true +} + +// GetType returns the type of 'bool'. +func (n *AddressValueNode) GetType() DataType { + return ComposeDataType(DataTypeMajorAddress, DataTypeMinorDontCare) +} + // IntegerValueNode is an integer constant. type IntegerValueNode struct { TaggedExprNodeBase diff --git a/core/vm/sqlvm/ast/types.go b/core/vm/sqlvm/ast/types.go index a441b4b2c..fc403fcb3 100644 --- a/core/vm/sqlvm/ast/types.go +++ b/core/vm/sqlvm/ast/types.go @@ -59,6 +59,7 @@ const ( // Special data types which are commonly used. const ( DataTypePending DataType = (DataType(DataTypeMajorPending) << 8) | DataType(DataTypeMinorDontCare) + DataTypeNull DataType = (DataType(DataTypeMajorSpecial) << 8) | DataType(DataTypeMinorSpecialNull) DataTypeBad DataType = math.MaxUint16 ) @@ -189,6 +190,136 @@ func (dt DataType) Size() uint8 { } } +// Valid checks whether a data type is set to a defined value. +func (dt DataType) Valid() bool { + major, minor := DecomposeDataType(dt) + switch major { + case DataTypeMajorPending, + DataTypeMajorBool, + DataTypeMajorAddress, + DataTypeMajorDynamicBytes: + return true + case DataTypeMajorSpecial: + switch minor { + case DataTypeMinorSpecialNull: + return true + } + case DataTypeMajorInt, + DataTypeMajorUint, + DataTypeMajorFixedBytes: + if minor <= 0x1f { + return true + } + } + if (major.IsFixedRange() || major.IsUfixedRange()) && minor <= 80 { + return true + } + return false +} + +// ValidColumn checks whether a data type is a valid column type. +func (dt DataType) ValidColumn() bool { + major, minor := DecomposeDataType(dt) + switch major { + case DataTypeMajorBool, + DataTypeMajorAddress, + DataTypeMajorDynamicBytes: + return true + case DataTypeMajorInt, + DataTypeMajorUint, + DataTypeMajorFixedBytes: + if minor <= 0x1f { + return true + } + } + if (major.IsFixedRange() || major.IsUfixedRange()) && minor <= 80 { + return true + } + return false +} + +// ValidExpr checks whether a data type is a valid type of an expression. +func (dt DataType) ValidExpr() bool { + return dt.ValidColumn() || dt == DataTypeNull +} + +// Equal checks whether two data types are equal and valid. If any of them is +// invalid, false is returned. +func (dt DataType) Equal(dt2 DataType) bool { + // Rename variables. + a := dt + b := dt2 + // Process the common case. 
+ if a == b { + return a.Valid() + } + // a ≠ b + aMajor, _ := DecomposeDataType(a) + bMajor, _ := DecomposeDataType(b) + if aMajor != bMajor { + return false + } + // a ≠ b, aMajor = bMajor ⇒ aMinor ≠ bMinor + switch aMajor { + case DataTypeMajorPending, + DataTypeMajorBool, + DataTypeMajorAddress, + DataTypeMajorDynamicBytes: + return true + default: + return false + } +} + +func (dt DataType) String() string { + major, minor := DecomposeDataType(dt) + switch major { + case DataTypeMajorPending: + return "<PENDING>" + case DataTypeMajorSpecial: + switch minor { + case DataTypeMinorSpecialNull: + return "NULL" + } + case DataTypeMajorBool: + return "BOOL" + case DataTypeMajorAddress: + return "ADDRESS" + case DataTypeMajorInt: + if minor <= 0x1f { + size := (uint32(minor) + 1) * 8 + return fmt.Sprintf("INT%d", size) + } + case DataTypeMajorUint: + if minor <= 0x1f { + size := (uint32(minor) + 1) * 8 + return fmt.Sprintf("UINT%d", size) + } + case DataTypeMajorFixedBytes: + if minor <= 0x1f { + size := uint32(minor) + 1 + return fmt.Sprintf("BYTES%d", size) + } + case DataTypeMajorDynamicBytes: + return "BYTES" + } + switch { + case major.IsFixedRange(): + if minor <= 80 { + size := (uint32(major-DataTypeMajorFixed) + 1) * 8 + fractionalDigits := uint32(minor) + return fmt.Sprintf("FIXED%dX%d", size, fractionalDigits) + } + case major.IsUfixedRange(): + if minor <= 80 { + size := (uint32(major-DataTypeMajorUfixed) + 1) * 8 + fractionalDigits := uint32(minor) + return fmt.Sprintf("UFIXED%dX%d", size, fractionalDigits) + } + } + return "<INVALID>" +} + // GetNode constructs an AST node from a data type. func (dt DataType) GetNode() TypeNode { major, minor := DecomposeDataType(dt) @@ -293,12 +424,39 @@ var decimalMinMaxMap = func() map[DataType]decimalMinMaxPair { // GetMinMax returns min, max pair according to given data type. func (dt DataType) GetMinMax() (decimal.Decimal, decimal.Decimal, bool) { - var ( - pair decimalMinMaxPair - ok bool - ) - pair, ok = decimalMinMaxMap[dt] - return pair.Min, pair.Max, ok + pair, ok := decimalMinMaxMap[dt] + if ok { + return pair.Min, pair.Max, true + } + + // Compute the range of fixed and ufixed types on demand. 
+ major, minor := DecomposeDataType(dt) + switch { + case major.IsFixedRange(): + mapInt := ComposeDataType( + DataTypeMajorInt, DataTypeMinor(major-DataTypeMajorFixed)) + pair, ok = decimalMinMaxMap[mapInt] + if !ok || minor > 80 { + return decimal.Zero, decimal.Zero, false + } + min := pair.Min.Shift(-int32(minor)) + max := pair.Max.Shift(-int32(minor)) + return min, max, true + + case major.IsUfixedRange(): + mapUint := ComposeDataType( + DataTypeMajorUint, DataTypeMinor(major-DataTypeMajorUfixed)) + pair, ok = decimalMinMaxMap[mapUint] + if !ok || minor > 80 { + return decimal.Zero, decimal.Zero, false + } + min := pair.Min.Shift(-int32(minor)) + max := pair.Max.Shift(-int32(minor)) + return min, max, true + + default: + return decimal.Zero, decimal.Zero, false + } } func decimalToBig(d decimal.Decimal) *big.Int { @@ -318,13 +476,21 @@ func decimalEncode(size int, d decimal.Decimal) []byte { if s > 0 { bs := b.Bytes() - copy(ret[size-len(bs):], bs) + if size >= len(bs) { + copy(ret[size-len(bs):], bs) + } else { + copy(ret, bs[len(bs)-size:]) + } return ret } b.Add(b, bigIntOne) bs := b.Bytes() - copy(ret[size-len(bs):], bs) + if size >= len(bs) { + copy(ret[size-len(bs):], bs) + } else { + copy(ret, bs[len(bs)-size:]) + } for idx := range ret { ret[idx] = ^ret[idx] } diff --git a/core/vm/sqlvm/checkers/actions.go b/core/vm/sqlvm/checkers/actions.go new file mode 100644 index 000000000..c15029d9a --- /dev/null +++ b/core/vm/sqlvm/checkers/actions.go @@ -0,0 +1,147 @@ +package checkers + +import ( + "fmt" + + "github.com/dexon-foundation/dexon/core/vm/sqlvm/ast" + "github.com/dexon-foundation/dexon/core/vm/sqlvm/errors" + "github.com/dexon-foundation/dexon/core/vm/sqlvm/schema" +) + +// CheckOptions stores boolean options for Check* functions. +type CheckOptions uint32 + +const ( + // CheckWithSafeMath enables overflow and underflow checks during expression + // evaluation. An error will be thrown when the result is out of range. + CheckWithSafeMath CheckOptions = 1 << iota + // CheckWithSafeCast enables overflow and underflow checks during casting. + // An error will be thrown if the value does not fit in the target type. + CheckWithSafeCast + // CheckWithConstantOnly restricts the expression to be a constant. An error + // will be thrown if the expression cannot be folded into a constant. + CheckWithConstantOnly +) + +// CheckCreate runs CREATE commands to generate a database schema. It modifies +// AST in-place during evaluation of expressions. 
+func CheckCreate(ss []ast.StmtNode, o CheckOptions) (schema.Schema, error) { + fn := "CheckCreate" + s := schema.Schema{} + c := newSchemaCache() + el := errors.ErrorList{} + + for idx := range ss { + if ss[idx] == nil { + continue + } + + switch n := ss[idx].(type) { + case *ast.CreateTableStmtNode: + checkCreateTableStmt(n, &s, o, c, &el) + case *ast.CreateIndexStmtNode: + checkCreateIndexStmt(n, &s, o, c, &el) + default: + el.Append(errors.Error{ + Position: ss[idx].GetPosition(), + Length: ss[idx].GetLength(), + Category: errors.ErrorCategoryCommand, + Code: errors.ErrorCodeDisallowedCommand, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "command %s is not allowed when creating a contract", + ast.QuoteIdentifier(ss[idx].GetVerb())), + }, nil) + } + } + + if len(s) == 0 && len(el) == 0 { + el.Append(errors.Error{ + Position: 0, + Length: 0, + Category: errors.ErrorCategoryCommand, + Code: errors.ErrorCodeNoCommand, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: "creating a contract without a table is not allowed", + }, nil) + } + if len(el) != 0 { + return s, el + } + return s, nil +} + +// CheckQuery checks and modifies SELECT commands with a given database schema. +func CheckQuery(ss []ast.StmtNode, s schema.Schema, o CheckOptions) error { + fn := "CheckQuery" + c := newSchemaCache() + el := errors.ErrorList{} + + for idx := range ss { + if ss[idx] == nil { + continue + } + + switch n := ss[idx].(type) { + case *ast.SelectStmtNode: + checkSelectStmt(n, s, o, c, &el) + default: + el.Append(errors.Error{ + Position: ss[idx].GetPosition(), + Length: ss[idx].GetLength(), + Category: errors.ErrorCategoryCommand, + Code: errors.ErrorCodeDisallowedCommand, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "command %s is not allowed when calling query", + ast.QuoteIdentifier(ss[idx].GetVerb())), + }, nil) + } + } + if len(el) != 0 { + return el + } + return nil +} + +// CheckExec checks and modifies UPDATE, DELETE, INSERT commands with a given +// database schema. 
+func CheckExec(ss []ast.StmtNode, s schema.Schema, o CheckOptions) error { + fn := "CheckExec" + c := newSchemaCache() + el := errors.ErrorList{} + + for idx := range ss { + if ss[idx] == nil { + continue + } + + switch n := ss[idx].(type) { + case *ast.UpdateStmtNode: + checkUpdateStmt(n, s, o, c, &el) + case *ast.DeleteStmtNode: + checkDeleteStmt(n, s, o, c, &el) + case *ast.InsertStmtNode: + checkInsertStmt(n, s, o, c, &el) + default: + el.Append(errors.Error{ + Position: ss[idx].GetPosition(), + Length: ss[idx].GetLength(), + Category: errors.ErrorCategoryCommand, + Code: errors.ErrorCodeDisallowedCommand, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "command %s is not allowed when calling exec", + ast.QuoteIdentifier(ss[idx].GetVerb())), + }, nil) + } + } + if len(el) != 0 { + return el + } + return nil +} diff --git a/core/vm/sqlvm/checkers/checkers.go b/core/vm/sqlvm/checkers/checkers.go new file mode 100644 index 000000000..429d08a5c --- /dev/null +++ b/core/vm/sqlvm/checkers/checkers.go @@ -0,0 +1,1496 @@ +package checkers + +import ( + "fmt" + "sort" + + "github.com/dexon-foundation/decimal" + + "github.com/dexon-foundation/dexon/core/vm/sqlvm/ast" + "github.com/dexon-foundation/dexon/core/vm/sqlvm/errors" + "github.com/dexon-foundation/dexon/core/vm/sqlvm/schema" +) + +// In addition to the convention mentioned in utils.go, we have these variable +// names in this file: +// +// ftd -> foreign table descriptor +// ftn -> foreign table name +// fcd -> foreign column descriptor +// fcn -> foreign column name +// fid -> foreign index descriptor +// fin -> foreign index name +// +// fmid -> first matching index descriptor +// fmir -> first matching index reference +// fmin -> first matching index name + +// findFirstMatchingIndex finds the first index in 'haystack' matching the +// declaration of 'needle' with attributes specified in 'attrDontCare' ignored. +// This function is considered as a part of the interface, so it have to work +// deterministically. +func findFirstMatchingIndex(haystack []schema.Index, needle schema.Index, + attrDontCare schema.IndexAttr) (schema.IndexRef, bool) { + + compareAttr := func(a1, a2 schema.IndexAttr) bool { + a1 = a1.GetDeclaredFlags() | attrDontCare + a2 = a2.GetDeclaredFlags() | attrDontCare + return a1 == a2 + } + + compareColumns := func(c1, c2 []schema.ColumnRef) bool { + if len(c1) != len(c2) { + return false + } + for ci := range c1 { + if c1[ci] != c2[ci] { + return false + } + } + return true + } + + for ii := range haystack { + if compareAttr(haystack[ii].Attr, needle.Attr) && + compareColumns(haystack[ii].Columns, needle.Columns) { + ir := schema.IndexRef(ii) + return ir, true + } + } + return 0, false +} + +func checkCreateTableStmt(n *ast.CreateTableStmtNode, s *schema.Schema, + o CheckOptions, c *schemaCache, el *errors.ErrorList) { + + fn := "CheckCreateTableStmt" + hasError := false + + if c.Begin() != 0 { + panic("schema cache must not have any open scope") + } + defer func() { + if hasError { + c.Rollback() + return + } + c.Commit() + }() + + // Return early if there are too many tables. We cannot ignore this error + // because it will overflow schema.TableRef, which is used as a part of + // column key in schemaCache. 
+ if len(*s) > schema.MaxTableRef { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManyTables, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("cannot have more than %d tables", + schema.MaxTableRef+1), + }, &hasError) + return + } + + table := schema.Table{} + tr := schema.TableRef(len(*s)) + td := schema.TableDescriptor{Table: tr} + + if len(n.Table.Name) == 0 { + el.Append(errors.Error{ + Position: n.Table.GetPosition(), + Length: n.Table.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeEmptyTableName, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: "cannot create a table with an empty name", + }, &hasError) + } + + tn := n.Table.Name + if !c.AddTable(string(tn), td) { + el.Append(errors.Error{ + Position: n.Table.GetPosition(), + Length: n.Table.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeDuplicateTableName, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("table %s already exists", + ast.QuoteIdentifier(tn)), + }, &hasError) + } + table.Name = n.Table.Name + table.Columns = make([]schema.Column, 0, len(n.Column)) + + // Handle the primary key index. + pk := []schema.ColumnRef{} + // Handle sequences. + seq := 0 + // Handle indices for unique constraints. + type localIndex struct { + index schema.Index + node ast.Node + } + localIndices := []localIndex{} + // Handle indices for foreign key constraints. + type foreignNewIndex struct { + table schema.TableDescriptor + index schema.Index + node ast.Node + } + foreignNewIndices := []foreignNewIndex{} + type foreignExistingIndex struct { + index schema.IndexDescriptor + node ast.Node + } + foreignExistingIndices := []foreignExistingIndex{} + + for ci := range n.Column { + if len(table.Columns) > schema.MaxColumnRef { + el.Append(errors.Error{ + Position: n.Column[ci].GetPosition(), + Length: n.Column[ci].GetLength(), + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManyColumns, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("cannot have more than %d columns", + schema.MaxColumnRef+1), + }, &hasError) + return + } + + column := schema.Column{} + ok := func() (ok bool) { + innerHasError := false + defer func() { ok = !innerHasError }() + + // Block access to the outer hasError variable. + hasError := struct{}{} + // Suppress “declared and not used” error. 
+ _ = hasError + + c.Begin() + defer func() { + if innerHasError { + c.Rollback() + return + } + c.Commit() + }() + + cr := schema.ColumnRef(len(table.Columns)) + cd := schema.ColumnDescriptor{Table: tr, Column: cr} + + if len(n.Column[ci].Column.Name) == 0 { + el.Append(errors.Error{ + Position: n.Column[ci].Column.GetPosition(), + Length: n.Column[ci].Column.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeEmptyColumnName, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: "cannot declare a column with an empty name", + }, &innerHasError) + } + + cn := n.Column[ci].Column.Name + if !c.AddColumn(string(cn), cd) { + el.Append(errors.Error{ + Position: n.Column[ci].Column.GetPosition(), + Length: n.Column[ci].Column.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeDuplicateColumnName, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("column %s already exists", + ast.QuoteIdentifier(cn)), + }, &innerHasError) + } else { + column.Name = n.Column[ci].Column.Name + } + + dt, code, message := n.Column[ci].DataType.GetType() + if code == errors.ErrorCodeNil { + if !dt.ValidColumn() { + el.Append(errors.Error{ + Position: n.Column[ci].DataType.GetPosition(), + Length: n.Column[ci].DataType.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeInvalidColumnDataType, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "cannot declare a column with type %s", dt.String()), + }, &innerHasError) + } + } else { + el.Append(errors.Error{ + Position: n.Column[ci].DataType.GetPosition(), + Length: n.Column[ci].DataType.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: code, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: message, + }, &innerHasError) + } + column.Type = dt + + // Backup lengths of slices in case we have to rollback. We don't + // have to copy slice headers or data stored in underlying arrays + // because we always append data at the end. + defer func(LPK, SEQ, LLI, LFNI, LFEI int) { + if innerHasError { + pk = pk[:LPK] + seq = SEQ + localIndices = localIndices[:LLI] + foreignNewIndices = foreignNewIndices[:LFNI] + foreignExistingIndices = foreignExistingIndices[:LFEI] + } + }( + len(pk), seq, len(localIndices), len(foreignNewIndices), + len(foreignExistingIndices), + ) + + // cs -> constraint + // csi -> constraint index + for csi := range n.Column[ci].Constraint { + // Cases are sorted in the same order as internal/grammar.peg. + cs: + switch cs := n.Column[ci].Constraint[csi].(type) { + case *ast.PrimaryOptionNode: + pk = append(pk, cr) + column.Attr |= schema.ColumnAttrPrimaryKey + + case *ast.NotNullOptionNode: + column.Attr |= schema.ColumnAttrNotNull + + case *ast.UniqueOptionNode: + if (column.Attr & schema.ColumnAttrUnique) != 0 { + // Don't create duplicate indices on a column. 
+ break cs + } + column.Attr |= schema.ColumnAttrUnique + indexName := fmt.Sprintf("%s_%s_unique", + table.Name, column.Name) + idx := schema.Index{ + Name: []byte(indexName), + Attr: schema.IndexAttrUnique, + Columns: []schema.ColumnRef{cr}, + } + localIndices = append(localIndices, localIndex{ + index: idx, + node: cs, + }) + + case *ast.DefaultOptionNode: + if (column.Attr & schema.ColumnAttrHasDefault) != 0 { + el.Append(errors.Error{ + Position: cs.GetPosition(), + Length: cs.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeMultipleDefaultValues, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: "cannot have multiple default values", + }, &innerHasError) + break cs + } + column.Attr |= schema.ColumnAttrHasDefault + + value := checkExpr(cs.Value, *s, o|CheckWithConstantOnly, + c, el, 0, newTypeActionAssign(column.Type)) + if value == nil { + innerHasError = true + break cs + } + cs.Value = value + + switch v := cs.Value.(ast.Valuer).(type) { + case *ast.BoolValueNode: + sb := v.V.NullBool() + if !sb.Valid { + el.Append(errors.Error{ + Position: cs.GetPosition(), + Length: cs.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeNullDefaultValue, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: "default value must not be NULL", + }, &innerHasError) + break cs + } + column.Default = sb.Bool + + case *ast.AddressValueNode: + column.Default = v.V + + case *ast.IntegerValueNode: + column.Default = v.V + + case *ast.DecimalValueNode: + column.Default = v.V + + case *ast.BytesValueNode: + column.Default = v.V + + case *ast.NullValueNode: + el.Append(errors.Error{ + Position: cs.GetPosition(), + Length: cs.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeNullDefaultValue, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: "default value must not be NULL", + }, &innerHasError) + break cs + + default: + panic(fmt.Sprintf("unknown constant value type %T", n)) + } + + case *ast.ForeignOptionNode: + if len(column.ForeignKeys) > schema.MaxForeignKeys { + el.Append(errors.Error{ + Position: cs.GetPosition(), + Length: cs.GetLength(), + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManyForeignKeys, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "cannot have more than %d foreign key "+ + "constraints in a column", + schema.MaxForeignKeys+1), + }, &innerHasError) + break cs + } + column.Attr |= schema.ColumnAttrHasForeignKey + ftn := cs.Table.Name + ftd, found := c.FindTableInBase(string(ftn)) + if !found { + el.Append(errors.Error{ + Position: cs.Table.GetPosition(), + Length: cs.Table.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTableNotFound, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "foreign table %s does not exist", + ast.QuoteIdentifier(ftn)), + }, &innerHasError) + break cs + } + fcn := cs.Column.Name + fcd, found := c.FindColumnInBase(ftd.Table, string(fcn)) + if !found { + el.Append(errors.Error{ + Position: cs.Column.GetPosition(), + Length: cs.Column.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeColumnNotFound, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "column %s does not exist in foreign table %s", + ast.QuoteIdentifier(fcn), + ast.QuoteIdentifier(ftn)), + }, &innerHasError) + break cs + } + foreignType := (*s)[fcd.Table].Columns[fcd.Column].Type + if 
!foreignType.Equal(column.Type) { + el.Append(errors.Error{ + Position: n.Column[ci].DataType.GetPosition(), + Length: n.Column[ci].DataType.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeForeignKeyDataTypeMismatch, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "foreign column has type %s (%04x), but "+ + "this column has type %s (%04x)", + foreignType.String(), uint16(foreignType), + column.Type.String(), uint16(column.Type)), + }, &innerHasError) + break cs + } + + idx := schema.Index{ + Attr: schema.IndexAttrReferenced, + Columns: []schema.ColumnRef{fcd.Column}, + } + fmir, found := findFirstMatchingIndex( + (*s)[ftd.Table].Indices, idx, schema.IndexAttrUnique) + if found { + fmid := schema.IndexDescriptor{ + Table: ftd.Table, + Index: fmir, + } + foreignExistingIndices = append( + foreignExistingIndices, foreignExistingIndex{ + index: fmid, + node: cs, + }) + } else { + if len(column.ForeignKeys) > 0 { + idx.Name = []byte(fmt.Sprintf("%s_%s_foreign_key_%d", + table.Name, column.Name, len(column.ForeignKeys))) + } else { + idx.Name = []byte(fmt.Sprintf("%s_%s_foreign_key", + table.Name, column.Name)) + } + foreignNewIndices = append( + foreignNewIndices, foreignNewIndex{ + table: ftd, + index: idx, + node: cs, + }) + } + column.ForeignKeys = append(column.ForeignKeys, fcd) + + case *ast.AutoIncrementOptionNode: + if (column.Attr & schema.ColumnAttrHasSequence) != 0 { + // Don't process AUTOINCREMENT twice. + break cs + } + // We set the flag regardless of the error because we + // don't want to produce duplicate errors. + column.Attr |= schema.ColumnAttrHasSequence + if seq > schema.MaxSequenceRef { + el.Append(errors.Error{ + Position: cs.GetPosition(), + Length: cs.GetLength(), + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManySequences, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "cannot have more than %d sequences", + schema.MaxSequenceRef+1), + }, &innerHasError) + break cs + } + major, _ := ast.DecomposeDataType(column.Type) + switch major { + case ast.DataTypeMajorInt, ast.DataTypeMajorUint: + default: + el.Append(errors.Error{ + Position: cs.GetPosition(), + Length: cs.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeInvalidAutoIncrementDataType, + Prefix: fn, + Message: fmt.Sprintf( + "AUTOINCREMENT is only supported on "+ + "INT and UINT types, but this column "+ + "has type %s (%04x)", + column.Type.String(), uint16(column.Type)), + }, &innerHasError) + break cs + } + column.Sequence = schema.SequenceRef(seq) + seq++ + + default: + panic(fmt.Sprintf("unknown constraint type %T", c)) + } + } + + // The return value will be set by the first defer function. + return + }() + + // If an error occurs in the function, stop here and continue + // processing the next column. + if !ok { + hasError = true + continue + } + + // Commit the column. + table.Columns = append(table.Columns, column) + } + + // Return early if there is any error. + if hasError { + return + } + + mustAddIndex := func(name *[]byte, id schema.IndexDescriptor) { + for !c.AddIndex(string(*name), id, true) { + *name = append(*name, '_') + } + } + + // Create the primary key index. This is the first index on the table, so + // it is not possible to exceed the limit on the number of indices. 
+ ir := schema.IndexRef(len(table.Indices)) + if len(pk) > 0 { + idx := schema.Index{ + Name: []byte(fmt.Sprintf("%s_primary_key", table.Name)), + Attr: schema.IndexAttrUnique, + Columns: pk, + } + id := schema.IndexDescriptor{Table: tr, Index: ir} + mustAddIndex(&idx.Name, id) + table.Indices = append(table.Indices, idx) + } + + // Create indices for the current table. + for ii := range localIndices { + if len(table.Indices) > schema.MaxIndexRef { + el.Append(errors.Error{ + Position: localIndices[ii].node.GetPosition(), + Length: localIndices[ii].node.GetLength(), + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManyIndices, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("cannot have more than %d indices", + schema.MaxIndexRef+1), + }, &hasError) + return + } + idx := localIndices[ii].index + ir := schema.IndexRef(len(table.Indices)) + id := schema.IndexDescriptor{Table: tr, Index: ir} + mustAddIndex(&idx.Name, id) + table.Indices = append(table.Indices, idx) + } + + // Create indices for foreign tables. + for ii := range foreignNewIndices { + ftd := foreignNewIndices[ii].table + if len((*s)[ftd.Table].Indices) > schema.MaxIndexRef { + el.Append(errors.Error{ + Position: foreignNewIndices[ii].node.GetPosition(), + Length: foreignNewIndices[ii].node.GetLength(), + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManyIndices, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "table %s already has %d indices", + ast.QuoteIdentifier((*s)[ftd.Table].Name), + schema.MaxIndexRef+1), + }, &hasError) + return + } + idx := foreignNewIndices[ii].index + ir := schema.IndexRef(len((*s)[ftd.Table].Indices)) + id := schema.IndexDescriptor{Table: ftd.Table, Index: ir} + mustAddIndex(&idx.Name, id) + (*s)[ftd.Table].Indices = append((*s)[ftd.Table].Indices, idx) + defer func(tr schema.TableRef, length schema.IndexRef) { + if hasError { + (*s)[tr].Indices = (*s)[tr].Indices[:ir] + } + }(ftd.Table, ir) + } + + // Mark existing indices as referenced. + for ii := range foreignExistingIndices { + fid := foreignExistingIndices[ii].index + (*s)[fid.Table].Indices[fid.Index].Attr |= schema.IndexAttrReferenced + } + + // Finally, we can commit the table definition to the schema. 
+ *s = append(*s, table) +} + +func checkCreateIndexStmt(n *ast.CreateIndexStmtNode, s *schema.Schema, + o CheckOptions, c *schemaCache, el *errors.ErrorList) { + + fn := "CheckCreateIndexStmt" + hasError := false + + if c.Begin() != 0 { + panic("schema cache must not have any open scope") + } + defer func() { + if hasError { + c.Rollback() + return + } + c.Commit() + }() + + tn := n.Table.Name + td, found := c.FindTableInBase(string(tn)) + if !found { + el.Append(errors.Error{ + Position: n.Table.GetPosition(), + Length: n.Table.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTableNotFound, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "index table %s does not exist", + ast.QuoteIdentifier(tn)), + }, &hasError) + return + } + + if len(n.Column) > schema.MaxColumnRef { + begin := n.Column[0].GetPosition() + last := len(n.Column) - 1 + end := n.Column[last].GetPosition() + n.Column[last].GetLength() + el.Append(errors.Error{ + Position: begin, + Length: end - begin, + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManyColumns, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "cannot create an index on more than %d columns", + schema.MaxColumnRef+1), + }, &hasError) + return + } + + columnRefs := newColumnRefSlice(uint8(len(n.Column))) + for ci := range n.Column { + cn := n.Column[ci].Name + cd, found := c.FindColumnInBase(td.Table, string(cn)) + if !found { + el.Append(errors.Error{ + Position: n.Column[ci].GetPosition(), + Length: n.Column[ci].GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeColumnNotFound, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "column %s does not exist in index table %s", + ast.QuoteIdentifier(cn), + ast.QuoteIdentifier(tn)), + }, &hasError) + continue + } + columnRefs.Append(cd.Column, uint8(ci)) + } + if hasError { + return + } + + sort.Stable(columnRefs) + for ci := 1; ci < len(n.Column); ci++ { + if columnRefs.columns[ci] == columnRefs.columns[ci-1] { + el.Append(errors.Error{ + Position: n.Column[columnRefs.nodes[ci]].GetPosition(), + Length: n.Column[columnRefs.nodes[ci]].GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeDuplicateIndexColumn, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "column %s already exists in the column list", + ast.QuoteIdentifier(n.Column[columnRefs.nodes[ci]].Name)), + }, &hasError) + return + } + } + + index := schema.Index{} + index.Columns = columnRefs.columns + + if len((*s)[td.Table].Indices) > schema.MaxIndexRef { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategoryLimit, + Code: errors.ErrorCodeTooManyIndices, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "cannot have more than %d indices in table %s", + schema.MaxIndexRef+1, + ast.QuoteIdentifier(tn)), + }, &hasError) + return + } + + ir := schema.IndexRef(len((*s)[td.Table].Indices)) + id := schema.IndexDescriptor{Table: td.Table, Index: ir} + if n.Unique != nil { + index.Attr |= schema.IndexAttrUnique + } + + if len(n.Index.Name) == 0 { + el.Append(errors.Error{ + Position: n.Index.GetPosition(), + Length: n.Table.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeEmptyIndexName, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: "cannot create an index with an empty name", + }, &hasError) + return + } + + // 
If there is an existing index that is automatically created, rename it + // instead of creating a new one. + rename := false + fmir, found := findFirstMatchingIndex((*s)[id.Table].Indices, index, 0) + if found { + fmid := schema.IndexDescriptor{Table: id.Table, Index: fmir} + fmin := (*s)[id.Table].Indices[fmir].Name + fminString := string(fmin) + fmidCache, auto, found := c.FindIndexInBase(fminString) + if !found { + panic(fmt.Sprintf("index %s exists in the schema, "+ + "but it cannot be found in the schema cache", + ast.QuoteIdentifier(fmin))) + } + if fmidCache != fmid { + panic(fmt.Sprintf("index %s has descriptor %+v, "+ + "but the schema cache records it as %+v", + ast.QuoteIdentifier(fmin), fmid, fmidCache)) + } + if auto { + if !c.DeleteIndex(fminString) { + panic(fmt.Sprintf("unable to mark index %s for deletion", + ast.QuoteIdentifier(fmin))) + } + rename = true + id = fmid + ir = id.Index + } + } + + in := n.Index.Name + if !c.AddIndex(string(in), id, false) { + el.Append(errors.Error{ + Position: n.Index.GetPosition(), + Length: n.Index.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeDuplicateIndexName, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("index %s already exists", + ast.QuoteIdentifier(in)), + }, &hasError) + return + } + + // Commit the change into the schema. + if rename { + (*s)[id.Table].Indices[id.Index].Name = n.Index.Name + } else { + index.Name = n.Index.Name + (*s)[id.Table].Indices = append((*s)[id.Table].Indices, index) + } +} + +func checkSelectStmt(n *ast.SelectStmtNode, s schema.Schema, + o CheckOptions, c *schemaCache, el *errors.ErrorList) { +} + +func checkUpdateStmt(n *ast.UpdateStmtNode, s schema.Schema, + o CheckOptions, c *schemaCache, el *errors.ErrorList) { +} + +func checkDeleteStmt(n *ast.DeleteStmtNode, s schema.Schema, + o CheckOptions, c *schemaCache, el *errors.ErrorList) { +} + +func checkInsertStmt(n *ast.InsertStmtNode, s schema.Schema, + o CheckOptions, c *schemaCache, el *errors.ErrorList) { +} + +func checkExpr(n ast.ExprNode, + s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList, + tr schema.TableRef, ta typeAction) ast.ExprNode { + + switch n := n.(type) { + case *ast.IdentifierNode: + return checkVariable(n, s, o, c, el, tr, ta) + + case *ast.BoolValueNode: + return checkBoolValue(n, o, el, ta) + + case *ast.AddressValueNode: + return checkAddressValue(n, o, el, ta) + + case *ast.IntegerValueNode: + return checkIntegerValue(n, o, el, ta) + + case *ast.DecimalValueNode: + return checkDecimalValue(n, o, el, ta) + + case *ast.BytesValueNode: + return checkBytesValue(n, o, el, ta) + + case *ast.NullValueNode: + return checkNullValue(n, o, el, ta) + + case *ast.PosOperatorNode: + return checkPosOperator(n, s, o, c, el, tr, ta) + + case *ast.NegOperatorNode: + return n + + case *ast.NotOperatorNode: + return n + + case *ast.ParenOperatorNode: + return n + + case *ast.AndOperatorNode: + return n + + case *ast.OrOperatorNode: + return n + + case *ast.GreaterOrEqualOperatorNode: + return n + + case *ast.LessOrEqualOperatorNode: + return n + + case *ast.NotEqualOperatorNode: + return n + + case *ast.EqualOperatorNode: + return n + + case *ast.GreaterOperatorNode: + return n + + case *ast.LessOperatorNode: + return n + + case *ast.ConcatOperatorNode: + return n + + case *ast.AddOperatorNode: + return n + + case *ast.SubOperatorNode: + return n + + case *ast.MulOperatorNode: + return n + + case *ast.DivOperatorNode: + return n + + case *ast.ModOperatorNode: + 
return n + + case *ast.IsOperatorNode: + return n + + case *ast.LikeOperatorNode: + return n + + case *ast.CastOperatorNode: + return n + + case *ast.InOperatorNode: + return n + + case *ast.FunctionOperatorNode: + return n + + default: + panic(fmt.Sprintf("unknown expression type %T", n)) + } +} + +func checkVariable(n *ast.IdentifierNode, + s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList, + tr schema.TableRef, ta typeAction) ast.ExprNode { + + fn := "CheckVariable" + + if (o & CheckWithConstantOnly) != 0 { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeNonConstantExpression, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf("%s is not a constant", + ast.QuoteIdentifier(n.Name)), + }, nil) + return nil + } + + cn := string(n.Name) + cd, found := c.FindColumnInBase(tr, cn) + if !found { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeColumnNotFound, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "cannot find column %s in table %s", + ast.QuoteIdentifier(n.Name), + ast.QuoteIdentifier(s[tr].Name)), + }, nil) + return nil + } + + cr := cd.Column + dt := s[tr].Columns[cr].Type + switch a := ta.(type) { + case typeActionInferDefault: + case typeActionInferWithSize: + case typeActionAssign: + if !dt.Equal(a.dt) { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTypeError, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "expect %s (%04x), but %s (%04x) is given", + a.dt.String(), uint16(a.dt), dt.String(), uint16(dt)), + }, nil) + return nil + } + } + + n.SetType(dt) + n.Desc = cd + return n +} + +func checkBoolValue(n *ast.BoolValueNode, + o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { + + fn := "CheckBoolValue" + + switch a := ta.(type) { + case typeActionInferDefault: + case typeActionInferWithSize: + case typeActionAssign: + major, _ := ast.DecomposeDataType(a.dt) + if major != ast.DataTypeMajorBool { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTypeError, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "expect %s (%04x), but boolean value is given", + a.dt.String(), uint16(a.dt)), + }, nil) + return nil + } + } + return n +} + +func checkAddressValue(n *ast.AddressValueNode, + o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { + + fn := "CheckAddressValue" + + switch a := ta.(type) { + case typeActionInferDefault: + case typeActionInferWithSize: + case typeActionAssign: + major, _ := ast.DecomposeDataType(a.dt) + if major != ast.DataTypeMajorAddress { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTypeError, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "expect %s (%04x), but address value is given", + a.dt.String(), uint16(a.dt)), + }, nil) + return nil + } + } + return n +} + +func mustGetMinMax(dt ast.DataType) (decimal.Decimal, decimal.Decimal) { + min, max, ok := dt.GetMinMax() + if !ok { + panic(fmt.Sprintf("GetMinMax does not handle %v", dt)) + } + return min, max +} + +func 
mustDecimalEncode(dt ast.DataType, d decimal.Decimal) []byte { + b, ok := ast.DecimalEncode(dt, d) + if !ok { + panic(fmt.Sprintf("DecimalEncode does not handle %v", dt)) + } + return b +} + +func mustDecimalDecode(dt ast.DataType, b []byte) decimal.Decimal { + d, ok := ast.DecimalDecode(dt, b) + if !ok { + panic(fmt.Sprintf("DecimalDecode does not handle %v", dt)) + } + return d +} + +func cropDecimal(dt ast.DataType, d decimal.Decimal) decimal.Decimal { + b := mustDecimalEncode(dt, d) + return mustDecimalDecode(dt, b) +} + +func elAppendConstantTooLongError(el *errors.ErrorList, n ast.Node, + fn string, v decimal.Decimal) { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeConstantTooLong, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "constant expression %s has more than %d digits", + ast.QuoteString(n.GetToken()), MaxIntegerPartDigits), + }, nil) +} + +func elAppendOverflowError(el *errors.ErrorList, n ast.Node, + fn string, dt ast.DataType, v, min, max decimal.Decimal) { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeOverflow, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "number %s (%s) overflow %s (%04x)", + v.String(), ast.QuoteString(n.GetToken()), + dt.String(), uint16(dt)), + }, nil) + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: 0, + Code: 0, + Severity: errors.ErrorSeverityNote, + Prefix: fn, + Message: fmt.Sprintf( + "the range of %s is [%s, %s]", + dt.String(), min.String(), max.String()), + }, nil) +} + +func elAppendOverflowWarning(el *errors.ErrorList, n ast.Node, + fn string, dt ast.DataType, from, to decimal.Decimal) { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: 0, + Code: 0, + Severity: errors.ErrorSeverityWarning, + Prefix: fn, + Message: fmt.Sprintf( + "number %s (%s) overflow %s (%04x), converted to %s", + from.String(), ast.QuoteString(n.GetToken()), + dt.String(), uint16(dt), to.String()), + }, nil) +} + +func checkIntegerValue(n *ast.IntegerValueNode, + o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { + + fn := "CheckIntegerValue" + + normalizeDecimal(&n.V) + if !safeDecimalRange(n.V) { + elAppendConstantTooLongError(el, n, fn, n.V) + return nil + } + + infer := func(size int) (ast.DataType, bool) { + // The first case: assume V fits in the signed integer. + minor := ast.DataTypeMinor(size - 1) + dt := ast.ComposeDataType(ast.DataTypeMajorInt, minor) + min, max := mustGetMinMax(dt) + // Return if V < min. V must be negative so it must be signed. + if n.V.LessThan(min) { + if (o & CheckWithSafeMath) != 0 { + elAppendOverflowError(el, n, fn, dt, n.V, min, max) + return dt, false + } + cropped := cropDecimal(dt, n.V) + elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) + normalizeDecimal(&cropped) + n.V = cropped + return dt, true + } + // We are done if V fits in the signed integer. + if n.V.LessThanOrEqual(max) { + return dt, true + } + + // The second case: V is a non-negative integer, but it does not fit + // in the signed integer. Test whether the unsigned integer works. + dt = ast.ComposeDataType(ast.DataTypeMajorUint, minor) + min, max = mustGetMinMax(dt) + // Return if V > max. We don't have to test whether V < min because min + // is always zero and we already know V is non-negative. 
+ if n.V.GreaterThan(max) { + if (o & CheckWithSafeMath) != 0 { + elAppendOverflowError(el, n, fn, dt, n.V, min, max) + return dt, false + } + cropped := cropDecimal(dt, n.V) + elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) + normalizeDecimal(&cropped) + n.V = cropped + return dt, true + } + return dt, true + } + + dt := ast.DataTypePending + switch a := ta.(type) { + case typeActionInferDefault: + // Default to int256 or uint256. + var ok bool + dt, ok = infer(256 / 8) + if !ok { + return nil + } + + case typeActionInferWithSize: + var ok bool + dt, ok = infer(a.size) + if !ok { + return nil + } + + case typeActionAssign: + dt = a.dt + major, _ := ast.DecomposeDataType(dt) + switch { + case major == ast.DataTypeMajorAddress: + if !n.IsAddress { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeInvalidAddressChecksum, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "expect %s (%04x), but %s has invalid checksum", + dt.String(), uint16(dt), n.GetToken()), + }, nil) + return nil + } + // Redirect to checkAddressValue if it becomes an address. + addrNode := &ast.AddressValueNode{} + addrNode.SetPosition(addrNode.GetPosition()) + addrNode.SetLength(addrNode.GetLength()) + addrNode.SetToken(addrNode.GetToken()) + addrNode.V = mustDecimalEncode(ast.ComposeDataType( + ast.DataTypeMajorUint, ast.DataTypeMinor(160/8-1)), n.V) + return checkAddressValue(addrNode, o, el, ta) + + case major == ast.DataTypeMajorInt, + major == ast.DataTypeMajorUint, + major.IsFixedRange(), + major.IsUfixedRange(): + min, max := mustGetMinMax(dt) + if n.V.LessThan(min) || n.V.GreaterThan(max) { + if (o & CheckWithSafeMath) != 0 { + elAppendOverflowError(el, n, fn, dt, n.V, min, max) + return nil + } + cropped := cropDecimal(dt, n.V) + elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) + normalizeDecimal(&cropped) + n.V = cropped + } + + default: + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTypeError, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "expect %s (%04x), but number value is given", + dt.String(), uint16(dt)), + }, nil) + return nil + } + } + + if dt != ast.DataTypePending { + n.SetType(dt) + } + return n +} + +func checkDecimalValue(n *ast.DecimalValueNode, + o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { + + fn := "CheckDecimalValue" + + normalizeDecimal(&n.V) + if !safeDecimalRange(n.V) { + elAppendConstantTooLongError(el, n, fn, n.V) + return nil + } + + // Redirect to checkIntegerValue if the value is an integer. 
+ if intPart := n.V.Truncate(0); n.V.Equal(intPart) { + intNode := &ast.IntegerValueNode{} + intNode.SetPosition(n.GetPosition()) + intNode.SetLength(n.GetLength()) + intNode.SetToken(n.GetToken()) + intNode.IsAddress = false + intNode.V = n.V + return checkIntegerValue(intNode, o, el, ta) + } + + infer := func(size, fractionalDigits int) (ast.DataType, bool) { + major := ast.DataTypeMajorFixed + ast.DataTypeMajor(size-1) + minor := ast.DataTypeMinor(fractionalDigits) + dt := ast.ComposeDataType(major, minor) + min, max := mustGetMinMax(dt) + if n.V.LessThan(min) { + if (o & CheckWithSafeMath) != 0 { + elAppendOverflowError(el, n, fn, dt, n.V, min, max) + return dt, false + } + cropped := cropDecimal(dt, n.V) + elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) + normalizeDecimal(&cropped) + n.V = cropped + return dt, false + } + if n.V.LessThanOrEqual(max) { + return dt, true + } + + major = ast.DataTypeMajorUfixed + ast.DataTypeMajor(size-1) + minor = ast.DataTypeMinor(fractionalDigits) + dt = ast.ComposeDataType(major, minor) + min, max = mustGetMinMax(dt) + if n.V.GreaterThan(max) { + if (o & CheckWithSafeMath) != 0 { + elAppendOverflowError(el, n, fn, dt, n.V, min, max) + return dt, false + } + cropped := cropDecimal(dt, n.V) + elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) + normalizeDecimal(&cropped) + n.V = cropped + return dt, true + } + return dt, true + } + + // Now we are sure the number we are dealing has fractional part. + dt := ast.DataTypePending + switch a := ta.(type) { + case typeActionInferDefault: + // Default to fixed128x18 and ufixed128x18. + var ok bool + dt, ok = infer(128/8, 18) + if !ok { + return nil + } + + case typeActionInferWithSize: + // It is unclear that what the size hint means for fixed-point numbers, + // so we just ignore it and do the same thing as the above case. 
+ var ok bool + dt, ok = infer(128/8, 18) + if !ok { + return nil + } + + case typeActionAssign: + dt = a.dt + major, _ := ast.DecomposeDataType(dt) + switch { + case major.IsFixedRange(), + major.IsUfixedRange(): + min, max := mustGetMinMax(dt) + if n.V.LessThan(min) || n.V.GreaterThan(max) { + if (o & CheckWithSafeMath) != 0 { + elAppendOverflowError(el, n, fn, dt, n.V, min, max) + return nil + } + cropped := cropDecimal(dt, n.V) + elAppendOverflowWarning(el, n, fn, dt, n.V, cropped) + normalizeDecimal(&cropped) + n.V = cropped + } + + case major == ast.DataTypeMajorInt, + major == ast.DataTypeMajorUint: + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTypeError, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "expect %s (%04x), but the number %s has fractional part", + dt.String(), uint16(dt), n.V.String()), + }, nil) + return nil + + default: + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTypeError, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "expect %s (%04x), but number value is given", + dt.String(), uint16(dt)), + }, nil) + return nil + } + } + + if dt != ast.DataTypePending { + n.SetType(dt) + _, minor := ast.DecomposeDataType(dt) + fractionalDigits := int32(minor) + n.V = n.V.Round(fractionalDigits) + normalizeDecimal(&n.V) + } + return n +} + +func checkBytesValue(n *ast.BytesValueNode, + o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { + + fn := "CheckBytesValue" + + dt := ast.DataTypePending + switch a := ta.(type) { + case typeActionInferDefault: + // Default to bytes. + major := ast.DataTypeMajorDynamicBytes + minor := ast.DataTypeMinorDontCare + dt = ast.ComposeDataType(major, minor) + + case typeActionInferWithSize: + // Still default to bytes. The size hint does not matter at all. + major := ast.DataTypeMajorDynamicBytes + minor := ast.DataTypeMinorDontCare + dt = ast.ComposeDataType(major, minor) + + case typeActionAssign: + dt = a.dt + major, minor := ast.DecomposeDataType(dt) + switch major { + case ast.DataTypeMajorDynamicBytes: + // Do nothing because it is always valid. 
+ + case ast.DataTypeMajorFixedBytes: + sizeGiven := len(n.V) + sizeExpected := int(minor) + 1 + if sizeGiven != sizeExpected { + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTypeError, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "expect %s (%04x), but %s has %d bytes", + dt.String(), uint16(dt), + ast.QuoteString(n.V), sizeGiven), + }, nil) + return nil + } + + default: + el.Append(errors.Error{ + Position: n.GetPosition(), + Length: n.GetLength(), + Category: errors.ErrorCategorySemantic, + Code: errors.ErrorCodeTypeError, + Severity: errors.ErrorSeverityError, + Prefix: fn, + Message: fmt.Sprintf( + "expect %s (%04x), but bytes value is given", + dt.String(), uint16(dt)), + }, nil) + return nil + } + } + + if dt != ast.DataTypePending { + n.SetType(dt) + } + return n +} + +func checkNullValue(n *ast.NullValueNode, + o CheckOptions, el *errors.ErrorList, ta typeAction) ast.ExprNode { + + dt := ast.DataTypePending + switch a := ta.(type) { + case typeActionInferDefault: + dt = ast.DataTypeNull + case typeActionInferWithSize: + dt = ast.DataTypeNull + case typeActionAssign: + dt = a.dt + } + + if dt != ast.DataTypePending { + n.SetType(dt) + } + return n +} + +func checkPosOperator(n *ast.PosOperatorNode, + s schema.Schema, o CheckOptions, c *schemaCache, el *errors.ErrorList, + tr schema.TableRef, ta typeAction) ast.ExprNode { + + /* + fn := "CheckPosOperator" + + switch a := ta.(type) { + case typeActionInferDefault: + case typeActionInferWithSize: + case typeActionAssign: + } + */ + return n +} diff --git a/core/vm/sqlvm/checkers/utils.go b/core/vm/sqlvm/checkers/utils.go new file mode 100644 index 000000000..4a8bbd96e --- /dev/null +++ b/core/vm/sqlvm/checkers/utils.go @@ -0,0 +1,471 @@ +package checkers + +import ( + "github.com/dexon-foundation/decimal" + + "github.com/dexon-foundation/dexon/core/vm/sqlvm/ast" + "github.com/dexon-foundation/dexon/core/vm/sqlvm/schema" +) + +// Variable name convention: +// +// fn -> function name +// el -> error list +// +// td -> table descriptor +// tr -> table reference +// ti -> table index +// tn -< table name +// +// cd -> column descriptor +// cr -> column reference +// ci -> column index +// cn -> column name +// +// id -> index descriptor +// ir -> index reference +// ii -> index index +// in -> index name + +const ( + MaxIntegerPartDigits int32 = 200 + MaxFractionalPartDigits int32 = 200 +) + +var ( + MaxConstant = func() decimal.Decimal { + max := (decimal.New(1, MaxIntegerPartDigits). + Sub(decimal.New(1, -MaxFractionalPartDigits))) + normalizeDecimal(&max) + return max + }() + MinConstant = MaxConstant.Neg() +) + +func normalizeDecimal(d *decimal.Decimal) { + if d.Exponent() != -MaxFractionalPartDigits { + *d = d.Rescale(-MaxFractionalPartDigits) + } +} + +func safeDecimalRange(d decimal.Decimal) bool { + return d.GreaterThanOrEqual(MinConstant) && d.LessThanOrEqual(MaxConstant) +} + +type schemaCache struct { + base schemaCacheBase + scopes []schemaCacheScope +} + +type schemaCacheIndexValue struct { + id schema.IndexDescriptor + auto bool +} + +type schemaCacheColumnKey struct { + tr schema.TableRef + n string +} + +type schemaCacheBase struct { + table map[string]schema.TableDescriptor + index map[string]schemaCacheIndexValue + column map[schemaCacheColumnKey]schema.ColumnDescriptor +} + +func (lower *schemaCacheBase) Merge(upper schemaCacheScope) { + // Process deletions. 
+ for n := range upper.tableDeleted { + delete(lower.table, n) + } + for n := range upper.indexDeleted { + delete(lower.index, n) + } + for ck := range upper.columnDeleted { + delete(lower.column, ck) + } + + // Process additions. + for n, td := range upper.table { + lower.table[n] = td + } + for n, iv := range upper.index { + lower.index[n] = iv + } + for ck, cd := range upper.column { + lower.column[ck] = cd + } +} + +type schemaCacheScope struct { + table map[string]schema.TableDescriptor + tableDeleted map[string]struct{} + index map[string]schemaCacheIndexValue + indexDeleted map[string]struct{} + column map[schemaCacheColumnKey]schema.ColumnDescriptor + columnDeleted map[schemaCacheColumnKey]struct{} +} + +func (lower *schemaCacheScope) Merge(upper schemaCacheScope) { + // Process deletions. + for n := range upper.tableDeleted { + delete(lower.table, n) + lower.tableDeleted[n] = struct{}{} + } + for n := range upper.indexDeleted { + delete(lower.index, n) + lower.indexDeleted[n] = struct{}{} + } + for ck := range upper.columnDeleted { + delete(lower.column, ck) + lower.columnDeleted[ck] = struct{}{} + } + + // Process additions. + for n, td := range upper.table { + lower.table[n] = td + } + for n, iv := range upper.index { + lower.index[n] = iv + } + for ck, cd := range upper.column { + lower.column[ck] = cd + } +} + +func newSchemaCache() *schemaCache { + return &schemaCache{ + base: schemaCacheBase{ + table: map[string]schema.TableDescriptor{}, + index: map[string]schemaCacheIndexValue{}, + column: map[schemaCacheColumnKey]schema.ColumnDescriptor{}, + }, + } +} + +func (c *schemaCache) Begin() int { + position := len(c.scopes) + scope := schemaCacheScope{ + table: map[string]schema.TableDescriptor{}, + tableDeleted: map[string]struct{}{}, + index: map[string]schemaCacheIndexValue{}, + indexDeleted: map[string]struct{}{}, + column: map[schemaCacheColumnKey]schema.ColumnDescriptor{}, + columnDeleted: map[schemaCacheColumnKey]struct{}{}, + } + c.scopes = append(c.scopes, scope) + return position +} + +func (c *schemaCache) Rollback() { + if len(c.scopes) == 0 { + panic("there is no scope to rollback") + } + c.scopes = c.scopes[:len(c.scopes)-1] +} + +func (c *schemaCache) RollbackTo(position int) { + for position <= len(c.scopes) { + c.Rollback() + } +} + +func (c *schemaCache) Commit() { + if len(c.scopes) == 0 { + panic("there is no scope to commit") + } + if len(c.scopes) == 1 { + c.base.Merge(c.scopes[0]) + } else { + src := len(c.scopes) - 1 + dst := len(c.scopes) - 2 + c.scopes[dst].Merge(c.scopes[src]) + } + c.scopes = c.scopes[:len(c.scopes)-1] +} + +func (c *schemaCache) CommitTo(position int) { + for position <= len(c.scopes) { + c.Commit() + } +} + +func (c *schemaCache) FindTableInBase(n string) ( + schema.TableDescriptor, bool) { + + td, exists := c.base.table[n] + return td, exists +} + +func (c *schemaCache) FindTableInScope(n string) ( + schema.TableDescriptor, bool) { + + for si := range c.scopes { + si = len(c.scopes) - si - 1 + if td, exists := c.scopes[si].table[n]; exists { + return td, true + } + if _, exists := c.scopes[si].tableDeleted[n]; exists { + return schema.TableDescriptor{}, false + } + } + return c.FindTableInBase(n) +} + +func (c *schemaCache) FindTableInBaseWithFallback(n string, + fallback schema.Schema) (schema.TableDescriptor, bool) { + + if td, found := c.FindTableInBase(n); found { + return td, true + } + if fallback == nil { + return schema.TableDescriptor{}, false + } + + s := fallback + for ti := range s { + if n == string(s[ti].Name) { + tr 
+func (c *schemaCache) FindTableInBase(n string) (
+ schema.TableDescriptor, bool) {
+
+ td, exists := c.base.table[n]
+ return td, exists
+}
+
+func (c *schemaCache) FindTableInScope(n string) (
+ schema.TableDescriptor, bool) {
+
+ for si := range c.scopes {
+ si = len(c.scopes) - si - 1
+ if td, exists := c.scopes[si].table[n]; exists {
+ return td, true
+ }
+ if _, exists := c.scopes[si].tableDeleted[n]; exists {
+ return schema.TableDescriptor{}, false
+ }
+ }
+ return c.FindTableInBase(n)
+}
+
+func (c *schemaCache) FindTableInBaseWithFallback(n string,
+ fallback schema.Schema) (schema.TableDescriptor, bool) {
+
+ if td, found := c.FindTableInBase(n); found {
+ return td, true
+ }
+ if fallback == nil {
+ return schema.TableDescriptor{}, false
+ }
+
+ s := fallback
+ for ti := range s {
+ if n == string(s[ti].Name) {
+ tr := schema.TableRef(ti)
+ td := schema.TableDescriptor{Table: tr}
+ c.base.table[n] = td
+ return td, true
+ }
+ }
+ return schema.TableDescriptor{}, false
+}
+
+func (c *schemaCache) FindIndexInBase(n string) (
+ schema.IndexDescriptor, bool, bool) {
+
+ iv, exists := c.base.index[n]
+ return iv.id, iv.auto, exists
+}
+
+func (c *schemaCache) FindIndexInScope(n string) (
+ schema.IndexDescriptor, bool, bool) {
+
+ for si := range c.scopes {
+ si = len(c.scopes) - si - 1
+ if iv, exists := c.scopes[si].index[n]; exists {
+ return iv.id, iv.auto, true
+ }
+ if _, exists := c.scopes[si].indexDeleted[n]; exists {
+ return schema.IndexDescriptor{}, false, false
+ }
+ }
+ return c.FindIndexInBase(n)
+}
+
+func (c *schemaCache) FindIndexInBaseWithFallback(n string,
+ fallback schema.Schema) (schema.IndexDescriptor, bool, bool) {
+
+ if id, auto, found := c.FindIndexInBase(n); found {
+ return id, auto, true
+ }
+ if fallback == nil {
+ return schema.IndexDescriptor{}, false, false
+ }
+
+ s := fallback
+ for ti := range s {
+ for ii := range s[ti].Indices {
+ if n == string(s[ti].Indices[ii].Name) {
+ tr := schema.TableRef(ti)
+ ir := schema.IndexRef(ii)
+ id := schema.IndexDescriptor{Table: tr, Index: ir}
+ iv := schemaCacheIndexValue{id: id, auto: false}
+ c.base.index[n] = iv
+ return id, false, true
+ }
+ }
+ }
+ return schema.IndexDescriptor{}, false, false
+}
+
+func (c *schemaCache) FindColumnInBase(tr schema.TableRef, n string) (
+ schema.ColumnDescriptor, bool) {
+
+ cd, exists := c.base.column[schemaCacheColumnKey{tr: tr, n: n}]
+ return cd, exists
+}
+
+func (c *schemaCache) FindColumnInScope(tr schema.TableRef, n string) (
+ schema.ColumnDescriptor, bool) {
+
+ ck := schemaCacheColumnKey{tr: tr, n: n}
+ for si := range c.scopes {
+ si = len(c.scopes) - si - 1
+ if cd, exists := c.scopes[si].column[ck]; exists {
+ return cd, true
+ }
+ if _, exists := c.scopes[si].columnDeleted[ck]; exists {
+ return schema.ColumnDescriptor{}, false
+ }
+ }
+ return c.FindColumnInBase(tr, n)
+}
+
+func (c *schemaCache) FindColumnInBaseWithFallback(tr schema.TableRef, n string,
+ fallback schema.Schema) (schema.ColumnDescriptor, bool) {
+
+ if cd, found := c.FindColumnInBase(tr, n); found {
+ return cd, true
+ }
+ if fallback == nil {
+ return schema.ColumnDescriptor{}, false
+ }
+
+ s := fallback
+ for ci := range s[tr].Columns {
+ if n == string(s[tr].Columns[ci].Name) {
+ cr := schema.ColumnRef(ci)
+ cd := schema.ColumnDescriptor{Table: tr, Column: cr}
+ ck := schemaCacheColumnKey{tr: tr, n: n}
+ c.base.column[ck] = cd
+ return cd, true
+ }
+ }
+ return schema.ColumnDescriptor{}, false
+}
+
+func (c *schemaCache) AddTable(n string,
+ td schema.TableDescriptor) bool {
+
+ top := len(c.scopes) - 1
+ if _, found := c.FindTableInScope(n); found {
+ return false
+ }
+
+ c.scopes[top].table[n] = td
+ return true
+}
+
+func (c *schemaCache) AddIndex(n string,
+ id schema.IndexDescriptor, auto bool) bool {
+
+ top := len(c.scopes) - 1
+ if _, _, found := c.FindIndexInScope(n); found {
+ return false
+ }
+
+ iv := schemaCacheIndexValue{id: id, auto: auto}
+ c.scopes[top].index[n] = iv
+ return true
+}
+
+func (c *schemaCache) AddColumn(n string,
+ cd schema.ColumnDescriptor) bool {
+
+ top := len(c.scopes) - 1
+ tr := cd.Table
+ if _, found := c.FindColumnInScope(tr, n); found {
+ return false
+ }
+
+ ck := schemaCacheColumnKey{tr: tr, n: n}
+ c.scopes[top].column[ck] = cd
+ return true
+}
+
+func (c *schemaCache) DeleteTable(n string) bool {
+ top := len(c.scopes) - 1
+ if _, found := c.FindTableInScope(n); !found {
+ return false
+ }
+
+ delete(c.scopes[top].table, n)
+ c.scopes[top].tableDeleted[n] = struct{}{}
+ return true
+}
+
+func (c *schemaCache) DeleteIndex(n string) bool {
+ top := len(c.scopes) - 1
+ if _, _, found := c.FindIndexInScope(n); !found {
+ return false
+ }
+
+ delete(c.scopes[top].index, n)
+ c.scopes[top].indexDeleted[n] = struct{}{}
+ return true
+}
+
+func (c *schemaCache) DeleteColumn(tr schema.TableRef, n string) bool {
+ top := len(c.scopes) - 1
+ if _, found := c.FindColumnInScope(tr, n); !found {
+ return false
+ }
+
+ ck := schemaCacheColumnKey{tr: tr, n: n}
+ delete(c.scopes[top].column, ck)
+ c.scopes[top].columnDeleted[ck] = struct{}{}
+ return true
+}
+
+type columnRefSlice struct {
+ columns []schema.ColumnRef
+ nodes []uint8
+}
+
+func newColumnRefSlice(c uint8) columnRefSlice {
+ return columnRefSlice{
+ columns: make([]schema.ColumnRef, 0, c),
+ nodes: make([]uint8, 0, c),
+ }
+}
+
+func (s *columnRefSlice) Append(c schema.ColumnRef, i uint8) {
+ s.columns = append(s.columns, c)
+ s.nodes = append(s.nodes, i)
+}
+
+func (s columnRefSlice) Len() int {
+ return len(s.columns)
+}
+
+func (s columnRefSlice) Less(i, j int) bool {
+ return s.columns[i] < s.columns[j]
+}
+
+func (s columnRefSlice) Swap(i, j int) {
+ s.columns[i], s.columns[j] = s.columns[j], s.columns[i]
+ s.nodes[i], s.nodes[j] = s.nodes[j], s.nodes[i]
+}
+
+//go-sumtype:decl typeAction
+type typeAction interface {
+ ˉtypeAction()
+}
+
+type typeActionInferDefault struct{}
+
+func newTypeActionInferDefaultSize() typeActionInferDefault {
+ return typeActionInferDefault{}
+}
+
+var _ typeAction = typeActionInferDefault{}
+
+func (typeActionInferDefault) ˉtypeAction() {}
+
+type typeActionInferWithSize struct {
+ size int
+}
+
+func newTypeActionInferWithSize(bytes int) typeActionInferWithSize {
+ return typeActionInferWithSize{size: bytes}
+}
+
+var _ typeAction = typeActionInferWithSize{}
+
+func (typeActionInferWithSize) ˉtypeAction() {}
+
+type typeActionAssign struct {
+ dt ast.DataType
+}
+
+func newTypeActionAssign(expected ast.DataType) typeActionAssign {
+ return typeActionAssign{dt: expected}
+}
+
+var _ typeAction = typeActionAssign{}
+
+func (typeActionAssign) ˉtypeAction() {}
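The typeAction values above carry the type checker's intent for an expression. A plausible in-package sketch of how a checker might dispatch on them; the function name and comments are illustrative and not part of this patch:

// applyTypeAction shows the expected exhaustive type switch over the
// go-sumtype-checked typeAction interface declared above.
func applyTypeAction(a typeAction) {
    switch act := a.(type) {
    case typeActionInferDefault:
        // infer a default size for the expression
    case typeActionInferWithSize:
        _ = act.size // infer the type using the requested byte size
    case typeActionAssign:
        _ = act.dt // coerce the expression to the expected data type
    }
}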
diff --git a/core/vm/sqlvm/cmd/ast-checker/main.go b/core/vm/sqlvm/cmd/ast-checker/main.go
new file mode 100644
index 000000000..c02b58f0f
--- /dev/null
+++ b/core/vm/sqlvm/cmd/ast-checker/main.go
@@ -0,0 +1,123 @@
+package main
+
+import (
+ "bytes"
+ "encoding/hex"
+ "flag"
+ "fmt"
+ "os"
+
+ "github.com/dexon-foundation/dexon/core/vm/sqlvm/checkers"
+ "github.com/dexon-foundation/dexon/core/vm/sqlvm/parser"
+ "github.com/dexon-foundation/dexon/core/vm/sqlvm/schema"
+ "github.com/dexon-foundation/dexon/rlp"
+)
+
+func create(sql string, o checkers.CheckOptions) int {
+ n, parseErr := parser.Parse([]byte(sql))
+ if parseErr != nil {
+ fmt.Fprintf(os.Stderr, "Parse error:\n%+v\n", parseErr)
+ }
+ s, checkErr := checkers.CheckCreate(n, o)
+ if checkErr != nil {
+ fmt.Fprintf(os.Stderr, "Check error:\n%+v\n", checkErr)
+ }
+ b := bytes.Buffer{}
+ rlpErr := rlp.Encode(&b, s)
+ if rlpErr != nil {
+ fmt.Fprintf(os.Stderr, "RLP encode error: %v\n", rlpErr)
+ return 1
+ }
+ fmt.Println(hex.EncodeToString(b.Bytes()))
+ if parseErr != nil || checkErr != nil {
+ return 1
+ }
+ return 0
+}
+
+func decode(ss string) int {
+ b, hexErr := hex.DecodeString(ss)
+ if hexErr != nil {
+ fmt.Fprintf(os.Stderr, "Hex decode error: %v\n", hexErr)
+ return 1
+ }
+ s := schema.Schema{}
+ rlpErr := rlp.Decode(bytes.NewReader(b), &s)
+ if rlpErr != nil {
+ fmt.Fprintf(os.Stderr, "RLP decode error: %v\n", rlpErr)
+ return 1
+ }
+ s.SetupColumnOffset()
+ fmt.Print(s.String())
+ return 0
+}
+
+func query(ss, sql string, o checkers.CheckOptions) int {
+ fmt.Fprintln(os.Stderr, "Function not implemented")
+ return 1
+}
+
+func exec(ss, sql string, o checkers.CheckOptions) int {
+ fmt.Fprintln(os.Stderr, "Function not implemented")
+ return 1
+}
+
+func main() {
+ var noSafeMath bool
+ var noSafeCast bool
+ flag.BoolVar(&noSafeMath, "no-safe-math", false, "disable safe math")
+ flag.BoolVar(&noSafeCast, "no-safe-cast", false, "disable safe cast")
+
+ flag.Parse()
+
+ if flag.NArg() < 1 {
+ fmt.Fprintf(os.Stderr,
+ "Usage: %s <action> <arguments>\n"+
+ "Actions:\n"+
+ " create <SQL> returns schema\n"+
+ " decode <schema> returns SQL\n"+
+ " query <schema> <SQL> returns AST\n"+
+ " exec <schema> <SQL> returns AST\n",
+ os.Args[0])
+ os.Exit(1)
+ }
+
+ o := checkers.CheckWithSafeMath | checkers.CheckWithSafeCast
+ if noSafeMath {
+ o &= ^(checkers.CheckWithSafeMath)
+ }
+ if noSafeCast {
+ o &= ^(checkers.CheckWithSafeCast)
+ }
+
+ action := flag.Arg(0)
+ switch action {
+ case "create":
+ if flag.NArg() < 2 {
+ fmt.Fprintln(os.Stderr, "create needs 1 argument")
+ os.Exit(1)
+ }
+ os.Exit(create(flag.Arg(1), o))
+ case "decode":
+ if flag.NArg() < 2 {
+ fmt.Fprintln(os.Stderr, "decode needs 1 argument")
+ os.Exit(1)
+ }
+ os.Exit(decode(flag.Arg(1)))
+ case "query":
+ if flag.NArg() < 3 {
+ fmt.Fprintln(os.Stderr, "query needs 2 arguments")
+ os.Exit(1)
+ }
+ os.Exit(query(flag.Arg(1), flag.Arg(2), o))
+ case "exec":
+ if flag.NArg() < 3 {
+ fmt.Fprintln(os.Stderr, "exec needs 2 arguments")
+ os.Exit(1)
+ }
+ os.Exit(exec(flag.Arg(1), flag.Arg(2), o))
+ default:
+ fmt.Fprintf(os.Stderr, "Invalid action %s\n", action)
+ os.Exit(1)
+ }
+}
diff --git a/core/vm/sqlvm/errors/errors.go b/core/vm/sqlvm/errors/errors.go
index c2e9dbab2..dcc8e6a48 100644
--- a/core/vm/sqlvm/errors/errors.go
+++ b/core/vm/sqlvm/errors/errors.go
@@ -58,17 +58,44 @@ func (e Error) Error() string {
 // ErrorList is a list of Error.
 type ErrorList []Error

-func (e ErrorList) Error() string {
+func (el ErrorList) Error() string {
 b := strings.Builder{}
- for i := range e {
+ for i, e := range el {
 if i > 0 {
 b.WriteByte('\n')
 }
- b.WriteString(e[i].Error())
+ b.WriteString(e.Error())
 }
 return b.String()
 }

+func (el ErrorList) HasError() bool {
+ for _, e := range el {
+ if e.Severity == ErrorSeverityError {
+ return true
+ }
+ }
+ return false
+}
+
+func (el ErrorList) GetFirstError() (Error, bool) {
+ for _, e := range el {
+ if e.Severity == ErrorSeverityError {
+ return e, true
+ }
+ }
+ return Error{}, false
+}
+
+func (el *ErrorList) Append(e Error, hasError *bool) {
+ *el = append(*el, e)
+ if hasError != nil {
+ if e.Severity == ErrorSeverityError {
+ *hasError = true
+ }
+ }
+}
+
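Append is presumably meant to let a checker accumulate diagnostics while tracking whether any of them is fatal. A small sketch under that assumption; the helper name is hypothetical, and fields of Error other than Severity are omitted:

// appendFatalError records one error-severity diagnostic and reports whether
// the list now contains a fatal entry.
func appendFatalError(el *ErrorList) bool {
    hasError := false
    el.Append(Error{Severity: ErrorSeverityError}, &hasError)
    return hasError // true here, and el.HasError() agrees
}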
 // ErrorCategory is used to distinguish errors that come from different phases.
 type ErrorCategory uint16

@@ -77,6 +104,7 @@ const (
 ErrorCategoryNil ErrorCategory = iota
 ErrorCategoryLimit
 ErrorCategoryGrammar
+ ErrorCategoryCommand
 ErrorCategorySemantic
 ErrorCategoryRuntime
 )
@@ -84,6 +112,7 @@ const (
 var errorCategoryMap = [...]string{
 ErrorCategoryLimit: "limit",
 ErrorCategoryGrammar: "grammar",
+ ErrorCategoryCommand: "command",
 ErrorCategorySemantic: "semantic",
 ErrorCategoryRuntime: "runtime",
 }
@@ -98,8 +127,14 @@ type ErrorCode uint16
 // Error code starts from 1. Zero value is invalid.
 const (
 ErrorCodeNil ErrorCode = iota
- ErrorCodeDepthLimitReached
 ErrorCodeParser
+ ErrorCodeDepthLimitReached
+ ErrorCodeTooManyTables
+ ErrorCodeTooManyColumns
+ ErrorCodeTooManyIndices
+ ErrorCodeTooManyForeignKeys
+ ErrorCodeTooManySequences
+ ErrorCodeTooManySelectColumns
 ErrorCodeInvalidIntegerSyntax
 ErrorCodeInvalidNumberSyntax
 ErrorCodeIntegerOutOfRange
@@ -115,6 +150,26 @@ const (
 ErrorCodeInvalidUfixedSize
 ErrorCodeInvalidFixedFractionalDigits
 ErrorCodeInvalidUfixedFractionalDigits
+ ErrorCodeNoCommand
+ ErrorCodeDisallowedCommand
+ ErrorCodeEmptyTableName
+ ErrorCodeDuplicateTableName
+ ErrorCodeEmptyColumnName
+ ErrorCodeDuplicateColumnName
+ ErrorCodeTableNotFound
+ ErrorCodeColumnNotFound
+ ErrorCodeInvalidColumnDataType
+ ErrorCodeForeignKeyDataTypeMismatch
+ ErrorCodeNullDefaultValue
+ ErrorCodeMultipleDefaultValues
+ ErrorCodeInvalidAutoIncrementDataType
+ ErrorCodeEmptyIndexName
+ ErrorCodeDuplicateIndexName
+ ErrorCodeDuplicateIndexColumn
+ ErrorCodeTypeError
+ ErrorCodeConstantTooLong
+ ErrorCodeNonConstantExpression
+ ErrorCodeInvalidAddressChecksum

 // Runtime Error
 ErrorCodeInvalidOperandNum
@@ -132,6 +187,12 @@ const (
 var errorCodeMap = [...]string{
 ErrorCodeDepthLimitReached: "depth limit reached",
+ ErrorCodeTooManyTables: "too many tables",
+ ErrorCodeTooManyColumns: "too many columns",
+ ErrorCodeTooManyIndices: "too many indices",
+ ErrorCodeTooManyForeignKeys: "too many foreign keys",
+ ErrorCodeTooManySequences: "too many sequences",
+ ErrorCodeTooManySelectColumns: "too many select columns",
 ErrorCodeParser: "parser error",
 ErrorCodeInvalidIntegerSyntax: "invalid integer syntax",
 ErrorCodeInvalidNumberSyntax: "invalid number syntax",
@@ -148,6 +209,27 @@ var errorCodeMap = [...]string{
 ErrorCodeInvalidUfixedSize: "invalid ufixed size",
 ErrorCodeInvalidFixedFractionalDigits: "invalid fixed fractional digits",
 ErrorCodeInvalidUfixedFractionalDigits: "invalid ufixed fractional digits",
+ ErrorCodeNoCommand: "no command",
+ ErrorCodeDisallowedCommand: "disallowed command",
+ ErrorCodeEmptyTableName: "empty table name",
+ ErrorCodeDuplicateTableName: "duplicate table name",
+ ErrorCodeEmptyColumnName: "empty column name",
+ ErrorCodeDuplicateColumnName: "duplicate column name",
+ ErrorCodeTableNotFound: "table not found",
+ ErrorCodeColumnNotFound: "column not found",
+ ErrorCodeInvalidColumnDataType: "invalid column data type",
+ ErrorCodeForeignKeyDataTypeMismatch: "foreign key data type mismatch",
+ ErrorCodeNullDefaultValue: "null default value",
+ ErrorCodeMultipleDefaultValues: "multiple default values",
+ ErrorCodeInvalidAutoIncrementDataType: "invalid auto increment data type",
+ ErrorCodeEmptyIndexName: "empty index name",
+ ErrorCodeDuplicateIndexName: "duplicate index name",
+ ErrorCodeDuplicateIndexColumn: "duplicate index column",
+ ErrorCodeTypeError: "type error",
+ ErrorCodeConstantTooLong: "constant too long",
+ ErrorCodeNonConstantExpression: "non-constant expression",
+ ErrorCodeInvalidAddressChecksum: "invalid address checksum",
+
 // Runtime Error
 ErrorCodeInvalidOperandNum: "invalid operand number",
 ErrorCodeInvalidDataType: "invalid data type",
diff --git a/core/vm/sqlvm/schema/schema.go b/core/vm/sqlvm/schema/schema.go
index 71122cc26..6d279aba0 100644
--- a/core/vm/sqlvm/schema/schema.go
+++ b/core/vm/sqlvm/schema/schema.go
@@ -2,7 +2,10 @@ package schema
 import (
 "errors"
+ "fmt"
 "io"
+ "math"
+ "strings"

 "github.com/dexon-foundation/decimal"
@@ -66,21 +69,39 @@ func (a ColumnAttr) GetDerivedFlags() ColumnAttr {
 // FunctionRef defines the type for the number of a builtin function.
 type FunctionRef uint16

+// MaxFunctionRef is the maximum value of FunctionRef.
+const MaxFunctionRef = math.MaxUint16
+
 // TableRef defines the type for table index in Schema.
 type TableRef uint8

+// MaxTableRef is the maximum value of TableRef.
+const MaxTableRef = math.MaxUint8
+
 // ColumnRef defines the type for column index in Table.Columns.
 type ColumnRef uint8

+// MaxColumnRef is the maximum value of ColumnRef.
+const MaxColumnRef = math.MaxUint8
+
 // IndexRef defines the type for array index of Column.Indices.
 type IndexRef uint8

+// MaxIndexRef is the maximum value of IndexRef.
+const MaxIndexRef = math.MaxUint8
+
 // SequenceRef defines the type for sequence index in Table.
 type SequenceRef uint8

+// MaxSequenceRef is the maximum value of SequenceRef.
+const MaxSequenceRef = math.MaxUint8
+
 // SelectColumnRef defines the type for column index in SelectStmtNode.Column.
 type SelectColumnRef uint16

+// MaxSelectColumnRef is the maximum value of SelectColumnRef.
+const MaxSelectColumnRef = math.MaxUint16
+
 // IndexAttr defines bit flags for describing index attribute.
 type IndexAttr uint16
@@ -116,6 +137,111 @@ func (s Schema) SetupColumnOffset() {
 }
 }

+func (s Schema) String() string {
+ b := strings.Builder{}
+ b.WriteString(fmt.Sprintf(
+ "-- DEXON SQLVM database schema dump (%d tables)\n", len(s)))
+ for _, t := range s {
+ b.WriteString("CREATE TABLE ")
+ b.Write(ast.QuoteIdentifierOptional(t.Name))
+ b.WriteString(" (\n")
+ for ci, c := range t.Columns {
+ b.WriteString(" ")
+ b.Write(ast.QuoteIdentifierOptional(c.Name))
+ b.WriteByte(' ')
+ b.WriteString(c.Type.String())
+ comments := []string{
+ fmt.Sprintf("slot %d", c.SlotOffset),
+ fmt.Sprintf("byte %d", c.ByteOffset),
+ }
+ if (c.Attr & ColumnAttrPrimaryKey) != 0 {
+ b.WriteString(" PRIMARY KEY")
+ }
+ if (c.Attr & ColumnAttrNotNull) != 0 {
+ b.WriteString(" NOT NULL")
+ }
+ if (c.Attr & ColumnAttrUnique) != 0 {
+ b.WriteString(" UNIQUE")
+ }
+ if (c.Attr & ColumnAttrHasDefault) != 0 {
+ b.WriteString(" DEFAULT ")
+ switch v := c.Default.(type) {
+ case nil:
+ b.WriteString("NULL")
+ case bool:
+ if v {
+ b.WriteString("TRUE")
+ } else {
+ b.WriteString("FALSE")
+ }
+ case []byte:
+ major, _ := ast.DecomposeDataType(c.Type)
+ if major == ast.DataTypeMajorAddress {
+ b.WriteString(common.BytesToAddress(v).String())
+ break
+ }
+ b.Write(ast.QuoteString(v))
+ case decimal.Decimal:
+ b.WriteString(v.String())
+ default:
+ b.WriteString("<?>")
+ }
+ }
+ if (c.Attr & ColumnAttrHasForeignKey) != 0 {
+ for _, fk := range c.ForeignKeys {
+ b.WriteString(" REFERENCES ")
+ b.Write(ast.QuoteIdentifierOptional(
+ s[fk.Table].Name))
+ b.WriteByte('(')
+ b.Write(ast.QuoteIdentifierOptional(
+ s[fk.Table].Columns[fk.Column].Name))
+ b.WriteByte(')')
+ }
+ }
+ if (c.Attr & ColumnAttrHasSequence) != 0 {
+ b.WriteString(" AUTOINCREMENT")
+ comments = append(comments,
+ fmt.Sprintf("sequence %d", c.Sequence))
+ }
+ if ci < len(t.Columns)-1 {
+ b.WriteByte(',')
+ }
+ b.WriteString(" -- ")
+ b.WriteString(strings.Join(comments, ", "))
+ b.WriteByte('\n')
+ }
+ b.WriteString(");\n")
+ for _, i := range t.Indices {
+ comments := []string{}
+ b.WriteString("CREATE")
+ if (i.Attr & IndexAttrUnique) != 0 {
+ b.WriteString(" UNIQUE")
+ }
+ if (i.Attr & IndexAttrReferenced) != 0 {
+ comments = append(comments, "referenced")
+ }
+ b.WriteString(" INDEX ")
+ b.Write(ast.QuoteIdentifierOptional(i.Name))
+ b.WriteString(" ON ")
+ b.Write(ast.QuoteIdentifierOptional(t.Name))
+ b.WriteByte('(')
+ for ci, c := range i.Columns {
+ b.Write(ast.QuoteIdentifierOptional(t.Columns[c].Name))
+ if ci < len(i.Columns)-1 {
+ b.WriteString(", ")
+ }
+ }
+ b.WriteString(");")
+ if len(comments) > 0 {
+ b.WriteString(" -- ")
+ b.WriteString(strings.Join(comments, ", "))
+ }
+ b.WriteByte('\n')
+ }
+ }
+ return b.String()
+}
+
 // Table defines sqlvm table struct.
 type Table struct {
 Name []byte
@@ -172,6 +298,10 @@ type column struct {
 Rest interface{}
 }

+// MaxForeignKeys is the maximum number of foreign key constraints which can be
+// defined on a column.
+const MaxForeignKeys = math.MaxUint8
+
 // Column defines sqlvm column struct.
 type Column struct {
 column
@@ -240,6 +370,7 @@ func (c *Column) DecodeRLP(s *rlp.Stream) error {
 // nil is converted to empty list by encoder, while empty list is
 // converted to []interface{} by decoder.
 // So we view this case as nil and skip it.
+ c.Default = nil
 case []byte:
 major, _ := ast.DecomposeDataType(c.Type)
 switch major {
@@ -249,7 +380,9 @@ func (c *Column) DecodeRLP(s *rlp.Stream) error {
 } else {
 c.Default = false
 }
- case ast.DataTypeMajorFixedBytes, ast.DataTypeMajorDynamicBytes:
+ case ast.DataTypeMajorAddress,
+ ast.DataTypeMajorFixedBytes,
+ ast.DataTypeMajorDynamicBytes:
 c.Default = rest
 default:
 d, ok := ast.DecimalDecode(c.Type, rest)
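Putting the new pieces together, a hedged end-to-end sketch of the workflow the ast-checker command automates: parse a CREATE TABLE statement, type-check it into a schema, and dump that schema back as SQL text with the new Schema.String method. The SQL literal and its column type names are illustrative only; the exact syntax accepted by the parser may differ.

package main

import (
    "fmt"
    "os"

    "github.com/dexon-foundation/dexon/core/vm/sqlvm/checkers"
    "github.com/dexon-foundation/dexon/core/vm/sqlvm/parser"
)

func main() {
    // Illustrative input; adjust the column types to whatever the grammar accepts.
    sql := "CREATE TABLE t (id uint64 PRIMARY KEY, name text)"
    n, err := parser.Parse([]byte(sql))
    if err != nil {
        fmt.Fprintf(os.Stderr, "parse error:\n%+v\n", err)
        return
    }
    opt := checkers.CheckWithSafeMath | checkers.CheckWithSafeCast
    s, err := checkers.CheckCreate(n, opt)
    if err != nil {
        fmt.Fprintf(os.Stderr, "check error:\n%+v\n", err)
        return
    }
    s.SetupColumnOffset() // fill in slot and byte offsets before dumping
    fmt.Print(s.String()) // prints the CREATE TABLE / CREATE INDEX dump
}

This mirrors the create and decode code paths shown in the patch rather than introducing any new API.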